mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-09 11:48:47 +00:00
* REFACTOR: A simpler way of interacting with embeddings' tables. This change adds a new abstraction called `Schema`, which acts as a repository that supports the same DB features `VectorRepresentations::Base` has, except that it removes the need for duplicated methods per embeddings table. It is also a bit more flexible when performing a similarity search, because you can pass it a block that gives you access to the builder, allowing you to add multiple joins/where conditions.
166 lines
4.6 KiB
Ruby
166 lines
4.6 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Abstract base class for embedding-model representations.
      #
      # Subclasses are expected to override the methods that raise
      # NotImplementedError here (model metadata such as `name`, `dimensions`,
      # `tokenizer`, plus `vector_from` and `inference_client`). This class
      # supplies lookup of the configured representation and the logic that
      # generates embeddings and persists them through
      # DiscourseAi::Embeddings::Schema.
      class Base
        class << self
          # Returns the first known representation class whose `name` equals
          # +model_name+, or nil when none matches.
          #
          # NOTE(review): presumably subclasses override `name` at class level
          # to return the model identifier; confirm against the subclasses.
          def find_representation(model_name)
            # The candidates are listed explicitly because the loader may not
            # have loaded the subclasses yet.
            [
              DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
              DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
              DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
              DiscourseAi::Embeddings::VectorRepresentations::Gemini,
              DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
            ].find { _1.name == model_name }
          end

          # Instantiates the representation selected by
          # `SiteSetting.ai_embeddings_model`, using a Truncation strategy.
          #
          # Raises NoMethodError if the setting matches no known
          # representation, because `find_representation` returns nil then.
          def current_representation
            truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
            find_representation(SiteSetting.ai_embeddings_model).new(truncation)
          end

          # Abstract: subclasses must override.
          def correctly_configured?
            raise NotImplementedError
          end

          # Abstract: subclasses must override. `configuration_hint` calls
          # `join` and `length` on the returned value.
          def dependant_setting_names
            raise NotImplementedError
          end

          # Builds the localized configuration hint from the settings this
          # representation depends on.
          def configuration_hint
            settings = dependant_setting_names

            I18n.t(
              "discourse_ai.embeddings.configuration.hint",
              settings: settings.join(", "),
              count: settings.length,
            )
          end
        end

        # @param strategy object responding to `id`, `version` and
        #   `prepare_text_from(record, tokenizer, max_length)`.
        def initialize(strategy)
          @strategy = strategy
        end

        # Abstract: subclasses must override.
        # The keyword is spelled `asymetric` on purpose: it is part of the
        # public interface and renaming it would break callers.
        def vector_from(text, asymetric: false)
          raise NotImplementedError
        end

        # Generates and stores embeddings for every record in +relation+,
        # running the inference calls on a thread pool.
        #
        # Records are skipped when their prepared text is blank or when the
        # stored digest already equals the SHA1 of the prepared text.
        #
        # The pool is always shut down, even when preparing text, running
        # inference or storing raises, so a failure cannot leak threads.
        #
        # @param relation collection of records; must respond to `first` and `map`.
        # @return the result of `pool.wait_for_termination`.
        def gen_bulk_reprensentations(relation)
          http_pool_size = 100
          pool =
            Concurrent::CachedThreadPool.new(
              min_threads: 0,
              max_threads: http_pool_size,
              idletime: 30,
            )

          terminated = nil

          begin
            schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)

            embedding_gen = inference_client
            promised_embeddings =
              relation
                .map do |record|
                  prepared_text = prepare_text(record)
                  next if prepared_text.blank?

                  new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
                  # Nothing to do when the stored embedding is already up to date.
                  next if schema.find_by_target(record)&.digest == new_digest

                  Concurrent::Promises
                    .fulfilled_future(
                      { target: record, text: prepared_text, digest: new_digest },
                      pool,
                    )
                    .then_on(pool) do |w_prepared_text|
                      w_prepared_text.merge(
                        embedding: embedding_gen.perform!(w_prepared_text[:text]),
                      )
                    end
                end
                .compact

            # `value!` re-raises if any future was rejected; storing happens
            # on the calling thread after all futures have resolved.
            Concurrent::Promises
              .zip(*promised_embeddings)
              .value!
              .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
          ensure
            pool.shutdown
            terminated = pool.wait_for_termination
          end

          terminated
        end

        # Generates the embedding for a single +target+ and, when +persist+ is
        # true, stores it through the schema.
        #
        # Returns nil without doing any work when the prepared text is blank
        # or the stored digest already matches. When +persist+ is false the
        # vector is still computed but neither stored nor returned.
        def generate_representation_from(target, persist: true)
          text = prepare_text(target)
          return if text.blank?

          schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)

          new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
          return if schema.find_by_target(target)&.digest == new_digest

          vector = vector_from(text)

          schema.store(target, vector, new_digest) if persist
        end

        # Name of the search index for +table_name+, derived from this
        # representation's id and the strategy's id.
        def index_name(table_name)
          "#{table_name}_#{id}_#{@strategy.id}_search"
        end

        # Abstract: subclasses must override.
        def name
          raise NotImplementedError
        end

        # Abstract: subclasses must override.
        def dimensions
          raise NotImplementedError
        end

        # Abstract: subclasses must override. Used by `prepare_text`, which
        # subtracts 2 from it.
        def max_sequence_length
          raise NotImplementedError
        end

        # Abstract: subclasses must override. Used by `index_name`.
        def id
          raise NotImplementedError
        end

        # Abstract: subclasses must override.
        def pg_function
          raise NotImplementedError
        end

        # Abstract: subclasses must override.
        def version
          raise NotImplementedError
        end

        # Abstract: subclasses must override. Passed to the strategy by
        # `prepare_text`.
        def tokenizer
          raise NotImplementedError
        end

        # Abstract: subclasses must override.
        def asymmetric_query_prefix
          raise NotImplementedError
        end

        # Id of the text-preparation strategy given to the constructor.
        def strategy_id
          @strategy.id
        end

        # Version of the text-preparation strategy given to the constructor.
        def strategy_version
          @strategy.version
        end

        protected

        # Abstract: subclasses must override. The returned object must
        # respond to `perform!(text)` (see `gen_bulk_reprensentations`).
        def inference_client
          raise NotImplementedError
        end

        # Delegates text preparation to the strategy.
        # NOTE(review): the `- 2` presumably reserves room for two special
        # tokens added by the model — confirm.
        def prepare_text(record)
          @strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
        end
      end
    end
  end
end
|