FEATURE: Per post embeddings (#387)

This commit is contained in:
Rafael dos Santos Silva 2023-12-29 12:28:45 -03:00 committed by GitHub
parent c3af27571b
commit 140359c2ef
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
11 changed files with 314 additions and 105 deletions

View File

@ -6,18 +6,20 @@ module Jobs
def execute(args) def execute(args)
return unless SiteSetting.ai_embeddings_enabled return unless SiteSetting.ai_embeddings_enabled
return if (topic_id = args[:topic_id]).blank? return if args[:target_type].blank? || args[:target_id].blank?
target = args[:target_type].constantize.find_by_id(args[:target_id])
return if target.nil? || target.deleted_at.present?
topic = Topic.find_by_id(topic_id) topic = target.is_a?(Topic) ? target : target.topic
return if topic.nil? || topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms post = target.is_a?(Post) ? target : target.first_post
post = topic.first_post return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
return if post.nil? || post.raw.blank? return if post.raw.blank?
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep = vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
vector_rep.generate_topic_representation_from(topic) vector_rep.generate_representation_from(target)
end end
end end
end end

View File

@ -15,7 +15,7 @@ module Jobs
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep = vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
table_name = vector_rep.table_name table_name = vector_rep.topic_table_name
topics = topics =
Topic Topic
@ -28,7 +28,7 @@ module Jobs
topics topics
.where("#{table_name}.topic_id IS NULL") .where("#{table_name}.topic_id IS NULL")
.find_each do |t| .find_each do |t|
vector_rep.generate_topic_representation_from(t) vector_rep.generate_representation_from(t)
rebaked += 1 rebaked += 1
end end
@ -45,7 +45,7 @@ module Jobs
#{table_name}.strategy_version < #{strategy.version} #{table_name}.strategy_version < #{strategy.version}
SQL SQL
.find_each do |t| .find_each do |t|
vector_rep.generate_topic_representation_from(t) vector_rep.generate_representation_from(t)
rebaked += 1 rebaked += 1
end end
@ -59,7 +59,57 @@ module Jobs
.limit((limit - rebaked) / 10) .limit((limit - rebaked) / 10)
.pluck(:id) .pluck(:id)
.each do |id| .each do |id|
vector_rep.generate_topic_representation_from(Topic.find_by(id: id)) vector_rep.generate_representation_from(Topic.find_by(id: id))
rebaked += 1
end
return if rebaked >= limit
# Now for posts
table_name = vector_rep.post_table_name
posts =
Post
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
.where(deleted_at: nil)
.limit(limit - rebaked)
# First, we'll try to backfill embeddings for posts that have none
posts
.where("#{table_name}.post_id IS NULL")
.find_each do |t|
vector_rep.generate_representation_from(t)
rebaked += 1
end
vector_rep.consider_indexing
return if rebaked >= limit
# Then, we'll try to backfill embeddings for posts that have outdated
# embeddings, be it model or strategy version
posts
.where(<<~SQL)
#{table_name}.model_version < #{vector_rep.version}
OR
#{table_name}.strategy_version < #{strategy.version}
SQL
.find_each do |t|
vector_rep.generate_representation_from(t)
rebaked += 1
end
return if rebaked >= limit
# Finally, we'll try to backfill embeddings for posts that have outdated
# embeddings due to edits. Here we only do 10% of the limit
posts
.where("#{table_name}.updated_at < ?", 7.days.ago)
.order("random()")
.limit((limit - rebaked) / 10)
.pluck(:id)
.each do |id|
vector_rep.generate_representation_from(Post.find_by(id: id))
rebaked += 1 rebaked += 1
end end

View File

@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
].map { |k| k.new(truncation) } ].map { |k| k.new(truncation) }
vector_reps.each do |vector_rep| vector_reps.each do |vector_rep|
new_table_name = vector_rep.table_name new_table_name = vector_rep.topic_table_name
old_table_name = "topic_embeddings_#{vector_rep.name.underscore}" old_table_name = "topic_embeddings_#{vector_rep.name.underscore}"
begin begin

View File

