FEATURE: Make tool support polymorphic (#798)

Polymorphic RAG means that we will be able to access RAG fragments from both AiPersona and AiCustomTool.

In turn this gives us support for richer RAG implementations.
This commit is contained in:
Sam 2024-09-16 08:17:17 +10:00 committed by GitHub
parent b16390ae2a
commit 03eccbe392
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 132 additions and 61 deletions

View File

@ -41,7 +41,7 @@ module DiscourseAi
def create def create
ai_persona = AiPersona.new(ai_persona_params.except(:rag_uploads)) ai_persona = AiPersona.new(ai_persona_params.except(:rag_uploads))
if ai_persona.save if ai_persona.save
RagDocumentFragment.link_persona_and_uploads(ai_persona, attached_upload_ids) RagDocumentFragment.link_target_and_uploads(ai_persona, attached_upload_ids)
render json: { render json: {
ai_persona: LocalizedAiPersonaSerializer.new(ai_persona, root: false), ai_persona: LocalizedAiPersonaSerializer.new(ai_persona, root: false),
@ -59,7 +59,7 @@ module DiscourseAi
def update def update
if @ai_persona.update(ai_persona_params.except(:rag_uploads)) if @ai_persona.update(ai_persona_params.except(:rag_uploads))
RagDocumentFragment.update_persona_uploads(@ai_persona, attached_upload_ids) RagDocumentFragment.update_target_uploads(@ai_persona, attached_upload_ids)
render json: LocalizedAiPersonaSerializer.new(@ai_persona, root: false) render json: LocalizedAiPersonaSerializer.new(@ai_persona, root: false)
else else

View File

@ -9,17 +9,24 @@ module ::Jobs
# TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads. # TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads.
def execute(args) def execute(args)
return if (upload = Upload.find_by(id: args[:upload_id])).nil? return if (upload = Upload.find_by(id: args[:upload_id])).nil?
return if (ai_persona = AiPersona.find_by(id: args[:ai_persona_id])).nil?
target_type = args[:target_type]
target_id = args[:target_id]
return if !target_type || !target_id
target = target_type.constantize.find_by(id: target_id)
return if !target
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep = vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
tokenizer = vector_rep.tokenizer tokenizer = vector_rep.tokenizer
chunk_tokens = ai_persona.rag_chunk_tokens chunk_tokens = target.rag_chunk_tokens
overlap_tokens = ai_persona.rag_chunk_overlap_tokens overlap_tokens = target.rag_chunk_overlap_tokens
fragment_ids = RagDocumentFragment.where(ai_persona: ai_persona, upload: upload).pluck(:id) fragment_ids = RagDocumentFragment.where(target: target, upload: upload).pluck(:id)
# Check if this is the first time we process this upload. # Check if this is the first time we process this upload.
if fragment_ids.empty? if fragment_ids.empty?
@ -39,7 +46,7 @@ module ::Jobs
overlap_tokens: overlap_tokens, overlap_tokens: overlap_tokens,
) do |chunk, metadata| ) do |chunk, metadata|
fragment_ids << RagDocumentFragment.create!( fragment_ids << RagDocumentFragment.create!(
ai_persona: ai_persona, target: target,
fragment: chunk, fragment: chunk,
fragment_number: idx + 1, fragment_number: idx + 1,
upload: upload, upload: upload,

View File

@ -17,10 +17,10 @@ module ::Jobs
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) } fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
last_fragment = fragments.last last_fragment = fragments.last
ai_persona = last_fragment.ai_persona target = last_fragment.target
upload = last_fragment.upload upload = last_fragment.upload
indexing_status = RagDocumentFragment.indexing_status(ai_persona, [upload])[upload.id] indexing_status = RagDocumentFragment.indexing_status(target, [upload])[upload.id]
RagDocumentFragment.publish_status(upload, indexing_status) RagDocumentFragment.publish_status(upload, indexing_status)
end end
end end

View File

@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 } validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 } validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 } validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
has_many :rag_document_fragments, dependent: :destroy, as: :target
belongs_to :created_by, class_name: "User" belongs_to :created_by, class_name: "User"
belongs_to :user belongs_to :user
@ -27,12 +28,7 @@ class AiPersona < ActiveRecord::Base
has_many :upload_references, as: :target, dependent: :destroy has_many :upload_references, as: :target, dependent: :destroy
has_many :uploads, through: :upload_references has_many :uploads, through: :upload_references
has_many :rag_document_fragment, dependent: :destroy
has_many :rag_document_fragments, through: :ai_persona_rag_document_fragments
before_destroy :ensure_not_system before_destroy :ensure_not_system
before_update :regenerate_rag_fragments before_update :regenerate_rag_fragments
def self.persona_cache def self.persona_cache
@ -230,7 +226,7 @@ class AiPersona < ActiveRecord::Base
def regenerate_rag_fragments def regenerate_rag_fragments
if rag_chunk_tokens_changed? || rag_chunk_overlap_tokens_changed? if rag_chunk_tokens_changed? || rag_chunk_overlap_tokens_changed?
RagDocumentFragment.where(ai_persona: self).delete_all RagDocumentFragment.where(target: self).delete_all
end end
end end

