FEATURE: add support for new OpenAI embedding models (#445)
* FEATURE: add support for new OpenAI embedding models This adds support for just released text_embedding_3_small and large Note, we have not yet implemented truncation support which is a new API feature. (triggered using dimensions) * Tiny side fix, recalc bots when ai is enabled or disabled * FIX: downsample to 2000 items per vector which is a pgvector limitation
This commit is contained in:
parent
4c4b418cff
commit
b2b01185f2
|
@ -236,6 +236,8 @@ discourse_ai:
|
||||||
choices:
|
choices:
|
||||||
- all-mpnet-base-v2
|
- all-mpnet-base-v2
|
||||||
- text-embedding-ada-002
|
- text-embedding-ada-002
|
||||||
|
- text-embedding-3-small
|
||||||
|
- text-embedding-3-large
|
||||||
- multilingual-e5-large
|
- multilingual-e5-large
|
||||||
- bge-large-en
|
- bge-large-en
|
||||||
- gemini
|
- gemini
|
||||||
|
|
|
@ -0,0 +1,49 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class CreateOpenaiTextEmbeddingTables < ActiveRecord::Migration[7.0]
|
||||||
|
def change
|
||||||
|
create_table :ai_topic_embeddings_6_1, id: false do |t|
|
||||||
|
t.integer :topic_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1536)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :topic_id, unique: true
|
||||||
|
end
|
||||||
|
|
||||||
|
create_table :ai_topic_embeddings_7_1, id: false do |t|
|
||||||
|
t.integer :topic_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(2000)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :topic_id, unique: true
|
||||||
|
end
|
||||||
|
|
||||||
|
create_table :ai_post_embeddings_6_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1536)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
|
||||||
|
create_table :ai_post_embeddings_7_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(2000)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -46,7 +46,8 @@ module DiscourseAi
|
||||||
|
|
||||||
def inject_into(plugin)
|
def inject_into(plugin)
|
||||||
plugin.on(:site_setting_changed) do |name, _old_value, _new_value|
|
plugin.on(:site_setting_changed) do |name, _old_value, _new_value|
|
||||||
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled
|
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled ||
|
||||||
|
name == :discourse_ai_enabled
|
||||||
DiscourseAi::AiBot::SiteSettingsExtension.enable_or_disable_ai_bots
|
DiscourseAi::AiBot::SiteSettingsExtension.enable_or_disable_ai_bots
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -18,7 +18,6 @@ module DiscourseAi
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
cache_for = results_ttl(topic)
|
cache_for = results_ttl(topic)
|
||||||
|
|
||||||
asd =
|
|
||||||
Discourse
|
Discourse
|
||||||
.cache
|
.cache
|
||||||
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
||||||
|
@ -71,7 +70,6 @@ module DiscourseAi
|
||||||
|
|
||||||
return "" if related_topics.empty?
|
return "" if related_topics.empty?
|
||||||
|
|
||||||
render_result =
|
|
||||||
ApplicationController.render(
|
ApplicationController.render(
|
||||||
template: "list/related_topics",
|
template: "list/related_topics",
|
||||||
layout: false,
|
layout: false,
|
||||||
|
|
|
@ -13,6 +13,8 @@ module DiscourseAi
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
||||||
|
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
||||||
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
|
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,53 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
module VectorRepresentations
|
||||||
|
class TextEmbedding3Large < Base
|
||||||
|
def id
|
||||||
|
7
|
||||||
|
end
|
||||||
|
|
||||||
|
def version
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def name
|
||||||
|
"text-embedding-3-large"
|
||||||
|
end
|
||||||
|
|
||||||
|
def dimensions
|
||||||
|
# real dimentions are 3072, but we only support up to 2000 in the
|
||||||
|
# indexes, so we downsample to 2000 via API
|
||||||
|
2000
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_sequence_length
|
||||||
|
8191
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_function
|
||||||
|
"<=>"
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_index_type
|
||||||
|
"vector_cosine_ops"
|
||||||
|
end
|
||||||
|
|
||||||
|
def vector_from(text)
|
||||||
|
response =
|
||||||
|
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
||||||
|
text,
|
||||||
|
model: name,
|
||||||
|
dimensions: dimensions,
|
||||||
|
)
|
||||||
|
response[:data].first[:embedding]
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,46 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Embeddings
|
||||||
|
module VectorRepresentations
|
||||||
|
class TextEmbedding3Small < Base
|
||||||
|
def id
|
||||||
|
6
|
||||||
|
end
|
||||||
|
|
||||||
|
def version
|
||||||
|
1
|
||||||
|
end
|
||||||
|
|
||||||
|
def name
|
||||||
|
"text-embedding-3-small"
|
||||||
|
end
|
||||||
|
|
||||||
|
def dimensions
|
||||||
|
1536
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_sequence_length
|
||||||
|
8191
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_function
|
||||||
|
"<=>"
|
||||||
|
end
|
||||||
|
|
||||||
|
def pg_index_type
|
||||||
|
"vector_cosine_ops"
|
||||||
|
end
|
||||||
|
|
||||||
|
def vector_from(text)
|
||||||
|
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
|
||||||
|
response[:data].first[:embedding]
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -33,7 +33,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def vector_from(text)
|
def vector_from(text)
|
||||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
|
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
|
||||||
response[:data].first[:embedding]
|
response[:data].first[:embedding]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class OpenAiEmbeddings
|
class OpenAiEmbeddings
|
||||||
def self.perform!(content, model = nil)
|
def self.perform!(content, model:, dimensions: nil)
|
||||||
headers = { "Content-Type" => "application/json" }
|
headers = { "Content-Type" => "application/json" }
|
||||||
|
|
||||||
if SiteSetting.ai_openai_embeddings_url.include?("azure")
|
if SiteSetting.ai_openai_embeddings_url.include?("azure")
|
||||||
|
@ -12,14 +12,10 @@ module ::DiscourseAi
|
||||||
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
||||||
end
|
end
|
||||||
|
|
||||||
model ||= "text-embedding-ada-002"
|
payload = { model: model, input: content }
|
||||||
|
payload[:dimensions] = dimensions if dimensions.present?
|
||||||
|
|
||||||
response =
|
response = Faraday.post(SiteSetting.ai_openai_embeddings_url, payload.to_json, headers)
|
||||||
Faraday.post(
|
|
||||||
SiteSetting.ai_openai_embeddings_url,
|
|
||||||
{ model: model, input: content }.to_json,
|
|
||||||
headers,
|
|
||||||
)
|
|
||||||
|
|
||||||
case response.status
|
case response.status
|
||||||
when 200
|
when 200
|
||||||
|
|
|
@ -25,7 +25,8 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||||
},
|
},
|
||||||
).to_return(status: 200, body: body_json, headers: {})
|
).to_return(status: 200, body: body_json, headers: {})
|
||||||
|
|
||||||
result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello")
|
result =
|
||||||
|
DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002")
|
||||||
|
|
||||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
||||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
||||||
|
@ -42,15 +43,22 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
||||||
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
||||||
}.to_json
|
}.to_json
|
||||||
|
|
||||||
|
body = { model: "text-embedding-ada-002", input: "hello", dimensions: 1000 }.to_json
|
||||||
|
|
||||||
stub_request(:post, "https://api.openai.com/v1/embeddings").with(
|
stub_request(:post, "https://api.openai.com/v1/embeddings").with(
|
||||||
body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
|
body: body,
|
||||||
headers: {
|
headers: {
|
||||||
"Authorization" => "Bearer 123456",
|
"Authorization" => "Bearer 123456",
|
||||||
"Content-Type" => "application/json",
|
"Content-Type" => "application/json",
|
||||||
},
|
},
|
||||||
).to_return(status: 200, body: body_json, headers: {})
|
).to_return(status: 200, body: body_json, headers: {})
|
||||||
|
|
||||||
result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello")
|
result =
|
||||||
|
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
||||||
|
"hello",
|
||||||
|
model: "text-embedding-ada-002",
|
||||||
|
dimensions: 1000,
|
||||||
|
)
|
||||||
|
|
||||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
||||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
||||||
|
|
Loading…
Reference in New Issue