@ -0,0 +1,60 @@
# frozen_string_literal: true
class CreateAiPostEmbeddingsTables < ActiveRecord::Migration[7.0]
def change
create_table :ai_post_embeddings_1_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(768)", null: false
t.timestamps
t.index :post_id, unique: true
end
create_table :ai_post_embeddings_2_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1536)", null: false
t.timestamps
t.index :post_id, unique: true
end
create_table :ai_post_embeddings_3_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1024)", null: false
t.timestamps
t.index :post_id, unique: true
end
create_table :ai_post_embeddings_4_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(1024)", null: false
t.timestamps
t.index :post_id, unique: true
end
create_table :ai_post_embeddings_5_1, id: false do |t|
t.integer :post_id, null: false
t.integer :model_version, null: false
t.integer :strategy_version, null: false
t.text :digest, null: false
t.column :embeddings, "vector(768)", null: false
t.timestamps
t.index :post_id, unique: true
end
end
end

View File

@ -43,14 +43,20 @@ module DiscourseAi
# embeddings generation. # embeddings generation.
callback = callback =
Proc.new do |topic| Proc.new do |target|
if SiteSetting.ai_embeddings_enabled if SiteSetting.ai_embeddings_enabled
Jobs.enqueue(:generate_embeddings, topic_id: topic.id) Jobs.enqueue(
:generate_embeddings,
target_id: target.id,
target_type: target.class.name,
)
end end
end end
plugin.on(:topic_created, &callback) plugin.on(:topic_created, &callback)
plugin.on(:topic_edited, &callback) plugin.on(:topic_edited, &callback)
plugin.on(:post_created, &callback)
plugin.on(:post_edited, &callback)
end end
end end
end end

View File

@ -50,9 +50,9 @@ module DiscourseAi
tokenizer.truncate(text, max_length) tokenizer.truncate(text, max_length)
end end
def post_truncation(topic, tokenizer, max_length) def post_truncation(post, tokenizer, max_length)
text = +topic_information(post.topic) text = +topic_information(post.topic)
text << post.raw text << Nokogiri::HTML5.fragment(post.cooked).text
tokenizer.truncate(text, max_length) tokenizer.truncate(text, max_length)
end end

View File