View File

@ -1,32 +1,40 @@
# frozen_string_literal: true # frozen_string_literal: true
class RagDocumentFragment < ActiveRecord::Base class RagDocumentFragment < ActiveRecord::Base
# TODO Jan 2025 - remove
self.ignored_columns = %i[ai_persona_id]
belongs_to :upload belongs_to :upload
belongs_to :ai_persona belongs_to :target, polymorphic: true
class << self class << self
def link_persona_and_uploads(persona, upload_ids) def link_target_and_uploads(target, upload_ids)
return if persona.blank? return if target.blank?
return if upload_ids.blank? return if upload_ids.blank?
return if !SiteSetting.ai_embeddings_enabled? return if !SiteSetting.ai_embeddings_enabled?
UploadReference.ensure_exist!(upload_ids: upload_ids, target: persona) UploadReference.ensure_exist!(upload_ids: upload_ids, target: target)
upload_ids.each do |upload_id| upload_ids.each do |upload_id|
Jobs.enqueue(:digest_rag_upload, ai_persona_id: persona.id, upload_id: upload_id) Jobs.enqueue(
:digest_rag_upload,
target_id: target.id,
target_type: target.class.to_s,
upload_id: upload_id,
)
end end
end end
def update_persona_uploads(persona, upload_ids) def update_target_uploads(target, upload_ids)
return if persona.blank? return if target.blank?
return if !SiteSetting.ai_embeddings_enabled? return if !SiteSetting.ai_embeddings_enabled?
if upload_ids.blank? if upload_ids.blank?
RagDocumentFragment.where(ai_persona: persona).destroy_all RagDocumentFragment.where(target: target).destroy_all
UploadReference.where(target: persona).destroy_all UploadReference.where(target: target).destroy_all
else else
RagDocumentFragment.where(ai_persona: persona).where.not(upload_id: upload_ids).destroy_all RagDocumentFragment.where(target: target).where.not(upload_id: upload_ids).destroy_all
link_persona_and_uploads(persona, upload_ids) link_target_and_uploads(target, upload_ids)
end end
end end
@ -37,18 +45,25 @@ class RagDocumentFragment < ActiveRecord::Base
embeddings_table = vector_rep.rag_fragments_table_name embeddings_table = vector_rep.rag_fragments_table_name
results = DB.query(<<~SQL, persona_id: persona.id, upload_ids: uploads.map(&:id)) results =
DB.query(
<<~SQL,
SELECT SELECT
uploads.id, uploads.id,
SUM(CASE WHEN (rdf.upload_id IS NOT NULL) THEN 1 ELSE 0 END) AS total, SUM(CASE WHEN (rdf.upload_id IS NOT NULL) THEN 1 ELSE 0 END) AS total,
SUM(CASE WHEN (eft.rag_document_fragment_id IS NOT NULL) THEN 1 ELSE 0 END) as indexed, SUM(CASE WHEN (eft.rag_document_fragment_id IS NOT NULL) THEN 1 ELSE 0 END) as indexed,
SUM(CASE WHEN (rdf.upload_id IS NOT NULL AND eft.rag_document_fragment_id IS NULL) THEN 1 ELSE 0 END) as left SUM(CASE WHEN (rdf.upload_id IS NOT NULL AND eft.rag_document_fragment_id IS NULL) THEN 1 ELSE 0 END) as left
FROM uploads FROM uploads
LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.ai_persona_id = :persona_id LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.target_id = :target_id
AND rdf.target_type = :target_type
LEFT OUTER JOIN #{embeddings_table} eft ON rdf.id = eft.rag_document_fragment_id LEFT OUTER JOIN #{embeddings_table} eft ON rdf.id = eft.rag_document_fragment_id
WHERE uploads.id IN (:upload_ids) WHERE uploads.id IN (:upload_ids)
GROUP BY uploads.id GROUP BY uploads.id
SQL SQL
target_id: persona.id,
target_type: persona.class.to_s,
upload_ids: uploads.map(&:id),
)
results.reduce({}) do |acc, r| results.reduce({}) do |acc, r|
acc[r.id] = { total: r.total, indexed: r.indexed, left: r.left } acc[r.id] = { total: r.total, indexed: r.indexed, left: r.left }
@ -78,4 +93,10 @@ end
# created_at :datetime not null # created_at :datetime not null
# updated_at :datetime not null # updated_at :datetime not null
# metadata :text # metadata :text
# target_id :integer
# target_type :string(800)
#
# Indexes
#
# index_rag_document_fragments_on_target_type_and_target_id (target_type,target_id)
# #

View File

@ -0,0 +1,9 @@
# frozen_string_literal: true
# Adds the nullable polymorphic reference columns (target_id / target_type)
# to rag_document_fragments, plus a composite index over the pair.
class AddTargetToRagDocumentFragment < ActiveRecord::Migration[7.1]
  def change
    change_table :rag_document_fragments do |t|
      # Both columns are created nullable here.
      t.integer :target_id, null: true
      t.string :target_type, limit: 800, null: true

      # Composite index with type first, then id.
      t.index %i[target_type target_id]
    end
  end
end

View File

@ -0,0 +1,22 @@
# frozen_string_literal: true
# Backfills the polymorphic target columns on rag_document_fragments from the
# ai_persona_id column, then drops ai_persona_id and makes target_id and
# target_type NOT NULL.
#
# Written as up/down rather than `change`: neither the raw SQL statements nor
# the untyped column removal can be reversed automatically, so the rollback
# path is declared irreversible explicitly instead of failing implicitly.
class DropPersonaIdFromRagDocumentFragments < ActiveRecord::Migration[7.1]
  def up
    # Every fragment that references a persona becomes an AiPersona target.
    execute <<~SQL
      UPDATE rag_document_fragments
      SET
        target_type = 'AiPersona',
        target_id = ai_persona_id
      WHERE ai_persona_id IS NOT NULL
    SQL

    # Unlikely, but let's be safe: rows that could not be backfilled would
    # violate the NOT NULL constraints applied below, so remove them first.
    execute <<~SQL
      DELETE FROM rag_document_fragments
      WHERE target_id IS NULL OR target_type IS NULL
    SQL

    remove_column :rag_document_fragments, :ai_persona_id
    change_column_null :rag_document_fragments, :target_id, false
    change_column_null :rag_document_fragments, :target_type, false
  end

  # The dropped ai_persona_id values and the deleted rows cannot be restored.
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end

View File

