Roman Rizzi 1f1c94e5c6
FEATURE: AI Bot RAG support. (#537)
This PR lets you associate uploads with an AI persona; we'll split them into fragments and generate embeddings from them. When building the system prompt for a bot reply, we'll run a similarity search followed by re-ranking (when a re-ranker is available). This lets us find the most relevant fragments from the body of knowledge associated with the persona, resulting in better, more informed responses.

For now, we'll only allow plain-text files, but this will change in the future.
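
As a rough sketch (not part of this commit) of how the retrieval side might be wired when building a bot reply: it relies on the asymmetric_rag_fragment_similarity_search method from the vector representation base class shown below, while the persona and conversation_text variables and the truncation strategy wiring are illustrative assumptions.

vector_rep =
  DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
    DiscourseAi::Embeddings::Strategies::Truncation.new,
  )

# Embed the latest conversation text and look up the persona's closest fragments.
query_vector = vector_rep.vector_from(conversation_text, asymetric: true)
fragment_ids =
  vector_rep.asymmetric_rag_fragment_similarity_search(
    query_vector,
    persona_id: persona.id,
    limit: 10,
    offset: 0,
  )

# The matching fragments are then re-ranked (when a re-ranker is configured)
# and appended to the persona's system prompt.
guidance_fragments = RagDocumentFragment.where(id: fragment_ids)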

Commits:

* FEATURE: RAG embeddings for the AI Bot

This first commit introduces a UI where admins can upload text files, which we'll store, split into fragments,
and generate embeddings for. In a follow-up commit, we'll use those embeddings to give the bot additional
information during conversations (a rough sketch of this ingestion step follows the commit list).

* Basic asymmetric similarity search to provide guidance in system prompt

* Fix tests and lint

* Apply reranker to fragments

* Uploads filter, css adjustments and file validations

* Add placeholder for rag fragments

* Update annotations
2024-04-01 13:43:34 -03:00
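
A rough sketch of that ingestion step, assuming the persona's fragments have already been stored as RagDocumentFragment rows; generate_representation_from is the method defined in the base class below, the rest of the wiring is illustrative.

vector_rep =
  DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(
    DiscourseAi::Embeddings::Strategies::Truncation.new,
  )

# Each stored fragment of the persona's uploads gets its own embedding row,
# persisted by generate_representation_from unless its digest is unchanged.
RagDocumentFragment.where(ai_persona_id: persona.id).find_each do |fragment|
  vector_rep.generate_representation_from(fragment)
end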

# frozen_string_literal: true
module DiscourseAi
module Embeddings
module VectorRepresentations
class Base
class << self
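# Returns the vector representation class matching the given model name.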
def find_representation(model_name)
# we are explicit here because the loader may not have
# loaded the subclasses yet
[
DiscourseAi::Embeddings::VectorRepresentations::AllMpnetBaseV2,
DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn,
DiscourseAi::Embeddings::VectorRepresentations::Gemini,
DiscourseAi::Embeddings::VectorRepresentations::MultilingualE5Large,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbeddingAda002,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Small,
DiscourseAi::Embeddings::VectorRepresentations::TextEmbedding3Large,
].find { _1.name == model_name }
end
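
# Instantiates the representation configured via the ai_embeddings_model site setting.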
def current_representation(strategy)
find_representation(SiteSetting.ai_embeddings_model).new(strategy)
end
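
# Subclasses declare whether the model is usable with the current settings
# and which site settings it depends on.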
def correctly_configured?
raise NotImplementedError
end
def dependant_setting_names
raise NotImplementedError
end
def configuration_hint
settings = dependant_setting_names
I18n.t(
"discourse_ai.embeddings.configuration.hint",
settings: settings.join(", "),
count: settings.length,
)
end
end
def initialize(strategy)
@strategy = strategy
end
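
# Creates or refreshes the ivfflat indexes on the topic and post embedding tables.
# `lists` is sized from the current row count and the matching `probes` value is
# cached so probes_sql can apply it at query time.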
def consider_indexing(memory: "100MB")
[topic_table_name, post_table_name].each do |table_name|
index_name = index_name(table_name)
# Using extension maintainer's recommendation for ivfflat indexes
# Results are not as good as without indexes, but it's much faster
# Disk usage is ~1x the size of the table, so this doubles table total size
count = DB.query_single("SELECT count(*) FROM #{table_name};").first
lists = [count < 1_000_000 ? count / 1000 : Math.sqrt(count).to_i, 10].max
probes = [count < 1_000_000 ? lists / 10 : Math.sqrt(lists).to_i, 1].max
Discourse.cache.write("#{table_name}-probes", probes)
existing_index = DB.query_single(<<~SQL, index_name: index_name).first
SELECT
indexdef
FROM
pg_indexes
WHERE
indexname = :index_name
AND schemaname = 'public'
LIMIT 1
SQL
if !existing_index.present?
Rails.logger.info("Index #{index_name} does not exist, creating...")
return create_index!(table_name, memory, lists, probes)
end
existing_index_age =
DB
.query_single(
"SELECT pg_catalog.obj_description((:index_name)::regclass, 'pg_class');",
index_name: index_name,
)
.first
.to_i || 0
new_rows =
DB.query_single(
"SELECT count(*) FROM #{table_name} WHERE created_at > '#{Time.at(existing_index_age)}';",
).first
existing_lists = existing_index.match(/lists='(\d+)'/)&.captures&.first&.to_i
if existing_index_age > 0 &&
existing_index_age <
(
if SiteSetting.ai_embeddings_semantic_related_topics_enabled
1.hour.ago.to_i
else
1.day.ago.to_i
end
)
if new_rows > 10_000
Rails.logger.info(
"Index #{index_name} is #{existing_index_age} seconds old, and there are #{new_rows} new rows, updating...",
)
return create_index!(table_name, memory, lists, probes)
elsif existing_lists != lists
Rails.logger.info(
"Index #{index_name} already exists, but lists is #{existing_lists} instead of #{lists}, updating...",
)
return create_index!(table_name, memory, lists, probes)
end
end
Rails.logger.info(
"Index #{index_name} kept. #{Time.now.to_i - existing_index_age} seconds old, #{new_rows} new rows, #{existing_lists} lists, #{probes} probes.",
)
end
end
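
# Drops and recreates the ivfflat index with the given number of lists, retrying
# with a larger memory setting if Postgres reports it needs more, and stamps the
# index with the creation timestamp as a comment so we can tell how stale it is.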
def create_index!(table_name, memory, lists, probes)
tries = 0
index_name = index_name(table_name)
DB.exec("SET work_mem TO '#{memory}';")
DB.exec("SET maintenance_work_mem TO '#{memory}';")
begin
DB.exec(<<~SQL)
DROP INDEX IF EXISTS #{index_name};
CREATE INDEX IF NOT EXISTS
#{index_name}
ON
#{table_name}
USING
ivfflat (embeddings #{pg_index_type})
WITH
(lists = #{lists});
SQL
rescue PG::ProgramLimitExceeded => e
parsed_error = e.message.match(/memory required is (\d+ [A-Z]{2}), ([a-z_]+)/)
if parsed_error[1].present? && parsed_error[2].present?
DB.exec("SET #{parsed_error[2]} TO '#{parsed_error[1].tr(" ", "")}';")
tries += 1
retry if tries < 3
else
raise e
end
end
DB.exec("COMMENT ON INDEX #{index_name} IS '#{Time.now.to_i}';")
DB.exec("RESET work_mem;")
DB.exec("RESET maintenance_work_mem;")
end
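
# Implemented by subclasses: turn text into an embedding vector using the
# underlying model or API.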
def vector_from(text, asymetric: false)
raise NotImplementedError
end
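
# Builds (and optionally persists) the embedding for a Topic, Post, or
# RagDocumentFragment, skipping the work when the prepared text's digest
# has not changed since the last run.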
def generate_representation_from(target, persist: true)
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
return if text.blank?
target_column =
case target
when Topic
"topic_id"
when Post
"post_id"
when RagDocumentFragment
"rag_document_fragment_id"
else
raise ArgumentError, "Invalid target type"
end
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
SELECT
digest
FROM
#{table_name(target)}
WHERE
#{target_column} = :target_id
LIMIT 1
SQL
return if current_digest == new_digest
vector = vector_from(text)
save_to_db(target, vector, new_digest) if persist
end
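
# Nearest-neighbour lookups: return the single closest topic or post to the given vector.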
def topic_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
topic_id
FROM
#{topic_table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT 1
SQL
end
def post_id_from_representation(raw_vector)
DB.query_single(<<~SQL, query_embedding: raw_vector).first
SELECT
post_id
FROM
#{post_table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT 1
SQL
end
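
# Asymmetric searches compare a query embedding against stored content embeddings
# and return the closest ids (optionally paired with distances), ordered by distance.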
def asymmetric_topics_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
#{probes_sql(topic_table_name)}
SELECT
topic_id,
embeddings #{pg_function} '[:query_embedding]' AS distance
FROM
#{topic_table_name}
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.topic_id, r.distance] }
else
results.map(&:topic_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
def asymmetric_posts_similarity_search(raw_vector, limit:, offset:, return_distance: false)
results = DB.query(<<~SQL, query_embedding: raw_vector, limit: limit, offset: offset)
#{probes_sql(post_table_name)}
SELECT
post_id,
embeddings #{pg_function} '[:query_embedding]' AS distance
FROM
#{post_table_name}
INNER JOIN
posts AS p ON p.id = post_id
INNER JOIN
topics AS t ON t.id = p.topic_id AND t.archetype = 'regular'
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
if return_distance
results.map { |r| [r.post_id, r.distance] }
else
results.map(&:post_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
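
# Returns the ids of the given persona's uploaded document fragments closest to
# the query embedding; used to ground the bot's system prompt.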
def asymmetric_rag_fragment_similarity_search(
raw_vector,
persona_id:,
limit:,
offset:,
return_distance: false
)
results =
DB.query(
<<~SQL,
#{probes_sql(rag_fragments_table_name)}
SELECT
rag_document_fragment_id,
embeddings #{pg_function} '[:query_embedding]' AS distance
FROM
#{rag_fragments_table_name}
INNER JOIN
rag_document_fragments AS rdf ON rdf.id = rag_document_fragment_id
WHERE
rdf.ai_persona_id = :persona_id
ORDER BY
embeddings #{pg_function} '[:query_embedding]'
LIMIT :limit
OFFSET :offset
SQL
query_embedding: raw_vector,
persona_id: persona_id,
limit: limit,
offset: offset,
)
if return_distance
results.map { |r| [r.rag_document_fragment_id, r.distance] }
else
results.map(&:rag_document_fragment_id)
end
rescue PG::Error => e
Rails.logger.error("Error #{e} querying embeddings for model #{name}")
raise MissingEmbeddingError
end
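
# Returns up to 100 topic ids ordered by similarity to the given topic's own embedding.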
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
#{probes_sql(topic_table_name)}
SELECT
topic_id
FROM
#{topic_table_name}
ORDER BY
embeddings #{pg_function} (
SELECT
embeddings
FROM
#{topic_table_name}
WHERE
topic_id = :topic_id
LIMIT 1
)
LIMIT 100
SQL
rescue PG::Error => e
Rails.logger.error(
"Error #{e} querying embeddings for topic #{topic.id} and model #{name}",
)
raise MissingEmbeddingError
end
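
# Each model/strategy pair gets its own embeddings tables, named after their ids.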
def topic_table_name
"ai_topic_embeddings_#{id}_#{@strategy.id}"
end
def post_table_name
"ai_post_embeddings_#{id}_#{@strategy.id}"
end
def rag_fragments_table_name
"ai_document_fragment_embeddings_#{id}_#{@strategy.id}"
end
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
when RagDocumentFragment
rag_fragments_table_name
else
raise ArgumentError, "Invalid target type"
end
end
def index_name(table_name)
"#{table_name}_search"
end
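
# Applies the ivfflat probes value cached by consider_indexing, when present.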
def probes_sql(table_name)
probes = Discourse.cache.read("#{table_name}-probes")
probes.present? ? "SET LOCAL ivfflat.probes TO #{probes};" : ""
end
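
# Model metadata supplied by each concrete representation subclass.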
def name
raise NotImplementedError
end
def dimensions
raise NotImplementedError
end
def max_sequence_length
raise NotImplementedError
end
def id
raise NotImplementedError
end
def pg_function
raise NotImplementedError
end
def version
raise NotImplementedError
end
def tokenizer
raise NotImplementedError
end
def asymmetric_query_prefix
raise NotImplementedError
end
protected
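# Upserts the embedding row for the target, keyed by topic, post, or fragment id.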
def save_to_db(target, vector, digest)
if target.is_a?(Topic)
DB.exec(
<<~SQL,
INSERT INTO #{topic_table_name} (topic_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:topic_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (topic_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
topic_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
elsif target.is_a?(Post)
DB.exec(
<<~SQL,
INSERT INTO #{post_table_name} (post_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:post_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (post_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
post_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
elsif target.is_a?(RagDocumentFragment)
DB.exec(
<<~SQL,
INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:fragment_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (rag_document_fragment_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
fragment_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
else
raise ArgumentError, "Invalid target type"
end
end
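
# Resolves the embeddings service endpoint, preferring DNS SRV discovery when
# the SRV site setting is present.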
def discourse_embeddings_endpoint
if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present?
service =
DiscourseAi::Utils::DnsSrv.lookup(
SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv,
)
"https://#{service.target}:#{service.port}"
else
SiteSetting.ai_embeddings_discourse_service_api_endpoint
end
end
end
end
end
end