FEATURE: Make tool support polymorphic (#798)

Polymorphic RAG means that RAG fragments can be accessed from both AiPersona and AiCustomTool.

In turn, this gives us support for richer RAG implementations.
This commit is contained in:
Sam 2024-09-16 08:17:17 +10:00 committed by GitHub
parent b16390ae2a
commit 03eccbe392
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 132 additions and 61 deletions

View File

@ -41,7 +41,7 @@ module DiscourseAi
def create
ai_persona = AiPersona.new(ai_persona_params.except(:rag_uploads))
if ai_persona.save
RagDocumentFragment.link_persona_and_uploads(ai_persona, attached_upload_ids)
RagDocumentFragment.link_target_and_uploads(ai_persona, attached_upload_ids)
render json: {
ai_persona: LocalizedAiPersonaSerializer.new(ai_persona, root: false),
@ -59,7 +59,7 @@ module DiscourseAi
def update
if @ai_persona.update(ai_persona_params.except(:rag_uploads))
RagDocumentFragment.update_persona_uploads(@ai_persona, attached_upload_ids)
RagDocumentFragment.update_target_uploads(@ai_persona, attached_upload_ids)
render json: LocalizedAiPersonaSerializer.new(@ai_persona, root: false)
else

View File

@ -9,17 +9,24 @@ module ::Jobs
# TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads.
def execute(args)
return if (upload = Upload.find_by(id: args[:upload_id])).nil?
return if (ai_persona = AiPersona.find_by(id: args[:ai_persona_id])).nil?
target_type = args[:target_type]
target_id = args[:target_id]
return if !target_type || !target_id
target = target_type.constantize.find_by(id: target_id)
return if !target
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
tokenizer = vector_rep.tokenizer
chunk_tokens = ai_persona.rag_chunk_tokens
overlap_tokens = ai_persona.rag_chunk_overlap_tokens
chunk_tokens = target.rag_chunk_tokens
overlap_tokens = target.rag_chunk_overlap_tokens
fragment_ids = RagDocumentFragment.where(ai_persona: ai_persona, upload: upload).pluck(:id)
fragment_ids = RagDocumentFragment.where(target: target, upload: upload).pluck(:id)
# Check if this is the first time we process this upload.
if fragment_ids.empty?
@ -39,7 +46,7 @@ module ::Jobs
overlap_tokens: overlap_tokens,
) do |chunk, metadata|
fragment_ids << RagDocumentFragment.create!(
ai_persona: ai_persona,
target: target,
fragment: chunk,
fragment_number: idx + 1,
upload: upload,

View File

@ -17,10 +17,10 @@ module ::Jobs
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
last_fragment = fragments.last
ai_persona = last_fragment.ai_persona
target = last_fragment.target
upload = last_fragment.upload
indexing_status = RagDocumentFragment.indexing_status(ai_persona, [upload])[upload.id]
indexing_status = RagDocumentFragment.indexing_status(target, [upload])[upload.id]
RagDocumentFragment.publish_status(upload, indexing_status)
end
end

View File

@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
has_many :rag_document_fragments, dependent: :destroy, as: :target
belongs_to :created_by, class_name: "User"
belongs_to :user
@ -27,12 +28,7 @@ class AiPersona < ActiveRecord::Base
has_many :upload_references, as: :target, dependent: :destroy
has_many :uploads, through: :upload_references
has_many :rag_document_fragment, dependent: :destroy
has_many :rag_document_fragments, through: :ai_persona_rag_document_fragments
before_destroy :ensure_not_system
before_update :regenerate_rag_fragments
def self.persona_cache
@ -230,7 +226,7 @@ class AiPersona < ActiveRecord::Base
def regenerate_rag_fragments
if rag_chunk_tokens_changed? || rag_chunk_overlap_tokens_changed?
RagDocumentFragment.where(ai_persona: self).delete_all
RagDocumentFragment.where(target: self).delete_all
end
end

View File

@ -1,32 +1,40 @@
# frozen_string_literal: true
class RagDocumentFragment < ActiveRecord::Base
# TODO Jan 2025 - remove
self.ignored_columns = %i[ai_persona_id]
belongs_to :upload
belongs_to :ai_persona
belongs_to :target, polymorphic: true
class << self
def link_persona_and_uploads(persona, upload_ids)
return if persona.blank?
def link_target_and_uploads(target, upload_ids)
return if target.blank?
return if upload_ids.blank?
return if !SiteSetting.ai_embeddings_enabled?
UploadReference.ensure_exist!(upload_ids: upload_ids, target: persona)
UploadReference.ensure_exist!(upload_ids: upload_ids, target: target)
upload_ids.each do |upload_id|
Jobs.enqueue(:digest_rag_upload, ai_persona_id: persona.id, upload_id: upload_id)
Jobs.enqueue(
:digest_rag_upload,
target_id: target.id,
target_type: target.class.to_s,
upload_id: upload_id,
)
end
end
def update_persona_uploads(persona, upload_ids)
return if persona.blank?
def update_target_uploads(target, upload_ids)
return if target.blank?
return if !SiteSetting.ai_embeddings_enabled?
if upload_ids.blank?
RagDocumentFragment.where(ai_persona: persona).destroy_all
UploadReference.where(target: persona).destroy_all
RagDocumentFragment.where(target: target).destroy_all
UploadReference.where(target: target).destroy_all
else
RagDocumentFragment.where(ai_persona: persona).where.not(upload_id: upload_ids).destroy_all
link_persona_and_uploads(persona, upload_ids)
RagDocumentFragment.where(target: target).where.not(upload_id: upload_ids).destroy_all
link_target_and_uploads(target, upload_ids)
end
end
@ -37,18 +45,25 @@ class RagDocumentFragment < ActiveRecord::Base
embeddings_table = vector_rep.rag_fragments_table_name
results = DB.query(<<~SQL, persona_id: persona.id, upload_ids: uploads.map(&:id))
results =
DB.query(
<<~SQL,
SELECT
uploads.id,
SUM(CASE WHEN (rdf.upload_id IS NOT NULL) THEN 1 ELSE 0 END) AS total,
SUM(CASE WHEN (eft.rag_document_fragment_id IS NOT NULL) THEN 1 ELSE 0 END) as indexed,
SUM(CASE WHEN (rdf.upload_id IS NOT NULL AND eft.rag_document_fragment_id IS NULL) THEN 1 ELSE 0 END) as left
FROM uploads
LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.ai_persona_id = :persona_id
LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.target_id = :target_id
AND rdf.target_type = :target_type
LEFT OUTER JOIN #{embeddings_table} eft ON rdf.id = eft.rag_document_fragment_id
WHERE uploads.id IN (:upload_ids)
GROUP BY uploads.id
SQL
target_id: persona.id,
target_type: persona.class.to_s,
upload_ids: uploads.map(&:id),
)
results.reduce({}) do |acc, r|
acc[r.id] = { total: r.total, indexed: r.indexed, left: r.left }
@ -78,4 +93,10 @@ end
# created_at :datetime not null
# updated_at :datetime not null
# metadata :text
# target_id :integer
# target_type :string(800)
#
# Indexes
#
# index_rag_document_fragments_on_target_type_and_target_id (target_type,target_id)
#

View File

@ -0,0 +1,9 @@
# frozen_string_literal: true

# Adds the `target_id` / `target_type` column pair to `rag_document_fragments`
# together with a composite index on (target_type, target_id).
# Both columns are created nullable by this migration.
class AddTargetToRagDocumentFragment < ActiveRecord::Migration[7.1]
  def change
    change_table :rag_document_fragments do |t|
      t.integer :target_id, null: true
      t.string :target_type, limit: 800, null: true
      # Type first, then id, so lookups scoped by type + id can use the index.
      t.index %i[target_type target_id]
    end
  end
end

View File

@ -0,0 +1,22 @@
# frozen_string_literal: true

# Copies every existing `ai_persona_id` into the `target_type` / `target_id`
# columns, deletes rows that still have no target, drops the `ai_persona_id`
# column and finally makes both target columns NOT NULL.
#
# The steps are destructive (raw SQL, row deletion, column removal without a
# recorded type), so the migration is written with explicit `up` / `down`
# rather than `change`, and `down` states outright that it cannot be reverted.
class DropPersonaIdFromRagDocumentFragments < ActiveRecord::Migration[7.1]
  def up
    # Backfill: every fragment that pointed at a persona now targets it
    # through the target_type / target_id pair.
    execute <<~SQL
      UPDATE rag_document_fragments
      SET
      target_type = 'AiPersona',
      target_id = ai_persona_id
      WHERE ai_persona_id IS NOT NULL
    SQL

    # Unlikely, but let's be safe: rows without a target would violate the
    # NOT NULL constraints added below.
    execute <<~SQL
      DELETE FROM rag_document_fragments
      WHERE target_id IS NULL OR target_type IS NULL
    SQL

    remove_column :rag_document_fragments, :ai_persona_id
    change_column_null :rag_document_fragments, :target_id, false
    change_column_null :rag_document_fragments, :target_type, false
  end

  # The dropped column and the deleted rows cannot be reconstructed.
  def down
    raise ActiveRecord::IrreversibleMigration
  end
end

View File

@ -288,7 +288,8 @@ module DiscourseAi
candidate_fragment_ids =
vector_rep.asymmetric_rag_fragment_similarity_search(
interactions_vector,
persona_id: id,
target_type: "AiPersona",
target_id: id,
limit:
(
if reranker.reranker_configured?

View File

@ -280,7 +280,8 @@ module DiscourseAi
def asymmetric_rag_fragment_similarity_search(
raw_vector,
persona_id:,
target_id:,
target_type:,
limit:,
offset:,
return_distance: false
@ -299,14 +300,16 @@ module DiscourseAi
WHERE
model_id = #{id} AND
strategy_id = #{@strategy.id} AND
rdf.ai_persona_id = :persona_id
rdf.target_id = :target_id AND
rdf.target_type = :target_type
ORDER BY
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
LIMIT :limit
OFFSET :offset
SQL
query_embedding: raw_vector,
persona_id: persona_id,
target_id: target_id,
target_type: target_type,
limit: limit,
offset: offset,
)

View File

@ -41,7 +41,11 @@ RSpec.describe Jobs::DigestRagUpload do
# be explicit here about chunking strategy
persona.update!(rag_chunk_tokens: 100, rag_chunk_overlap_tokens: 10)
described_class.new.execute(upload_id: upload_with_metadata.id, ai_persona_id: persona.id)
described_class.new.execute(
upload_id: upload_with_metadata.id,
target_id: persona.id,
target_type: persona.class.to_s,
)
parsed = +""
first = true
@ -66,7 +70,11 @@ RSpec.describe Jobs::DigestRagUpload do
before { File.expects(:open).returns(document_file) }
it "splits an upload into chunks" do
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
subject.execute(
upload_id: upload.id,
target_id: persona.id,
target_type: persona.class.to_s,
)
created_fragment = RagDocumentFragment.last
@ -76,19 +84,23 @@ RSpec.describe Jobs::DigestRagUpload do
end
it "queue jobs to generate embeddings for each fragment" do
expect { subject.execute(upload_id: upload.id, ai_persona_id: persona.id) }.to change(
Jobs::GenerateRagEmbeddings.jobs,
:size,
).by(1)
expect {
subject.execute(
upload_id: upload.id,
target_id: persona.id,
target_type: persona.class.to_s,
)
}.to change(Jobs::GenerateRagEmbeddings.jobs, :size).by(1)
end
end
it "doesn't generate new fragments if we already processed the upload" do
Fabricate(:rag_document_fragment, upload: upload, ai_persona: persona)
previous_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
Fabricate(:rag_document_fragment, upload: upload, target: persona)
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
updated_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
previous_count = RagDocumentFragment.where(upload: upload, target: persona).count
subject.execute(upload_id: upload.id, target_id: persona.id, target_type: persona.class.to_s)
updated_count = RagDocumentFragment.where(upload: upload, target: persona).count
expect(updated_count).to eq(previous_count)
end

View File

@ -11,8 +11,8 @@ RSpec.describe Jobs::GenerateRagEmbeddings do
fab!(:ai_persona)
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
before do
SiteSetting.ai_embeddings_enabled = true

View File

@ -395,7 +395,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
candidate_ids << Fabricate(
:rag_document_fragment,
fragment: "fragment-n#{i}",
ai_persona_id: ai_persona.id,
target_id: ai_persona.id,
target_type: "AiPersona",
upload: upload,
).id
end

View File

@ -55,7 +55,7 @@ RSpec.describe AiPersona do
id =
RagDocumentFragment.create!(
ai_persona: persona,
target: persona,
fragment: "test",
fragment_number: 1,
upload: Fabricate(:upload),

View File

@ -12,14 +12,14 @@ RSpec.describe RagDocumentFragment do
describe ".link_uploads_and_persona" do
it "does nothing if there is no persona" do
expect { described_class.link_persona_and_uploads(nil, [upload_1.id]) }.not_to change(
expect { described_class.link_target_and_uploads(nil, [upload_1.id]) }.not_to change(
Jobs::DigestRagUpload.jobs,
:size,
)
end
it "does nothing if there are no uploads" do
expect { described_class.link_persona_and_uploads(persona, []) }.not_to change(
expect { described_class.link_target_and_uploads(persona, []) }.not_to change(
Jobs::DigestRagUpload.jobs,
:size,
)
@ -27,12 +27,12 @@ RSpec.describe RagDocumentFragment do
it "queues a job for each upload to generate fragments" do
expect {
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
}.to change(Jobs::DigestRagUpload.jobs, :size).by(2)
end
it "creates references between the persona an each upload" do
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
refs = UploadReference.where(target: persona).pluck(:upload_id)
@ -40,26 +40,25 @@ RSpec.describe RagDocumentFragment do
end
end
describe ".update_persona_uploads" do
describe ".update_target_uploads" do
it "does nothing if there is no persona" do
expect { described_class.update_persona_uploads(nil, [upload_1.id]) }.not_to change(
expect { described_class.update_target_uploads(nil, [upload_1.id]) }.not_to change(
Jobs::DigestRagUpload.jobs,
:size,
)
end
it "deletes the fragment if its not present in the uploads list" do
fragment = Fabricate(:rag_document_fragment, ai_persona: persona)
fragment = Fabricate(:rag_document_fragment, target: persona)
described_class.update_persona_uploads(persona, [])
described_class.update_target_uploads(persona, [])
expect { fragment.reload }.to raise_error(ActiveRecord::RecordNotFound)
end
it "delete references between the upload and the persona" do
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
described_class.update_persona_uploads(persona, [upload_2.id])
described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
described_class.update_target_uploads(persona, [upload_2.id])
refs = UploadReference.where(target: persona).pluck(:upload_id)
@ -67,7 +66,7 @@ RSpec.describe RagDocumentFragment do
end
it "queues jobs to generate new fragments" do
expect { described_class.update_persona_uploads(persona, [upload_1.id]) }.to change(
expect { described_class.update_target_uploads(persona, [upload_1.id]) }.to change(
Jobs::DigestRagUpload.jobs,
:size,
).by(1)
@ -81,11 +80,11 @@ RSpec.describe RagDocumentFragment do
end
fab!(:rag_document_fragment_1) do
Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona)
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end
fab!(:rag_document_fragment_2) do
Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona)
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }

View File

@ -139,7 +139,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
it "includes rag uploads for each persona" do
upload = Fabricate(:upload)
RagDocumentFragment.link_persona_and_uploads(ai_persona, [upload.id])
RagDocumentFragment.link_target_and_uploads(ai_persona, [upload.id])
get "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json"
expect(response).to be_successful