FEATURE: add support for new OpenAI embedding models (#445)

* FEATURE: add support for new OpenAI embedding models

This adds support for just released text_embedding_3_small and large

Note, we have not yet implemented truncation support which is a
new API feature. (triggered using dimensions)

* Tiny side fix, recalc bots when ai is enabled or disabled

* FIX: downsample to 2000 items per vector which is a pgvector limitation
This commit is contained in:
Sam 2024-01-30 03:24:30 +11:00 committed by GitHub
parent 4c4b418cff
commit b2b01185f2
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 190 additions and 35 deletions

View File

@ -236,6 +236,8 @@ discourse_ai:
choices: choices:
- all-mpnet-base-v2 - all-mpnet-base-v2
- text-embedding-ada-002 - text-embedding-ada-002
- text-embedding-3-small
- text-embedding-3-large
- multilingual-e5-large - multilingual-e5-large
- bge-large-en - bge-large-en
- gemini - gemini

View File

@ -0,0 +1,49 @@
# frozen_string_literal: true
class CreateOpenaiTextEmbeddingTables < ActiveRecord::Migration[7.0]
def change
create_table :ai_topic_embeddings_6_1, id: false do |t|
t.integer :topic_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1536)", null: false
t.timestamps
t.index :topic_id, unique: true
end
create_table :ai_topic_embeddings_7_1, id: false do |t|
t.integer :topic_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(2000)", null: false
t.timestamps
t.index :topic_id, unique: true
end
create_table :ai_post_embeddings_6_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1536)", null: false
t.timestamps
t.index :post_id, unique: true
end
create_table :ai_post_embeddings_7_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(2000)", null: false
t.timestamps
t.index :post_id, unique: true
end
end
end

View File

@ -46,7 +46,8 @@ module DiscourseAi
def inject_into(plugin) def inject_into(plugin)
plugin.on(:site_setting_changed) do |name, _old_value, _new_value| plugin.on(:site_setting_changed) do |name, _old_value, _new_value|
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled ||
name == :discourse_ai_enabled
DiscourseAi::AiBot::SiteSettingsExtension.enable_or_disable_ai_bots DiscourseAi::AiBot::SiteSettingsExtension.enable_or_disable_ai_bots
end end
end end

View File

@ -18,7 +18,6 @@ module DiscourseAi
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
cache_for = results_ttl(topic) cache_for = results_ttl(topic)
asd =
Discourse Discourse
.cache .cache
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do .fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
@ -71,7 +70,6 @@ module DiscourseAi
return "" if related_topics.empty? return "" if related_topics.empty?
render_result =
ApplicationController.render( ApplicationController.render(
template: "list/related_topics", template: "list/related_topics",
layout: false, layout: false,

View File

@ -13,6 +13,8 @@ module DiscourseAi
DiscourseAi::Embeddings::VectorRepresentations::Gemini, DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large, DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002, DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model } ].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
end end

View File

@ -0,0 +1,53 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbedding3Large < Base
def id
7
end
def version
1
end
def name
"text-embedding-3-large"
end
def dimensions
# real dimentions are 3072, but we only support up to 2000 in the
# indexes, so we downsample to 2000 via API
2000
end
def max_sequence_length
8191
end
def pg_function
"<=>"
end
def pg_index_type
"vector_cosine_ops"
end
def vector_from(text)
response =
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
text,
model: name,
dimensions: dimensions,
)
response[:data].first[:embedding]
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
end
end
end

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class TextEmbedding3Small < Base
def id
6
end
def version
1
end
def name
"text-embedding-3-small"
end
def dimensions
1536
end
def max_sequence_length
8191
end
def pg_function
"<=>"
end
def pg_index_type
"vector_cosine_ops"
end
def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
response[:data].first[:embedding]
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
end
end
end

View File

@ -33,7 +33,7 @@ module DiscourseAi
end end
def vector_from(text) def vector_from(text)
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text) response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
response[:data].first[:embedding] response[:data].first[:embedding]
end end

View File

@ -3,7 +3,7 @@
module ::DiscourseAi module ::DiscourseAi
module Inference module Inference
class OpenAiEmbeddings class OpenAiEmbeddings
def self.perform!(content, model = nil) def self.perform!(content, model:, dimensions: nil)
headers = { "Content-Type" => "application/json" } headers = { "Content-Type" => "application/json" }
if SiteSetting.ai_openai_embeddings_url.include?("azure") if SiteSetting.ai_openai_embeddings_url.include?("azure")
@ -12,14 +12,10 @@ module ::DiscourseAi
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}" headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
end end
model ||= "text-embedding-ada-002" payload = { model: model, input: content }
payload[:dimensions] = dimensions if dimensions.present?
response = response = Faraday.post(SiteSetting.ai_openai_embeddings_url, payload.to_json, headers)
Faraday.post(
SiteSetting.ai_openai_embeddings_url,
{ model: model, input: content }.to_json,
headers,
)
case response.status case response.status
when 200 when 200

View File

@ -25,7 +25,8 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
}, },
).to_return(status: 200, body: body_json, headers: {}) ).to_return(status: 200, body: body_json, headers: {})
result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello") result =
DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002")
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 }) expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] }) expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
@ -42,15 +43,22 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
data: [{ object: "embedding", embedding: [0.0, 0.1] }], data: [{ object: "embedding", embedding: [0.0, 0.1] }],
}.to_json }.to_json
body = { model: "text-embedding-ada-002", input: "hello", dimensions: 1000 }.to_json
stub_request(:post, "https://api.openai.com/v1/embeddings").with( stub_request(:post, "https://api.openai.com/v1/embeddings").with(
body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}", body: body,
headers: { headers: {
"Authorization" => "Bearer 123456", "Authorization" => "Bearer 123456",
"Content-Type" => "application/json", "Content-Type" => "application/json",
}, },
).to_return(status: 200, body: body_json, headers: {}) ).to_return(status: 200, body: body_json, headers: {})
result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello") result =
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
"hello",
model: "text-embedding-ada-002",
dimensions: 1000,
)
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 }) expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] }) expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })