Rafael dos Santos Silva 791fad1e6a
FEATURE: Index embeddings using bit vectors (#824)
On very large sites, the rare cache misses for Related Topics can take around 200ms, which affects our p99 metric on the topic page. In order to mitigate this impact, we now have several tools at our disposal.

The first is to migrate the index embedding type from halfvec to bit and change the related topic query to leverage the new bit index by changing the search algorithm from inner product to Hamming distance. This will reduce our index sizes by 90%, severely reducing the impact of embeddings on our storage. By making the related query a bit smarter, we can have zero impact on recall by using the index to over-capture N*2 results, then re-ordering those N*2 using the full halfvec vectors and taking the top N. The expected impact is to go from 200ms to <20ms for cache misses and from a 2.5GB index to a 250MB index on a large site.

Another tool is migrating our index type from IVFFLAT to HNSW, which can improve cache-miss performance even further, eventually putting us in the under 5ms territory.

Co-authored-by: Roman Rizzi <roman@discourse.org>
2024-10-14 13:26:03 -03:00

441 lines
15 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
class << self
# Looks up the representation class whose class-level +name+ equals
# +model_name+.
#
# The candidates are listed explicitly because the loader may not have
# loaded the subclasses yet.
#
# @param model_name [String] value compared against each candidate's +name+
# @return [Class, nil] the first matching class, or nil when none matches
def find_representation(model_name)
  known_representations = [
    DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
    DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
    DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
    DiscourseAi::Embeddings::VectorRepresentations::Gemini,
    DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
    DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
    DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
    DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
  ]

  known_representations.find { |representation| representation.name == model_name }
end
# Instantiates the representation selected by the
# +ai_embeddings_model+ site setting.
#
# @param strategy [Object] passed straight to the representation's constructor
# @return [Object] a new instance of the class returned by +find_representation+
def current_representation(strategy)
  representation_class = find_representation(SiteSetting.ai_embeddings_model)
  representation_class.new(strategy)
end
# Abstract hook: not implemented on the base class.
# Presumably each concrete representation reports whether its required
# settings are in place — confirm against the subclasses.
#
# @raise [NotImplementedError] always; the message names the receiver and
#   the missing method so the failing override is obvious from the error alone
def correctly_configured?
  raise NotImplementedError, "#{self} must implement .#{__method__}"
end
# Abstract hook: not implemented on the base class.
# +configuration_hint+ joins and counts the returned value, so overrides
# are expected to return a list of setting names.
#
# @raise [NotImplementedError] always; the message names the receiver and
#   the missing method so the failing override is obvious from the error alone
def dependant_setting_names
  raise NotImplementedError, "#{self} must implement .#{__method__}"
end
# Builds the translated hint listing the settings this representation
# depends on.
#
# @return [Object] whatever +I18n.t+ returns for the hint key, given the
#   comma-separated setting names and their count
def configuration_hint
  setting_names = dependant_setting_names
  hint_options = { settings: setting_names.join(", "), count: setting_names.length }

  I18n.t("discourse_ai.embeddings.configuration.hint", **hint_options)
end
end
# Stores the strategy object. The other methods in this class read its
# +id+ and +version+ and call its +prepare_text_from+.
#
# @param strategy [Object] the strategy this representation works with
def initialize(strategy)
  @strategy = strategy
end
# Abstract hook: not implemented on the base class.
# +generate_representation_from+ calls it with the prepared text and hands
# the result to +save_to_db+ as the embedding.
#
# @param text [String] text to embed
# @param asymetric [Boolean] unused here; meaning is defined by overrides
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def vector_from(text, asymetric: false)
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Generates an embedding for +target+ and, when +persist+ is true, stores it.
#
# Nothing happens when the strategy yields blank text, or when the SHA1
# digest of the prepared text equals the digest already stored for this
# model/strategy/target combination.
#
# @param target [Topic, Post, RagDocumentFragment] record to embed
# @param persist [Boolean] whether to hand the vector to +save_to_db+
# @return [Object, nil] the result of +save_to_db+, or nil when skipped or
#   not persisted
# @raise [ArgumentError] when +target+ is not one of the supported types
def generate_representation_from(target, persist: true)
  # NOTE(review): the "- 2" presumably leaves room for special tokens — confirm.
  prepared_text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
  return if prepared_text.blank?

  id_column =
    if target.is_a?(Topic)
      "topic_id"
    elsif target.is_a?(Post)
      "post_id"
    elsif target.is_a?(RagDocumentFragment)
      "rag_document_fragment_id"
    else
      raise ArgumentError, "Invalid target type"
    end

  fresh_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)

  digest_lookup_sql = <<~SQL
    SELECT
    digest
    FROM
    #{table_name(target)}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id} AND
    #{id_column} = :target_id
    LIMIT 1
  SQL
  stored_digest = DB.query_single(digest_lookup_sql, target_id: target.id).first

  # An unchanged digest means the stored embedding is still current.
  return if stored_digest == fresh_digest

  embedding = vector_from(prepared_text)
  save_to_db(target, embedding, fresh_digest) if persist
