# discourse-ai/lib/embeddings/vector_representations/base.rb

# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
class << self
# Returns the representation class whose class-level +name+ equals
# +model_name+, or nil when none of the listed classes matches.
def find_representation(model_name)
  # The candidates are spelled out explicitly because the loader may not
  # have loaded the subclasses yet.
  known_representations = [
    DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
    DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
    DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
    DiscourseAi::Embeddings::VectorRepresentations::Gemini,
    DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
    DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
    DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
    DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
  ]

  known_representations.find { |representation| representation.name == model_name }
end
# Instantiates the representation class selected by the
# +ai_embeddings_model+ site setting, passing it +strategy+.
def current_representation(strategy)
  representation_class = find_representation(SiteSetting.ai_embeddings_model)
  representation_class.new(strategy)
end
# No implementation is provided here; calling this on Base always raises
# NotImplementedError.
def correctly_configured?
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
# configuration_hint calls +join+ and +length+ on the value an
# implementation returns.
def dependant_setting_names
  raise NotImplementedError
end
# Builds the translated configuration hint, listing the settings returned
# by dependant_setting_names (comma separated) together with their count.
def configuration_hint
  setting_names = dependant_setting_names
  translation_args = { settings: setting_names.join(", "), count: setting_names.length }

  I18n.t("discourse_ai.embeddings.configuration.hint", **translation_args)
end
end
# Stores +strategy+ for later use. Other methods in this class read its
# +id+ and +version+ and call +prepare_text_from+ on it.
def initialize(strategy)
  @strategy = strategy
end
# No implementation is provided here; always raises NotImplementedError.
# generate_representation_from passes the value an implementation returns
# to save_to_db as the embedding.
def vector_from(text, asymetric: false)
  raise NotImplementedError
end
# Generates and persists embeddings for every record in +relation+, running
# the inference calls on a thread pool.
#
# A record is skipped when its prepared text is blank, or when the SHA1
# digest of that text equals the digest already stored for the record.
# Any error raised by an inference call is re-raised here (via +value!+);
# in that case nothing from this batch is written.
#
# Returns the result of +pool.wait_for_termination+.
def gen_bulk_reprensentations(relation)
  http_pool_size = 100
  pool =
    Concurrent::CachedThreadPool.new(
      min_threads: 0,
      max_threads: http_pool_size,
      idletime: 30,
    )

  begin
    embedding_gen = inference_client

    promised_embeddings =
      relation.filter_map do |record|
        prepared_text = prepare_text(record)
        next if prepared_text.blank?

        new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
        next if find_digest_of(record) == new_digest

        Concurrent::Promises
          .fulfilled_future({ target: record, text: prepared_text, digest: new_digest }, pool)
          .then_on(pool) do |w_prepared_text|
            w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
          end
      end

    # Writes happen on the calling thread once every future has resolved.
    Concurrent::Promises
      .zip(*promised_embeddings)
      .value!
      .each { |e| save_to_db(e[:target], e[:embedding], e[:digest]) }
  ensure
    # Release the pool even when an inference call or a write raised;
    # otherwise its threads would be leaked.
    pool.shutdown
    terminated = pool.wait_for_termination
  end

  terminated
end
# Computes the embedding for +target+ and, when +persist+ is true, upserts
# it through save_to_db.
#
# Returns nil without computing anything when the prepared text is blank or
# when its SHA1 digest equals the digest already stored for +target+.
# With +persist: false+ the vector is still computed but is neither stored
# nor returned.
def generate_representation_from(target, persist: true)
  text = prepare_text(target)
  return if text.blank?

  new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
  return if find_digest_of(target) == new_digest

  vector = vector_from(text)
  save_to_db(target, vector, new_digest) if persist
end
# Returns the topic_id of the single stored topic embedding (for this model
# and strategy) that sorts first under +pg_function+ against +raw_vector+,
# or nil when the query returns no row.
def topic_id_from_representation(raw_vector)
  DB.query_single(<<~SQL, query_embedding: raw_vector).first
    SELECT
    topic_id
    FROM
    #{topic_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id}
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT 1
  SQL
