FEATURE: Per post embeddings (#387)
This commit is contained in:
parent
c3af27571b
commit
140359c2ef
|
@ -6,18 +6,20 @@ module Jobs
|
||||||
|
|
||||||
def execute(args)
|
def execute(args)
|
||||||
return unless SiteSetting.ai_embeddings_enabled
|
return unless SiteSetting.ai_embeddings_enabled
|
||||||
return if (topic_id = args[:topic_id]).blank?
|
return if args[:target_type].blank? || args[:target_id].blank?
|
||||||
|
target = args[:target_type].constantize.find_by_id(args[:target_id])
|
||||||
|
return if target.nil? || target.deleted_at.present?
|
||||||
|
|
||||||
topic = Topic.find_by_id(topic_id)
|
topic = target.is_a?(Topic) ? target : target.topic
|
||||||
return if topic.nil? || topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
|
post = target.is_a?(Post) ? target : target.first_post
|
||||||
post = topic.first_post
|
return if topic.private_message? && !SiteSetting.ai_embeddings_generate_for_pms
|
||||||
return if post.nil? || post.raw.blank?
|
return if post.raw.blank?
|
||||||
|
|
||||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
vector_rep =
|
vector_rep =
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
|
|
||||||
vector_rep.generate_topic_representation_from(topic)
|
vector_rep.generate_representation_from(target)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -15,7 +15,7 @@ module Jobs
|
||||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
vector_rep =
|
vector_rep =
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
table_name = vector_rep.table_name
|
table_name = vector_rep.topic_table_name
|
||||||
|
|
||||||
topics =
|
topics =
|
||||||
Topic
|
Topic
|
||||||
|
@ -28,7 +28,7 @@ module Jobs
|
||||||
topics
|
topics
|
||||||
.where("#{table_name}.topic_id IS NULL")
|
.where("#{table_name}.topic_id IS NULL")
|
||||||
.find_each do |t|
|
.find_each do |t|
|
||||||
vector_rep.generate_topic_representation_from(t)
|
vector_rep.generate_representation_from(t)
|
||||||
rebaked += 1
|
rebaked += 1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -45,7 +45,7 @@ module Jobs
|
||||||
#{table_name}.strategy_version < #{strategy.version}
|
#{table_name}.strategy_version < #{strategy.version}
|
||||||
SQL
|
SQL
|
||||||
.find_each do |t|
|
.find_each do |t|
|
||||||
vector_rep.generate_topic_representation_from(t)
|
vector_rep.generate_representation_from(t)
|
||||||
rebaked += 1
|
rebaked += 1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -59,7 +59,57 @@ module Jobs
|
||||||
.limit((limit - rebaked) / 10)
|
.limit((limit - rebaked) / 10)
|
||||||
.pluck(:id)
|
.pluck(:id)
|
||||||
.each do |id|
|
.each do |id|
|
||||||
vector_rep.generate_topic_representation_from(Topic.find_by(id: id))
|
vector_rep.generate_representation_from(Topic.find_by(id: id))
|
||||||
|
rebaked += 1
|
||||||
|
end
|
||||||
|
|
||||||
|
return if rebaked >= limit
|
||||||
|
|
||||||
|
# Now for posts
|
||||||
|
table_name = vector_rep.post_table_name
|
||||||
|
|
||||||
|
posts =
|
||||||
|
Post
|
||||||
|
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
||||||
|
.where(deleted_at: nil)
|
||||||
|
.limit(limit - rebaked)
|
||||||
|
|
||||||
|
# First, we'll try to backfill embeddings for posts that have none
|
||||||
|
posts
|
||||||
|
.where("#{table_name}.post_id IS NULL")
|
||||||
|
.find_each do |t|
|
||||||
|
vector_rep.generate_representation_from(t)
|
||||||
|
rebaked += 1
|
||||||
|
end
|
||||||
|
|
||||||
|
vector_rep.consider_indexing
|
||||||
|
|
||||||
|
return if rebaked >= limit
|
||||||
|
|
||||||
|
# Then, we'll try to backfill embeddings for posts that have outdated
|
||||||
|
# embeddings, be it model or strategy version
|
||||||
|
posts
|
||||||
|
.where(<<~SQL)
|
||||||
|
#{table_name}.model_version < #{vector_rep.version}
|
||||||
|
OR
|
||||||
|
#{table_name}.strategy_version < #{strategy.version}
|
||||||
|
SQL
|
||||||
|
.find_each do |t|
|
||||||
|
vector_rep.generate_representation_from(t)
|
||||||
|
rebaked += 1
|
||||||
|
end
|
||||||
|
|
||||||
|
return if rebaked >= limit
|
||||||
|
|
||||||
|
# Finally, we'll try to backfill embeddings for posts that have outdated
|
||||||
|
# embeddings due to edits. Here we only do 10% of the limit
|
||||||
|
posts
|
||||||
|
.where("#{table_name}.updated_at < ?", 7.days.ago)
|
||||||
|
.order("random()")
|
||||||
|
.limit((limit - rebaked) / 10)
|
||||||
|
.pluck(:id)
|
||||||
|
.each do |id|
|
||||||
|
vector_rep.generate_representation_from(Post.find_by(id: id))
|
||||||
rebaked += 1
|
rebaked += 1
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ class MigrateEmbeddingsFromDedicatedDatabase < ActiveRecord::Migration[7.0]
|
||||||
].map { |k| k.new(truncation) }
|
].map { |k| k.new(truncation) }
|
||||||
|
|
||||||
vector_reps.each do |vector_rep|
|
vector_reps.each do |vector_rep|
|
||||||
new_table_name = vector_rep.table_name
|
new_table_name = vector_rep.topic_table_name
|
||||||
old_table_name = "topic_embeddings_#{vector_rep.name.underscore}"
|
old_table_name = "topic_embeddings_#{vector_rep.name.underscore}"
|
||||||
|
|
||||||
begin
|
begin
|
||||||
|
|
|
@ -0,0 +1,60 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
class CreateAiPostEmbeddingsTables < ActiveRecord::Migration[7.0]
|
||||||
|
def change
|
||||||
|
create_table :ai_post_embeddings_1_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(768)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
|
||||||
|
create_table :ai_post_embeddings_2_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1536)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
|
||||||
|
create_table :ai_post_embeddings_3_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1024)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
|
||||||
|
create_table :ai_post_embeddings_4_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(1024)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
|
||||||
|
create_table :ai_post_embeddings_5_1, id: false do |t|
|
||||||
|
t.integer :post_id, null: false
|
||||||
|
t.integer :model_version, null: false
|
||||||
|
t.integer :strategy_version, null: false
|
||||||
|
t.text :digest, null: false
|
||||||
|
t.column :embeddings, "vector(768)", null: false
|
||||||
|
t.timestamps
|
||||||
|
|
||||||
|
t.index :post_id, unique: true
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -43,14 +43,20 @@ module DiscourseAi
|
||||||
|
|
||||||
# embeddings generation.
|
# embeddings generation.
|
||||||
callback =
|
callback =
|
||||||
Proc.new do |topic|
|
Proc.new do |target|
|
||||||
if SiteSetting.ai_embeddings_enabled
|
if SiteSetting.ai_embeddings_enabled
|
||||||
Jobs.enqueue(:generate_embeddings, topic_id: topic.id)
|
Jobs.enqueue(
|
||||||
|
:generate_embeddings,
|
||||||
|
target_id: target.id,
|
||||||
|
target_type: target.class.name,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
plugin.on(:topic_created, &callback)
|
plugin.on(:topic_created, &callback)
|
||||||
plugin.on(:topic_edited, &callback)
|
plugin.on(:topic_edited, &callback)
|
||||||
|
plugin.on(:post_created, &callback)
|
||||||
|
plugin.on(:post_edited, &callback)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -50,9 +50,9 @@ module DiscourseAi
|
||||||
tokenizer.truncate(text, max_length)
|
tokenizer.truncate(text, max_length)
|
||||||
end
|
end
|
||||||
|
|
||||||
def post_truncation(topic, tokenizer, max_length)
|
def post_truncation(post, tokenizer, max_length)
|
||||||
text = +topic_information(post.topic)
|
text = +topic_information(post.topic)
|
||||||
text << post.raw
|
text << Nokogiri::HTML5.fragment(post.cooked).text
|
||||||
|
|
||||||
tokenizer.truncate(text, max_length)
|
tokenizer.truncate(text, max_length)
|
||||||
end
|
end
|
||||||
|
|
|
@ -21,62 +21,66 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def consider_indexing(memory: "100MB")
|
def consider_indexing(memory: "100MB")
|
||||||
# Using extension maintainer's recommendation for ivfflat indexes
|
[topic_table_name, post_table_name].each do |table_name|
|
||||||
# Results are not as good as without indexes, but it's much faster
|
index_name = index_name(table_name)
|
||||||
# Disk usage is ~1x the size of the table, so this doubles table total size
|
# Using extension maintainer's recommendation for ivfflat indexes
|
||||||
count = DB.query_single("SELECT count(*) FROM #{table_name};").first
|
# Results are not as good as without indexes, but it's much faster
|
||||||
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
|
# Disk usage is ~1x the size of the table, so this doubles table total size
|
||||||
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
|
count = DB.query_single("SELECT count(*) FROM #{table_name};").first
|
||||||
|
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
|
||||||
|
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
|
||||||
|
|
||||||
existing_index = DB.query_single(<<~SQL, index_name: index_name).first
|
existing_index = DB.query_single(<<~SQL, index_name: index_name).first
|
||||||
SELECT
|
SELECT
|
||||||
indexdef
|
indexdef
|
||||||
FROM
|
FROM
|
||||||
pg_indexes
|
pg_indexes
|
||||||
WHERE
|
WHERE
|
||||||
indexname = :index_name
|
indexname = :index_name
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
SQL
|
SQL
|
||||||
|
|
||||||
if !existing_index.present?
|
if !existing_index.present?
|
||||||
Rails.logger.info("Index #{index_name} does not exist, creating...")
|
Rails.logger.info("Index #{index_name} does not exist, creating...")
|
||||||
return create_index!(memory, lists, probes)
|
return create_index!(table_name, memory, lists, probes)
|
||||||
end
|
|
||||||
|
|
||||||
existing_index_age =
|
|
||||||
DB
|
|
||||||
.query_single(
|
|
||||||
"SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
|
|
||||||
index_name: index_name,
|
|
||||||
)
|
|
||||||
.first
|
|
||||||
.to_i || 0
|
|
||||||
new_rows =
|
|
||||||
DB.query_single(
|
|
||||||
"SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';",
|
|
||||||
).first
|
|
||||||
existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i
|
|
||||||
|
|
||||||
if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i
|
|
||||||
if new_rows > 10_000
|
|
||||||
Rails.logger.info(
|
|
||||||
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
|
|
||||||
)
|
|
||||||
return create_index!(memory, lists, probes)
|
|
||||||
elsif existing_lists != lists
|
|
||||||
Rails.logger.info(
|
|
||||||
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
|
|
||||||
)
|
|
||||||
return create_index!(memory, lists, probes)
|
|
||||||
end
|
end
|
||||||
end
|
|
||||||
|
|
||||||
Rails.logger.info(
|
existing_index_age =
|
||||||
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
|
DB
|
||||||
)
|
.query_single(
|
||||||
|
"SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
|
||||||
|
index_name: index_name,
|
||||||
|
)
|
||||||
|
.first
|
||||||
|
.to_i || 0
|
||||||
|
new_rows =
|
||||||
|
DB.query_single(
|
||||||
|
"SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';",
|
||||||
|
).first
|
||||||
|
existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i
|
||||||
|
|
||||||
|
if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i
|
||||||
|
if new_rows > 10_000
|
||||||
|
Rails.logger.info(
|
||||||
|
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
|
||||||
|
)
|
||||||
|
return create_index!(table_name, memory, lists, probes)
|
||||||
|
elsif existing_lists != lists
|
||||||
|
Rails.logger.info(
|
||||||
|
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
|
||||||
|
)
|
||||||
|
return create_index!(table_name, memory, lists, probes)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
Rails.logger.info(
|
||||||
|
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
|
||||||
|
)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def create_index!(memory, lists, probes)
|
def create_index!(table_name, memory, lists, probes)
|
||||||
|
index_name = index_name(table_name)
|
||||||
DB.exec("SET work_mem TO '#{memory}';")
|
DB.exec("SET work_mem TO '#{memory}';")
|
||||||
DB.exec("SET maintenance_work_mem TO '#{memory}';")
|
DB.exec("SET maintenance_work_mem TO '#{memory}';")
|
||||||
DB.exec(<<~SQL)
|
DB.exec(<<~SQL)
|
||||||
|
@ -102,17 +106,17 @@ module DiscourseAi
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
end
|
end
|
||||||
|
|
||||||
def generate_topic_representation_from(target, persist: true)
|
def generate_representation_from(target, persist: true)
|
||||||
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
|
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
|
||||||
|
|
||||||
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
||||||
current_digest = DB.query_single(<<~SQL, topic_id: target.id).first
|
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
|
||||||
SELECT
|
SELECT
|
||||||
digest
|
digest
|
||||||
FROM
|
FROM
|
||||||
#{table_name}
|
#{table_name(target)}
|
||||||
WHERE
|
WHERE
|
||||||
topic_id = :topic_id
|
#{target.is_a?(Topic) ? "topic_id" : "post_id"} = :target_id
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
SQL
|
SQL
|
||||||
return if current_digest == new_digest
|
return if current_digest == new_digest
|
||||||
|
@ -127,7 +131,19 @@ module DiscourseAi
|
||||||
SELECT
|
SELECT
|
||||||
topic_id
|
topic_id
|
||||||
FROM
|
FROM
|
||||||
#{table_name}
|
#{topic_table_name}
|
||||||
|
ORDER BY
|
||||||
|
embeddings #{pg_function} '[:query_embedding]'
|
||||||
|
LIMIT 1
|
||||||
|
SQL
|
||||||
|
end
|
||||||
|
|
||||||
|
def post_id_from_representation(raw_vector)
|
||||||
|
DB.query_single(<<~SQL, query_embedding: raw_vector).first
|
||||||
|
SELECT
|
||||||
|
post_id
|
||||||
|
FROM
|
||||||
|
#{post_table_name}
|
||||||
ORDER BY
|
ORDER BY
|
||||||
embeddings #{pg_function} '[:query_embedding]'
|
embeddings #{pg_function} '[:query_embedding]'
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
|
@ -140,7 +156,7 @@ module DiscourseAi
|
||||||
topic_id,
|
topic_id,
|
||||||
embeddings #{pg_function} '[:query_embedding]' AS distance
|
embeddings #{pg_function} '[:query_embedding]' AS distance
|
||||||
FROM
|
FROM
|
||||||
#{table_name}
|
#{topic_table_name}
|
||||||
ORDER BY
|
ORDER BY
|
||||||
embeddings #{pg_function} '[:query_embedding]'
|
embeddings #{pg_function} '[:query_embedding]'
|
||||||
LIMIT :limit
|
LIMIT :limit
|
||||||
|
@ -162,13 +178,13 @@ module DiscourseAi
|
||||||
SELECT
|
SELECT
|
||||||
topic_id
|
topic_id
|
||||||
FROM
|
FROM
|
||||||
#{table_name}
|
#{topic_table_name}
|
||||||
ORDER BY
|
ORDER BY
|
||||||
embeddings #{pg_function} (
|
embeddings #{pg_function} (
|
||||||
SELECT
|
SELECT
|
||||||
embeddings
|
embeddings
|
||||||
FROM
|
FROM
|
||||||
#{table_name}
|
#{topic_table_name}
|
||||||
WHERE
|
WHERE
|
||||||
topic_id = :topic_id
|
topic_id = :topic_id
|
||||||
LIMIT 1
|
LIMIT 1
|
||||||
|
@ -182,11 +198,26 @@ module DiscourseAi
|
||||||
raise MissingEmbeddingError
|
raise MissingEmbeddingError
|
||||||
end
|
end
|
||||||
|
|
||||||
def table_name
|
def topic_table_name
|
||||||
"ai_topic_embeddings_#{id}_#{@strategy.id}"
|
"ai_topic_embeddings_#{id}_#{@strategy.id}"
|
||||||
end
|
end
|
||||||
|
|
||||||
def index_name
|
def post_table_name
|
||||||
|
"ai_post_embeddings_#{id}_#{@strategy.id}"
|
||||||
|
end
|
||||||
|
|
||||||
|
def table_name(target)
|
||||||
|
case target
|
||||||
|
when Topic
|
||||||
|
topic_table_name
|
||||||
|
when Post
|
||||||
|
post_table_name
|
||||||
|
else
|
||||||
|
raise ArgumentError, "Invalid target type"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def index_name(table_name)
|
||||||
"#{table_name}_search"
|
"#{table_name}_search"
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -221,24 +252,47 @@ module DiscourseAi
|
||||||
protected
|
protected
|
||||||
|
|
||||||
def save_to_db(target, vector, digest)
|
def save_to_db(target, vector, digest)
|
||||||
DB.exec(
|
if target.is_a?(Topic)
|
||||||
<<~SQL,
|
DB.exec(
|
||||||
INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
<<~SQL,
|
||||||
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
INSERT INTO #{topic_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||||
ON CONFLICT (topic_id)
|
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
DO UPDATE SET
|
ON CONFLICT (topic_id)
|
||||||
model_version = :model_version,
|
DO UPDATE SET
|
||||||
strategy_version = :strategy_version,
|
model_version = :model_version,
|
||||||
digest = :digest,
|
strategy_version = :strategy_version,
|
||||||
embeddings = '[:embeddings]',
|
digest = :digest,
|
||||||
updated_at = CURRENT_TIMESTAMP
|
embeddings = '[:embeddings]',
|
||||||
SQL
|
updated_at = CURRENT_TIMESTAMP
|
||||||
topic_id: target.id,
|
SQL
|
||||||
model_version: version,
|
topic_id: target.id,
|
||||||
strategy_version: @strategy.version,
|
model_version: version,
|
||||||
digest: digest,
|
strategy_version: @strategy.version,
|
||||||
embeddings: vector,
|
digest: digest,
|
||||||
)
|
embeddings: vector,
|
||||||
|
)
|
||||||
|
elsif target.is_a?(Post)
|
||||||
|
DB.exec(
|
||||||
|
<<~SQL,
|
||||||
|
INSERT INTO #{post_table_name} (post_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
||||||
|
VALUES (:post_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
||||||
|
ON CONFLICT (post_id)
|
||||||
|
DO UPDATE SET
|
||||||
|
model_version = :model_version,
|
||||||
|
strategy_version = :strategy_version,
|
||||||
|
digest = :digest,
|
||||||
|
embeddings = '[:embeddings]',
|
||||||
|
updated_at = CURRENT_TIMESTAMP
|
||||||
|
SQL
|
||||||
|
post_id: target.id,
|
||||||
|
model_version: version,
|
||||||
|
strategy_version: @strategy.version,
|
||||||
|
digest: digest,
|
||||||
|
embeddings: vector,
|
||||||
|
)
|
||||||
|
else
|
||||||
|
raise ArgumentError, "Invalid target type"
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,23 +1,33 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
desc "Backfill embeddings for all topics"
|
desc "Backfill embeddings for all topics and posts"
|
||||||
task "ai:embeddings:backfill", [:start_topic] => [:environment] do |_, args|
|
task "ai:embeddings:backfill" => [:environment] do
|
||||||
public_categories = Category.where(read_restricted: false).pluck(:id)
|
public_categories = Category.where(read_restricted: false).pluck(:id)
|
||||||
|
|
||||||
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||||
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
|
||||||
table_name = vector_rep.table_name
|
table_name = vector_rep.topic_table_name
|
||||||
|
|
||||||
Topic
|
Topic
|
||||||
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
.joins("LEFT JOIN #{table_name} ON #{table_name}.topic_id = topics.id")
|
||||||
.where("#{table_name}.topic_id IS NULL")
|
.where("#{table_name}.topic_id IS NULL")
|
||||||
.where("topics.id >= ?", args[:start_topic].to_i || 0)
|
|
||||||
.where("category_id IN (?)", public_categories)
|
.where("category_id IN (?)", public_categories)
|
||||||
.where(deleted_at: nil)
|
.where(deleted_at: nil)
|
||||||
.order("topics.id ASC")
|
.order("topics.id DESC")
|
||||||
.find_each do |t|
|
.find_each do |t|
|
||||||
print "."
|
print "."
|
||||||
vector_rep.generate_topic_representation_from(t)
|
vector_rep.generate_representation_from(t)
|
||||||
|
end
|
||||||
|
|
||||||
|
table_name = vector_rep.post_table_name
|
||||||
|
Post
|
||||||
|
.joins("LEFT JOIN #{table_name} ON #{table_name}.post_id = posts.id")
|
||||||
|
.where("#{table_name}.post_id IS NULL")
|
||||||
|
.where(deleted_at: nil)
|
||||||
|
.order("posts.id DESC")
|
||||||
|
.find_each do |t|
|
||||||
|
print "."
|
||||||
|
vector_rep.generate_representation_from(t)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -18,7 +18,7 @@ describe DiscourseAi::Embeddings::EntryPoint do
|
||||||
it "queues a job on create if embeddings is enabled" do
|
it "queues a job on create if embeddings is enabled" do
|
||||||
SiteSetting.ai_embeddings_enabled = true
|
SiteSetting.ai_embeddings_enabled = true
|
||||||
|
|
||||||
expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(1)
|
expect { creator.create }.to change(Jobs::GenerateEmbeddings.jobs, :size).by(2) # topic_created and post_created
|
||||||
end
|
end
|
||||||
|
|
||||||
it "does nothing if sentiment analysis is disabled" do
|
it "does nothing if sentiment analysis is disabled" do
|
||||||
|
|
|
@ -18,7 +18,7 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
||||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "works" do
|
it "works for topics" do
|
||||||
expected_embedding = [0.0038493] * vector_rep.dimensions
|
expected_embedding = [0.0038493] * vector_rep.dimensions
|
||||||
|
|
||||||
text =
|
text =
|
||||||
|
@ -29,9 +29,21 @@ RSpec.describe Jobs::GenerateEmbeddings do
|
||||||
)
|
)
|
||||||
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||||
|
|
||||||
job.execute(topic_id: topic.id)
|
job.execute(target_id: topic.id, target_type: "Topic")
|
||||||
|
|
||||||
expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id)
|
expect(vector_rep.topic_id_from_representation(expected_embedding)).to eq(topic.id)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "works for posts" do
|
||||||
|
expected_embedding = [0.0038493] * vector_rep.dimensions
|
||||||
|
|
||||||
|
text =
|
||||||
|
truncation.prepare_text_from(post, vector_rep.tokenizer, vector_rep.max_sequence_length - 2)
|
||||||
|
EmbeddingsGenerationStubs.discourse_service(vector_rep.name, text, expected_embedding)
|
||||||
|
|
||||||
|
job.execute(target_id: post.id, target_type: "Post")
|
||||||
|
|
||||||
|
expect(vector_rep.post_id_from_representation(expected_embedding)).to eq(post.id)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -12,9 +12,10 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#generate_topic_representation_from" do
|
describe "#generate_representation_from" do
|
||||||
fab!(:topic) { Fabricate(:topic) }
|
fab!(:topic) { Fabricate(:topic) }
|
||||||
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
fab!(:post) { Fabricate(:post, post_number: 1, topic: topic) }
|
||||||
|
fab!(:post2) { Fabricate(:post, post_number: 2, topic: topic) }
|
||||||
|
|
||||||
it "creates a vector from a topic and stores it in the database" do
|
it "creates a vector from a topic and stores it in the database" do
|
||||||
text =
|
text =
|
||||||
|
@ -25,10 +26,24 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
||||||
)
|
)
|
||||||
stub_vector_mapping(text, @expected_embedding)
|
stub_vector_mapping(text, @expected_embedding)
|
||||||
|
|
||||||
vector_rep.generate_topic_representation_from(topic)
|
vector_rep.generate_representation_from(topic)
|
||||||
|
|
||||||
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
|
expect(vector_rep.topic_id_from_representation(@expected_embedding)).to eq(topic.id)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "creates a vector from a post and stores it in the database" do
|
||||||
|
text =
|
||||||
|
truncation.prepare_text_from(
|
||||||
|
post2,
|
||||||
|
vector_rep.tokenizer,
|
||||||
|
vector_rep.max_sequence_length - 2,
|
||||||
|
)
|
||||||
|
stub_vector_mapping(text, @expected_embedding)
|
||||||
|
|
||||||
|
vector_rep.generate_representation_from(post)
|
||||||
|
|
||||||
|
expect(vector_rep.post_id_from_representation(@expected_embedding)).to eq(post.id)
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "#asymmetric_topics_similarity_search" do
|
describe "#asymmetric_topics_similarity_search" do
|
||||||
|
@ -44,7 +59,7 @@ RSpec.shared_examples "generates and store embedding using with vector represent
|
||||||
vector_rep.max_sequence_length - 2,
|
vector_rep.max_sequence_length - 2,
|
||||||
)
|
)
|
||||||
stub_vector_mapping(text, @expected_embedding)
|
stub_vector_mapping(text, @expected_embedding)
|
||||||
vector_rep.generate_topic_representation_from(topic)
|
vector_rep.generate_representation_from(topic)
|
||||||
|
|
||||||
expect(
|
expect(
|
||||||
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
|
vector_rep.asymmetric_topics_similarity_search(similar_vector, limit: 1, offset: 0),
|
||||||
|
|
Loading…
Reference in New Issue