2023-09-05 11:08:23 -03:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
  module Embeddings
    module VectorRepresentations
      # Abstract base class for embedding vector representations.
      #
      # Concrete subclasses (one per embedding model) implement the metadata
      # accessors (+name+, +dimensions+, +id+, ...) and +inference_client+.
      # Instances are constructed with a text-preparation strategy object that
      # responds to +id+, +version+ and +prepare_text_from+.
      class Base
        class << self
          # Resolves the concrete representation class whose +name+ equals
          # +model_name+ (typically the value of SiteSetting.ai_embeddings_model).
          # Returns nil when no subclass matches.
          def find_representation(model_name)
            # We are explicit here because the autoloader may have not
            # loaded the subclasses yet.
            [
              DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
              DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
              DiscourseAi::Embeddings::VectorRepresentations::BgeM3,
              DiscourseAi::Embeddings::VectorRepresentations::Gemini,
              DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
              DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
            ].find { _1.name == model_name }
          end

          # Builds the representation instance for the currently configured
          # embeddings model, wired to the truncation strategy.
          def current_representation
            truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
            find_representation(SiteSetting.ai_embeddings_model).new(truncation)
          end

          # Subclasses report whether their required site settings are present.
          def correctly_configured?
            raise NotImplementedError
          end

          # Subclasses list the site-setting names they depend on, used to
          # build the configuration hint below.
          def dependant_setting_names
            raise NotImplementedError
          end

          # Human-readable hint telling the admin which settings to fill in.
          def configuration_hint
            settings = dependant_setting_names
            I18n.t(
              "discourse_ai.embeddings.configuration.hint",
              settings: settings.join(", "),
              count: settings.length,
            )
          end
        end

        # @param strategy [#id, #version, #prepare_text_from] text-preparation strategy
        def initialize(strategy)
          @strategy = strategy
        end

        # Returns the embedding vector for +text+. Subclasses implement this.
        # NOTE: the keyword is intentionally spelled "asymetric" (sic) — callers
        # already pass it by name, so renaming would break the public interface.
        def vector_from(text, asymetric: false)
          raise NotImplementedError
        end

        # Generates and stores embeddings for every record in +relation+,
        # fanning HTTP inference calls out over a thread pool. Records whose
        # prepared text is blank or whose stored digest is unchanged are skipped.
        # NOTE: method name is misspelled but kept for existing callers; see the
        # correctly spelled alias below.
        def gen_bulk_reprensentations(relation)
          # Nothing to embed; also avoids handing NilClass to Schema.for via
          # relation.first.class on an empty relation.
          return if relation.first.nil?

          http_pool_size = 100
          pool =
            Concurrent::CachedThreadPool.new(
              min_threads: 0,
              max_threads: http_pool_size,
              idletime: 30,
            )

          schema = DiscourseAi::Embeddings::Schema.for(relation.first.class, vector: self)

          embedding_gen = inference_client
          promised_embeddings =
            relation
              .map do |record|
                prepared_text = prepare_text(record)
                next if prepared_text.blank?

                new_digest = OpenSSL::Digest::SHA1.hexdigest(prepared_text)
                # Skip records whose content hasn't changed since last embedding.
                next if schema.find_by_target(record)&.digest == new_digest

                Concurrent::Promises
                  .fulfilled_future(
                    { target: record, text: prepared_text, digest: new_digest },
                    pool,
                  )
                  .then_on(pool) do |w_prepared_text|
                    w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
                  end
              end
              .compact

          Concurrent::Promises
            .zip(*promised_embeddings)
            .value!
            .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
        ensure
          # Always tear the pool down — value! re-raises worker exceptions, and
          # without the ensure a failure here would leak the thread pool.
          pool&.shutdown
          pool&.wait_for_termination
        end

        # Correctly spelled, backward-compatible alias for new callers.
        alias_method :gen_bulk_representations, :gen_bulk_reprensentations

        # Generates (and by default persists) the embedding for a single
        # +target+ record. No-ops when the prepared text is blank or the stored
        # digest already matches the current content.
        def generate_representation_from(target, persist: true)
          text = prepare_text(target)
          return if text.blank?

          schema = DiscourseAi::Embeddings::Schema.for(target.class, vector: self)

          new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
          return if schema.find_by_target(target)&.digest == new_digest

          vector = vector_from(text)

          schema.store(target, vector, new_digest) if persist
        end

        # Name of the DB search index for this model + strategy combination.
        def index_name(table_name)
          "#{table_name}_#{id}_#{@strategy.id}_search"
        end

        # Model name used by find_representation. Subclass responsibility.
        def name
          raise NotImplementedError
        end

        # Dimensionality of the produced vectors. Subclass responsibility.
        def dimensions
          raise NotImplementedError
        end

        # Maximum token sequence length the model accepts. Subclass responsibility.
        def max_sequence_length
          raise NotImplementedError
        end

        # Stable numeric identifier for this representation. Subclass responsibility.
        def id
          raise NotImplementedError
        end

        # Postgres distance function used for similarity search. Subclass responsibility.
        def pg_function
          raise NotImplementedError
        end

        # Version of this representation (bump to force re-embedding). Subclass responsibility.
        def version
          raise NotImplementedError
        end

        # Tokenizer used to measure/prepare text. Subclass responsibility.
        def tokenizer
          raise NotImplementedError
        end

        # Prefix prepended to queries for asymmetric-retrieval models. Subclass responsibility.
        def asymmetric_query_prefix
          raise NotImplementedError
        end

        def strategy_id
          @strategy.id
        end

        def strategy_version
          @strategy.version
        end

        protected

        # HTTP client used to call the inference endpoint. Subclass responsibility.
        def inference_client
          raise NotImplementedError
        end

        # Delegates text preparation to the strategy. The -2 reserves room for
        # special tokens — presumably BOS/EOS; TODO confirm against tokenizers.
        def prepare_text(record)
          @strategy.prepare_text_from(record, tokenizer, max_sequence_length - 2)
        end
      end
    end
  end
end
|