@ -288,7 +288,8 @@ module DiscourseAi
candidate_fragment_ids = candidate_fragment_ids =
vector_rep.asymmetric_rag_fragment_similarity_search( vector_rep.asymmetric_rag_fragment_similarity_search(
interactions_vector, interactions_vector,
persona_id: id, target_type: "AiPersona",
target_id: id,
limit: limit:
( (
if reranker.reranker_configured? if reranker.reranker_configured?

View File

@ -280,7 +280,8 @@ module DiscourseAi
def asymmetric_rag_fragment_similarity_search( def asymmetric_rag_fragment_similarity_search(
raw_vector, raw_vector,
persona_id:, target_id:,
target_type:,
limit:, limit:,
offset:, offset:,
return_distance: false return_distance: false
@ -299,14 +300,16 @@ module DiscourseAi
WHERE WHERE
model_id = #{id} AND model_id = #{id} AND
strategy_id = #{@strategy.id} AND strategy_id = #{@strategy.id} AND
rdf.ai_persona_id = :persona_id rdf.target_id = :target_id AND
rdf.target_type = :target_type
ORDER BY ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit LIMIT :limit
OFFSET :offset OFFSET :offset
SQL SQL
query_embedding: raw_vector, query_embedding: raw_vector,
persona_id: persona_id, target_id: target_id,
target_type: target_type,
limit: limit, limit: limit,
offset: offset, offset: offset,
) )

View File

@ -41,7 +41,11 @@ RSpec.describe Jobs::DigestRagUpload do
# be explicit here about chunking strategy # be explicit here about chunking strategy
persona.update!(rag_chunk_tokens: 100, rag_chunk_overlap_tokens: 10) persona.update!(rag_chunk_tokens: 100, rag_chunk_overlap_tokens: 10)
described_class.new.execute(upload_id: upload_with_metadata.id, ai_persona_id: persona.id) described_class.new.execute(
upload_id: upload_with_metadata.id,
target_id: persona.id,
target_type: persona.class.to_s,
)
parsed = +"" parsed = +""
first = true first = true
@ -66,7 +70,11 @@ RSpec.describe Jobs::DigestRagUpload do
before { File.expects(:open).returns(document_file) } before { File.expects(:open).returns(document_file) }
it "splits an upload into chunks" do it "splits an upload into chunks" do
subject.execute(upload_id: upload.id, ai_persona_id: persona.id) subject.execute(
upload_id: upload.id,
target_id: persona.id,
target_type: persona.class.to_s,
)
created_fragment = RagDocumentFragment.last created_fragment = RagDocumentFragment.last
@ -76,19 +84,23 @@ RSpec.describe Jobs::DigestRagUpload do
end end
it "queue jobs to generate embeddings for each fragment" do it "queue jobs to generate embeddings for each fragment" do
expect { subject.execute(upload_id: upload.id, ai_persona_id: persona.id) }.to change( expect {
Jobs::GenerateRagEmbeddings.jobs, subject.execute(
:size, upload_id: upload.id,
).by(1) target_id: persona.id,
target_type: persona.class.to_s,
)
}.to change(Jobs::GenerateRagEmbeddings.jobs, :size).by(1)
end end
end end
it "doesn't generate new fragments if we already processed the upload" do it "doesn't generate new fragments if we already processed the upload" do
Fabricate(:rag_document_fragment, upload: upload, ai_persona: persona) Fabricate(:rag_document_fragment, upload: upload, target: persona)
previous_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
subject.execute(upload_id: upload.id, ai_persona_id: persona.id) previous_count = RagDocumentFragment.where(upload: upload, target: persona).count
updated_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
subject.execute(upload_id: upload.id, target_id: persona.id, target_type: persona.class.to_s)
updated_count = RagDocumentFragment.where(upload: upload, target: persona).count
expect(updated_count).to eq(previous_count) expect(updated_count).to eq(previous_count)
end end

View File

@ -11,8 +11,8 @@ RSpec.describe Jobs::GenerateRagEmbeddings do
fab!(:ai_persona) fab!(:ai_persona)
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) } fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) } fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
before do before do
SiteSetting.ai_embeddings_enabled = true SiteSetting.ai_embeddings_enabled = true

View File

@ -395,7 +395,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
candidate_ids << Fabricate( candidate_ids << Fabricate(
:rag_document_fragment, :rag_document_fragment,
fragment: "fragment-n#{i}", fragment: "fragment-n#{i}",
ai_persona_id: ai_persona.id, target_id: ai_persona.id,
target_type: "AiPersona",
upload: upload, upload: upload,
).id ).id
end end

View File

