FEATURE: add support for new OpenAI embedding models (#445)
* FEATURE: add support for new OpenAI embedding models This adds support for just released text_embedding_3_small and large Note, we have not yet implemented truncation support which is a new API feature. (triggered using dimensions) * Tiny side fix, recalc bots when ai is enabled or disabled * FIX: downsample to 2000 items per vector which is a pgvector limitation
This commit is contained in:
parent
4c4b418cff
commit
b2b01185f2
|
@ -236,6 +236,8 @@ discourse_ai:
|
|||
choices:
|
||||
- all-mpnet-base-v2
|
||||
- text-embedding-ada-002
|
||||
- text-embedding-3-small
|
||||
- text-embedding-3-large
|
||||
- multilingual-e5-large
|
||||
- bge-large-en
|
||||
- gemini
|
||||
|
|
|
@ -0,0 +1,49 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class CreateOpenaiTextEmbeddingTables < ActiveRecord::Migration[7.0]
|
||||
def change
|
||||
create_table :ai_topic_embeddings_6_1, id: false do |t|
|
||||
t.integer :topic_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1536)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :topic_id, unique: true
|
||||
end
|
||||
|
||||
create_table :ai_topic_embeddings_7_1, id: false do |t|
|
||||
t.integer :topic_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(2000)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :topic_id, unique: true
|
||||
end
|
||||
|
||||
create_table :ai_post_embeddings_6_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1536)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
|
||||
create_table :ai_post_embeddings_7_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(2000)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
end
|
||||
end
|
|
@ -46,7 +46,8 @@ module DiscourseAi
|
|||
|
||||
def inject_into(plugin)
|
||||
plugin.on(:site_setting_changed) do |name, _old_value, _new_value|
|
||||
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled
|
||||
if name == :ai_bot_enabled_chat_bots || name == :ai_bot_enabled ||
|
||||
name == :discourse_ai_enabled
|
||||
DiscourseAi::AiBot::SiteSettingsExtension.enable_or_disable_ai_bots
|
||||
end
|
||||
end
|
||||
|
|
|
@ -18,20 +18,19 @@ module DiscourseAi
|
|||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
cache_for = results_ttl(topic)
|
||||
|
||||
asd =
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
||||
vector_rep
|
||||
.symmetric_topics_similarity_search(topic)
|
||||
.tap do |candidate_ids|
|
||||
# Happens when the topic doesn't have any embeddings
|
||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
|
||||
end
|
||||
Discourse
|
||||
.cache
|
||||
.fetch(semantic_suggested_key(topic.id), expires_in: cache_for) do
|
||||
vector_rep
|
||||
.symmetric_topics_similarity_search(topic)
|
||||
.tap do |candidate_ids|
|
||||
# Happens when the topic doesn't have any embeddings
|
||||
# I'd rather not use Exceptions to control the flow, so this should be refactored soon
|
||||
if candidate_ids.empty? || !candidate_ids.include?(topic.id)
|
||||
raise MissingEmbeddingError, "No embeddings found for topic #{topic.id}"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
rescue MissingEmbeddingError
|
||||
# avoid a flood of jobs when visiting topic
|
||||
if Discourse.redis.set(
|
||||
|
@ -71,15 +70,14 @@ module DiscourseAi
|
|||
|
||||
return "" if related_topics.empty?
|
||||
|
||||
render_result =
|
||||
ApplicationController.render(
|
||||
template: "list/related_topics",
|
||||
layout: false,
|
||||
assigns: {
|
||||
list: related_topics,
|
||||
topic: topic,
|
||||
},
|
||||
)
|
||||
ApplicationController.render(
|
||||
template: "list/related_topics",
|
||||
layout: false,
|
||||
assigns: {
|
||||
list: related_topics,
|
||||
topic: topic,
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
private
|
||||
|
|
|
@ -13,6 +13,8 @@ module DiscourseAi
|
|||
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
|
||||
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
|
||||
].map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
|
||||
end
|
||||
|
||||
|
|
|
@ -0,0 +1,53 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbedding3Large < Base
|
||||
def id
|
||||
7
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"text-embedding-3-large"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
# real dimentions are 3072, but we only support up to 2000 in the
|
||||
# indexes, so we downsample to 2000 via API
|
||||
2000
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8191
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_cosine_ops"
|
||||
end
|
||||
|
||||
def vector_from(text)
|
||||
response =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
||||
text,
|
||||
model: name,
|
||||
dimensions: dimensions,
|
||||
)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -0,0 +1,46 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Embeddings
|
||||
module VectorRepresentations
|
||||
class TextEmbedding3Small < Base
|
||||
def id
|
||||
6
|
||||
end
|
||||
|
||||
def version
|
||||
1
|
||||
end
|
||||
|
||||
def name
|
||||
"text-embedding-3-small"
|
||||
end
|
||||
|
||||
def dimensions
|
||||
1536
|
||||
end
|
||||
|
||||
def max_sequence_length
|
||||
8191
|
||||
end
|
||||
|
||||
def pg_function
|
||||
"<=>"
|
||||
end
|
||||
|
||||
def pg_index_type
|
||||
"vector_cosine_ops"
|
||||
end
|
||||
|
||||
def vector_from(text)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -33,7 +33,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def vector_from(text)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text)
|
||||
response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: name)
|
||||
response[:data].first[:embedding]
|
||||
end
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class OpenAiEmbeddings
|
||||
def self.perform!(content, model = nil)
|
||||
def self.perform!(content, model:, dimensions: nil)
|
||||
headers = { "Content-Type" => "application/json" }
|
||||
|
||||
if SiteSetting.ai_openai_embeddings_url.include?("azure")
|
||||
|
@ -12,14 +12,10 @@ module ::DiscourseAi
|
|||
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
|
||||
end
|
||||
|
||||
model ||= "text-embedding-ada-002"
|
||||
payload = { model: model, input: content }
|
||||
payload[:dimensions] = dimensions if dimensions.present?
|
||||
|
||||
response =
|
||||
Faraday.post(
|
||||
SiteSetting.ai_openai_embeddings_url,
|
||||
{ model: model, input: content }.to_json,
|
||||
headers,
|
||||
)
|
||||
response = Faraday.post(SiteSetting.ai_openai_embeddings_url, payload.to_json, headers)
|
||||
|
||||
case response.status
|
||||
when 200
|
||||
|
|
|
@ -25,7 +25,8 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
|||
},
|
||||
).to_return(status: 200, body: body_json, headers: {})
|
||||
|
||||
result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello")
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002")
|
||||
|
||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
||||
|
@ -42,15 +43,22 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do
|
|||
data: [{ object: "embedding", embedding: [0.0, 0.1] }],
|
||||
}.to_json
|
||||
|
||||
body = { model: "text-embedding-ada-002", input: "hello", dimensions: 1000 }.to_json
|
||||
|
||||
stub_request(:post, "https://api.openai.com/v1/embeddings").with(
|
||||
body: "{\"model\":\"text-embedding-ada-002\",\"input\":\"hello\"}",
|
||||
body: body,
|
||||
headers: {
|
||||
"Authorization" => "Bearer 123456",
|
||||
"Content-Type" => "application/json",
|
||||
},
|
||||
).to_return(status: 200, body: body_json, headers: {})
|
||||
|
||||
result = DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello")
|
||||
result =
|
||||
DiscourseAi::Inference::OpenAiEmbeddings.perform!(
|
||||
"hello",
|
||||
model: "text-embedding-ada-002",
|
||||
dimensions: 1000,
|
||||
)
|
||||
|
||||
expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 })
|
||||
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] })
|
||||
|
|
Loading…
Reference in New Issue