FEATURE: Per post embeddings (#387)
This commit is contained in:
parent
c3af27571b
commit
140359c2ef
|
@ -6,18 +6,20 @@ module Jobs
|
|||
|
||||
def execute(args)
|
||||
return unless SiteSetting.ai_embeddings_enabled
|
||||
return if (topic_id = args[:topic_id]).blank?
|
||||
return if args[:target_type].blank? || args[:target_id].blank?
|
||||
target = args[:target_type].constantize.find_by_id(args[:target_id])
|
||||
return if target.nil? || target.deleted_at.present?
|
||||
|
||||
topic = Topic.find_by_id(topic_id)
|
||||
return if topic.nil? || topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
|
||||
post = topic.first_post
|
||||
return if post.nil? || post.raw.blank?
|
||||
topic = target.is_a?(Topic) ? target : target.topic
|
||||
post = target.is_a?(Post) ? target : target.first_post
|
||||
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
|
||||
return if post.raw.blank?
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
|
||||
vector_rep.generate_topic_representation_from(topic)
|
||||
vector_rep.generate_representation_from(target)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -15,7 +15,7 @@ module Jobs
|
|||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
table_name = vector_rep.table_name
|
||||
table_name = vector_rep.topic_table_name
|
||||
|
||||
topics =
|
||||
Topic
|
||||
|
@ -28,7 +28,7 @@ module Jobs
|
|||
topics
|
||||
.where("#{table_name}.topic_id IS NULL")
|
||||
.find_each do |t|
|
||||
vector_rep.generate_topic_representation_from(t)
|
||||
vector_rep.generate_representation_from(t)
|
||||
rebaked += 1
|
||||
end
|
||||
|
||||
|
@ -45,7 +45,7 @@ module Jobs
|
|||
#{table_name}.strategy_version < #{strategy.version}
|
||||
SQL
|
||||
.find_each do |t|
|
||||
vector_rep.generate_topic_representation_from(t)
|
||||
vector_rep.generate_representation_from(t)
|
||||
rebaked += 1
|
||||
end
|
||||
|
||||
|
@ -59,7 +59,57 @@ module Jobs
|
|||
.limit((limit - rebaked) / 10)
|
||||
.pluck(:id)
|
||||
.each do |id|
|
||||
vector_rep.generate_topic_representation_from(Topic.find_by(id: id))
|
||||
vector_rep.generate_representation_from(Topic.find_by(id: id))
|
||||
rebaked += 1
|
||||
end
|
||||
|
||||
return if rebaked >= limit
|
||||
|
||||
# Now for posts
|
||||
table_name = vector_rep.post_table_name
|
||||
|
||||
posts =
|
||||
Post
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
||||
.where(deleted_at: nil)
|
||||
.limit(limit - rebaked)
|
||||
|
||||
# First, we'll try to backfill embeddings for posts that have none
|
||||
posts
|
||||
.where("#{table_name}.post_id IS NULL")
|
||||
.find_each do |t|
|
||||
vector_rep.generate_representation_from(t)
|
||||
rebaked += 1
|
||||
end
|
||||
|
||||
vector_rep.consider_indexing
|
||||
|
||||
return if rebaked >= limit
|
||||
|
||||
# Then, we'll try to backfill embeddings for posts that have outdated
|
||||
# embeddings, be it model or strategy version
|
||||
posts
|
||||
.where(<<~SQL)
|
||||
#{table_name}.model_version < #{vector_rep.version}
|
||||
OR
|
||||
#{table_name}.strategy_version < #{strategy.version}
|
||||
SQL
|
||||
.find_each do |t|
|
||||
vector_rep.generate_representation_from(t)
|
||||
rebaked += 1
|
||||
end
|
||||
|
||||
return if rebaked >= limit
|
||||
|
||||
# Finally, we'll try to backfill embeddings for posts that have outdated
|
||||
# embeddings due to edits. Here we only do 10% of the limit
|
||||
posts
|
||||
.where("#{table_name}.updated_at < ?", 7.days.ago)
|
||||
.order("random()")
|
||||
.limit((limit - rebaked) / 10)
|
||||
.pluck(:id)
|
||||
.each do |id|
|
||||
vector_rep.generate_representation_from(Post.find_by(id: id))
|
||||
rebaked += 1
|
||||
end
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
|
|||
].map { |k| k.new(truncation) }
|
||||
|
||||
vector_reps.each do |vector_rep|
|
||||
new_table_name = vector_rep.table_name
|
||||
new_table_name = vector_rep.topic_table_name
|
||||
old_table_name = "topic_embeddings_#{vector_rep.name.underscore}"
|
||||
|
||||
begin
|
||||
|
|
|
@ -0,0 +1,60 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class CreateAiPostEmbeddingsTables < ActiveRecord::Migration[7.0]
|
||||
def change
|
||||
create_table :ai_post_embeddings_1_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(768)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
|
||||
create_table :ai_post_embeddings_2_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1536)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
|
||||
create_table :ai_post_embeddings_3_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1024)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
|
||||
create_table :ai_post_embeddings_4_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(1024)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
|
||||
create_table :ai_post_embeddings_5_1, id: false do |t|
|
||||
t.integer :post_id, null: false
|
||||
t.integer :model_version, null: false
|
||||
t.integer :strategy_version, null: false
|
||||
t.text :digest, null: false
|
||||
t.column :embeddings, "vector(768)", null: false
|
||||
t.timestamps
|
||||
|
||||
t.index :post_id, unique: true
|
||||
end
|
||||
end
|
||||
end
|
|
@ -43,14 +43,20 @@ module DiscourseAi
|
|||
|
||||
# embeddings generation.
|
||||
callback =
|
||||
Proc.new do |topic|
|
||||
Proc.new do |target|
|
||||
if SiteSetting.ai_embeddings_enabled
|
||||
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
|
||||
Jobs.enqueue(
|
||||
:generate_embeddings,
|
||||
target_id: target.id,
|
||||
target_type: target.class.name,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
plugin.on(:topic_created, &callback)
|
||||
plugin.on(:topic_edited, &callback)
|
||||
plugin.on(:post_created, &callback)
|
||||
plugin.on(:post_edited, &callback)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -50,9 +50,9 @@ module DiscourseAi
|
|||
tokenizer.truncate(text, max_length)
|
||||
end
|
||||
|
||||
def post_truncation(topic, tokenizer, max_length)
|
||||
def post_truncation(post, tokenizer, max_length)
|
||||
text = +topic_information(post.topic)
|
||||
text << post.raw
|
||||
text << Nokogiri::HTML5.fragment(post.cooked).text
|
||||
|
||||
tokenizer.truncate(text, max_length)
|
||||
end
|
||||
|
|
|
@ -21,6 +21,8 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def consider_indexing(memory: "100MB")
|
||||
[topic_table_name, post_table_name].each do |table_name|
|
||||
index_name = index_name(table_name)
|
||||
# Using extension maintainer's recommendation for ivfflat indexes
|
||||
# Results are not as good as without indexes, but it's much faster
|
||||
# Disk usage is ~1x the size of the table, so this doubles table total size
|
||||
|
@ -40,7 +42,7 @@ module DiscourseAi
|
|||
|
||||
if !existing_index.present?
|
||||
Rails.logger.info("Index #{index_name} does not exist, creating...")
|
||||
return create_index!(memory, lists, probes)
|
||||
return create_index!(table_name, memory, lists, probes)
|
||||
end
|
||||
|
||||
existing_index_age =
|
||||
|
@ -62,12 +64,12 @@ module DiscourseAi
|
|||
Rails.logger.info(
|
||||
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
|
||||
)
|
||||
return create_index!(memory, lists, probes)
|
||||
return create_index!(table_name, memory, lists, probes)
|
||||
elsif existing_lists != lists
|
||||
Rails.logger.info(
|
||||
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
|
||||
)
|
||||
return create_index!(memory, lists, probes)
|
||||
return create_index!(table_name, memory, lists, probes)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -75,8 +77,10 @@ module DiscourseAi
|
|||
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def create_index!(memory, lists, probes)
|
||||
def create_index!(table_name, memory, lists, probes)
|
||||
index_name = index_name(table_name)
|
||||
DB.exec("SET work_mem TO '#{memory}';")
|
||||
DB.exec("SET maintenance_work_mem TO '#{memory}';")
|
||||
DB.exec(<<~SQL)
|
||||
|
@ -102,17 +106,17 @@ module DiscourseAi
|
|||
raise NotImplementedError
|
||||
end
|
||||
|
||||
def generate_topic_representation_from(target, persist: true)
|
||||
def generate_representation_from(target, persist: true)
|
||||
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
|
||||
|
||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||
current_digest = DB.query_single(<<~SQL, topic_id: target.id).first
|
||||
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
|
||||
SELECT
|
||||
digest
|
||||
FROM
|
||||
#{table_name}
|
||||
#{table_name(target)}
|
||||
WHERE
|
||||
topic_id = :topic_id
|
||||
#{target.is_a?(Topic) ? "topic_id" : "post_id"} = :target_id
|
||||
LIMIT 1
|
||||
SQL
|
||||
return if current_digest == new_digest
|
||||
|
@ -127,7 +131,19 @@ module DiscourseAi
|
|||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
#{table_name}
|
||||
#{topic_table_name}
|
||||
ORDER BY
|
||||
embeddings #{pg_function} '[:query_embedding]'
|
||||
LIMIT 1
|
||||
SQL
|
||||
end
|
||||
|
||||
def post_id_from_representation(raw_vector)
|
||||
DB.query_single(<<~SQL, query_embedding: raw_vector).first
|
||||
SELECT
|
||||
post_id
|
||||
FROM
|
||||
#{post_table_name}
|
||||
ORDER BY
|
||||
embeddings #{pg_function} '[:query_embedding]'
|
||||
LIMIT 1
|
||||
|
@ -140,7 +156,7 @@ module DiscourseAi
|
|||
topic_id,
|
||||
embeddings #{pg_function} '[:query_embedding]' AS distance
|
||||
FROM
|
||||
#{table_name}
|
||||
#{topic_table_name}
|
||||
ORDER BY
|
||||
embeddings #{pg_function} '[:query_embedding]'
|
||||
LIMIT :limit
|
||||
|
@ -162,13 +178,13 @@ module DiscourseAi
|
|||
SELECT
|
||||
topic_id
|
||||
FROM
|
||||
#{table_name}
|
||||
#{topic_table_name}
|
||||
ORDER BY
|
||||
embeddings #{pg_function} (
|
||||
SELECT
|
||||
embeddings
|
||||
FROM
|
||||
#{table_name}
|
||||
#{topic_table_name}
|
||||
WHERE
|
||||
topic_id = :topic_id
|
||||
LIMIT 1
|
||||
|
@ -182,11 +198,26 @@ module DiscourseAi
|
|||
raise MissingEmbeddingError
|
||||
end
|
||||
|
||||
def table_name
|
||||
def topic_table_name
|
||||
"ai_topic_embeddings_#{id}_#{@strategy.id}"
|
||||
end
|
||||
|
||||
def index_name
|
||||
def post_table_name
|
||||
"ai_post_embeddings_#{id}_#{@strategy.id}"
|
||||
end
|
||||
|
||||
def table_name(target)
|
||||
case target
|
||||
when Topic
|
||||
topic_table_name
|
||||
when Post
|
||||
post_table_name
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
end
|
||||
|
||||
def index_name(table_name)
|
||||
"#{table_name}_search"
|
||||
end
|
||||
|
||||
|
@ -221,9 +252,10 @@ module DiscourseAi
|
|||
protected
|
||||
|
||||
def save_to_db(target, vector, digest)
|
||||
if target.is_a?(Topic)
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
INSERT INTO #{topic_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (topic_id)
|
||||
DO UPDATE SET
|
||||
|
@ -239,6 +271,28 @@ module DiscourseAi
|
|||
digest: digest,
|
||||
embeddings: vector,
|
||||
)
|
||||
elsif target.is_a?(Post)
|
||||
DB.exec(
|
||||
<<~SQL,
|
||||
INSERT INTO #{post_table_name} (post_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||
VALUES (:post_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||
ON CONFLICT (post_id)
|
||||
DO UPDATE SET
|
||||
model_version = :model_version,
|
||||
strategy_version = :strategy_version,
|
||||
digest = :digest,
|
||||
embeddings = '[:embeddings]',
|
||||
updated_at = CURRENT_TIMESTAMP
|
||||
SQL
|
||||
post_id: target.id,
|
||||
model_version: version,
|
||||
strategy_version: @strategy.version,
|
||||
digest: digest,
|
||||
embeddings: vector,
|
||||
)
|
||||
else
|
||||
raise ArgumentError, "Invalid target type"
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,23 +1,33 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
desc "Backfill embeddings for all topics"
|
||||
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
|
||||
desc "Backfill embeddings for all topics and posts"
|
||||
task "ai:embeddings:backfill" => [:environment] do
|
||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||
|
||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||
table_name = vector_rep.table_name
|
||||
table_name = vector_rep.topic_table_name
|
||||
|
||||
Topic
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
||||
.where("#{table_name}.topic_id IS NULL")
|
||||
.where("topics.id >= ?", args[:start_topic].to_i || 0)
|
||||
.where("category_id IN (?)", public_categories)
|
||||
.where(deleted_at: nil)
|
||||
.order("topics.id ASC")
|
||||
.order("topics.id DESC")
|
||||
.find_each do |t|
|
||||
print "."
|
||||
vector_rep.generate_topic_representation_from(t)
|
||||
vector_rep.generate_representation_from(t)
|
||||
end
|
||||
|
||||
table_name = vector_rep.post_table_name
|
||||
Post
|
||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
||||
.where("#{table_name}.post_id IS NULL")
|
||||
.where(deleted_at: nil)
|
||||
.order("posts.id DESC")
|
||||
.find_each do |t|
|
||||
print "."
|
||||
vector_rep.generate_representation_from(t)
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -18,7 +18,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
|||
it "queues a job on create if embeddings is enabled" do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
|
||||
expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(1)
|
||||
expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(2) # topic_created and post_created
|
||||
end
|
||||
|
||||
it "does nothing if sentiment analysis is disabled" do
|
||||
|
|
|
@ -18,7 +18,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
end
|
||||
|
||||
it "works" do
|
||||
it "works for topics" do
|
||||
expected_embedding = [0.0038493] * vector_rep.dimensions
|
||||
|
||||
text =
|
||||
|
@ -29,9 +29,21 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
|||
)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||
|
||||
job.execute(topic_id: topic.id)
|
||||
job.execute(target_id: topic.id, target_type: "Topic")
|
||||
|
||||
expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "works for posts" do
|
||||
expected_embedding = [0.0038493] * vector_rep.dimensions
|
||||
|
||||
text =
|
||||
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
|
||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||
|
||||
job.execute(target_id: post.id, target_type: "Post")
|
||||
|
||||
expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -12,9 +12,10 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
end
|
||||
end
|
||||
|
||||
describe "#generate_topic_representation_from" do
|
||||
describe "#generate_representation_from" do
|
||||
fab!(:topic) { Fabricate(:topic) }
|
||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
|
||||
|
||||
it "creates a vector from a topic and stores it in the database" do
|
||||
text =
|
||||
|
@ -25,10 +26,24 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
)
|
||||
stub_vector_mapping(text, @expected_embedding)
|
||||
|
||||
vector_rep.generate_topic_representation_from(topic)
|
||||
vector_rep.generate_representation_from(topic)
|
||||
|
||||
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
|
||||
end
|
||||
|
||||
it "creates a vector from a post and stores it in the database" do
|
||||
text =
|
||||
truncation.prepare_text_from(
|
||||
post2,
|
||||
vector_rep.tokenizer,
|
||||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, @expected_embedding)
|
||||
|
||||
vector_rep.generate_representation_from(post)
|
||||
|
||||
expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#asymmetric_topics_similarity_search" do
|
||||
|
@ -44,7 +59,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
|||
vector_rep.max_sequence_length - 2,
|
||||
)
|
||||
stub_vector_mapping(text, @expected_embedding)
|
||||
vector_rep.generate_topic_representation_from(topic)
|
||||
vector_rep.generate_representation_from(topic)
|
||||
|
||||
expect(
|
||||
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
|
||||
|
|
Loading…
Reference in New Issue