mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-03 15:59:59 +00:00
From [pgvector/pgvector](https://github.com/pgvector/pgvector) README > With approximate indexes, filtering is applied after the index is scanned. If a condition matches 10% of rows, with HNSW and the default hnsw.ef_search of 40, only 4 rows will match on average. For more rows, increase hnsw.ef_search. > > Starting with 0.8.0, you can enable [iterative index scans](https://github.com/pgvector/pgvector#iterative-index-scans), which will automatically scan more of the index when needed. Since we are stuck on 0.7.0, we are going with the first option for now.
286 lines
8.9 KiB
Ruby
286 lines
8.9 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
# We don't have AR objects for our embeddings, so this class
|
|
# acts as an intermediary between us and the DB.
|
|
# It lets us retrieve embeddings either symmetrically or asymmetrically,
|
|
# and also store them.
|
|
|
|
module DiscourseAi
  module Embeddings
    # Data-access layer over the per-target embeddings tables.
    #
    # Each instance is bound to one table (topics, posts or RAG document
    # fragments), the column holding the embedded record's id, and the
    # EmbeddingDefinition selected in site settings. It stores embeddings and
    # runs similarity searches that first narrow candidates by Hamming distance
    # over binary-quantized vectors, then re-rank them with the model's distance
    # operator.
    class Schema
      TOPICS_TABLE = "ai_topics_embeddings"
      POSTS_TABLE = "ai_posts_embeddings"
      RAG_DOCS_TABLE = "ai_document_fragments_embeddings"

      # Suffixes used to build table ("ai_<target>_embeddings") and index names.
      EMBEDDING_TARGETS = %w[topics posts document_fragments].freeze
      EMBEDDING_TABLES = [TOPICS_TABLE, POSTS_TABLE, RAG_DOCS_TABLE].freeze

      # pgvector's default hnsw.ef_search. With approximate indexes pgvector
      # filters *after* the index scan, so queries needing more candidates than
      # this must raise it for the transaction.
      DEFAULT_HNSW_EF_SEARCH = 40
      # Largest hnsw.ef_search pgvector accepts; larger values make SET fail.
      MAX_HNSW_EF_SEARCH = 1000

      MissingEmbeddingError = Class.new(StandardError)

      class << self
        # Builds a Schema for the table that stores embeddings of +target_klass+,
        # using the definition selected in +ai_embeddings_selected_model+.
        #
        # @param target_klass [Class, nil] Topic, Post or RagDocumentFragment
        # @return [Schema]
        # @raise [RuntimeError] if the selected embedding definition does not exist
        # @raise [ArgumentError] if +target_klass+ is not one of the supported classes
        def for(target_klass)
          vector_def = EmbeddingDefinition.find_by(id: SiteSetting.ai_embeddings_selected_model)
          raise "Invalid embeddings selected model" if vector_def.nil?

          case target_klass&.name
          when "Topic"
            new(TOPICS_TABLE, "topic_id", vector_def)
          when "Post"
            new(POSTS_TABLE, "post_id", vector_def)
          when "RagDocumentFragment"
            new(RAG_DOCS_TABLE, "rag_document_fragment_id", vector_def)
          else
            raise ArgumentError, "Invalid target type for embeddings"
          end
        end

        # Name of the partial bit-quantized HNSW search index for one target and
        # one embedding definition. The hard-coded "_1_" matches the
        # `strategy_id = 1` predicate used by prepare_search_indexes.
        #
        # @param table [String] an EMBEDDING_TARGETS entry (e.g. "topics"), not a full table name
        # @param def_id [Integer] EmbeddingDefinition id
        # @return [String]
        def search_index_name(table, def_id)
          "ai_#{table}_embeddings_#{def_id}_1_search_bit"
        end

        # Creates (if missing) one partial HNSW index per embeddings table over
        # binary_quantize(embeddings) cast to bit(<dimensions>), restricted to
        # rows of +vector_def+ with strategy_id = 1.
        #
        # @param vector_def [EmbeddingDefinition]
        def prepare_search_indexes(vector_def)
          EMBEDDING_TARGETS.each { |target| DB.exec <<~SQL }
            CREATE INDEX IF NOT EXISTS #{search_index_name(target, vector_def.id)} ON ai_#{target}_embeddings
            USING hnsw ((binary_quantize(embeddings)::bit(#{vector_def.dimensions})) bit_hamming_ops)
            WHERE model_id = #{vector_def.id} AND strategy_id = 1;
          SQL
        end

        # True when every search index for +vector_def+ exists and each index
        # definition casts to bit(<current dimensions>). A dimension change
        # therefore reports false.
        #
        # @param vector_def [EmbeddingDefinition]
        # @return [Boolean]
        def correctly_indexed?(vector_def)
          index_names = EMBEDDING_TARGETS.map { |t| search_index_name(t, vector_def.id) }
          indexdefs =
            DB.query_single(
              "SELECT indexdef FROM pg_indexes WHERE indexname IN (:names)",
              names: index_names,
            )

          return false if indexdefs.length < index_names.length

          indexdefs.all? do |defs|
            defs.include? "(binary_quantize(embeddings))::bit(#{vector_def.dimensions})"
          end
        end

        # Purges embeddings, and drops search indexes, for model_ids that no
        # longer have a row in embedding_definitions.
        #
        # Orphans are collected from every embeddings table, not only topics,
        # because a definition may have embedded posts or RAG fragments without
        # ever embedding a topic.
        #
        # @return [nil] when there is nothing to remove; otherwise the result of
        #   the DROP INDEX statement
        def remove_orphaned_data
          removed_defs_ids =
            EMBEDDING_TABLES
              .flat_map do |t|
                DB.query_single(
                  "SELECT DISTINCT(model_id) FROM #{t} te LEFT JOIN embedding_definitions ed ON te.model_id = ed.id WHERE ed.id IS NULL",
                )
              end
              .uniq

          # Nothing orphaned: skip the no-op DELETEs and the empty DROP statement.
          return if removed_defs_ids.empty?

          EMBEDDING_TABLES.each do |t|
            DB.exec(
              "DELETE FROM #{t} WHERE model_id IN (:removed_defs)",
              removed_defs: removed_defs_ids,
            )
          end

          drop_index_statement =
            EMBEDDING_TARGETS
              .product(removed_defs_ids)
              .map { |target, def_id| "DROP INDEX IF EXISTS #{search_index_name(target, def_id)};" }
              .join("\n")

          DB.exec(drop_index_statement)
        end
      end

      # @param table [String] one of EMBEDDING_TABLES
      # @param target_column [String] column holding the embedded record's id
      # @param vector_def [EmbeddingDefinition] definition whose rows this schema reads/writes
      def initialize(table, target_column, vector_def)
        @table = table
        @target_column = target_column
        @vector_def = vector_def
      end

      attr_reader :table, :target_column, :vector_def

      # Returns the single stored row (all columns) closest to +embedding+ by the
      # model's distance operator (+pg_function+), or nil if the table has no
      # rows for this definition.
      #
      # @param embedding [Array<Numeric>] query vector
      def find_by_embedding(embedding)
        DB.query(
          <<~SQL,
            SELECT *
            FROM #{table}
            WHERE
              model_id = :vid AND strategy_id = :vsid
            ORDER BY
              embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
            LIMIT 1
          SQL
          query_embedding: embedding,
          vid: vector_def.id,
          vsid: vector_def.strategy_id,
        ).first
      end

      # Returns the stored row (all columns) for +target+ under this
      # definition, or nil if it has not been embedded.
      #
      # @param target [#id] the embedded record
      def find_by_target(target)
        DB.query(
          <<~SQL,
            SELECT *
            FROM #{table}
            WHERE
              model_id = :vid AND
              strategy_id = :vsid AND
              #{target_column} = :target_id
            LIMIT 1
          SQL
          target_id: target.id,
          vid: vector_def.id,
          vsid: vector_def.strategy_id,
        ).first
      end

      # Nearest-neighbour search for an arbitrary query vector.
      #
      # Phase 1 (candidates CTE) orders rows by Hamming distance (<~>) between
      # binary-quantized vectors, the expression prepare_search_indexes indexes.
      # Phase 2 re-ranks those candidates by +pg_function+ on the halfvec values
      # and applies LIMIT/OFFSET.
      #
      # If a block is given, the DB builder is yielded so callers can fill the
      # /*join*/ and /*where*/ slots; these filter the candidate phase.
      #
      # @param embedding [Array<Numeric>] query vector
      # @param limit [Integer] rows to return
      # @param offset [Integer] rows to skip in the re-ranked result (pagination)
      # @return rows with +target_column+ and +distance+
      # @raise [MissingEmbeddingError] when the query fails with a PG::Error
      def asymmetric_similarity_search(embedding, limit:, offset:)
        candidates_limit = similarity_candidates_limit(limit, offset)
        # Size ef_search from the candidate LIMIT (as symmetric search does),
        # otherwise the HNSW scan caps phase 1 below candidates_limit.
        before_query = hnsw_search_workaround(candidates_limit)

        builder = DB.build(<<~SQL)
          WITH candidates AS (
            SELECT
              #{target_column},
              embeddings::halfvec(#{dimensions}) AS embeddings
            FROM
              #{table}
            /*join*/
            /*where*/
            ORDER BY
              binary_quantize(embeddings)::bit(#{dimensions}) <~> binary_quantize('[:query_embedding]'::halfvec(#{dimensions}))
            LIMIT :candidates_limit
          )
          SELECT
            #{target_column},
            embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) AS distance
          FROM
            candidates
          ORDER BY
            embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
          LIMIT :limit
          OFFSET :offset;
        SQL

        builder.where(
          "model_id = :model_id AND strategy_id = :strategy_id",
          model_id: vector_def.id,
          strategy_id: vector_def.strategy_id,
        )

        yield(builder) if block_given?

        # SET LOCAL only lasts for the enclosing transaction.
        ActiveRecord::Base.transaction do
          DB.exec(before_query) if before_query.present?
          builder.query(
            query_embedding: embedding,
            candidates_limit: candidates_limit,
            limit: limit,
            offset: offset,
          )
        end
      rescue PG::Error => e
        Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
        raise MissingEmbeddingError
      end

      # Finds records similar to +record+ using +record+'s own stored
      # embedding as the query vector.
      #
      # Takes the 200 nearest rows by Hamming distance over binary-quantized
      # vectors, re-ranks them by +pg_function+, and returns the ids of the
      # best 100. If a block is given, the DB builder is yielded so callers can
      # fill the /*join*/ and /*where*/ slots of the candidate query.
      #
      # @param record [#id] an already-embedded record of this schema's target type
      # @return rows with only +target_column+
      # @raise [MissingEmbeddingError] when the query fails with a PG::Error
      def symmetric_similarity_search(record)
        limit = 200
        before_query = hnsw_search_workaround(limit)

        builder = DB.build(<<~SQL)
          WITH le_target AS (
            SELECT
              embeddings
            FROM
              #{table}
            WHERE
              model_id = :vid AND
              strategy_id = :vsid AND
              #{target_column} = :target_id
            LIMIT 1
          )
          SELECT #{target_column} FROM (
            SELECT
              #{target_column}, embeddings
            FROM
              #{table}
            /*join*/
            /*where*/
            ORDER BY
              binary_quantize(embeddings)::bit(#{dimensions}) <~> (
                SELECT
                  binary_quantize(embeddings)::bit(#{dimensions})
                FROM
                  le_target
                LIMIT 1
              )
            LIMIT #{limit}
          ) AS widenet
          ORDER BY
            embeddings::halfvec(#{dimensions}) #{pg_function} (
              SELECT
                embeddings::halfvec(#{dimensions})
              FROM
                le_target
              LIMIT 1
            )
          LIMIT #{limit / 2};
        SQL

        builder.where("model_id = :vid AND strategy_id = :vsid")

        yield(builder) if block_given?

        # SET LOCAL only lasts for the enclosing transaction.
        ActiveRecord::Base.transaction do
          DB.exec(before_query) if before_query.present?
          builder.query(vid: vector_def.id, vsid: vector_def.strategy_id, target_id: record.id)
        end
      rescue PG::Error => e
        Rails.logger.error("Error #{e} querying embeddings for model #{vector_def.display_name}")
        raise MissingEmbeddingError
      end

      # Inserts or updates the embedding of +record+ for this definition. The
      # upsert is keyed on (model_id, strategy_id, +target_column+).
      #
      # @param record [#id] the embedded record
      # @param embedding [Array<Numeric>] vector to store
      # @param digest [String] stored as-is alongside the embedding
      def store(record, embedding, digest)
        DB.exec(
          <<~SQL,
            INSERT INTO #{table} (#{target_column}, model_id, model_version, strategy_id, strategy_version, digest, embeddings, created_at, updated_at)
            VALUES (:target_id, :model_id, :model_version, :strategy_id, :strategy_version, :digest, '[:embeddings]', :now, :now)
            ON CONFLICT (model_id, strategy_id, #{target_column})
            DO UPDATE SET
              model_version = :model_version,
              strategy_version = :strategy_version,
              digest = :digest,
              embeddings = '[:embeddings]',
              updated_at = :now
          SQL
          target_id: record.id,
          model_id: vector_def.id,
          model_version: vector_def.version,
          strategy_id: vector_def.strategy_id,
          strategy_version: vector_def.strategy_version,
          digest: digest,
          embeddings: embedding,
          now: Time.zone.now,
        )
      end

      private

      # Number of binary-quantized candidates to fetch before exact re-ranking.
      # Over-fetches 2x the rows the caller will see, including +offset+ so that
      # later pages still have rows left once OFFSET is applied.
      #
      # @param limit [Integer]
      # @param offset [Integer, nil] nil is treated as 0
      # @return [Integer]
      def similarity_candidates_limit(limit, offset)
        wanted = (limit + offset.to_i) * 2
        # A too low limit exacerbates the recall loss of binary quantization.
        table == RAG_DOCS_TABLE ? [wanted, 100].max : wanted
      end

      # SQL that raises hnsw.ef_search for the current transaction to twice
      # +limit+, so the HNSW scan yields enough rows to fill a candidate LIMIT of
      # +limit+ even after post-scan filtering.
      #
      # Returns "" when the value would be below pgvector's default. Caps the
      # value at MAX_HNSW_EF_SEARCH because pgvector rejects larger settings.
      #
      # @param limit [Integer] the LIMIT of the HNSW-served candidate query
      # @return [String]
      def hnsw_search_workaround(limit)
        threshold = [limit * 2, MAX_HNSW_EF_SEARCH].min

        return "" if threshold < DEFAULT_HNSW_EF_SEARCH
        "SET LOCAL hnsw.ef_search = #{threshold};"
      end

      delegate :dimensions, :pg_function, to: :vector_def
    end
  end
end
|