@ -21,62 +21,66 @@ module DiscourseAi
end end
def consider_indexing(memory: "100MB") def consider_indexing(memory: "100MB")
# Using extension maintainer's recommendation for ivfflat indexes [topic_table_name, post_table_name].each do |table_name|
# Results are not as good as without indexes, but it's much faster index_name = index_name(table_name)
# Disk usage is ~1x the size of the table, so this doubles table total size # Using extension maintainer's recommendation for ivfflat indexes
count = DB.query_single("SELECT count(*) FROM #{table_name};").first # Results are not as good as without indexes, but it's much faster
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max # Disk usage is ~1x the size of the table, so this doubles table total size
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max count = DB.query_single("SELECT count(*) FROM #{table_name};").first
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
existing_index = DB.query_single(<<~SQL, index_name: index_name).first existing_index = DB.query_single(<<~SQL, index_name: index_name).first
SELECT SELECT
indexdef indexdef
FROM FROM
pg_indexes pg_indexes
WHERE WHERE
indexname = :index_name indexname = :index_name
LIMIT 1 LIMIT 1
SQL SQL
if !existing_index.present? if !existing_index.present?
Rails.logger.info("Index #{index_name} does not exist, creating...") Rails.logger.info("Index #{index_name} does not exist, creating...")
return create_index!(memory, lists, probes) return create_index!(table_name, memory, lists, probes)
end
existing_index_age =
DB
.query_single(
"SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
index_name: index_name,
)
.first
.to_i || 0
new_rows =
DB.query_single(
"SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';",
).first
existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i
if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i
if new_rows > 10_000
Rails.logger.info(
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
)
return create_index!(memory, lists, probes)
elsif existing_lists != lists
Rails.logger.info(
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
)
return create_index!(memory, lists, probes)
end end
end
Rails.logger.info( existing_index_age =
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.", DB
) .query_single(
"SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
index_name: index_name,
)
.first
.to_i || 0
new_rows =
DB.query_single(
"SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';",
).first
existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i
if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i
if new_rows > 10_000
Rails.logger.info(
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
)
return create_index!(table_name, memory, lists, probes)
elsif existing_lists != lists
Rails.logger.info(
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
)
return create_index!(table_name, memory, lists, probes)
end
end
Rails.logger.info(
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
)
end
end end
def create_index!(memory, lists, probes) def create_index!(table_name, memory, lists, probes)
index_name = index_name(table_name)
DB.exec("SET work_mem TO '#{memory}';") DB.exec("SET work_mem TO '#{memory}';")
DB.exec("SET maintenance_work_mem TO '#{memory}';") DB.exec("SET maintenance_work_mem TO '#{memory}';")
DB.exec(<<~SQL) DB.exec(<<~SQL)
@ -102,17 +106,17 @@ module DiscourseAi
raise NotImplementedError raise NotImplementedError
end end
def generate_topic_representation_from(target, persist: true) def generate_representation_from(target, persist: true)
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2) text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
new_digest = OpenSSL::Digest::SHA1.hexdigest(text) new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
current_digest = DB.query_single(<<~SQL, topic_id: target.id).first current_digest = DB.query_single(<<~SQL, target_id: target.id).first
SELECT SELECT
digest digest
FROM FROM
#{table_name} #{table_name(target)}
WHERE WHERE
topic_id = :topic_id #{target.is_a?(Topic) ? "topic_id" : "post_id"} = :target_id
LIMIT 1 LIMIT 1
SQL SQL
return if current_digest == new_digest return if current_digest == new_digest
@ -127,7 +131,19 @@ module DiscourseAi
SELECT SELECT
topic_id topic_id
FROM FROM
#{table_name} #{topic_table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT 1
SQL
end
def post_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
post_id
FROM
#{post_table_name}
ORDER BY ORDER BY
embeddings #{pg_function} '[:query_embedding]' embeddings #{pg_function} '[:query_embedding]'
LIMIT 1 LIMIT 1
@ -140,7 +156,7 @@ module DiscourseAi
topic_id, topic_id,
embeddings #{pg_function} '[:query_embedding]' AS distance embeddings #{pg_function} '[:query_embedding]' AS distance
FROM FROM
#{table_name} #{topic_table_name}
ORDER BY ORDER BY
embeddings #{pg_function} '[:query_embedding]' embeddings #{pg_function} '[:query_embedding]'
LIMIT :limit LIMIT :limit
@ -162,13 +178,13 @@ module DiscourseAi
SELECT SELECT
topic_id topic_id
FROM FROM
#{table_name} #{topic_table_name}
ORDER BY ORDER BY
embeddings #{pg_function} ( embeddings #{pg_function} (
SELECT SELECT
embeddings embeddings
FROM FROM
#{table_name} #{topic_table_name}
WHERE WHERE
topic_id = :topic_id topic_id = :topic_id
LIMIT 1 LIMIT 1
@ -182,11 +198,26 @@ module DiscourseAi
raise MissingEmbeddingError raise MissingEmbeddingError
end end
def table_name def topic_table_name
"ai_topic_embeddings_#{id}_#{@strategy.id}" "ai_topic_embeddings_#{id}_#{@strategy.id}"
end end
def index_name def post_table_name
"ai_post_embeddings_#{id}_#{@strategy.id}"
end
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
else
raise ArgumentError, "Invalid target type"
end
end
def index_name(table_name)
"#{table_name}_search" "#{table_name}_search"
end end
@ -221,24 +252,47 @@ module DiscourseAi
protected protected
def save_to_db(target, vector, digest) def save_to_db(target, vector, digest)
DB.exec( if target.is_a?(Topic)
<<~SQL, DB.exec(
INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at) <<~SQL,
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP) INSERT INTO #{topic_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
ON CONFLICT (topic_id) VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
DO UPDATE SET ON CONFLICT (topic_id)
model_version = :model_version, DO UPDATE SET
strategy_version = :strategy_version, model_version = :model_version,
digest = :digest, strategy_version = :strategy_version,
embeddings = '[:embeddings]', digest = :digest,
updated_at = CURRENT_TIMESTAMP embeddings = '[:embeddings]',
SQL updated_at = CURRENT_TIMESTAMP
topic_id: target.id, SQL
model_version: version, topic_id: target.id,
strategy_version: @strategy.version, model_version: version,
digest: digest, strategy_version: @strategy.version,
embeddings: vector, digest: digest,
) embeddings: vector,
)
elsif target.is_a?(Post)
DB.exec(
<<~SQL,
INSERT INTO #{post_table_name} (post_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:post_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (post_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
post_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
else
raise ArgumentError, "Invalid target type"
end
end end
end end
end end

View File

@ -1,23 +1,33 @@
# frozen_string_literal: true # frozen_string_literal: true
desc "Backfill embeddings for all topics" desc "Backfill embeddings for all topics and posts"
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args| task "ai:embeddings:backfill" => [:environment] do
public_categories = Category.where(read_restricted: false).pluck(:id) public_categories = Category.where(read_restricted: false).pluck(:id)
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
table_name = vector_rep.table_name table_name = vector_rep.topic_table_name
Topic Topic
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id") .joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
.where("#{table_name}.topic_id IS NULL") .where("#{table_name}.topic_id IS NULL")
.where("topics.id >= ?", args[:start_topic].to_i || 0)
.where("category_id IN (?)", public_categories) .where("category_id IN (?)", public_categories)
.where(deleted_at: nil) .where(deleted_at: nil)
.order("topics.id ASC") .order("topics.id DESC")
.find_each do |t| .find_each do |t|
print "." print "."
vector_rep.generate_topic_representation_from(t) vector_rep.generate_representation_from(t)
end
table_name = vector_rep.post_table_name
Post
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
.where("#{table_name}.post_id IS NULL")
.where(deleted_at: nil)
.order("posts.id DESC")
.find_each do |t|
print "."
vector_rep.generate_representation_from(t)
end end
end end

View File

@ -18,7 +18,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
it "queues a job on create if embeddings is enabled" do it "queues a job on create if embeddings is enabled" do
SiteSetting.ai_embeddings_enabled = true SiteSetting.ai_embeddings_enabled = true
expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(1) expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(2) # topic_created and post_created
end end
it "does nothing if sentiment analysis is disabled" do it "does nothing if sentiment analysis is disabled" do

View File

@ -18,7 +18,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end end
it "works" do it "works for topics" do
expected_embedding = [0.0038493] * vector_rep.dimensions expected_embedding = [0.0038493] * vector_rep.dimensions
text = text =
@ -29,9 +29,21 @@ RSpec.describe Jobs::GenerateEmbeddings do
) )
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding) EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
job.execute(topic_id: topic.id) job.execute(target_id: topic.id, target_type: "Topic")
expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id) expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id)
end end
it "works for posts" do
expected_embedding = [0.0038493] * vector_rep.dimensions
text =
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
job.execute(target_id: post.id, target_type: "Post")
expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id)
end
end end
end end

