Roman Rizzi ddf2bf7034
DEV: Backfill embeddings concurrently. (#941)
We are adding a new method for generating and storing embeddings in bulk, which relies on `Concurrent::Promises::Future`. Generating an embedding consists of three steps:

Prepare text
HTTP call to retrieve the vector
Save to DB.
Each one is independently executed on whatever thread the pool gives us.

We are bringing a custom thread pool instead of the global executor since we want control over how many threads we spawn to limit concurrency. We also avoid firing thousands of HTTP requests when working with large batches.
2024-11-26 14:12:32 -03:00

470 lines
16 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
class << self
def find_representation(model_name)
# we are explicit here cause the loader may have not
# loaded the subclasses yet
[
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
].find { _1.name == model_name }
end
def current_representation(strategy)
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
end
def correctly_configured?
raise NotImplementedError
end
def dependant_setting_names
raise NotImplementedError
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.embeddings.configuration.hint",
settings: settings.join(", "),
count: settings.length,
)
end
end
def initialize(strategy)
@strategy = strategy
end
def vector_from(text, asymetric: false)
raise NotImplementedError
end
def gen_bulk_reprensentations(relation)
http_pool_size = 100
pool =
Concurrent::CachedThreadPool.new(
min_threads: 0,
max_threads: http_pool_size,
idletime: 30,
)
embedding_gen = inference_client
promised_embeddings =
relation.map do |record|
materials = { target: record, text: prepare_text(record) }
Concurrent::Promises
.fulfilled_future(materials, pool)
.then_on(pool) do |w_prepared_text|
w_prepared_text.merge(
embedding: embedding_gen.perform!(w_prepared_text[:text]),
digest: OpenSSL::Digest::SHA1.hexdigest(w_prepared_text[:text]),
)
end
end
Concurrent::Promises
.zip(*promised_embeddings)
.value!
.each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
end
def generate_representation_from(target, persist: true)
text = prepare_text(target)
return if text.blank?
target_column =
case target
when Topic
"topic_id"
when Post
"post_id"
when RagDocumentFragment
"rag_document_fragment_id"
else
raise ArgumentError, "Invalid target type"
end
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
SELECT
digest
FROM
#{table_name(target)}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
#{target_column} = :target_id
LIMIT 1
SQL
return if current_digest == new_digest
vector = vector_from(text)
save_to_db(target, vector, new_digest) if persist
end
def topic_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
topic_id
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
end
def post_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
post_id
FROM
#{post_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT 1
SQL
end
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
WITH candidates AS (
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
)
SELECT
topic_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.topic_id, r.distance] }
else
results.map(&:topic_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
WITH candidates AS (
SELECT
post_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{post_table_name}
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :limit * 2
)
SELECT
post_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.post_id, r.distance] }
else
results.map(&:post_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def asymmetric_rag_fragment_similarity_search(
raw_vector,
target_id:,
target_type:,
limit:,
offset:,
return_distance: false
)
# A too low limit exacerbates the the recall loss of binary quantization
binary_search_limit = [limit * 2, 100].max
results =
DB.query(
<<~SQL,
WITH candidates AS (
SELECT
rag_document_fragment_id,
embeddings::halfvec(#{dimensions}) AS embeddings
FROM
#{rag_fragments_table_name}
INNER JOIN
rag_document_fragments ON
rag_document_fragments.id = rag_document_fragment_id AND
rag_document_fragments.target_id = :target_id AND
rag_document_fragments.target_type = :target_type
WHERE
model_id = #{id} AND strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
LIMIT :binary_search_limit
)
SELECT
rag_document_fragment_id,
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
FROM
candidates
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
query_embedding: raw_vector,
target_id: target_id,
target_type: target_type,
limit: limit,
offset: offset,
binary_search_limit: binary_search_limit,
)
if return_distance
results.map { |r| [r.rag_document_fragment_id, r.distance] }
else
results.map(&:rag_document_fragment_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
WITH le_target AS (
SELECT
embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
topic_id = :topic_id
LIMIT 1
)
SELECT topic_id FROM (
SELECT
topic_id, embeddings
FROM
#{topic_table_name}
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id}
ORDER BY
binary_quantize(embeddings)::bit(#{dimensions}) <~> (
SELECT
binary_quantize(embeddings)::bit(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 200
) AS widenet
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} (
SELECT
embeddings::halfvec(#{dimensions})
FROM
le_target
LIMIT 1
)
LIMIT 100;
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
)
raise MissingEmbeddingError
end
def topic_table_name
"ai_topic_embeddings"
end
def post_table_name
"ai_post_embeddings"
end
def rag_fragments_table_name
"ai_document_fragment_embeddings"
end
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
when RagDocumentFragment
rag_fragments_table_name
else
raise ArgumentError, "Invalid target type"
end
end
def index_name(table_name)
"#{table_name}_#{id}_#{@strategy.id}_search"
end
def name
raise NotImplementedError
end
def dimensions
raise NotImplementedError
end
def max_sequence_length
raise NotImplementedError
end
def id
raise NotImplementedError
end
def pg_function
raise NotImplementedError
end
def version
raise NotImplementedError
end
def tokenizer
raise NotImplementedError
end
def asymmetric_query_prefix
raise NotImplementedError
end
protected
def save_to_db(target, vector, digest)
if target.is_a?(Topic)
DB.exec(
<<~SQL,
INSERT INTO #{topic_table_name} (topic_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:topic_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (strategy_id, model_id, topic_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
topic_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
elsif target.is_a?(Post)
DB.exec(
<<~SQL,
INSERT INTO #{post_table_name} (post_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:post_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (model_id, strategy_id, post_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
post_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
elsif target.is_a?(RagDocumentFragment)
DB.exec(
<<~SQL,
INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:fragment_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
ON CONFLICT (model_id, strategy_id, rag_document_fragment_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = :now
SQL
fragment_id: target.id,
model_id: id,
model_version: version,
strategy_id: @strategy.id,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
now: Time.zone.now,
)
else
raise ArgumentError, "Invalid target type"
end
end
def inference_client
raise NotImplementedError
end
def prepare_text(record)
@strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
end
end
end
end
end