mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-02-20 10:25:33 +00:00
Previous to this change we relied on explicit loading for a files in Discourse AI. This had a few downsides: - Busywork whenever you add a file (an extra require relative) - We were not keeping to conventions internally ... some places were OpenAI others are OpenAi - Autoloader did not work which lead to lots of full application broken reloads when developing. This moves all of DiscourseAI into a Zeitwerk compatible structure. It also leaves some minimal amount of manual loading (automation - which is loading into an existing namespace that may or may not be there) To avoid needing /lib/discourse_ai/... we mount a namespace thus we are able to keep /lib pointed at ::DiscourseAi Various files were renamed to get around zeitwerk rules and minimize usage of custom inflections Though we can get custom inflections to work it is not worth it, will require a Discourse core patch which means we create a hard dependency.
239 lines
7.5 KiB
Ruby
239 lines
7.5 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
|
|
module Embeddings
|
|
module VectorRepresentations
|
|
class Base
|
|
def self.current_representation(strategy)
|
|
subclasses.map { _1.new(strategy) }.find { _1.name == SiteSetting.ai_embeddings_model }
|
|
end
|
|
|
|
def initialize(strategy)
|
|
@strategy = strategy
|
|
end
|
|
|
|
def consider_indexing(memory: "100MB")
|
|
# Using extension maintainer's recommendation for ivfflat indexes
|
|
# Results are not as good as without indexes, but it's much faster
|
|
# Disk usage is ~1x the size of the table, so this doubles table total size
|
|
count = DB.query_single("SELECT count(*) FROM #{table_name};").first
|
|
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
|
|
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
|
|
|
|
existing_index = DB.query_single(<<~SQL, index_name: index_name).first
|
|
SELECT
|
|
indexdef
|
|
FROM
|
|
pg_indexes
|
|
WHERE
|
|
indexname = :index_name
|
|
LIMIT 1
|
|
SQL
|
|
|
|
if !existing_index.present?
|
|
Rails.logger.info("Index #{index_name} does not exist, creating...")
|
|
return create_index!(memory, lists, probes)
|
|
end
|
|
|
|
existing_index_age =
|
|
DB
|
|
.query_single(
|
|
"SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
|
|
index_name: index_name,
|
|
)
|
|
.first
|
|
.to_i || 0
|
|
new_rows =
|
|
DB.query_single(
|
|
"SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';",
|
|
).first
|
|
existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i
|
|
|
|
if existing_index_age > 0 && existing_index_age < 1.hour.ago.to_i
|
|
if new_rows > 10_000
|
|
Rails.logger.info(
|
|
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
|
|
)
|
|
return create_index!(memory, lists, probes)
|
|
elsif existing_lists != lists
|
|
Rails.logger.info(
|
|
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
|
|
)
|
|
return create_index!(memory, lists, probes)
|
|
end
|
|
end
|
|
|
|
Rails.logger.info(
|
|
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
|
|
)
|
|
end
|
|
|
|
def create_index!(memory, lists, probes)
|
|
DB.exec("SET work_mem TO '#{memory}';")
|
|
DB.exec("SET maintenance_work_mem TO '#{memory}';")
|
|
DB.exec(<<~SQL)
|
|
DROP INDEX IF EXISTS #{index_name};
|
|
CREATE INDEX IF NOT EXISTS
|
|
#{index_name}
|
|
ON
|
|
#{table_name}
|
|
USING
|
|
ivfflat (embeddings #{pg_index_type})
|
|
WITH
|
|
(lists = #{lists});
|
|
SQL
|
|
DB.exec("COMMENT ON INDEX #{index_name} IS '#{Time.now.to_i}';")
|
|
DB.exec("RESET work_mem;")
|
|
DB.exec("RESET maintenance_work_mem;")
|
|
|
|
database = DB.query_single("SELECT current_database();").first
|
|
DB.exec("ALTER DATABASE #{database} SET ivfflat.probes = #{probes};")
|
|
end
|
|
|
|
def vector_from(text)
|
|
raise NotImplementedError
|
|
end
|
|
|
|
def generate_topic_representation_from(target, persist: true)
|
|
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
|
|
|
|
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
|
|
current_digest = DB.query_single(<<~SQL, topic_id: target.id).first
|
|
SELECT
|
|
digest
|
|
FROM
|
|
#{table_name}
|
|
WHERE
|
|
topic_id = :topic_id
|
|
LIMIT 1
|
|
SQL
|
|
return if current_digest == new_digest
|
|
|
|
vector = vector_from(text)
|
|
|
|
save_to_db(target, vector, new_digest) if persist
|
|
end
|
|
|
|
def topic_id_from_representation(raw_vector)
|
|
DB.query_single(<<~SQL, query_embedding: raw_vector).first
|
|
SELECT
|
|
topic_id
|
|
FROM
|
|
#{table_name}
|
|
ORDER BY
|
|
embeddings #{pg_function} '[:query_embedding]'
|
|
LIMIT 1
|
|
SQL
|
|
end
|
|
|
|
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
|
|
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
|
|
SELECT
|
|
topic_id,
|
|
embeddings #{pg_function} '[:query_embedding]' AS distance
|
|
FROM
|
|
#{table_name}
|
|
ORDER BY
|
|
embeddings #{pg_function} '[:query_embedding]'
|
|
LIMIT :limit
|
|
OFFSET :offset
|
|
SQL
|
|
|
|
if return_distance
|
|
results.map { |r| [r.topic_id, r.distance] }
|
|
else
|
|
results.map(&:topic_id)
|
|
end
|
|
rescue PG::Error => e
|
|
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
|
|
raise MissingEmbeddingError
|
|
end
|
|
|
|
def symmetric_topics_similarity_search(topic)
|
|
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
|
|
SELECT
|
|
topic_id
|
|
FROM
|
|
#{table_name}
|
|
ORDER BY
|
|
embeddings #{pg_function} (
|
|
SELECT
|
|
embeddings
|
|
FROM
|
|
#{table_name}
|
|
WHERE
|
|
topic_id = :topic_id
|
|
LIMIT 1
|
|
)
|
|
LIMIT 100
|
|
SQL
|
|
rescue PG::Error => e
|
|
Rails.logger.error(
|
|
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
|
|
)
|
|
raise MissingEmbeddingError
|
|
end
|
|
|
|
def table_name
|
|
"ai_topic_embeddings_#{id}_#{@strategy.id}"
|
|
end
|
|
|
|
def index_name
|
|
"#{table_name}_search"
|
|
end
|
|
|
|
def name
|
|
raise NotImplementedError
|
|
end
|
|
|
|
def dimensions
|
|
raise NotImplementedError
|
|
end
|
|
|
|
def max_sequence_length
|
|
raise NotImplementedError
|
|
end
|
|
|
|
def id
|
|
raise NotImplementedError
|
|
end
|
|
|
|
def pg_function
|
|
raise NotImplementedError
|
|
end
|
|
|
|
def version
|
|
raise NotImplementedError
|
|
end
|
|
|
|
def tokenizer
|
|
raise NotImplementedError
|
|
end
|
|
|
|
protected
|
|
|
|
def save_to_db(target, vector, digest)
|
|
DB.exec(
|
|
<<~SQL,
|
|
INSERT INTO #{table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
|
|
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
|
|
ON CONFLICT (topic_id)
|
|
DO UPDATE SET
|
|
model_version = :model_version,
|
|
strategy_version = :strategy_version,
|
|
digest = :digest,
|
|
embeddings = '[:embeddings]',
|
|
updated_at = CURRENT_TIMESTAMP
|
|
SQL
|
|
topic_id: target.id,
|
|
model_version: version,
|
|
strategy_version: @strategy.version,
|
|
digest: digest,
|
|
embeddings: vector,
|
|
)
|
|
end
|
|
end
|
|
end
|
|
end
|
|
end
|