@ -55,7 +55,7 @@ RSpec.describe AiPersona do
id = id =
RagDocumentFragment.create!( RagDocumentFragment.create!(
ai_persona: persona, target: persona,
fragment: "test", fragment: "test",
fragment_number: 1, fragment_number: 1,
upload: Fabricate(:upload), upload: Fabricate(:upload),

View File

@ -12,14 +12,14 @@ RSpec.describe RagDocumentFragment do
describe ".link_uploads_and_persona" do describe ".link_uploads_and_persona" do
it "does nothing if there is no persona" do it "does nothing if there is no persona" do
expect { described_class.link_persona_and_uploads(nil, [upload_1.id]) }.not_to change( expect { described_class.link_target_and_uploads(nil, [upload_1.id]) }.not_to change(
Jobs::DigestRagUpload.jobs, Jobs::DigestRagUpload.jobs,
:size, :size,
) )
end end
it "does nothing if there are no uploads" do it "does nothing if there are no uploads" do
expect { described_class.link_persona_and_uploads(persona, []) }.not_to change( expect { described_class.link_target_and_uploads(persona, []) }.not_to change(
Jobs::DigestRagUpload.jobs, Jobs::DigestRagUpload.jobs,
:size, :size,
) )
@ -27,12 +27,12 @@ RSpec.describe RagDocumentFragment do
it "queues a job for each upload to generate fragments" do it "queues a job for each upload to generate fragments" do
expect { expect {
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id]) described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
}.to change(Jobs::DigestRagUpload.jobs, :size).by(2) }.to change(Jobs::DigestRagUpload.jobs, :size).by(2)
end end
it "creates references between the persona an each upload" do it "creates references between the persona an each upload" do
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id]) described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
refs = UploadReference.where(target: persona).pluck(:upload_id) refs = UploadReference.where(target: persona).pluck(:upload_id)
@ -40,26 +40,25 @@ RSpec.describe RagDocumentFragment do
end end
end end
describe ".update_persona_uploads" do describe ".update_target_uploads" do
it "does nothing if there is no persona" do it "does nothing if there is no persona" do
expect { described_class.update_persona_uploads(nil, [upload_1.id]) }.not_to change( expect { described_class.update_target_uploads(nil, [upload_1.id]) }.not_to change(
Jobs::DigestRagUpload.jobs, Jobs::DigestRagUpload.jobs,
:size, :size,
) )
end end
it "deletes the fragment if its not present in the uploads list" do it "deletes the fragment if its not present in the uploads list" do
fragment = Fabricate(:rag_document_fragment, ai_persona: persona) fragment = Fabricate(:rag_document_fragment, target: persona)
described_class.update_persona_uploads(persona, []) described_class.update_target_uploads(persona, [])
expect { fragment.reload }.to raise_error(ActiveRecord::RecordNotFound) expect { fragment.reload }.to raise_error(ActiveRecord::RecordNotFound)
end end
it "delete references between the upload and the persona" do it "delete references between the upload and the persona" do
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id]) described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
described_class.update_target_uploads(persona, [upload_2.id])
described_class.update_persona_uploads(persona, [upload_2.id])
refs = UploadReference.where(target: persona).pluck(:upload_id) refs = UploadReference.where(target: persona).pluck(:upload_id)
@ -67,7 +66,7 @@ RSpec.describe RagDocumentFragment do
end end
it "queues jobs to generate new fragments" do it "queues jobs to generate new fragments" do
expect { described_class.update_persona_uploads(persona, [upload_1.id]) }.to change( expect { described_class.update_target_uploads(persona, [upload_1.id]) }.to change(
Jobs::DigestRagUpload.jobs, Jobs::DigestRagUpload.jobs,
:size, :size,
).by(1) ).by(1)
@ -81,11 +80,11 @@ RSpec.describe RagDocumentFragment do
end end
fab!(:rag_document_fragment_1) do fab!(:rag_document_fragment_1) do
Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona) Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end end
fab!(:rag_document_fragment_2) do fab!(:rag_document_fragment_2) do
Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona) Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }

View File

@ -139,7 +139,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
it "includes rag uploads for each persona" do it "includes rag uploads for each persona" do
upload = Fabricate(:upload) upload = Fabricate(:upload)
RagDocumentFragment.link_persona_and_uploads(ai_persona, [upload.id]) RagDocumentFragment.link_target_and_uploads(ai_persona, [upload.id])
get "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json" get "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json"
expect(response).to be_successful expect(response).to be_successful