diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index d3bcc7ef..4082692e 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -41,7 +41,7 @@ module DiscourseAi def create ai_persona = AiPersona.new(ai_persona_params.except(:rag_uploads)) if ai_persona.save - RagDocumentFragment.link_persona_and_uploads(ai_persona, attached_upload_ids) + RagDocumentFragment.link_target_and_uploads(ai_persona, attached_upload_ids) render json: { ai_persona: LocalizedAiPersonaSerializer.new(ai_persona, root: false), @@ -59,7 +59,7 @@ module DiscourseAi def update if @ai_persona.update(ai_persona_params.except(:rag_uploads)) - RagDocumentFragment.update_persona_uploads(@ai_persona, attached_upload_ids) + RagDocumentFragment.update_target_uploads(@ai_persona, attached_upload_ids) render json: LocalizedAiPersonaSerializer.new(@ai_persona, root: false) else diff --git a/app/jobs/regular/digest_rag_upload.rb b/app/jobs/regular/digest_rag_upload.rb index 5f99d31d..b24bba1e 100644 --- a/app/jobs/regular/digest_rag_upload.rb +++ b/app/jobs/regular/digest_rag_upload.rb @@ -9,17 +9,24 @@ module ::Jobs # TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads. def execute(args) return if (upload = Upload.find_by(id: args[:upload_id])).nil? - return if (ai_persona = AiPersona.find_by(id: args[:ai_persona_id])).nil? + + target_type = args[:target_type] + target_id = args[:target_id] + + return if !target_type || !target_id + + target = target_type.constantize.find_by(id: target_id) + return if !target truncation = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation) tokenizer = vector_rep.tokenizer - chunk_tokens = ai_persona.rag_chunk_tokens - overlap_tokens = ai_persona.rag_chunk_overlap_tokens + chunk_tokens = target.rag_chunk_tokens + overlap_tokens = target.rag_chunk_overlap_tokens - fragment_ids = RagDocumentFragment.where(ai_persona: ai_persona, upload: upload).pluck(:id) + fragment_ids = RagDocumentFragment.where(target: target, upload: upload).pluck(:id) # Check if this is the first time we process this upload. if fragment_ids.empty? @@ -39,7 +46,7 @@ module ::Jobs overlap_tokens: overlap_tokens, ) do |chunk, metadata| fragment_ids << RagDocumentFragment.create!( - ai_persona: ai_persona, + target: target, fragment: chunk, fragment_number: idx + 1, upload: upload, diff --git a/app/jobs/regular/generate_rag_embeddings.rb b/app/jobs/regular/generate_rag_embeddings.rb index d1d05a19..a125a21b 100644 --- a/app/jobs/regular/generate_rag_embeddings.rb +++ b/app/jobs/regular/generate_rag_embeddings.rb @@ -17,10 +17,10 @@ module ::Jobs fragments.map { |fragment| vector_rep.generate_representation_from(fragment) } last_fragment = fragments.last - ai_persona = last_fragment.ai_persona + target = last_fragment.target upload = last_fragment.upload - indexing_status = RagDocumentFragment.indexing_status(ai_persona, [upload])[upload.id] + indexing_status = RagDocumentFragment.indexing_status(target, [upload])[upload.id] RagDocumentFragment.publish_status(upload, indexing_status) end end diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index 6a6e4c3f..29d31e54 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -20,6 +20,7 @@ class AiPersona < ActiveRecord::Base validates :rag_chunk_tokens, numericality: { greater_than: 0, maximum: 50_000 } validates :rag_chunk_overlap_tokens, numericality: { greater_than: -1, maximum: 200 } validates :rag_conversation_chunks, numericality: { greater_than: 0, maximum: 1000 } + has_many :rag_document_fragments, dependent: :destroy, as: :target belongs_to :created_by, class_name: "User" belongs_to :user @@ -27,12 +28,7 @@ class AiPersona < ActiveRecord::Base has_many :upload_references, as: :target, dependent: :destroy has_many :uploads, through: :upload_references - has_many :rag_document_fragment, dependent: :destroy - - has_many :rag_document_fragments, through: :ai_persona_rag_document_fragments - before_destroy :ensure_not_system - before_update :regenerate_rag_fragments def self.persona_cache @@ -230,7 +226,7 @@ class AiPersona < ActiveRecord::Base def regenerate_rag_fragments if rag_chunk_tokens_changed? || rag_chunk_overlap_tokens_changed? - RagDocumentFragment.where(ai_persona: self).delete_all + RagDocumentFragment.where(target: self).delete_all end end diff --git a/app/models/rag_document_fragment.rb b/app/models/rag_document_fragment.rb index bffd65e0..344506be 100644 --- a/app/models/rag_document_fragment.rb +++ b/app/models/rag_document_fragment.rb @@ -1,32 +1,40 @@ # frozen_string_literal: true class RagDocumentFragment < ActiveRecord::Base + # TODO Jan 2025 - remove + self.ignored_columns = %i[ai_persona_id] + belongs_to :upload - belongs_to :ai_persona + belongs_to :target, polymorphic: true class << self - def link_persona_and_uploads(persona, upload_ids) - return if persona.blank? + def link_target_and_uploads(target, upload_ids) + return if target.blank? return if upload_ids.blank? return if !SiteSetting.ai_embeddings_enabled? - UploadReference.ensure_exist!(upload_ids: upload_ids, target: persona) + UploadReference.ensure_exist!(upload_ids: upload_ids, target: target) upload_ids.each do |upload_id| - Jobs.enqueue(:digest_rag_upload, ai_persona_id: persona.id, upload_id: upload_id) + Jobs.enqueue( + :digest_rag_upload, + target_id: target.id, + target_type: target.class.to_s, + upload_id: upload_id, + ) end end - def update_persona_uploads(persona, upload_ids) - return if persona.blank? + def update_target_uploads(target, upload_ids) + return if target.blank? return if !SiteSetting.ai_embeddings_enabled? if upload_ids.blank? - RagDocumentFragment.where(ai_persona: persona).destroy_all - UploadReference.where(target: persona).destroy_all + RagDocumentFragment.where(target: target).destroy_all + UploadReference.where(target: target).destroy_all else - RagDocumentFragment.where(ai_persona: persona).where.not(upload_id: upload_ids).destroy_all - link_persona_and_uploads(persona, upload_ids) + RagDocumentFragment.where(target: target).where.not(upload_id: upload_ids).destroy_all + link_target_and_uploads(target, upload_ids) end end @@ -37,18 +45,25 @@ class RagDocumentFragment < ActiveRecord::Base embeddings_table = vector_rep.rag_fragments_table_name - results = DB.query(<<~SQL, persona_id: persona.id, upload_ids: uploads.map(&:id)) + results = + DB.query( + <<~SQL, SELECT uploads.id, SUM(CASE WHEN (rdf.upload_id IS NOT NULL) THEN 1 ELSE 0 END) AS total, SUM(CASE WHEN (eft.rag_document_fragment_id IS NOT NULL) THEN 1 ELSE 0 END) as indexed, SUM(CASE WHEN (rdf.upload_id IS NOT NULL AND eft.rag_document_fragment_id IS NULL) THEN 1 ELSE 0 END) as left FROM uploads - LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.ai_persona_id = :persona_id + LEFT OUTER JOIN rag_document_fragments rdf ON uploads.id = rdf.upload_id AND rdf.target_id = :target_id + AND rdf.target_type = :target_type LEFT OUTER JOIN #{embeddings_table} eft ON rdf.id = eft.rag_document_fragment_id WHERE uploads.id IN (:upload_ids) GROUP BY uploads.id SQL + target_id: persona.id, + target_type: persona.class.to_s, + upload_ids: uploads.map(&:id), + ) results.reduce({}) do |acc, r| acc[r.id] = { total: r.total, indexed: r.indexed, left: r.left } @@ -78,4 +93,10 @@ end # created_at :datetime not null # updated_at :datetime not null # metadata :text +# target_id :integer +# target_type :string(800) +# +# Indexes +# +# index_rag_document_fragments_on_target_type_and_target_id (target_type,target_id) # diff --git a/db/migrate/20240912052713_add_target_to_rag_document_fragment.rb b/db/migrate/20240912052713_add_target_to_rag_document_fragment.rb new file mode 100644 index 00000000..43b48ad4 --- /dev/null +++ b/db/migrate/20240912052713_add_target_to_rag_document_fragment.rb @@ -0,0 +1,9 @@ +# frozen_string_literal: true + +class AddTargetToRagDocumentFragment < ActiveRecord::Migration[7.1] + def change + add_column :rag_document_fragments, :target_id, :integer, null: true + add_column :rag_document_fragments, :target_type, :string, limit: 800, null: true + add_index :rag_document_fragments, %i[target_type target_id] + end +end diff --git a/db/post_migrate/20240912055831_drop_persona_id_from_rag_document_fragments.rb b/db/post_migrate/20240912055831_drop_persona_id_from_rag_document_fragments.rb new file mode 100644 index 00000000..238c2968 --- /dev/null +++ b/db/post_migrate/20240912055831_drop_persona_id_from_rag_document_fragments.rb @@ -0,0 +1,22 @@ +# frozen_string_literal: true +class DropPersonaIdFromRagDocumentFragments < ActiveRecord::Migration[7.1] + def change + execute <<~SQL + UPDATE rag_document_fragments + SET + target_type = 'AiPersona', + target_id = ai_persona_id + WHERE ai_persona_id IS NOT NULL + SQL + + # unlikely but lets be safe + execute <<~SQL + DELETE FROM rag_document_fragments + WHERE target_id IS NULL OR target_type IS NULL + SQL + + remove_column :rag_document_fragments, :ai_persona_id + change_column_null :rag_document_fragments, :target_id, false + change_column_null :rag_document_fragments, :target_type, false + end +end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index cfb0d2bd..b46abebe 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -288,7 +288,8 @@ module DiscourseAi candidate_fragment_ids = vector_rep.asymmetric_rag_fragment_similarity_search( interactions_vector, - persona_id: id, + target_type: "AiPersona", + target_id: id, limit: ( if reranker.reranker_configured? diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index 5febc5f2..db4d4a82 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -280,7 +280,8 @@ module DiscourseAi def asymmetric_rag_fragment_similarity_search( raw_vector, - persona_id:, + target_id:, + target_type:, limit:, offset:, return_distance: false @@ -299,14 +300,16 @@ module DiscourseAi WHERE model_id = #{id} AND strategy_id = #{@strategy.id} AND - rdf.ai_persona_id = :persona_id + rdf.target_id = :target_id AND + rdf.target_type = :target_type ORDER BY embeddings::halfvec(#{dimensions}) #{pg_function} '[:query_embedding]'::halfvec(#{dimensions}) LIMIT :limit OFFSET :offset SQL query_embedding: raw_vector, - persona_id: persona_id, + target_id: target_id, + target_type: target_type, limit: limit, offset: offset, ) diff --git a/spec/jobs/regular/digest_rag_upload_spec.rb b/spec/jobs/regular/digest_rag_upload_spec.rb index 3c0bc103..063bbf73 100644 --- a/spec/jobs/regular/digest_rag_upload_spec.rb +++ b/spec/jobs/regular/digest_rag_upload_spec.rb @@ -41,7 +41,11 @@ RSpec.describe Jobs::DigestRagUpload do # be explicit here about chunking strategy persona.update!(rag_chunk_tokens: 100, rag_chunk_overlap_tokens: 10) - described_class.new.execute(upload_id: upload_with_metadata.id, ai_persona_id: persona.id) + described_class.new.execute( + upload_id: upload_with_metadata.id, + target_id: persona.id, + target_type: persona.class.to_s, + ) parsed = +"" first = true @@ -66,7 +70,11 @@ RSpec.describe Jobs::DigestRagUpload do before { File.expects(:open).returns(document_file) } it "splits an upload into chunks" do - subject.execute(upload_id: upload.id, ai_persona_id: persona.id) + subject.execute( + upload_id: upload.id, + target_id: persona.id, + target_type: persona.class.to_s, + ) created_fragment = RagDocumentFragment.last @@ -76,19 +84,23 @@ RSpec.describe Jobs::DigestRagUpload do end it "queue jobs to generate embeddings for each fragment" do - expect { subject.execute(upload_id: upload.id, ai_persona_id: persona.id) }.to change( - Jobs::GenerateRagEmbeddings.jobs, - :size, - ).by(1) + expect { + subject.execute( + upload_id: upload.id, + target_id: persona.id, + target_type: persona.class.to_s, + ) + }.to change(Jobs::GenerateRagEmbeddings.jobs, :size).by(1) end end it "doesn't generate new fragments if we already processed the upload" do - Fabricate(:rag_document_fragment, upload: upload, ai_persona: persona) - previous_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count + Fabricate(:rag_document_fragment, upload: upload, target: persona) - subject.execute(upload_id: upload.id, ai_persona_id: persona.id) - updated_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count + previous_count = RagDocumentFragment.where(upload: upload, target: persona).count + + subject.execute(upload_id: upload.id, target_id: persona.id, target_type: persona.class.to_s) + updated_count = RagDocumentFragment.where(upload: upload, target: persona).count expect(updated_count).to eq(previous_count) end diff --git a/spec/jobs/regular/generate_rag_embeddings_spec.rb b/spec/jobs/regular/generate_rag_embeddings_spec.rb index b200bb01..26b843a3 100644 --- a/spec/jobs/regular/generate_rag_embeddings_spec.rb +++ b/spec/jobs/regular/generate_rag_embeddings_spec.rb @@ -11,8 +11,8 @@ RSpec.describe Jobs::GenerateRagEmbeddings do fab!(:ai_persona) - fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) } - fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) } + fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, target: ai_persona) } + fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, target: ai_persona) } before do SiteSetting.ai_embeddings_enabled = true diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 824ed4da..2fb95d19 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -395,7 +395,8 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do candidate_ids << Fabricate( :rag_document_fragment, fragment: "fragment-n#{i}", - ai_persona_id: ai_persona.id, + target_id: ai_persona.id, + target_type: "AiPersona", upload: upload, ).id end diff --git a/spec/models/ai_persona_spec.rb b/spec/models/ai_persona_spec.rb index 7f661e4d..0d48b4cf 100644 --- a/spec/models/ai_persona_spec.rb +++ b/spec/models/ai_persona_spec.rb @@ -55,7 +55,7 @@ RSpec.describe AiPersona do id = RagDocumentFragment.create!( - ai_persona: persona, + target: persona, fragment: "test", fragment_number: 1, upload: Fabricate(:upload), diff --git a/spec/models/rag_document_fragment_spec.rb b/spec/models/rag_document_fragment_spec.rb index dbcd4c9b..9ddbc830 100644 --- a/spec/models/rag_document_fragment_spec.rb +++ b/spec/models/rag_document_fragment_spec.rb @@ -12,14 +12,14 @@ RSpec.describe RagDocumentFragment do describe ".link_uploads_and_persona" do it "does nothing if there is no persona" do - expect { described_class.link_persona_and_uploads(nil, [upload_1.id]) }.not_to change( + expect { described_class.link_target_and_uploads(nil, [upload_1.id]) }.not_to change( Jobs::DigestRagUpload.jobs, :size, ) end it "does nothing if there are no uploads" do - expect { described_class.link_persona_and_uploads(persona, []) }.not_to change( + expect { described_class.link_target_and_uploads(persona, []) }.not_to change( Jobs::DigestRagUpload.jobs, :size, ) @@ -27,12 +27,12 @@ RSpec.describe RagDocumentFragment do it "queues a job for each upload to generate fragments" do expect { - described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id]) + described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id]) }.to change(Jobs::DigestRagUpload.jobs, :size).by(2) end it "creates references between the persona an each upload" do - described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id]) + described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id]) refs = UploadReference.where(target: persona).pluck(:upload_id) @@ -40,26 +40,25 @@ RSpec.describe RagDocumentFragment do end end - describe ".update_persona_uploads" do + describe ".update_target_uploads" do it "does nothing if there is no persona" do - expect { described_class.update_persona_uploads(nil, [upload_1.id]) }.not_to change( + expect { described_class.update_target_uploads(nil, [upload_1.id]) }.not_to change( Jobs::DigestRagUpload.jobs, :size, ) end it "deletes the fragment if its not present in the uploads list" do - fragment = Fabricate(:rag_document_fragment, ai_persona: persona) + fragment = Fabricate(:rag_document_fragment, target: persona) - described_class.update_persona_uploads(persona, []) + described_class.update_target_uploads(persona, []) expect { fragment.reload }.to raise_error(ActiveRecord::RecordNotFound) end it "delete references between the upload and the persona" do - described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id]) - - described_class.update_persona_uploads(persona, [upload_2.id]) + described_class.link_target_and_uploads(persona, [upload_1.id, upload_2.id]) + described_class.update_target_uploads(persona, [upload_2.id]) refs = UploadReference.where(target: persona).pluck(:upload_id) @@ -67,7 +66,7 @@ RSpec.describe RagDocumentFragment do end it "queues jobs to generate new fragments" do - expect { described_class.update_persona_uploads(persona, [upload_1.id]) }.to change( + expect { described_class.update_target_uploads(persona, [upload_1.id]) }.to change( Jobs::DigestRagUpload.jobs, :size, ).by(1) @@ -81,11 +80,11 @@ RSpec.describe RagDocumentFragment do end fab!(:rag_document_fragment_1) do - Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona) + Fabricate(:rag_document_fragment, upload: upload_1, target: persona) end fab!(:rag_document_fragment_2) do - Fabricate(:rag_document_fragment, upload: upload_1, ai_persona: persona) + Fabricate(:rag_document_fragment, upload: upload_1, target: persona) end let(:expected_embedding) { [0.0038493] * vector_rep.dimensions } diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index bf1cbfc7..091ddee2 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -139,7 +139,7 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do it "includes rag uploads for each persona" do upload = Fabricate(:upload) - RagDocumentFragment.link_persona_and_uploads(ai_persona, [upload.id]) + RagDocumentFragment.link_target_and_uploads(ai_persona, [upload.id]) get "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json" expect(response).to be_successful