end
# Returns the post_id of the single stored post embedding (for this model
# and strategy) that sorts first under +pg_function+ against +raw_vector+,
# or nil when the query returns no row.
def post_id_from_representation(raw_vector)
  nearest_post_sql = <<~SQL
    SELECT
    post_id
    FROM
    #{post_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id}
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT 1
  SQL

  matches = DB.query_single(nearest_post_sql, query_embedding: raw_vector)
  matches.first
end
# Two-stage search over topic embeddings for +raw_vector+.
#
# Stage one orders rows with the `<~>` operator over binary_quantize(...)
# bits and keeps +limit+ * 2 candidates; stage two re-orders those
# candidates with +pg_function+ on halfvec values and applies
# LIMIT/OFFSET.
#
# When +exclude_category_ids+ is present, topics in those categories, or in
# categories whose parent is one of them, are filtered out.
#
# Returns topic ids, or [topic_id, distance] pairs when +return_distance+
# is true. A PG::Error is logged and re-raised as MissingEmbeddingError.
#
# NOTE(review): OFFSET is applied to the candidate set, which holds at most
# limit * 2 rows, so an offset at or beyond that yields no rows.
def asymmetric_topics_similarity_search(
  raw_vector,
  limit:,
  offset:,
  return_distance: false,
  exclude_category_ids: nil
)
  query_builder = DB.build(<<~SQL)
    WITH candidates AS (
    SELECT
    topic_id,
    embeddings::halfvec(#{dimensions}) AS embeddings
    FROM
    #{topic_table_name}
    /*join*/
    /*where*/
    ORDER BY
    binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
    LIMIT :limit * 2
    )
    SELECT
    topic_id,
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
    FROM
    candidates
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT :limit
    OFFSET :offset
  SQL

  query_builder.where(
    "model_id = :model_id AND strategy_id = :strategy_id",
    model_id: id,
    strategy_id: @strategy.id,
  )

  if exclude_category_ids.present?
    query_builder.join("topics t on t.id = topic_id")

    category_filter = <<~SQL
      t.category_id NOT IN (:exclude_category_ids) AND
      t.category_id NOT IN (SELECT categories.id FROM categories WHERE categories.parent_category_id IN (:exclude_category_ids))
    SQL
    excluded_ids = exclude_category_ids.map(&:to_i)
    query_builder.where(category_filter, exclude_category_ids: excluded_ids)
  end

  rows = query_builder.query(query_embedding: raw_vector, limit: limit, offset: offset)
  return rows.map(&:topic_id) unless return_distance

  rows.map { |row| [row.topic_id, row.distance] }
rescue PG::Error => e
  Rails.logger.error("Error #{e} querying embeddings for model #{name}")
  raise MissingEmbeddingError
end
# Two-stage search over post embeddings for +raw_vector+.
#
# Stage one orders rows with the `<~>` operator over binary_quantize(...)
# bits and keeps +limit+ * 2 candidates; stage two re-orders those
# candidates with +pg_function+ on halfvec values and applies
# LIMIT/OFFSET.
#
# Returns post ids, or [post_id, distance] pairs when +return_distance+ is
# true. A PG::Error is logged and re-raised as MissingEmbeddingError.
#
# NOTE(review): OFFSET is applied to the candidate set, which holds at most
# limit * 2 rows, so an offset at or beyond that yields no rows.
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
  search_sql = <<~SQL
    WITH candidates AS (
    SELECT
    post_id,
    embeddings::halfvec(#{dimensions}) AS embeddings
    FROM
    #{post_table_name}
    WHERE
    model_id = #{id} AND strategy_id = #{@strategy.id}
    ORDER BY
    binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
    LIMIT :limit * 2
    )
    SELECT
    post_id,
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
    FROM
    candidates
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT :limit
    OFFSET :offset
  SQL

  rows = DB.query(search_sql, query_embedding: raw_vector, limit: limit, offset: offset)
  return rows.map(&:post_id) unless return_distance

  rows.map { |row| [row.post_id, row.distance] }
rescue PG::Error => e
  Rails.logger.error("Error #{e} querying embeddings for model #{name}")
  raise MissingEmbeddingError
end
# Two-stage search over RAG document fragment embeddings for +raw_vector+,
# restricted to fragments whose target_id/target_type match the arguments.
#
# Stage one orders rows with the `<~>` operator over binary_quantize(...)
# bits and keeps max(limit * 2, 100) candidates; stage two re-orders those
# candidates with +pg_function+ on halfvec values and applies
# LIMIT/OFFSET.
#
# Returns fragment ids, or [rag_document_fragment_id, distance] pairs when
# +return_distance+ is true. A PG::Error is logged and re-raised as
# MissingEmbeddingError.
def asymmetric_rag_fragment_similarity_search(
  raw_vector,
  target_id:,
  target_type:,
  limit:,
  offset:,
  return_distance: false
)
  # A too low limit exacerbates the recall loss of binary quantization
  candidate_pool_size = [limit * 2, 100].max

  search_sql = <<~SQL
    WITH candidates AS (
    SELECT
    rag_document_fragment_id,
    embeddings::halfvec(#{dimensions}) AS embeddings
    FROM
    #{rag_fragments_table_name}
    INNER JOIN
    rag_document_fragments ON
    rag_document_fragments.id = rag_document_fragment_id AND
    rag_document_fragments.target_id = :target_id AND
    rag_document_fragments.target_type = :target_type
    WHERE
    model_id = #{id} AND strategy_id = #{@strategy.id}
    ORDER BY
    binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
    LIMIT :binary_search_limit
    )
    SELECT
    rag_document_fragment_id,
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
    FROM
    candidates
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT :limit
    OFFSET :offset
  SQL

  rows =
    DB.query(
      search_sql,
      query_embedding: raw_vector,
      target_id: target_id,
      target_type: target_type,
      limit: limit,
      offset: offset,
      binary_search_limit: candidate_pool_size,
    )
  return rows.map(&:rag_document_fragment_id) unless return_distance

  rows.map { |row| [row.rag_document_fragment_id, row.distance] }
rescue PG::Error => e
  Rails.logger.error("Error #{e} querying embeddings for model #{name}")
  raise MissingEmbeddingError
end
# Returns up to 100 topic ids ordered by +pg_function+ against the stored
# embedding of +topic+ (same model and strategy).
#
# The query first takes the 200 rows that sort first under the `<~>`
# operator over binary_quantize(...) bits, then re-orders that set with
# +pg_function+ on halfvec values. It does not exclude +topic+ itself.
#
# A PG::Error is logged and re-raised as MissingEmbeddingError.
def symmetric_topics_similarity_search(topic)
  DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
    WITH le_target AS (
    SELECT
    embeddings
    FROM
    #{topic_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id} AND
    topic_id = :topic_id
    LIMIT 1
    )
    SELECT topic_id FROM (
    SELECT
    topic_id, embeddings
    FROM
    #{topic_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id}
    ORDER BY
    binary_quantize(embeddings)::bit(#{dimensions}) <~> (
    SELECT
    binary_quantize(embeddings)::bit(#{dimensions})
    FROM
    le_target
    LIMIT 1
    )
    LIMIT 200
    ) AS widenet
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} (
    SELECT
    embeddings::halfvec(#{dimensions})
    FROM
    le_target
    LIMIT 1
    )
    LIMIT 100;
  SQL
rescue PG::Error => e
  Rails.logger.error(
    "Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
  )
  raise MissingEmbeddingError
end
# Name of the table that stores topic embeddings.
def topic_table_name
  "ai_topic_embeddings"
end
# Name of the table that stores post embeddings.
def post_table_name
  "ai_post_embeddings"
end
# Name of the table that stores RAG document fragment embeddings.
def rag_fragments_table_name
  "ai_document_fragment_embeddings"
end
# Maps +target+ (a Topic, Post or RagDocumentFragment) to the name of the
# table that stores its embeddings.
#
# Raises ArgumentError for any other kind of object.
def table_name(target)
  case target
  when Topic
    topic_table_name
  when Post
    post_table_name
  when RagDocumentFragment
    rag_fragments_table_name
  else
    raise ArgumentError, "Invalid target type"
  end
end
# Builds the search index name for +table_name+:
# "<table_name>_<model id>_<strategy id>_search".
def index_name(table_name)
  format("%s_%s_%s_search", table_name, id, @strategy.id)
end
# No implementation is provided here; always raises NotImplementedError.
# The value is interpolated into this class's error log messages.
def name
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
# The value is interpolated into the halfvec(...) and bit(...) casts of the
# search queries in this class.
def dimensions
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
# prepare_text subtracts 2 from the value and hands it to the strategy.
def max_sequence_length
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
# The value is used as +model_id+ in every query of this class.
def id
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
# The value is interpolated as the operator placed between two halfvec
# values in the ORDER BY / distance expressions of the search queries.
def pg_function
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
# save_to_db stores the value in the +model_version+ column.
def version
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
# prepare_text hands the value to the strategy's +prepare_text_from+.
def tokenizer
  raise NotImplementedError
end
# No implementation is provided here; always raises NotImplementedError.
def asymmetric_query_prefix
  raise NotImplementedError
end
protected
# Looks up the digest stored with the embedding of +target+ for this model
# and strategy. Returns nil when the query yields no row.
#
# Raises ArgumentError when +target+ is not a Topic, Post or
# RagDocumentFragment.
def find_digest_of(target)
  id_column =
    case target
    when Topic
      "topic_id"
    when Post
      "post_id"
    when RagDocumentFragment
      "rag_document_fragment_id"
    else
      raise ArgumentError, "Invalid target type"
    end

  digest_sql = <<~SQL
    SELECT
    digest
    FROM
    #{table_name(target)}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id} AND
    #{id_column} = :target_id
    LIMIT 1
  SQL

  digests = DB.query_single(digest_sql, target_id: target.id)
  digests.first
end
# Upserts the embedding +vector+ and its +digest+ for +target+ into the
# table matching the target's type (Topic, Post or RagDocumentFragment).
#
# On conflict for the same model, strategy and target id, the existing row
# is updated with the new versions, digest, embeddings and +updated_at+.
#
# Raises ArgumentError for any other kind of target.
def save_to_db(target, vector, digest)
  case target
  when Topic
    DB.exec(
      <<~SQL,
        INSERT INTO #{topic_table_name} (topic_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
        VALUES (:topic_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
        ON CONFLICT (strategy_id, model_id, topic_id)
        DO UPDATE SET
        model_version = :model_version,
        strategy_version = :strategy_version,
        digest = :digest,
        embeddings = '[:embeddings]',
        updated_at = :now
      SQL
      topic_id: target.id,
      model_id: id,
      model_version: version,
      strategy_id: @strategy.id,
      strategy_version: @strategy.version,
      digest: digest,
      embeddings: vector,
      now: Time.zone.now,
    )
  when Post
    DB.exec(
      <<~SQL,
        INSERT INTO #{post_table_name} (post_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
        VALUES (:post_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
        ON CONFLICT (model_id, strategy_id, post_id)
        DO UPDATE SET
        model_version = :model_version,
        strategy_version = :strategy_version,
        digest = :digest,
        embeddings = '[:embeddings]',
        updated_at = :now
      SQL
      post_id: target.id,
      model_id: id,
      model_version: version,
      strategy_id: @strategy.id,
      strategy_version: @strategy.version,
      digest: digest,
      embeddings: vector,
      now: Time.zone.now,
    )
  when RagDocumentFragment
    DB.exec(
      <<~SQL,
        INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
        VALUES (:fragment_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
        ON CONFLICT (model_id, strategy_id, rag_document_fragment_id)
        DO UPDATE SET
        model_version = :model_version,
        strategy_version = :strategy_version,
        digest = :digest,
        embeddings = '[:embeddings]',
        updated_at = :now
      SQL
      fragment_id: target.id,
      model_id: id,
      model_version: version,
      strategy_id: @strategy.id,
      strategy_version: @strategy.version,
      digest: digest,
      embeddings: vector,
      now: Time.zone.now,
    )
  else
    raise ArgumentError, "Invalid target type"
  end
end
# No implementation is provided here; always raises NotImplementedError.
# gen_bulk_reprensentations calls +perform!(text)+ on the returned object.
def inference_client
  raise NotImplementedError
end
# Asks the strategy to turn +record+ into the text to embed, passing the
# tokenizer and a length budget of max_sequence_length - 2.
# NOTE(review): the reason for reserving 2 is not visible here; presumably
# room for special tokens. Confirm.
def prepare_text(record)
  active_tokenizer = tokenizer
  length_budget = max_sequence_length - 2

  @strategy.prepare_text_from(record, active_tokenizer, length_budget)
end
end
end
end
end