View File

@ -12,9 +12,10 @@ RSpec.shared_examples "generates and store embedding using with vector represent
end end
end end
describe "#generate_topic_representation_from" do describe "#generate_representation_from" do
fab!(:topic) { Fabricate(:topic) } fab!(:topic) { Fabricate(:topic) }
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) } fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
it "creates a vector from a topic and stores it in the database" do it "creates a vector from a topic and stores it in the database" do
text = text =
@ -25,10 +26,24 @@ RSpec.shared_examples "generates and store embedding using with vector represent
) )
stub_vector_mapping(text, @expected_embedding) stub_vector_mapping(text, @expected_embedding)
vector_rep.generate_topic_representation_from(topic) vector_rep.generate_representation_from(topic)
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id) expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
end end
it "creates a vector from a post and stores it in the database" do
text =
truncation.prepare_text_from(
post2,
vector_rep.tokenizer,
vector_rep.max_sequence_length - 2,
)
stub_vector_mapping(text, @expected_embedding)
vector_rep.generate_representation_from(post)
expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
end
end end
describe "#asymmetric_topics_similarity_search" do describe "#asymmetric_topics_similarity_search" do
@ -44,7 +59,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
vector_rep.max_sequence_length - 2, vector_rep.max_sequence_length - 2,
) )
stub_vector_mapping(text, @expected_embedding) stub_vector_mapping(text, @expected_embedding)
vector_rep.generate_topic_representation_from(topic) vector_rep.generate_representation_from(topic)
expect( expect(
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0), vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),