end
# Returns the id of the topic whose stored embedding ranks first against
# +raw_vector+ under +pg_function+, for this model and strategy.
#
# @param raw_vector [Object] query embedding bound as +:query_embedding+
# @return [Object, nil] the first value of the single-column result, or nil
#   when there are no rows
def topic_id_from_representation(raw_vector)
  nearest_topic_sql = <<~SQL
    SELECT
    topic_id
    FROM
    #{topic_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id}
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT 1
  SQL

  matches = DB.query_single(nearest_topic_sql, query_embedding: raw_vector)
  matches.first
end
# Returns the id of the post whose stored embedding ranks first against
# +raw_vector+ under +pg_function+, for this model and strategy.
#
# @param raw_vector [Object] query embedding bound as +:query_embedding+
# @return [Object, nil] the first value of the single-column result, or nil
#   when there are no rows
def post_id_from_representation(raw_vector)
  nearest_post_sql = <<~SQL
    SELECT
    post_id
    FROM
    #{post_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id}
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT 1
  SQL

  matches = DB.query_single(nearest_post_sql, query_embedding: raw_vector)
  matches.first
end
# Finds the topics whose stored embeddings are closest to +raw_vector+.
#
# The search runs in two phases: a pass ordered by the distance between
# binary-quantized embeddings collects a pool of candidates, which is then
# re-ranked with +pg_function+ over the full halfvec embeddings.
#
# The pool is sized from +limit+ plus +offset+. The final LIMIT/OFFSET is
# applied to the pool, so a pool of only +limit * 2+ rows would leave every
# page past the second one empty.
#
# @param raw_vector [Object] query embedding bound as +:query_embedding+
# @param limit [Integer] maximum number of rows to return
# @param offset [Integer] number of re-ranked rows to skip
# @param return_distance [Boolean] when true, return [topic_id, distance] pairs
# @return [Array] topic ids, or [topic_id, distance] pairs
# @raise [MissingEmbeddingError] when the query fails with a PG::Error
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
  results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
    WITH candidates AS (
      SELECT
        topic_id,
        embeddings::halfvec(#{dimensions}) AS embeddings
      FROM
        #{topic_table_name}
      WHERE
        model_id = #{id} AND strategy_id = #{@strategy.id}
      ORDER BY
        binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
      LIMIT (:limit + :offset) * 2
    )
    SELECT
      topic_id,
      embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
    FROM
      candidates
    ORDER BY
      embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT :limit
    OFFSET :offset
  SQL

  if return_distance
    results.map { |r| [r.topic_id, r.distance] }
  else
    results.map(&:topic_id)
  end
rescue PG::Error => e
  Rails.logger.error("Error #{e} querying embeddings for model #{name}")
  raise MissingEmbeddingError
end
# Finds the posts whose stored embeddings are closest to +raw_vector+.
#
# The search runs in two phases: a pass ordered by the distance between
# binary-quantized embeddings collects a pool of candidates, which is then
# re-ranked with +pg_function+ over the full halfvec embeddings.
#
# The pool is sized from +limit+ plus +offset+. The final LIMIT/OFFSET is
# applied to the pool, so a pool of only +limit * 2+ rows would leave every
# page past the second one empty.
#
# @param raw_vector [Object] query embedding bound as +:query_embedding+
# @param limit [Integer] maximum number of rows to return
# @param offset [Integer] number of re-ranked rows to skip
# @param return_distance [Boolean] when true, return [post_id, distance] pairs
# @return [Array] post ids, or [post_id, distance] pairs
# @raise [MissingEmbeddingError] when the query fails with a PG::Error
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
  results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
    WITH candidates AS (
      SELECT
        post_id,
        embeddings::halfvec(#{dimensions}) AS embeddings
      FROM
        #{post_table_name}
      WHERE
        model_id = #{id} AND strategy_id = #{@strategy.id}
      ORDER BY
        binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
      LIMIT (:limit + :offset) * 2
    )
    SELECT
      post_id,
      embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
    FROM
      candidates
    ORDER BY
      embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
    LIMIT :limit
    OFFSET :offset
  SQL

  if return_distance
    results.map { |r| [r.post_id, r.distance] }
  else
    results.map(&:post_id)
  end
rescue PG::Error => e
  Rails.logger.error("Error #{e} querying embeddings for model #{name}")
  raise MissingEmbeddingError
end
# Finds the RAG document fragments belonging to one target whose stored
# embeddings are closest to +raw_vector+.
#
# The search runs in two phases: a pass ordered by the distance between
# binary-quantized embeddings collects a pool of candidates, which is then
# re-ranked with +pg_function+ over the full halfvec embeddings.
#
# Candidates are restricted to fragments of the requested target. Without
# that filter the bound +target_id+ / +target_type+ values were ignored and
# fragments of every target were returned.
# NOTE(review): assumes rag_document_fragments has target_id and
# target_type columns — confirm against the schema.
#
# @param raw_vector [Object] query embedding bound as +:query_embedding+
# @param target_id [Object] id of the record owning the fragments
# @param target_type [Object] type of the record owning the fragments
# @param limit [Integer] maximum number of rows to return
# @param offset [Integer] number of re-ranked rows to skip
# @param return_distance [Boolean] when true, return
#   [rag_document_fragment_id, distance] pairs
# @return [Array] fragment ids, or [fragment id, distance] pairs
# @raise [MissingEmbeddingError] when the query fails with a PG::Error
def asymmetric_rag_fragment_similarity_search(
  raw_vector,
  target_id:,
  target_type:,
  limit:,
  offset:,
  return_distance: false
)
  # A too low limit exacerbates the recall loss of binary quantization.
  # The pool must also cover the skipped rows, otherwise deep pages are empty.
  binary_search_limit = [(limit + offset) * 2, 100].max

  results =
    DB.query(
      <<~SQL,
        WITH candidates AS (
          SELECT
            rag_document_fragment_id,
            embeddings::halfvec(#{dimensions}) AS embeddings
          FROM
            #{rag_fragments_table_name}
          INNER JOIN
            rag_document_fragments ON rag_document_fragments.id = rag_document_fragment_id
          WHERE
            model_id = #{id} AND
            strategy_id = #{@strategy.id} AND
            rag_document_fragments.target_id = :target_id AND
            rag_document_fragments.target_type = :target_type
          ORDER BY
            binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
          LIMIT :binary_search_limit
        )
        SELECT
          rag_document_fragment_id,
          embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
        FROM
          candidates
        ORDER BY
          embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
        LIMIT :limit
        OFFSET :offset
      SQL
      query_embedding: raw_vector,
      target_id: target_id,
      target_type: target_type,
      limit: limit,
      offset: offset,
      binary_search_limit: binary_search_limit,
    )

  if return_distance
    results.map { |r| [r.rag_document_fragment_id, r.distance] }
  else
    results.map(&:rag_document_fragment_id)
  end
rescue PG::Error => e
  Rails.logger.error("Error #{e} querying embeddings for model #{name}")
  raise MissingEmbeddingError
end
# Returns ids of topics whose embeddings are closest to the embedding
# stored for +topic+ (same model and strategy).
#
# The query first takes the 200 nearest rows by distance between
# binary-quantized embeddings ("widenet"), then re-orders those with
# +pg_function+ over the halfvec embeddings and keeps the first 100.
#
# @param topic [#id] topic whose stored embedding is the query
# @return [Array] topic ids in ranked order
# @raise [MissingEmbeddingError] when the query fails with a PG::Error
def symmetric_topics_similarity_search(topic)
  related_topics_sql = <<~SQL
    WITH le_target AS (
    SELECT
    embeddings
    FROM
    #{topic_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id} AND
    topic_id = :topic_id
    LIMIT 1
    )
    SELECT topic_id FROM (
    SELECT
    topic_id, embeddings
    FROM
    #{topic_table_name}
    WHERE
    model_id = #{id} AND
    strategy_id = #{@strategy.id}
    ORDER BY
    binary_quantize(embeddings)::bit(#{dimensions}) <~> (
    SELECT
    binary_quantize(embeddings)::bit(#{dimensions})
    FROM
    le_target
    LIMIT 1
    )
    LIMIT 200
    ) AS widenet
    ORDER BY
    embeddings::halfvec(#{dimensions}) #{pg_function} (
    SELECT
    embeddings::halfvec(#{dimensions})
    FROM
    le_target
    LIMIT 1
    )
    LIMIT 100;
  SQL

  rows = DB.query(related_topics_sql, topic_id: topic.id)
  rows.map { |row| row.topic_id }
rescue PG::Error => error
  failure_message = "Error #{error} querying embeddings for topic #{topic.id} and model #{name}"
  Rails.logger.error(failure_message)
  raise MissingEmbeddingError
end
# Name of the table queried and written for topic embeddings.
#
# @return [String]
def topic_table_name
  "ai_topic_embeddings"
end
# Name of the table queried and written for post embeddings.
#
# @return [String]
def post_table_name
  "ai_post_embeddings"
end
# Name of the table queried and written for RAG document fragment embeddings.
#
# @return [String]
def rag_fragments_table_name
  "ai_document_fragment_embeddings"
end
# Maps a record to the embeddings table used for its type.
#
# @param target [Topic, Post, RagDocumentFragment]
# @return [String] the matching table name
# @raise [ArgumentError] when +target+ is none of the supported types
def table_name(target)
  return topic_table_name if target.is_a?(Topic)
  return post_table_name if target.is_a?(Post)
  return rag_fragments_table_name if target.is_a?(RagDocumentFragment)

  raise ArgumentError, "Invalid target type"
end
# Builds the search index name for a table: the table name, this
# representation's id, the strategy id and the suffix "search", joined by
# underscores.
#
# @param table_name [String] table the index belongs to
# @return [String]
def index_name(table_name)
  [table_name, id, @strategy.id, "search"].join("_")
end
# Abstract hook: not implemented on the base class.
# The value is interpolated into the error log messages of the similarity
# searches ("... for model #{name}").
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def name
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Abstract hook: not implemented on the base class.
# The value is interpolated into the halfvec(...) and bit(...) casts of the
# similarity queries.
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def dimensions
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Abstract hook: not implemented on the base class.
# +generate_representation_from+ passes this value minus 2 to the
# strategy's +prepare_text_from+.
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def max_sequence_length
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Abstract hook: not implemented on the base class.
# The value is used as the +model_id+ in every query and as part of the
# index name.
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def id
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Abstract hook: not implemented on the base class.
# The value is interpolated as the operator placed between two halfvec
# expressions in the ORDER BY clauses of the similarity queries.
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def pg_function
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Abstract hook: not implemented on the base class.
# +save_to_db+ stores the value in the +model_version+ column.
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def version
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Abstract hook: not implemented on the base class.
# +generate_representation_from+ passes the value to the strategy's
# +prepare_text_from+.
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def tokenizer
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
# Abstract hook: not implemented on the base class.
# NOTE(review): not called anywhere in this class; presumably a prefix
# prepended to query text for asymmetric search — confirm in the subclasses.
#
# @raise [NotImplementedError] always; the message names the receiver's
#   class and the missing method
def asymmetric_query_prefix
  raise NotImplementedError, "#{self.class} must implement ##{__method__}"
end
protected
# Inserts or updates the embedding row for +target+.
#
# One parametrised upsert serves all three target types; only the table and
# the id column differ. On a conflict for the same (model, strategy, target)
# the versions, digest, embedding and updated_at are overwritten. The
# conflict columns are listed in the same order for every table; PostgreSQL
# matches the unique index without regard to column order.
#
# @param target [Topic, Post, RagDocumentFragment] record the embedding belongs to
# @param vector [Object] embedding bound into the '[:embeddings]' literal
# @param digest [String] digest of the text the embedding was generated from
# @return [Object] whatever +DB.exec+ returns
# @raise [ArgumentError] when +target+ is none of the supported types
def save_to_db(target, vector, digest)
  table, target_column =
    case target
    when Topic
      [topic_table_name, "topic_id"]
    when Post
      [post_table_name, "post_id"]
    when RagDocumentFragment
      [rag_fragments_table_name, "rag_document_fragment_id"]
    else
      raise ArgumentError, "Invalid target type"
    end

  DB.exec(
    <<~SQL,
      INSERT INTO #{table} (#{target_column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
      VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
      ON CONFLICT (model_id, strategy_id, #{target_column})
      DO UPDATE SET
        model_version = :model_version,
        strategy_version = :strategy_version,
        digest = :digest,
        embeddings = '[:embeddings]',
        updated_at = :now
    SQL
    target_id: target.id,
    model_id: id,
    model_version: version,
    strategy_id: @strategy.id,
    strategy_version: @strategy.version,
    digest: digest,
    embeddings: vector,
    now: Time.zone.now,
  )
end
# Resolves the base URL of the Discourse embeddings service.
#
# When the SRV site setting is present, the record is looked up through
# +DiscourseAi::Utils::DnsSrv+ and an https URL is built from its target and
# port. Otherwise the plain endpoint site setting is returned.
#
# @return [String] the endpoint URL
def discourse_embeddings_endpoint
  srv_record = SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv
  return SiteSetting.ai_embeddings_discourse_service_api_endpoint unless srv_record.present?

  service = DiscourseAi::Utils::DnsSrv.lookup(srv_record)
  "https://#{service.target}:#{service.port}"
end
end
end
end
end