mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-02-20 10:25:33 +00:00
On very large sites, the rare cache misses for Related Topics can take around 200ms, which affects our p99 latency metric on the topic page. To mitigate this, we now have several tools at our disposal. The first is to migrate the index embedding type from halfvec to bit and change the related-topics query to use the new bit index, switching the distance metric from inner product to Hamming distance. This reduces our index sizes by 90%, greatly reducing the storage footprint of embeddings. By making the related-topics query a bit smarter, we can keep recall unaffected: the index is used to over-capture N*2 results, which are then re-ordered using the full halfvec vectors before taking the top N. The expected impact is to go from 200ms to <20ms for cache misses, and from a 2.5GB index to a 250MB index on a large site. Another tool is migrating our index type from IVFFLAT to HNSW, which can improve cache-miss performance even further, eventually putting us under 5ms. Co-authored-by: Roman Rizzi <roman@discourse.org>
441 lines
15 KiB
Ruby
441 lines
15 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Abstract base class for an embeddings model ("vector representation").
      #
      # Subclasses provide the model-specific pieces (`id`, `name`, `version`,
      # `dimensions`, `pg_function`, `tokenizer`, `max_sequence_length`,
      # `vector_from`, ...). This class owns the SQL shared by all models for
      # persisting embeddings and for similarity search over the topic, post
      # and RAG document fragment embedding tables.
      #
      # Similarity searches run in two stages: a cheap ordering by Hamming
      # distance (`<~>`) between `binary_quantize`d vectors selects an
      # over-sized candidate pool, and only those candidates are re-ranked with
      # the full-precision halfvec vectors using `pg_function`.
      class Base
        class << self
          # Looks up a representation class by its class-level `name`.
          #
          # @param model_name [String] e.g. the value of SiteSetting.ai_embeddings_model
          # @return [Class, nil] the matching subclass, or nil when none matches
          def find_representation(model_name)
            # we are explicit here cause the loader may have not
            # loaded the subclasses yet
            [
              DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
              DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
              DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
              DiscourseAi::Embeddings::VectorRepresentations::Gemini,
              DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
            ].find { _1.name == model_name }
          end

          # Instantiates the representation named by SiteSetting.ai_embeddings_model.
          #
          # NOTE(review): if the setting names no known representation,
          # find_representation returns nil and this raises NoMethodError.
          def current_representation(strategy)
            find_representation(SiteSetting.ai_embeddings_model).new(strategy)
          end

          # @abstract Whether the settings this representation needs are usable.
          def correctly_configured?
            raise NotImplementedError
          end

          # @abstract Names of the site settings this representation depends on.
          def dependant_setting_names
            raise NotImplementedError
          end

          # Translated hint listing the settings from dependant_setting_names.
          #
          # @return [String]
          def configuration_hint
            settings = dependant_setting_names
            I18n.t(
              "discourse_ai.embeddings.configuration.hint",
              settings: settings.join(", "),
              count: settings.length,
            )
          end
        end

        # @param strategy the text-preparation strategy; this class reads its
        #   `id`, `version` and `prepare_text_from`.
        def initialize(strategy)
          @strategy = strategy
        end

        # @abstract Computes the embedding vector for +text+.
        #
        # The keyword is spelled `asymetric`; that spelling is part of the
        # public interface, so it is kept as-is.
        def vector_from(text, asymetric: false)
          raise NotImplementedError
        end

        # Embeds +target+ and, when +persist+ is true, upserts the vector.
        #
        # The strategy prepares the text with a budget of
        # `max_sequence_length - 2` tokens (NOTE(review): presumably leaving
        # room for special tokens - confirm). Blank text is skipped, and
        # nothing is recomputed when the SHA1 digest of the prepared text
        # equals the digest already stored for this model/strategy/target.
        #
        # @param target [Topic, Post, RagDocumentFragment]
        # @param persist [Boolean] false computes the vector without saving it
        # @raise [ArgumentError] for any other target type
        def generate_representation_from(target, persist: true)
          text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
          return if text.blank?

          target_column = target_column_for(target)

          new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
          current_digest = DB.query_single(<<~SQL, target_id: target.id).first
            SELECT
              digest
            FROM
              #{table_name(target)}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id} AND
              #{target_column} = :target_id
            LIMIT 1
          SQL
          return if current_digest == new_digest

          vector = vector_from(text)

          save_to_db(target, vector, new_digest) if persist
        end

        # Id of the topic whose stored embedding is closest to +raw_vector+
        # under pg_function (full-precision ordering, no binary pre-filter).
        #
        # @return [Integer, nil] nil when no embeddings exist for this model/strategy
        def topic_id_from_representation(raw_vector)
          DB.query_single(<<~SQL, query_embedding: raw_vector).first
            SELECT
              topic_id
            FROM
              #{topic_table_name}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id}
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
            LIMIT 1
          SQL
        end

        # Id of the post whose stored embedding is closest to +raw_vector+
        # under pg_function (full-precision ordering, no binary pre-filter).
        #
        # @return [Integer, nil] nil when no embeddings exist for this model/strategy
        def post_id_from_representation(raw_vector)
          DB.query_single(<<~SQL, query_embedding: raw_vector).first
            SELECT
              post_id
            FROM
              #{post_table_name}
            WHERE
              model_id = #{id} AND
              strategy_id = #{@strategy.id}
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
            LIMIT 1
          SQL
        end

        # Topics nearest to a query vector (two-stage: Hamming candidates,
        # then full-precision re-ranking).
        #
        # @param raw_vector [Array<Numeric>] the query embedding
        # @param limit [Integer] page size
        # @param offset [Integer] rows to skip in the re-ranked ordering
        # @param return_distance [Boolean] return [topic_id, distance] pairs
        # @return [Array<Integer>, Array<Array>]
        # @raise [MissingEmbeddingError] when the query fails with PG::Error
        def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
          results =
            DB.query(
              <<~SQL,
              WITH candidates AS (
                SELECT
                  topic_id,
                  embeddings::halfvec(#{dimensions}) AS embeddings
                FROM
                  #{topic_table_name}
                WHERE
                  model_id = #{id} AND strategy_id = #{@strategy.id}
                ORDER BY
                  binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
                LIMIT :candidates_limit
              )
              SELECT
                topic_id,
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
              FROM
                candidates
              ORDER BY
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
              LIMIT :limit
              OFFSET :offset
            SQL
              query_embedding: raw_vector,
              limit: limit,
              offset: offset,
              candidates_limit: binary_candidates_limit(limit, offset),
            )

          extract_ids(results, :topic_id, return_distance)
        rescue PG::Error => e
          Rails.logger.error("Error #{e} querying embeddings for model #{name}")
          raise MissingEmbeddingError
        end

        # Posts nearest to a query vector (two-stage: Hamming candidates,
        # then full-precision re-ranking).
        #
        # @param raw_vector [Array<Numeric>] the query embedding
        # @param limit [Integer] page size
        # @param offset [Integer] rows to skip in the re-ranked ordering
        # @param return_distance [Boolean] return [post_id, distance] pairs
        # @return [Array<Integer>, Array<Array>]
        # @raise [MissingEmbeddingError] when the query fails with PG::Error
        def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
          results =
            DB.query(
              <<~SQL,
              WITH candidates AS (
                SELECT
                  post_id,
                  embeddings::halfvec(#{dimensions}) AS embeddings
                FROM
                  #{post_table_name}
                WHERE
                  model_id = #{id} AND strategy_id = #{@strategy.id}
                ORDER BY
                  binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
                LIMIT :candidates_limit
              )
              SELECT
                post_id,
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
              FROM
                candidates
              ORDER BY
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
              LIMIT :limit
              OFFSET :offset
            SQL
              query_embedding: raw_vector,
              limit: limit,
              offset: offset,
              candidates_limit: binary_candidates_limit(limit, offset),
            )

          extract_ids(results, :post_id, return_distance)
        rescue PG::Error => e
          Rails.logger.error("Error #{e} querying embeddings for model #{name}")
          raise MissingEmbeddingError
        end

        # RAG fragments of a single target nearest to a query vector
        # (two-stage: Hamming candidates, then full-precision re-ranking).
        #
        # Only fragments whose rag_document_fragments row matches
        # +target_id+ / +target_type+ are considered.
        #
        # @param raw_vector [Array<Numeric>] the query embedding
        # @param target_id [Integer] owner id of the fragments to search
        # @param target_type [String] owner type of the fragments to search
        # @param limit [Integer] page size
        # @param offset [Integer] rows to skip in the re-ranked ordering
        # @param return_distance [Boolean] return [fragment_id, distance] pairs
        # @return [Array<Integer>, Array<Array>]
        # @raise [MissingEmbeddingError] when the query fails with PG::Error
        def asymmetric_rag_fragment_similarity_search(
          raw_vector,
          target_id:,
          target_type:,
          limit:,
          offset:,
          return_distance: false
        )
          results =
            DB.query(
              <<~SQL,
              WITH candidates AS (
                SELECT
                  rag_document_fragment_id,
                  embeddings::halfvec(#{dimensions}) AS embeddings
                FROM
                  #{rag_fragments_table_name}
                INNER JOIN
                  rag_document_fragments ON rag_document_fragments.id = rag_document_fragment_id
                WHERE
                  model_id = #{id} AND
                  strategy_id = #{@strategy.id} AND
                  rag_document_fragments.target_id = :target_id AND
                  rag_document_fragments.target_type = :target_type
                ORDER BY
                  binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
                LIMIT :candidates_limit
              )
              SELECT
                rag_document_fragment_id,
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
              FROM
                candidates
              ORDER BY
                embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
              LIMIT :limit
              OFFSET :offset
            SQL
              query_embedding: raw_vector,
              target_id: target_id,
              target_type: target_type,
              limit: limit,
              offset: offset,
              # A too low limit exacerbates the recall loss of binary quantization
              candidates_limit: binary_candidates_limit(limit, offset, minimum: 100),
            )

          extract_ids(results, :rag_document_fragment_id, return_distance)
        rescue PG::Error => e
          Rails.logger.error("Error #{e} querying embeddings for model #{name}")
          raise MissingEmbeddingError
        end

        # Up to 100 topic ids related to +topic+.
        #
        # Rows are first ordered by Hamming distance to the topic's
        # binary-quantized embedding (200 kept), then re-ranked by pg_function
        # on the full halfvec vectors. Nothing excludes +topic+ itself.
        #
        # NOTE(review): when +topic+ has no stored embedding, le_target is empty,
        # the scalar subqueries yield NULL, and the ordering becomes arbitrary
        # rather than raising - confirm callers guard against this.
        #
        # @raise [MissingEmbeddingError] when the query fails with PG::Error
        def symmetric_topics_similarity_search(topic)
          DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
            WITH le_target AS (
              SELECT
                embeddings
              FROM
                #{topic_table_name}
              WHERE
                model_id = #{id} AND
                strategy_id = #{@strategy.id} AND
                topic_id = :topic_id
              LIMIT 1
            )
            SELECT topic_id FROM (
              SELECT
                topic_id, embeddings
              FROM
                #{topic_table_name}
              WHERE
                model_id = #{id} AND
                strategy_id = #{@strategy.id}
              ORDER BY
                binary_quantize(embeddings)::bit(#{dimensions}) <~> (
                  SELECT
                    binary_quantize(embeddings)::bit(#{dimensions})
                  FROM
                    le_target
                  LIMIT 1
                )
              LIMIT 200
            ) AS widenet
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} (
                SELECT
                  embeddings::halfvec(#{dimensions})
                FROM
                  le_target
                LIMIT 1
              )
            LIMIT 100;
          SQL
        rescue PG::Error => e
          Rails.logger.error(
            "Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
          )
          raise MissingEmbeddingError
        end

        def topic_table_name
          "ai_topic_embeddings"
        end

        def post_table_name
          "ai_post_embeddings"
        end

        def rag_fragments_table_name
          "ai_document_fragment_embeddings"
        end

        # Embeddings table for +target+'s type.
        #
        # @raise [ArgumentError] for unsupported target types
        def table_name(target)
          case target
          when Topic
            topic_table_name
          when Post
            post_table_name
          when RagDocumentFragment
            rag_fragments_table_name
          else
            raise ArgumentError, "Invalid target type"
          end
        end

        # Index name for +table_name+, scoped by model id and strategy id.
        def index_name(table_name)
          "#{table_name}_#{id}_#{@strategy.id}_search"
        end

        # @abstract Model name, used in log messages.
        def name
          raise NotImplementedError
        end

        # @abstract Vector dimensionality, used in the halfvec/bit casts.
        def dimensions
          raise NotImplementedError
        end

        # @abstract Maximum token count the model accepts.
        def max_sequence_length
          raise NotImplementedError
        end

        # @abstract Model id stored in the model_id column.
        def id
          raise NotImplementedError
        end

        # @abstract SQL operator placed between two halfvec expressions to rank
        # results (presumably a pgvector distance operator - defined by subclasses).
        def pg_function
          raise NotImplementedError
        end

        # @abstract Model version stored in the model_version column.
        def version
          raise NotImplementedError
        end

        # @abstract Tokenizer passed to the strategy when preparing text.
        def tokenizer
          raise NotImplementedError
        end

        # @abstract Prefix for asymmetric queries.
        def asymmetric_query_prefix
          raise NotImplementedError
        end

        protected

        # Upserts +vector+ and +digest+ for +target+ under this model/strategy.
        #
        # PostgreSQL infers the unique index for ON CONFLICT regardless of
        # column order, so one statement serves all three tables.
        #
        # @raise [ArgumentError] for unsupported target types
        def save_to_db(target, vector, digest)
          table = table_name(target)
          column = target_column_for(target)

          DB.exec(
            <<~SQL,
            INSERT INTO #{table} (#{column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
            VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
            ON CONFLICT (model_id, strategy_id, #{column})
            DO UPDATE SET
              model_version = :model_version,
              strategy_version = :strategy_version,
              digest = :digest,
              embeddings = '[:embeddings]',
              updated_at = :now
          SQL
            target_id: target.id,
            model_id: id,
            model_version: version,
            strategy_id: @strategy.id,
            strategy_version: @strategy.version,
            digest: digest,
            embeddings: vector,
            now: Time.zone.now,
          )
        end

        # Name of the id column that references +target+ in its embeddings table.
        #
        # @raise [ArgumentError] for unsupported target types
        def target_column_for(target)
          case target
          when Topic
            "topic_id"
          when Post
            "post_id"
          when RagDocumentFragment
            "rag_document_fragment_id"
          else
            raise ArgumentError, "Invalid target type"
          end
        end

        # Size of the Hamming-distance candidate pool for a re-ranked search.
        #
        # The pool is over-captured 2x so full-precision re-ranking can recover
        # results the binary ordering misplaced. It must also cover +offset+:
        # OFFSET is applied inside the pool, so a pool of only `limit * 2` rows
        # returns nothing once offset reaches `2 * limit`.
        #
        # @return [Integer, nil] nil when +limit+ is nil (LIMIT NULL means no limit)
        def binary_candidates_limit(limit, offset, minimum: 0)
          return if limit.nil?

          [(limit + offset.to_i) * 2, minimum].max
        end

        # Maps result rows to ids, or to [id, distance] pairs.
        def extract_ids(rows, id_column, return_distance)
          if return_distance
            rows.map { |row| [row.public_send(id_column), row.distance] }
          else
            rows.map { |row| row.public_send(id_column) }
          end
        end

        # Base URL of the Discourse embeddings service: resolved through the SRV
        # record setting when it is present, otherwise the plain endpoint setting.
        def discourse_embeddings_endpoint
          if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
            service =
              DiscourseAi::Utils::DnsSrv.lookup(
                SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
              )
            "https://#{service.target}:#{service.port}"
          else
            SiteSetting.ai_embeddings_discourse_service_api_endpoint
          end
        end
      end
    end
  end
end
|