discourse-ai/lib/embeddings/vector_representations/base.rb

359 lines
12 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
class << self
def find_representation(model_name)
# we are explicit here cause the loader may have not
# loaded the subclasses yet
[
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
].find { _1.name == model_name }
end
def current_representation(strategy)
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
end
def correctly_configured?
raise NotImplementedError
end
def dependant_setting_names
raise NotImplementedError
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.embeddings.configuration.hint",
settings: settings.join(", "),
count: settings.length,
)
end
end
def initialize(strategy)
@strategy = strategy
end
def consider_indexing(memory: "100MB")
[topic_table_name, post_table_name].each do |table_name|
index_name = index_name(table_name)
# Using extension maintainer's recommendation for ivfflat indexes
# Results are not as good as without indexes, but it's much faster
# Disk usage is ~1x the size of the table, so this doubles table total size
count = DB.query_single("SELECT count(*) FROM #{table_name};").first
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
existing_index = DB.query_single(<<~SQL, index_name: index_name).first
SELECT
indexdef
FROM
pg_indexes
WHERE
indexname = :index_name
AND schemaname = 'public'
LIMIT 1
SQL
if !existing_index.present?
Rails.logger.info("Index #{index_name} does not exist, creating...")
return create_index!(table_name, memory, lists, probes)
end
existing_index_age =
DB
.query_single(
"SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
index_name: index_name,
)
.first
.to_i || 0
new_rows =
DB.query_single(
"SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';",
).first
existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i
if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i
if new_rows > 10_000
Rails.logger.info(
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
)
return create_index!(table_name, memory, lists, probes)
elsif existing_lists != lists
Rails.logger.info(
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
)
return create_index!(table_name, memory, lists, probes)
end
end
Rails.logger.info(
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
)
end
end
def create_index!(table_name, memory, lists, probes)
tries = 0
index_name = index_name(table_name)
DB.exec("SET work_mem TO '#{memory}';")
DB.exec("SET maintenance_work_mem TO '#{memory}';")
begin
DB.exec(<<~SQL)
DROP INDEX IF EXISTS #{index_name};
CREATE INDEX IF NOT EXISTS
#{index_name}
ON
#{table_name}
USING
ivfflat (embeddings #{pg_index_type})
WITH
(lists = #{lists});
SQL
rescue PG::ProgramLimitExceeded => e
parsed_error = e.message.match(/memory required is (\d+ [A-Z]{2}), ([a-z_]+)/)
if parsed_error[1].present? && parsed_error[2].present?
DB.exec("SET #{parsed_error[2]} TO '#{parsed_error[1].tr(" ", "")}';")
tries += 1
retry if tries < 3
else
raise e
end
end
DB.exec("COMMENT ON INDEX #{index_name} IS '#{Time.now.to_i}';")
DB.exec("RESET work_mem;")
DB.exec("RESET maintenance_work_mem;")
database = DB.query_single("SELECT current_database();").first
# This is a global setting, if we set it based on post count
# we will be unable to use the index for topics
# Hopefully https://github.com/pgvector/pgvector/issues/235 will make this better
if table_name == topic_table_name
DB.exec("ALTER DATABASE #{database} SET ivfflat.probes = #{probes};")
end
end
def vector_from(text)
raise NotImplementedError
end
def generate_representation_from(target, persist: true)
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
return if text.blank?
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
SELECT
digest
FROM
#{table_name(target)}
WHERE
#{target.is_a?(Topic) ? "topic_id" : "post_id"} = :target_id
LIMIT 1
SQL
return if current_digest == new_digest
vector = vector_from(text)
save_to_db(target, vector, new_digest) if persist
end
def topic_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
topic_id
FROM
#{topic_table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT 1
SQL
end
def post_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
post_id
FROM
#{post_table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT 1
SQL
end
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
SELECT
topic_id,
embeddings #{pg_function} '[:query_embedding]' AS distance
FROM
#{topic_table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.topic_id, r.distance] }
else
results.map(&:topic_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
SELECT
topic_id
FROM
#{topic_table_name}
ORDER BY
embeddings #{pg_function} (
SELECT
embeddings
FROM
#{topic_table_name}
WHERE
topic_id = :topic_id
LIMIT 1
)
LIMIT 100
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
)
raise MissingEmbeddingError
end
def topic_table_name
"ai_topic_embeddings_#{id}_#{@strategy.id}"
end
def post_table_name
"ai_post_embeddings_#{id}_#{@strategy.id}"
end
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
else
raise ArgumentError, "Invalid target type"
end
end
def index_name(table_name)
"#{table_name}_search"
end
def name
raise NotImplementedError
end
def dimensions
raise NotImplementedError
end
def max_sequence_length
raise NotImplementedError
end
def id
raise NotImplementedError
end
def pg_function
raise NotImplementedError
end
def version
raise NotImplementedError
end
def tokenizer
raise NotImplementedError
end
protected
def save_to_db(target, vector, digest)
if target.is_a?(Topic)
DB.exec(
<<~SQL,
INSERT INTO #{topic_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (topic_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
topic_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
elsif target.is_a?(Post)
DB.exec(
<<~SQL,
INSERT INTO #{post_table_name} (post_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:post_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (post_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
post_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
else
raise ArgumentError, "Invalid target type"
end
end
def discourse_embeddings_endpoint
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_embeddings_discourse_service_api_endpoint
end
end
end
end
end
end