FEATURE: Make tool support polymorphic (#798)
Polymorphic RAG means that we will be able to access RAG fragments both from AiPersona and AiCustomTool In turn this gives us support for richer RAG implementations.
This commit is contained in:
parent
b16390ae2a
commit
03eccbe392
|
@ -41,7 +41,7 @@ module DiscourseAi
|
|||
def create
|
||||
ai_persona = AiPersona.new(ai_persona_params.except(:rag_uploads))
|
||||
if ai_persona.save
|
||||
RagDocumentFragment.link_persona_and_uploads(ai_persona, attached_upload_ids)
|
||||
RagDocumentFragment.link_target_and_uploads(ai_persona, attached_upload_ids)
|
||||
|
||||
render json: {
|
||||
ai_persona: LocalizedAiPersonaSerializer.new(ai_persona, root: false),
|
||||
|
@ -59,7 +59,7 @@ module DiscourseAi
|
|||
|
||||
def update
|
||||
if @ai_persona.update(ai_persona_params.except(:rag_uploads))
|
||||
RagDocumentFragment.update_persona_uploads(@ai_persona, attached_upload_ids)
|
||||
RagDocumentFragment.update_target_uploads(@ai_persona, attached_upload_ids)
|
||||
|
||||
render json: LocalizedAiPersonaSerializer.new(@ai_persona, root: false)
|
||||
else
|
||||
|
|
|
@ -9,17 +9,24 @@ module ::Jobs
|
|||
# TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads.
|
||||
def execute(args)
|
||||
return if (upload = Upload.find_by(id: args[:upload_id])).nil?
|
||||
return if (ai_persona = AiPersona.find_by(id: args[:ai_persona_id])).nil?
|
||||
|
||||
target_type = args[:target_type]
|
||||
target_id = args[:target_id]
|
||||
|
||||
return if !target_type || !target_id
|
||||
|
||||
target = target_type.constantize.find_by(id: target_id)
|
||||
return if !target
|
||||
|
||||
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
|
||||
vector_rep =
|
||||
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
|
||||
|
||||
tokenizer = vector_rep.tokenizer
|
||||
chunk_tokens = ai_persona.rag_chunk_tokens
|
||||
overlap_tokens = ai_persona.rag_chunk_overlap_tokens
|
||||
chunk_tokens = target.rag_chunk_tokens
|
||||
overlap_tokens = target.rag_chunk_overlap_tokens
|
||||
|
||||
fragment_ids = RagDocumentFragment.where(ai_persona: ai_persona, upload: upload).pluck(:id)
|
||||
fragment_ids = RagDocumentFragment.where(target: target, upload: upload).pluck(:id)
|
||||
|
||||
# Check if this is the first time we process this upload.
|
||||
if fragment_ids.empty?
|
||||
|
@ -39,7 +46,7 @@ module ::Jobs
|
|||
overlap_tokens: overlap_tokens,
|
||||
) do |chunk, metadata|
|
||||
fragment_ids << RagDocumentFragment.create!(
|
||||
ai_persona: ai_persona,
|
||||
target: target,
|
||||
fragment: chunk,
|
||||
fragment_number: idx + 1,
|
||||
upload: upload,
|
||||
|
|
|
@ -17,10 +17,10 @@ module ::Jobs
|
|||
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
|
||||
|
||||
last_fragment = fragments.last
|
||||
ai_persona = last_fragment.ai_persona
|
||||
target = last_fragment.target
|
||||
upload = last_fragment.upload
|
||||
|
||||
indexing_status = RagDocumentFragment.indexing_status(ai_persona, [upload])[upload.id]
|
||||
indexing_status = RagDocumentFragment.indexing_status(target, [upload])[upload.id]
|
||||
RagDocumentFragment.publish_status(upload, indexing_status)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base
|
|||
validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 }
|
||||
validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 }
|
||||
validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 }
|
||||
has_many :rag_document_fragments, dependent: :destroy, as: :target
|
||||
|
||||
belongs_to :created_by, class_name: "User"
|
||||
belongs_to :user
|
||||
|
@ -27,12 +28,7 @@ class AiPersona < ActiveRecord::Base
|
|||
has_many :upload_references, as: :target, dependent: :destroy
|
||||
has_many :uploads, through: :upload_references
|
||||
|
||||
has_many :rag_document_fragment, dependent: :destroy
|
||||
|
||||
has_many :rag_document_fragments, through: :ai_persona_rag_document_fragments
|
||||
|
||||
before_destroy :ensure_not_system
|
||||
|
||||
before_update :regenerate_rag_fragments
|
||||
|
||||
def self.persona_cache
|
||||
|
@ -230,7 +226,7 @@ class AiPersona < ActiveRecord::Base
|
|||
|
||||
def regenerate_rag_fragments
|
||||
if rag_chunk_tokens_changed? || rag_chunk_overlap_tokens_changed?
|
||||
RagDocumentFragment.where(ai_persona: self).delete_all
|
||||
RagDocumentFragment.where(target: self).delete_all
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -1,32 +1,40 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class RagDocumentFragment < ActiveRecord::Base
|
||||
# TODO Jan 2025 - remove
|
||||
self.ignored_columns = %i[ai_persona_id]
|
||||
|
||||
belongs_to :upload
|
||||
belongs_to :ai_persona
|
||||
belongs_to :target, polymorphic: true
|
||||
|
||||
class << self
|
||||
def link_persona_and_uploads(persona, upload_ids)
|
||||
return if persona.blank?
|
||||
def link_target_and_uploads(target, upload_ids)
|
||||
return if target.blank?
|
||||
return if upload_ids.blank?
|
||||
return if !SiteSetting.ai_embeddings_enabled?
|
||||
|
||||
UploadReference.ensure_exist!(upload_ids: upload_ids, target: persona)
|
||||
UploadReference.ensure_exist!(upload_ids: upload_ids, target: target)
|
||||
|
||||
upload_ids.each do |upload_id|
|
||||
Jobs.enqueue(:digest_rag_upload, ai_persona_id: persona.id, upload_id: upload_id)
|
||||
Jobs.enqueue(
|
||||
:digest_rag_upload,
|
||||
target_id: target.id,
|
||||
target_type: target.class.to_s,
|
||||
upload_id: upload_id,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
def update_persona_uploads(persona, upload_ids)
|
||||
return if persona.blank?
|
||||
def update_target_uploads(target, upload_ids)
|
||||
return if target.blank?
|
||||
return if !SiteSetting.ai_embeddings_enabled?
|
||||
|
||||
if upload_ids.blank?
|
||||
RagDocumentFragment.where(ai_persona: persona).destroy_all
|
||||
UploadReference.where(target: persona).destroy_all
|
||||
RagDocumentFragment.where(target: target).destroy_all
|
||||
UploadReference.where(target: target).destroy_all
|
||||
else
|
||||
RagDocumentFragment.where(ai_persona: persona).where.not(upload_id: upload_ids).destroy_all
|
||||
link_persona_and_uploads(persona, upload_ids)
|
||||
RagDocumentFragment.where(target: target).where.not(upload_id: upload_ids).destroy_all
|
||||
link_target_and_uploads(target, upload_ids)
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -37,18 +45,25 @@ class RagDocumentFragment < ActiveRecord::Base
|
|||
|
||||
embeddings_table = vector_rep.rag_fragments_table_name
|
||||
|
||||
results = DB.query(<<~SQL, persona_id: persona.id, upload_ids: uploads.map(&:id))
|
||||
results =
|
||||
DB.query(
|
||||
<<~SQL,
|
||||
SELECT
|
||||
uploads.id,
|
||||
SUM(CASE WHEN (rdf.upload_id IS NOT NULL) THEN 1 ELSE 0 END) AS total,
|
||||
SUM(CASE WHEN (eft.rag_document_fragment_id IS NOT NULL) THEN 1 ELSE 0 END) as indexed,
|
||||
SUM(CASE WHEN (rdf.upload_id IS NOT NULL AND eft.rag_document_fragment_id IS NULL) THEN 1 ELSE 0 END) as left
|
||||
FROM uploads
|
||||
LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.ai_persona_id = :persona_id
|
||||
LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.target_id = :target_id
|
||||
AND rdf.target_type = :target_type
|
||||
LEFT OUTER JOIN #{embeddings_table} eft ON rdf.id = eft.rag_document_fragment_id
|
||||
WHERE uploads.id IN (:upload_ids)
|
||||
GROUP BY uploads.id
|
||||
SQL
|
||||
target_id: persona.id,
|
||||
target_type: persona.class.to_s,
|
||||
upload_ids: uploads.map(&:id),
|
||||
)
|
||||
|
||||
results.reduce({}) do |acc, r|
|
||||
acc[r.id] = { total: r.total, indexed: r.indexed, left: r.left }
|
||||
|
@ -78,4 +93,10 @@ end
|
|||
# created_at :datetime not null
|
||||
# updated_at :datetime not null
|
||||
# metadata :text
|
||||
# target_id :integer
|
||||
# target_type :string(800)
|
||||
#
|
||||
# Indexes
|
||||
#
|
||||
# index_rag_document_fragments_on_target_type_and_target_id (target_type,target_id)
|
||||
#
|
||||
|
|
|
@ -0,0 +1,9 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
class AddTargetToRagDocumentFragment < ActiveRecord::Migration[7.1]
|
||||
def change
|
||||
add_column :rag_document_fragments, :target_id, :integer, null: true
|
||||
add_column :rag_document_fragments, :target_type, :string, limit: 800, null: true
|
||||
add_index :rag_document_fragments, %i[target_type target_id]
|
||||
end
|
||||
end
|
|
@ -0,0 +1,22 @@
|
|||
# frozen_string_literal: true
|
||||
class DropPersonaIdFromRagDocumentFragments < ActiveRecord::Migration[7.1]
|
||||
def change
|
||||
execute <<~SQL
|
||||
UPDATE rag_document_fragments
|
||||
SET
|
||||
target_type = 'AiPersona',
|
||||
target_id = ai_persona_id
|
||||
WHERE ai_persona_id IS NOT NULL
|
||||
SQL
|
||||
|
||||
# unlikely but lets be safe
|
||||
execute <<~SQL
|
||||
DELETE FROM rag_document_fragments
|
||||
WHERE target_id IS NULL OR target_type IS NULL
|
||||
SQL
|
||||
|
||||
remove_column :rag_document_fragments, :ai_persona_id
|
||||
change_column_null :rag_document_fragments, :target_id, false
|
||||
change_column_null :rag_document_fragments, :target_type, false
|
||||
end
|
||||
end
|
|
@ -288,7 +288,8 @@ module DiscourseAi
|
|||
candidate_fragment_ids =
|
||||
vector_rep.asymmetric_rag_fragment_similarity_search(
|
||||
interactions_vector,
|
||||
persona_id: id,
|
||||
target_type: "AiPersona",
|
||||
target_id: id,
|
||||
limit:
|
||||
(
|
||||
if reranker.reranker_configured?
|
||||
|
|
|
@ -280,7 +280,8 @@ module DiscourseAi
|
|||
|
||||
def asymmetric_rag_fragment_similarity_search(
|
||||
raw_vector,
|
||||
persona_id:,
|
||||
target_id:,
|
||||
target_type:,
|
||||
limit:,
|
||||
offset:,
|
||||
return_distance: false
|
||||
|
@ -299,14 +300,16 @@ module DiscourseAi
|
|||
WHERE
|
||||
model_id = #{id} AND
|
||||
strategy_id = #{@strategy.id} AND
|
||||
rdf.ai_persona_id = :persona_id
|
||||
rdf.target_id = :target_id AND
|
||||
rdf.target_type = :target_type
|
||||
ORDER BY
|
||||
embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions})
|
||||
LIMIT :limit
|
||||
OFFSET :offset
|
||||
SQL
|
||||
query_embedding: raw_vector,
|
||||
persona_id: persona_id,
|
||||
target_id: target_id,
|
||||
target_type: target_type,
|
||||
limit: limit,
|
||||
offset: offset,
|
||||
)
|
||||
|
|
|
@ -41,7 +41,11 @@ RSpec.describe Jobs::DigestRagUpload do
|
|||
# be explicit here about chunking strategy
|
||||
persona.update!(rag_chunk_tokens: 100, rag_chunk_overlap_tokens: 10)
|
||||
|
||||
described_class.new.execute(upload_id: upload_with_metadata.id, ai_persona_id: persona.id)
|
||||
described_class.new.execute(
|
||||
upload_id: upload_with_metadata.id,
|
||||
target_id: persona.id,
|
||||
target_type: persona.class.to_s,
|
||||
)
|
||||
|
||||
parsed = +""
|
||||
first = true
|
||||
|
@ -66,7 +70,11 @@ RSpec.describe Jobs::DigestRagUpload do
|
|||
before { File.expects(:open).returns(document_file) }
|
||||
|
||||
it "splits an upload into chunks" do
|
||||
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
|
||||
subject.execute(
|
||||
upload_id: upload.id,
|
||||
target_id: persona.id,
|
||||
target_type: persona.class.to_s,
|
||||
)
|
||||
|
||||
created_fragment = RagDocumentFragment.last
|
||||
|
||||
|
@ -76,19 +84,23 @@ RSpec.describe Jobs::DigestRagUpload do
|
|||
end
|
||||
|
||||
it "queue jobs to generate embeddings for each fragment" do
|
||||
expect { subject.execute(upload_id: upload.id, ai_persona_id: persona.id) }.to change(
|
||||
Jobs::GenerateRagEmbeddings.jobs,
|
||||
:size,
|
||||
).by(1)
|
||||
expect {
|
||||
subject.execute(
|
||||
upload_id: upload.id,
|
||||
target_id: persona.id,
|
||||
target_type: persona.class.to_s,
|
||||
)
|
||||
}.to change(Jobs::GenerateRagEmbeddings.jobs, :size).by(1)
|
||||
end
|
||||
end
|
||||
|
||||
it "doesn't generate new fragments if we already processed the upload" do
|
||||
Fabricate(:rag_document_fragment, upload: upload, ai_persona: persona)
|
||||
previous_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
|
||||
Fabricate(:rag_document_fragment, upload: upload, target: persona)
|
||||
|
||||
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
|
||||
updated_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
|
||||
previous_count = RagDocumentFragment.where(upload: upload, target: persona).count
|
||||
|
||||
subject.execute(upload_id: upload.id, target_id: persona.id, target_type: persona.class.to_s)
|
||||
updated_count = RagDocumentFragment.where(upload: upload, target: persona).count
|
||||
|
||||
expect(updated_count).to eq(previous_count)
|
||||
end
|
||||
|
|
|
@ -11,8 +11,8 @@ RSpec.describe Jobs::GenerateRagEmbeddings do
|
|||
|
||||
fab!(:ai_persona)
|
||||
|
||||
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
|
||||
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
|
||||
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) }
|
||||
|
||||
before do
|
||||
SiteSetting.ai_embeddings_enabled = true
|
||||
|
|
|
@ -395,7 +395,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
candidate_ids << Fabricate(
|
||||
:rag_document_fragment,
|
||||
fragment: "fragment-n#{i}",
|
||||
ai_persona_id: ai_persona.id,
|
||||
target_id: ai_persona.id,
|
||||
target_type: "AiPersona",
|
||||
upload: upload,
|
||||
).id
|
||||
end
|
||||
|
|
|
@ -55,7 +55,7 @@ RSpec.describe AiPersona do
|
|||
|
||||
id =
|
||||
RagDocumentFragment.create!(
|
||||
ai_persona: persona,
|
||||
target: persona,
|
||||
fragment: "test",
|
||||
fragment_number: 1,
|
||||
upload: Fabricate(:upload),
|
||||
|
|
|
@ -12,14 +12,14 @@ RSpec.describe RagDocumentFragment do
|
|||
|
||||
describe ".link_uploads_and_persona" do
|
||||
it "does nothing if there is no persona" do
|
||||
expect { described_class.link_persona_and_uploads(nil, [upload_1.id]) }.not_to change(
|
||||
expect { described_class.link_target_and_uploads(nil, [upload_1.id]) }.not_to change(
|
||||
Jobs::DigestRagUpload.jobs,
|
||||
:size,
|
||||
)
|
||||
end
|
||||
|
||||
it "does nothing if there are no uploads" do
|
||||
expect { described_class.link_persona_and_uploads(persona, []) }.not_to change(
|
||||
expect { described_class.link_target_and_uploads(persona, []) }.not_to change(
|
||||
Jobs::DigestRagUpload.jobs,
|
||||
:size,
|
||||
)
|
||||
|
@ -27,12 +27,12 @@ RSpec.describe RagDocumentFragment do
|
|||
|
||||
it "queues a job for each upload to generate fragments" do
|
||||
expect {
|
||||
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
|
||||
described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
|
||||
}.to change(Jobs::DigestRagUpload.jobs, :size).by(2)
|
||||
end
|
||||
|
||||
it "creates references between the persona an each upload" do
|
||||
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
|
||||
described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
|
||||
|
||||
refs = UploadReference.where(target: persona).pluck(:upload_id)
|
||||
|
||||
|
@ -40,26 +40,25 @@ RSpec.describe RagDocumentFragment do
|
|||
end
|
||||
end
|
||||
|
||||
describe ".update_persona_uploads" do
|
||||
describe ".update_target_uploads" do
|
||||
it "does nothing if there is no persona" do
|
||||
expect { described_class.update_persona_uploads(nil, [upload_1.id]) }.not_to change(
|
||||
expect { described_class.update_target_uploads(nil, [upload_1.id]) }.not_to change(
|
||||
Jobs::DigestRagUpload.jobs,
|
||||
:size,
|
||||
)
|
||||
end
|
||||
|
||||
it "deletes the fragment if its not present in the uploads list" do
|
||||
fragment = Fabricate(:rag_document_fragment, ai_persona: persona)
|
||||
fragment = Fabricate(:rag_document_fragment, target: persona)
|
||||
|
||||
described_class.update_persona_uploads(persona, [])
|
||||
described_class.update_target_uploads(persona, [])
|
||||
|
||||
expect { fragment.reload }.to raise_error(ActiveRecord::RecordNotFound)
|
||||
end
|
||||
|
||||
it "delete references between the upload and the persona" do
|
||||
described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
|
||||
|
||||
described_class.update_persona_uploads(persona, [upload_2.id])
|
||||
described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id])
|
||||
described_class.update_target_uploads(persona, [upload_2.id])
|
||||
|
||||
refs = UploadReference.where(target: persona).pluck(:upload_id)
|
||||
|
||||
|
@ -67,7 +66,7 @@ RSpec.describe RagDocumentFragment do
|
|||
end
|
||||
|
||||
it "queues jobs to generate new fragments" do
|
||||
expect { described_class.update_persona_uploads(persona, [upload_1.id]) }.to change(
|
||||
expect { described_class.update_target_uploads(persona, [upload_1.id]) }.to change(
|
||||
Jobs::DigestRagUpload.jobs,
|
||||
:size,
|
||||
).by(1)
|
||||
|
@ -81,11 +80,11 @@ RSpec.describe RagDocumentFragment do
|
|||
end
|
||||
|
||||
fab!(:rag_document_fragment_1) do
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona)
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||
end
|
||||
|
||||
fab!(:rag_document_fragment_2) do
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona)
|
||||
Fabricate(:rag_document_fragment, upload: upload_1, target: persona)
|
||||
end
|
||||
|
||||
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
|
||||
|
|
|
@ -139,7 +139,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
|
|||
|
||||
it "includes rag uploads for each persona" do
|
||||
upload = Fabricate(:upload)
|
||||
RagDocumentFragment.link_persona_and_uploads(ai_persona, [upload.id])
|
||||
RagDocumentFragment.link_target_and_uploads(ai_persona, [upload.id])
|
||||
|
||||
get "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json"
|
||||
expect(response).to be_successful
|
||||
|
|
Loading…
Reference in New Issue