mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-07-07 23:12:34 +00:00
FEATURE: Seamless embedding model upgrades (#1486)
This commit is contained in:
parent
ab5edae121
commit
6247906c13
@ -18,105 +18,115 @@ module Jobs
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
rebaked = 0
|
production_vector = DiscourseAi::Embeddings::Vector.instance
|
||||||
|
|
||||||
vector = DiscourseAi::Embeddings::Vector.instance
|
if SiteSetting.ai_embeddings_backfill_model.present? &&
|
||||||
vector_def = vector.vdef
|
SiteSetting.ai_embeddings_backfill_model != SiteSetting.ai_embeddings_selected_model
|
||||||
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
|
backfill_vector =
|
||||||
|
DiscourseAi::Embeddings::Vector.new(
|
||||||
topics =
|
EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_backfill_model),
|
||||||
Topic
|
|
||||||
.joins(
|
|
||||||
"LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}",
|
|
||||||
)
|
)
|
||||||
.where(archetype: Archetype.default)
|
end
|
||||||
.where(deleted_at: nil)
|
|
||||||
.order("topics.bumped_at DESC")
|
|
||||||
|
|
||||||
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
|
topic_work_list = []
|
||||||
|
topic_work_list << production_vector
|
||||||
|
topic_work_list << backfill_vector if backfill_vector
|
||||||
|
|
||||||
return if rebaked >= limit
|
topic_work_list.each do |vector|
|
||||||
|
rebaked = 0
|
||||||
|
table_name = DiscourseAi::Embeddings::Schema::TOPICS_TABLE
|
||||||
|
vector_def = vector.vdef
|
||||||
|
|
||||||
# Then, we'll try to backfill embeddings for topics that have outdated
|
topics =
|
||||||
# embeddings, be it model or strategy version
|
Topic
|
||||||
relation = topics.where(<<~SQL).limit(limit - rebaked)
|
.joins(
|
||||||
#{table_name}.model_version < #{vector_def.version}
|
"LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id AND #{table_name}.model_id = #{vector_def.id}",
|
||||||
OR
|
)
|
||||||
#{table_name}.strategy_version < #{vector_def.strategy_version}
|
.where(archetype: Archetype.default)
|
||||||
SQL
|
.where(deleted_at: nil)
|
||||||
|
.order("topics.bumped_at DESC")
|
||||||
|
|
||||||
rebaked += populate_topic_embeddings(vector, relation, force: true)
|
rebaked += populate_topic_embeddings(vector, topics.limit(limit - rebaked))
|
||||||
|
|
||||||
return if rebaked >= limit
|
next if rebaked >= limit
|
||||||
|
|
||||||
# Finally, we'll try to backfill embeddings for topics that have outdated
|
# Then, we'll try to backfill embeddings for topics that have outdated
|
||||||
# embeddings due to edits or new replies. Here we only do 10% of the limit
|
# embeddings, be it model or strategy version
|
||||||
relation =
|
relation = topics.where(<<~SQL).limit(limit - rebaked)
|
||||||
topics
|
#{table_name}.model_version < #{vector_def.version}
|
||||||
.where("#{table_name}.updated_at < ?", 6.hours.ago)
|
OR
|
||||||
.where("#{table_name}.updated_at < topics.updated_at")
|
#{table_name}.strategy_version < #{vector_def.strategy_version}
|
||||||
|
SQL
|
||||||
|
|
||||||
|
rebaked += populate_topic_embeddings(vector, relation, force: true)
|
||||||
|
|
||||||
|
next if rebaked >= limit
|
||||||
|
|
||||||
|
# Finally, we'll try to backfill embeddings for topics that have outdated
|
||||||
|
# embeddings due to edits or new replies. Here we only do 10% of the limit
|
||||||
|
relation =
|
||||||
|
topics
|
||||||
|
.where("#{table_name}.updated_at < ?", 6.hours.ago)
|
||||||
|
.where("#{table_name}.updated_at < topics.updated_at")
|
||||||
|
.limit((limit - rebaked) / 10)
|
||||||
|
|
||||||
|
populate_topic_embeddings(vector, relation, force: true)
|
||||||
|
|
||||||
|
next unless SiteSetting.ai_embeddings_per_post_enabled
|
||||||
|
|
||||||
|
# Now for posts
|
||||||
|
table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
|
||||||
|
posts_batch_size = 1000
|
||||||
|
|
||||||
|
posts =
|
||||||
|
Post
|
||||||
|
.joins(
|
||||||
|
"LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}",
|
||||||
|
)
|
||||||
|
.where(deleted_at: nil)
|
||||||
|
.where(post_type: Post.types[:regular])
|
||||||
|
|
||||||
|
# First, we'll try to backfill embeddings for posts that have none
|
||||||
|
posts
|
||||||
|
.where("#{table_name}.post_id IS NULL")
|
||||||
|
.limit(limit - rebaked)
|
||||||
|
.pluck(:id)
|
||||||
|
.each_slice(posts_batch_size) do |batch|
|
||||||
|
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
||||||
|
rebaked += batch.length
|
||||||
|
end
|
||||||
|
|
||||||
|
next if rebaked >= limit
|
||||||
|
|
||||||
|
# Then, we'll try to backfill embeddings for posts that have outdated
|
||||||
|
# embeddings, be it model or strategy version
|
||||||
|
posts
|
||||||
|
.where(<<~SQL)
|
||||||
|
#{table_name}.model_version < #{vector_def.version}
|
||||||
|
OR
|
||||||
|
#{table_name}.strategy_version < #{vector_def.strategy_version}
|
||||||
|
SQL
|
||||||
|
.limit(limit - rebaked)
|
||||||
|
.pluck(:id)
|
||||||
|
.each_slice(posts_batch_size) do |batch|
|
||||||
|
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
||||||
|
rebaked += batch.length
|
||||||
|
end
|
||||||
|
|
||||||
|
next if rebaked >= limit
|
||||||
|
|
||||||
|
# Finally, we'll try to backfill embeddings for posts that have outdated
|
||||||
|
# embeddings due to edits. Here we only do 10% of the limit
|
||||||
|
posts
|
||||||
|
.where("#{table_name}.updated_at < ?", 7.days.ago)
|
||||||
|
.order("random()")
|
||||||
.limit((limit - rebaked) / 10)
|
.limit((limit - rebaked) / 10)
|
||||||
|
.pluck(:id)
|
||||||
populate_topic_embeddings(vector, relation, force: true)
|
.each_slice(posts_batch_size) do |batch|
|
||||||
|
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
||||||
return if rebaked >= limit
|
rebaked += batch.length
|
||||||
|
end
|
||||||
return unless SiteSetting.ai_embeddings_per_post_enabled
|
end
|
||||||
|
|
||||||
# Now for posts
|
|
||||||
table_name = DiscourseAi::Embeddings::Schema::POSTS_TABLE
|
|
||||||
posts_batch_size = 1000
|
|
||||||
|
|
||||||
posts =
|
|
||||||
Post
|
|
||||||
.joins(
|
|
||||||
"LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id AND #{table_name}.model_id = #{vector_def.id}",
|
|
||||||
)
|
|
||||||
.where(deleted_at: nil)
|
|
||||||
.where(post_type: Post.types[:regular])
|
|
||||||
|
|
||||||
# First, we'll try to backfill embeddings for posts that have none
|
|
||||||
posts
|
|
||||||
.where("#{table_name}.post_id IS NULL")
|
|
||||||
.limit(limit - rebaked)
|
|
||||||
.pluck(:id)
|
|
||||||
.each_slice(posts_batch_size) do |batch|
|
|
||||||
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
|
||||||
rebaked += batch.length
|
|
||||||
end
|
|
||||||
|
|
||||||
return if rebaked >= limit
|
|
||||||
|
|
||||||
# Then, we'll try to backfill embeddings for posts that have outdated
|
|
||||||
# embeddings, be it model or strategy version
|
|
||||||
posts
|
|
||||||
.where(<<~SQL)
|
|
||||||
#{table_name}.model_version < #{vector_def.version}
|
|
||||||
OR
|
|
||||||
#{table_name}.strategy_version < #{vector_def.strategy_version}
|
|
||||||
SQL
|
|
||||||
.limit(limit - rebaked)
|
|
||||||
.pluck(:id)
|
|
||||||
.each_slice(posts_batch_size) do |batch|
|
|
||||||
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
|
||||||
rebaked += batch.length
|
|
||||||
end
|
|
||||||
|
|
||||||
return if rebaked >= limit
|
|
||||||
|
|
||||||
# Finally, we'll try to backfill embeddings for posts that have outdated
|
|
||||||
# embeddings due to edits. Here we only do 10% of the limit
|
|
||||||
posts
|
|
||||||
.where("#{table_name}.updated_at < ?", 7.days.ago)
|
|
||||||
.order("random()")
|
|
||||||
.limit((limit - rebaked) / 10)
|
|
||||||
.pluck(:id)
|
|
||||||
.each_slice(posts_batch_size) do |batch|
|
|
||||||
vector.gen_bulk_reprensentations(Post.where(id: batch))
|
|
||||||
rebaked += batch.length
|
|
||||||
end
|
|
||||||
|
|
||||||
rebaked
|
|
||||||
end
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
@ -230,20 +230,26 @@ discourse_ai:
|
|||||||
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
|
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
|
||||||
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
|
validator: "DiscourseAi::Configuration::EmbeddingDefsValidator"
|
||||||
area: "ai-features/embeddings"
|
area: "ai-features/embeddings"
|
||||||
|
ai_embeddings_backfill_model:
|
||||||
|
type: enum
|
||||||
|
default: ""
|
||||||
|
allow_any: false
|
||||||
|
enum: "DiscourseAi::Configuration::EmbeddingDefsEnumerator"
|
||||||
|
hidden: true
|
||||||
ai_embeddings_per_post_enabled:
|
ai_embeddings_per_post_enabled:
|
||||||
default: false
|
default: false
|
||||||
hidden: true
|
hidden: true
|
||||||
ai_embeddings_generate_for_pms:
|
ai_embeddings_generate_for_pms:
|
||||||
default: false
|
default: false
|
||||||
area: "ai-features/embeddings"
|
area: "ai-features/embeddings"
|
||||||
ai_embeddings_semantic_related_topics_enabled:
|
ai_embeddings_semantic_related_topics_enabled:
|
||||||
default: false
|
default: false
|
||||||
client: true
|
client: true
|
||||||
area: "ai-features/embeddings"
|
area: "ai-features/embeddings"
|
||||||
ai_embeddings_semantic_related_topics:
|
ai_embeddings_semantic_related_topics:
|
||||||
default: 5
|
default: 5
|
||||||
area: "ai-features/embeddings"
|
area: "ai-features/embeddings"
|
||||||
ai_embeddings_semantic_related_include_closed_topics:
|
ai_embeddings_semantic_related_include_closed_topics:
|
||||||
default: true
|
default: true
|
||||||
area: "ai-features/embeddings"
|
area: "ai-features/embeddings"
|
||||||
ai_embeddings_backfill_batch_size:
|
ai_embeddings_backfill_batch_size:
|
||||||
|
@ -20,8 +20,11 @@ module DiscourseAi
|
|||||||
MissingEmbeddingError = Class.new(StandardError)
|
MissingEmbeddingError = Class.new(StandardError)
|
||||||
|
|
||||||
class << self
|
class << self
|
||||||
def for(target_klass)
|
def for(target_klass, vector_def: nil)
|
||||||
vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
|
vector_def =
|
||||||
|
EmbeddingDefinition.find_by(
|
||||||
|
id: SiteSetting.ai_embeddings_selected_model,
|
||||||
|
) if vector_def.nil?
|
||||||
raise "Invalid embeddings selected model" if vector_def.nil?
|
raise "Invalid embeddings selected model" if vector_def.nil?
|
||||||
|
|
||||||
case target_klass&.name
|
case target_klass&.name
|
||||||
|
@ -3,6 +3,8 @@
|
|||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Embeddings
|
module Embeddings
|
||||||
class SemanticRelated
|
class SemanticRelated
|
||||||
|
CACHE_PREFIX = "semantic-suggested-topic-"
|
||||||
|
|
||||||
def self.clear_cache_for(topic)
|
def self.clear_cache_for(topic)
|
||||||
Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
|
Discourse.cache.delete("semantic-suggested-topic-#{topic.id}")
|
||||||
Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
|
Discourse.redis.del("build-semantic-suggested-topic-#{topic.id}")
|
||||||
@ -79,14 +81,21 @@ module DiscourseAi
|
|||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def self.clear_cache!
|
||||||
|
Discourse
|
||||||
|
.cache
|
||||||
|
.keys("#{CACHE_PREFIX}*")
|
||||||
|
.each { |key| Discourse.cache.delete(key.split(":").last) }
|
||||||
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
def semantic_suggested_key(topic_id)
|
def semantic_suggested_key(topic_id)
|
||||||
"semantic-suggested-topic-#{topic_id}"
|
"#{CACHE_PREFIX}#{topic_id}"
|
||||||
end
|
end
|
||||||
|
|
||||||
def build_semantic_suggested_key(topic_id)
|
def build_semantic_suggested_key(topic_id)
|
||||||
"build-semantic-suggested-topic-#{topic_id}"
|
"build-#{CACHE_PREFIX}#{topic_id}"
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -25,7 +25,7 @@ module DiscourseAi
|
|||||||
idletime: 30,
|
idletime: 30,
|
||||||
)
|
)
|
||||||
|
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class)
|
schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector_def: @vdef)
|
||||||
|
|
||||||
embedding_gen = vdef.inference_client
|
embedding_gen = vdef.inference_client
|
||||||
promised_embeddings =
|
promised_embeddings =
|
||||||
@ -58,7 +58,7 @@ module DiscourseAi
|
|||||||
text = vdef.prepare_target_text(target)
|
text = vdef.prepare_target_text(target)
|
||||||
return if text.blank?
|
return if text.blank?
|
||||||
|
|
||||||
schema = DiscourseAi::Embeddings::Schema.for(target.class)
|
schema = DiscourseAi::Embeddings::Schema.for(target.class, vector_def: @vdef)
|
||||||
|
|
||||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||||
return if schema.find_by_target(target)&.digest == new_digest
|
return if schema.find_by_target(target)&.digest == new_digest
|
||||||
|
@ -20,6 +20,8 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||||||
end
|
end
|
||||||
|
|
||||||
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
fab!(:vector_def) { Fabricate(:embedding_definition) }
|
||||||
|
fab!(:vector_def2) { Fabricate(:embedding_definition) }
|
||||||
|
fab!(:embedding_array) { Array.new(1024) { 1 } }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
SiteSetting.ai_embeddings_selected_model = vector_def.id
|
||||||
@ -27,16 +29,14 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||||||
SiteSetting.ai_embeddings_backfill_batch_size = 1
|
SiteSetting.ai_embeddings_backfill_batch_size = 1
|
||||||
SiteSetting.ai_embeddings_per_post_enabled = true
|
SiteSetting.ai_embeddings_per_post_enabled = true
|
||||||
Jobs.run_immediately!
|
Jobs.run_immediately!
|
||||||
end
|
|
||||||
|
|
||||||
it "backfills topics based on bumped_at date" do
|
|
||||||
embedding = Array.new(1024) { 1 }
|
|
||||||
|
|
||||||
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
|
WebMock.stub_request(:post, "https://test.com/embeddings").to_return(
|
||||||
status: 200,
|
status: 200,
|
||||||
body: JSON.dump(embedding),
|
body: JSON.dump(embedding_array),
|
||||||
)
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "backfills topics based on bumped_at date" do
|
||||||
Jobs::EmbeddingsBackfill.new.execute({})
|
Jobs::EmbeddingsBackfill.new.execute({})
|
||||||
|
|
||||||
topic_ids =
|
topic_ids =
|
||||||
@ -68,4 +68,19 @@ RSpec.describe Jobs::EmbeddingsBackfill do
|
|||||||
|
|
||||||
expect(index_date).to be_within_one_second_of(Time.zone.now)
|
expect(index_date).to be_within_one_second_of(Time.zone.now)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "backfills embeddings for the ai_embeddings_backfill_model" do
|
||||||
|
SiteSetting.ai_embeddings_backfill_model = vector_def2.id
|
||||||
|
SiteSetting.ai_embeddings_backfill_batch_size = 100
|
||||||
|
|
||||||
|
Jobs::EmbeddingsBackfill.new.execute({})
|
||||||
|
|
||||||
|
topic_ids =
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT topic_id from #{DiscourseAi::Embeddings::Schema::TOPICS_TABLE} WHERE model_id = ?",
|
||||||
|
vector_def2.id,
|
||||||
|
)
|
||||||
|
|
||||||
|
expect(topic_ids).to contain_exactly(first_topic.id, second_topic.id, third_topic.id)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
x
Reference in New Issue
Block a user