diff --git a/assets/javascripts/discourse/components/persona-rag-uploader.gjs b/assets/javascripts/discourse/components/persona-rag-uploader.gjs
new file mode 100644
index 00000000..1068e9d1
--- /dev/null
+++ b/assets/javascripts/discourse/components/persona-rag-uploader.gjs
@@ -0,0 +1,153 @@
+import { tracked } from "@glimmer/tracking";
+import Component, { Input } from "@ember/component";
+import { fn } from "@ember/helper";
+import { on } from "@ember/modifier";
+import { action } from "@ember/object";
+import { inject as service } from "@ember/service";
+import DButton from "discourse/components/d-button";
+import UppyUploadMixin from "discourse/mixins/uppy-upload";
+import icon from "discourse-common/helpers/d-icon";
+import discourseDebounce from "discourse-common/lib/debounce";
+import I18n from "discourse-i18n";
+
+export default class PersonaRagUploader extends Component.extend(
+ UppyUploadMixin
+) {
+ @service appEvents;
+
+ @tracked term = null;
+ @tracked filteredUploads = null;
+ id = "discourse-ai-persona-rag-uploader";
+ maxFiles = 20;
+ uploadUrl = "/admin/plugins/discourse-ai/ai-personas/files/upload";
+ preventDirectS3Uploads = true;
+
+ didReceiveAttrs() {
+ super.didReceiveAttrs(...arguments);
+
+ if (this.inProgressUploads?.length > 0) {
+ this._uppyInstance?.cancelAll();
+ }
+
+ this.filteredUploads = this.ragUploads || [];
+ }
+
+ uploadDone(uploadedFile) {
+ this.onAdd(uploadedFile.upload);
+ this.debouncedSearch();
+ }
+
+ @action
+ submitFiles() {
+ this.fileInputEl.click();
+ }
+
+ @action
+ cancelUploading(upload) {
+ this.appEvents.trigger(`upload-mixin:${this.id}:cancel-upload`, {
+ fileId: upload.id,
+ });
+ }
+
+ @action
+ search() {
+ if (this.term) {
+ this.filteredUploads = this.ragUploads.filter((u) => {
+ return (
+ u.original_filename.toUpperCase().indexOf(this.term.toUpperCase()) >
+ -1
+ );
+ });
+ } else {
+ this.filteredUploads = this.ragUploads;
+ }
+ }
+
+ @action
+ debouncedSearch() {
+ discourseDebounce(this, this.search, 100);
+ }
+
+
+
+
{{I18n.t "discourse_ai.ai_persona.uploads.title"}}
+
{{I18n.t "discourse_ai.ai_persona.uploads.description"}}
+
{{I18n.t "discourse_ai.ai_persona.uploads.hint"}}
+
+
+
+
+
+ {{#each this.filteredUploads as |upload|}}
+
+
+ {{icon
+ "file"
+ }}
+ {{upload.original_filename}} |
+ {{icon "check"}}
+ {{I18n.t "discourse_ai.ai_persona.uploads.complete"}} |
+
+
+ |
+
+ {{/each}}
+ {{#each this.inProgressUploads as |upload|}}
+
+ {{icon
+ "file"
+ }}
+ {{upload.original_filename}} |
+
+
+ {{I18n.t "discourse_ai.ai_persona.uploads.uploading"}}
+ {{upload.uploadProgress}}%
+ |
+
+
+ |
+
+ {{/each}}
+
+
+
+
+
+
+
+}
diff --git a/assets/stylesheets/modules/ai-bot/common/ai-persona.scss b/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
index af2b2ce3..c6ff1c01 100644
--- a/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
+++ b/assets/stylesheets/modules/ai-bot/common/ai-persona.scss
@@ -76,4 +76,69 @@
display: flex;
align-items: center;
}
+
+ .persona-rag-uploader {
+ width: 500px;
+
+ &__search-input {
+ display: flex;
+ align-items: center;
+ border: 1px solid var(--primary-400);
+ width: 100%;
+ box-sizing: border-box;
+ height: 35px;
+ padding: 0 0.5rem;
+
+ &:focus,
+ &:focus-within {
+ @include default-focus();
+ }
+
+ &-container {
+ display: flex;
+ flex-grow: 1;
+ }
+
+ &__search-icon {
+ background: none !important;
+ color: var(--primary-medium);
+ }
+
+ &__input {
+ width: 100% !important;
+ }
+
+ &__input,
+ &__input:focus {
+ margin: 0 !important;
+ border: 0 !important;
+ appearance: none !important;
+ outline: none !important;
+ background: none !important;
+ }
+ }
+
+ &__uploads-list {
+ margin-bottom: 20px;
+
+ tbody {
+ border-top: none;
+ }
+ }
+
+ &__upload-status {
+ text-align: right;
+ padding-right: 0;
+ color: var(--success);
+ }
+
+ &__remove-file {
+ text-align: right;
+ padding-left: 0;
+ }
+
+ &__rag-file-icon {
+ margin-right: 5px;
+ }
+ }
}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 33221310..fd3e6754 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -164,6 +164,14 @@ en:
#### Group-Specific Access to AI Personas
Moreover, you can set it up so that certain user groups have access to specific personas. This means you can have different AI behaviors for different sections of your forum, further enhancing the diversity and richness of your community's interactions.
+
+ uploads:
+ title: "Uploads"
+ description: "Your AI persona will be able to search and reference the content of included files. Uploaded files must be formatted as plaintext (.txt)"
+ hint: "To control where the file's content gets placed within the system prompt, include the {uploads} placeholder in the system prompt above."
+ button: "Add Files"
+ filter: "Filter uploads"
+ complete: "Complete"
related_topics:
title: "Related Topics"
diff --git a/config/routes.rb b/config/routes.rb
index c9c5ac8b..8d349265 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -41,5 +41,7 @@ Discourse::Application.routes.draw do
controller: "discourse_ai/admin/ai_personas"
post "/ai-personas/:id/create-user", to: "discourse_ai/admin/ai_personas#create_user"
+ post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file"
+ put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file"
end
end
diff --git a/db/migrate/20240309034752_create_rag_document_fragment_table.rb b/db/migrate/20240309034752_create_rag_document_fragment_table.rb
new file mode 100644
index 00000000..a6a3e233
--- /dev/null
+++ b/db/migrate/20240309034752_create_rag_document_fragment_table.rb
@@ -0,0 +1,13 @@
+# frozen_string_literal: true
+
+class CreateRagDocumentFragmentTable < ActiveRecord::Migration[7.0]
+ def change
+ create_table :rag_document_fragments do |t|
+ t.text :fragment, null: false
+ t.integer :upload_id, null: false
+ t.integer :ai_persona_id, null: false
+ t.integer :fragment_number, null: false
+ t.timestamps
+ end
+ end
+end
diff --git a/db/migrate/20240313165121_embedding_tables_for_rag_uploads.rb b/db/migrate/20240313165121_embedding_tables_for_rag_uploads.rb
new file mode 100644
index 00000000..0aa7d502
--- /dev/null
+++ b/db/migrate/20240313165121_embedding_tables_for_rag_uploads.rb
@@ -0,0 +1,96 @@
+# frozen_string_literal: true
+
+class EmbeddingTablesForRagUploads < ActiveRecord::Migration[7.0]
+ def change
+ create_table :ai_document_fragment_embeddings_1_1, id: false do |t|
+ t.integer :rag_document_fragment_id, null: false
+ t.integer :model_version, null: false
+ t.integer :strategy_version, null: false
+ t.text :digest, null: false
+ t.column :embeddings, "vector(768)", null: false
+ t.timestamps
+
+ t.index :rag_document_fragment_id,
+ unique: true,
+ name: "rag_document_fragment_id_embeddings_1_1"
+ end
+
+ create_table :ai_document_fragment_embeddings_2_1, id: false do |t|
+ t.integer :rag_document_fragment_id, null: false
+ t.integer :model_version, null: false
+ t.integer :strategy_version, null: false
+ t.text :digest, null: false
+ t.column :embeddings, "vector(1536)", null: false
+ t.timestamps
+
+ t.index :rag_document_fragment_id,
+ unique: true,
+ name: "rag_document_fragment_id_embeddings_2_1"
+ end
+
+ create_table :ai_document_fragment_embeddings_3_1, id: false do |t|
+ t.integer :rag_document_fragment_id, null: false
+ t.integer :model_version, null: false
+ t.integer :strategy_version, null: false
+ t.text :digest, null: false
+ t.column :embeddings, "vector(1024)", null: false
+ t.timestamps
+
+ t.index :rag_document_fragment_id,
+ unique: true,
+ name: "rag_document_fragment_id_embeddings_3_1"
+ end
+
+ create_table :ai_document_fragment_embeddings_4_1, id: false do |t|
+ t.integer :rag_document_fragment_id, null: false
+ t.integer :model_version, null: false
+ t.integer :strategy_version, null: false
+ t.text :digest, null: false
+ t.column :embeddings, "vector(1024)", null: false
+ t.timestamps
+
+ t.index :rag_document_fragment_id,
+ unique: true,
+ name: "rag_document_fragment_id_embeddings_4_1"
+ end
+
+ create_table :ai_document_fragment_embeddings_5_1, id: false do |t|
+ t.integer :rag_document_fragment_id, null: false
+ t.integer :model_version, null: false
+ t.integer :strategy_version, null: false
+ t.text :digest, null: false
+ t.column :embeddings, "vector(768)", null: false
+ t.timestamps
+
+ t.index :rag_document_fragment_id,
+ unique: true,
+ name: "rag_document_fragment_id_embeddings_5_1"
+ end
+
+ create_table :ai_document_fragment_embeddings_6_1, id: false do |t|
+ t.integer :rag_document_fragment_id, null: false
+ t.integer :model_version, null: false
+ t.integer :strategy_version, null: false
+ t.text :digest, null: false
+ t.column :embeddings, "vector(1536)", null: false
+ t.timestamps
+
+ t.index :rag_document_fragment_id,
+ unique: true,
+ name: "rag_document_fragment_id_embeddings_6_1"
+ end
+
+ create_table :ai_document_fragment_embeddings_7_1, id: false do |t|
+ t.integer :rag_document_fragment_id, null: false
+ t.integer :model_version, null: false
+ t.integer :strategy_version, null: false
+ t.text :digest, null: false
+ t.column :embeddings, "vector(2000)", null: false
+ t.timestamps
+
+ t.index :rag_document_fragment_id,
+ unique: true,
+ name: "rag_document_fragment_id_embeddings_7_1"
+ end
+ end
+end
diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb
index 0c0562a1..88c7ef02 100644
--- a/lib/ai_bot/entry_point.rb
+++ b/lib/ai_bot/entry_point.rb
@@ -200,6 +200,17 @@ module DiscourseAi
if plugin.respond_to?(:register_editable_topic_custom_field)
plugin.register_editable_topic_custom_field(:ai_persona_id)
end
+
+ plugin.on(:site_setting_changed) do |name, old_value, new_value|
+ if name == "ai_embeddings_model" && SiteSetting.ai_embeddings_enabled? &&
+ new_value != old_value
+ RagDocumentFragment.find_in_batches do |batch|
+ batch.each_slice(100) do |fragments|
+ Jobs.enqueue(:generate_rag_embeddings, fragment_ids: fragments.map(&:id))
+ end
+ end
+ end
+ end
end
end
end
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index b4c6060f..bb8e9f08 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -93,6 +93,10 @@ module DiscourseAi
end
end
+ def id
+ @ai_persona&.id || self.class.system_personas[self.class]
+ end
+
def tools
[]
end
@@ -124,12 +128,24 @@ module DiscourseAi
found.nil? ? match : found.to_s
end
+ prompt_insts = <<~TEXT.strip
+ #{system_insts}
+ #{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
+ TEXT
+
+ fragments_guidance = rag_fragments_prompt(context[:conversation_context].to_a)&.strip
+
+ if fragments_guidance.present?
+ if system_insts.include?("{uploads}")
+ prompt_insts = prompt_insts.gsub("{uploads}", fragments_guidance)
+ else
+ prompt_insts << fragments_guidance
+ end
+ end
+
prompt =
DiscourseAi::Completions::Prompt.new(
- <<~TEXT.strip,
- #{system_insts}
- #{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
- TEXT
+ prompt_insts,
messages: context[:conversation_context].to_a,
topic_id: context[:topic_id],
post_id: context[:post_id],
@@ -181,6 +197,68 @@ module DiscourseAi
persona_options: options[tool_klass].to_h,
)
end
+
+ def rag_fragments_prompt(conversation_context)
+ upload_refs =
+ UploadReference.where(target_id: id, target_type: "AiPersona").pluck(:upload_id)
+
+ return nil if !SiteSetting.ai_embeddings_enabled?
+ return nil if conversation_context.blank? || upload_refs.blank?
+
+ latest_interactions =
+ conversation_context
+ .select { |ctx| %i[model user].include?(ctx[:type]) }
+ .map { |ctx| ctx[:content] }
+ .last(10)
+ .join("\n")
+
+ strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
+ vector_rep =
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
+ reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
+
+ interactions_vector = vector_rep.vector_from(latest_interactions)
+
+ candidate_fragment_ids =
+ vector_rep.asymmetric_rag_fragment_similarity_search(
+ interactions_vector,
+ persona_id: id,
+ limit: reranker.reranker_configured? ? 50 : 10,
+ offset: 0,
+ )
+
+ guidance =
+ RagDocumentFragment.where(upload_id: upload_refs, id: candidate_fragment_ids).pluck(
+ :fragment,
+ )
+
+ if reranker.reranker_configured?
+ ranks =
+ DiscourseAi::Inference::HuggingFaceTextEmbeddings
+ .rerank(conversation_context.last[:content], guidance)
+ .to_a
+ .take(10)
+ .map { _1[:index] }
+
+ if ranks.empty?
+ guidance = guidance.take(10)
+ else
+ guidance = ranks.map { |idx| guidance[idx] }
+ end
+ end
+
+ <<~TEXT
+
+ The following texts will give you additional guidance to elaborate a response.
+ We included them because we believe they are relevant to this conversation topic.
+ Take them into account to elaborate a response.
+
+ Texts:
+
+ #{guidance.join("\n")}
+
+ TEXT
+ end
end
end
end
diff --git a/lib/embeddings/strategies/truncation.rb b/lib/embeddings/strategies/truncation.rb
index b2b29041..ced670ab 100644
--- a/lib/embeddings/strategies/truncation.rb
+++ b/lib/embeddings/strategies/truncation.rb
@@ -18,6 +18,8 @@ module DiscourseAi
topic_truncation(target, tokenizer, max_length)
when Post
post_truncation(target, tokenizer, max_length)
+ when RagDocumentFragment
+ tokenizer.truncate(target.fragment, max_length)
else
raise ArgumentError, "Invalid target type"
end
diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb
index 205639b2..ea880ece 100644
--- a/lib/embeddings/vector_representations/base.rb
+++ b/lib/embeddings/vector_representations/base.rb
@@ -155,6 +155,18 @@ module DiscourseAi
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
return if text.blank?
+ target_column =
+ case target
+ when Topic
+ "topic_id"
+ when Post
+ "post_id"
+ when RagDocumentFragment
+ "rag_document_fragment_id"
+ else
+ raise ArgumentError, "Invalid target type"
+ end
+
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
SELECT
@@ -162,7 +174,7 @@ module DiscourseAi
FROM
#{table_name(target)}
WHERE
- #{target.is_a?(Topic) ? "topic_id" : "post_id"} = :target_id
+ #{target_column} = :target_id
LIMIT 1
SQL
return if current_digest == new_digest
@@ -248,6 +260,47 @@ module DiscourseAi
raise MissingEmbeddingError
end
+ def asymmetric_rag_fragment_similarity_search(
+ raw_vector,
+ persona_id:,
+ limit:,
+ offset:,
+ return_distance: false
+ )
+ results =
+ DB.query(
+ <<~SQL,
+ #{probes_sql(post_table_name)}
+ SELECT
+ rag_document_fragment_id,
+ embeddings #{pg_function} '[:query_embedding]' AS distance
+ FROM
+ #{rag_fragments_table_name}
+ INNER JOIN
+ rag_document_fragments AS rdf ON rdf.id = rag_document_fragment_id
+ WHERE
+ rdf.ai_persona_id = :persona_id
+ ORDER BY
+ embeddings #{pg_function} '[:query_embedding]'
+ LIMIT :limit
+ OFFSET :offset
+ SQL
+ query_embedding: raw_vector,
+ persona_id: persona_id,
+ limit: limit,
+ offset: offset,
+ )
+
+ if return_distance
+ results.map { |r| [r.rag_document_fragment_id, r.distance] }
+ else
+ results.map(&:rag_document_fragment_id)
+ end
+ rescue PG::Error => e
+ Rails.logger.error("Error #{e} querying embeddings for model #{name}")
+ raise MissingEmbeddingError
+ end
+
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
#{probes_sql(topic_table_name)}
@@ -282,12 +335,18 @@ module DiscourseAi
"ai_post_embeddings_#{id}_#{@strategy.id}"
end
+ def rag_fragments_table_name
+ "ai_document_fragment_embeddings_#{id}_#{@strategy.id}"
+ end
+
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
+ when RagDocumentFragment
+ rag_fragments_table_name
else
raise ArgumentError, "Invalid target type"
end
@@ -375,6 +434,25 @@ module DiscourseAi
digest: digest,
embeddings: vector,
)
+ elsif target.is_a?(RagDocumentFragment)
+ DB.exec(
+ <<~SQL,
+ INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
+ VALUES (:fragment_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
+ ON CONFLICT (rag_document_fragment_id)
+ DO UPDATE SET
+ model_version = :model_version,
+ strategy_version = :strategy_version,
+ digest = :digest,
+ embeddings = '[:embeddings]',
+ updated_at = CURRENT_TIMESTAMP
+ SQL
+ fragment_id: target.id,
+ model_version: version,
+ strategy_version: @strategy.version,
+ digest: digest,
+ embeddings: vector,
+ )
else
raise ArgumentError, "Invalid target type"
end
diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text_embeddings.rb
index d0a5cddb..1daab50b 100644
--- a/lib/inference/hugging_face_text_embeddings.rb
+++ b/lib/inference/hugging_face_text_embeddings.rb
@@ -56,6 +56,11 @@ module ::DiscourseAi
JSON.parse(response.body, symbolize_names: true)
end
+ def reranker_configured?
+ SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
+ SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
+ end
+
def configured?
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?
diff --git a/plugin.rb b/plugin.rb
index aac7776d..d4964f9e 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -10,6 +10,7 @@
gem "tokenizers", "0.4.3"
gem "tiktoken_ruby", "0.0.7"
+gem "baran", "0.1.10"
enabled_site_setting :discourse_ai_enabled
diff --git a/spec/fabricators/rag_document_fragment_fabricator.rb b/spec/fabricators/rag_document_fragment_fabricator.rb
new file mode 100644
index 00000000..d70d2f9d
--- /dev/null
+++ b/spec/fabricators/rag_document_fragment_fabricator.rb
@@ -0,0 +1,7 @@
+# frozen_string_literal: true
+
+Fabricator(:rag_document_fragment) do
+ fragment { sequence(:fragment) { |n| "Document fragment #{n}" } }
+ upload
+ fragment_number { sequence(:fragment_number) { |n| n + 1 } }
+end
diff --git a/spec/jobs/regular/digest_rag_upload_spec.rb b/spec/jobs/regular/digest_rag_upload_spec.rb
new file mode 100644
index 00000000..e82b1384
--- /dev/null
+++ b/spec/jobs/regular/digest_rag_upload_spec.rb
@@ -0,0 +1,58 @@
+# frozen_string_literal: true
+
+RSpec.describe Jobs::DigestRagUpload do
+ fab!(:persona) { Fabricate(:ai_persona) }
+ fab!(:upload)
+
+ let(:document_file) { StringIO.new("some text" * 200) }
+
+ let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+ let(:vector_rep) do
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
+ end
+
+ let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
+
+ before do
+ SiteSetting.ai_embeddings_enabled = true
+ SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+
+ WebMock.stub_request(
+ :post,
+ "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+ ).to_return(status: 200, body: JSON.dump(expected_embedding))
+ end
+
+ describe "#execute" do
+ context "when processing an upload for the first time" do
+ before { File.expects(:open).returns(document_file) }
+
+ it "splits an upload into chunks" do
+ subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
+
+ created_fragment = RagDocumentFragment.last
+
+ expect(created_fragment).to be_present
+ expect(created_fragment.fragment).to be_present
+ expect(created_fragment.fragment_number).to eq(2)
+ end
+
+ it "queues jobs to generate embeddings for each fragment" do
+ expect { subject.execute(upload_id: upload.id, ai_persona_id: persona.id) }.to change(
+ Jobs::GenerateRagEmbeddings.jobs,
+ :size,
+ ).by(1)
+ end
+ end
+
+ it "doesn't generate new fragments if we already processed the upload" do
+ Fabricate(:rag_document_fragment, upload: upload, ai_persona: persona)
+ previous_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
+
+ subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
+ updated_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
+
+ expect(updated_count).to eq(previous_count)
+ end
+ end
+end
diff --git a/spec/jobs/regular/generate_rag_embeddings_spec.rb b/spec/jobs/regular/generate_rag_embeddings_spec.rb
new file mode 100644
index 00000000..b8ead8e2
--- /dev/null
+++ b/spec/jobs/regular/generate_rag_embeddings_spec.rb
@@ -0,0 +1,38 @@
+# frozen_string_literal: true
+
+RSpec.describe Jobs::GenerateRagEmbeddings do
+ describe "#execute" do
+ let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
+ let(:vector_rep) do
+ DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
+ end
+
+ let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
+
+ fab!(:ai_persona)
+
+ fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
+ fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
+
+ before do
+ SiteSetting.ai_embeddings_enabled = true
+ SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+
+ WebMock.stub_request(
+ :post,
+ "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
+ ).to_return(status: 200, body: JSON.dump(expected_embedding))
+ end
+
+ it "generates a new vector for each fragment" do
+ expected_embeddings = 2
+
+ subject.execute(fragment_ids: [rag_document_fragment_1.id, rag_document_fragment_2.id])
+
+ embeddings_count =
+ DB.query_single("SELECT COUNT(*) from #{vector_rep.rag_fragments_table_name}").first
+
+ expect(embeddings_count).to eq(expected_embeddings)
+ end
+ end
+end
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index dcbb32c0..673b0c28 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -196,4 +196,128 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
)
end
end
+
+ describe "#craft_prompt" do
+ before do
+ Group.refresh_automatic_groups!
+ SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+ SiteSetting.ai_embeddings_enabled = true
+ end
+
+ let(:ai_persona) { DiscourseAi::AiBot::Personas::Persona.all(user: user).first.new }
+
+ let(:with_cc) do
+ context.merge(conversation_context: [{ content: "Tell me the time", type: :user }])
+ end
+
+ context "when a persona has no uploads" do
+ it "doesn't include RAG guidance" do
+ guidance_fragment =
+ "The following texts will give you additional guidance to elaborate a response."
+
+ expect(ai_persona.craft_prompt(with_cc).messages.first[:content]).not_to include(
+ guidance_fragment,
+ )
+ end
+ end
+
+ context "when a persona has RAG uploads" do
+ fab!(:upload)
+
+ def stub_fragments(limit)
+ candidate_ids = []
+
+ limit.times do |i|
+ candidate_ids << Fabricate(
+ :rag_document_fragment,
+ fragment: "fragment-n#{i}",
+ ai_persona_id: ai_persona.id,
+ upload: upload,
+ ).id
+ end
+
+ DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
+ .any_instance
+ .expects(:asymmetric_rag_fragment_similarity_search)
+ .returns(candidate_ids)
+ end
+
+ before do
+ stored_ai_persona = AiPersona.find(ai_persona.id)
+ UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])
+
+ context_embedding = [0.049382, 0.9999]
+ EmbeddingsGenerationStubs.discourse_service(
+ SiteSetting.ai_embeddings_model,
+ with_cc.dig(:conversation_context, 0, :content),
+ context_embedding,
+ )
+ end
+
+ context "when the system prompt has an uploads placeholder" do
+ before { stub_fragments(10) }
+
+ it "replaces the placeholder with the fragments" do
+ custom_persona_record =
+ AiPersona.create!(
+ name: "custom",
+ description: "description",
+ system_prompt: "instructions\n{uploads}\nmore instructions",
+ allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
+ )
+ UploadReference.ensure_exist!(target: custom_persona_record, upload_ids: [upload.id])
+ custom_persona =
+ DiscourseAi::AiBot::Personas::Persona.find_by(
+ id: custom_persona_record.id,
+ user: user,
+ ).new
+
+ crafted_system_prompt = custom_persona.craft_prompt(with_cc).messages.first[:content]
+
+ expect(crafted_system_prompt).to include("fragment-n0")
+
+ expect(crafted_system_prompt.ends_with?("")).to eq(false)
+ end
+ end
+
+ context "when the reranker is available" do
+ before do
+ SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com"
+
+ stub_fragments(15) # Mimic limit being more than 10 results
+ end
+
+ it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do
+ expected_reranked = (0..14).to_a.reverse.map { |idx| { index: idx } }
+
+ WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return(
+ status: 200,
+ body: JSON.dump(expected_reranked),
+ )
+
+ crafted_system_prompt = ai_persona.craft_prompt(with_cc).messages.first[:content]
+
+ expect(crafted_system_prompt).to include("fragment-n14")
+ expect(crafted_system_prompt).to include("fragment-n13")
+ expect(crafted_system_prompt).to include("fragment-n12")
+
+ expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included
+ end
+ end
+
+ context "when the reranker is not available" do
+ before { stub_fragments(10) }
+
+ it "picks the first 10 candidates from the similarity search" do
+ crafted_system_prompt = ai_persona.craft_prompt(with_cc).messages.first[:content]
+
+ expect(crafted_system_prompt).to include("fragment-n0")
+ expect(crafted_system_prompt).to include("fragment-n1")
+ expect(crafted_system_prompt).to include("fragment-n2")
+
+ expect(crafted_system_prompt).not_to include("fragment-n10") # Fragment #10 not included
+ end
+ end
+ end
+ end
end
diff --git a/spec/models/rag_document_fragment_spec.rb b/spec/models/rag_document_fragment_spec.rb
new file mode 100644
index 00000000..3d16218c
--- /dev/null
+++ b/spec/models/rag_document_fragment_spec.rb
@@ -0,0 +1,76 @@
+# frozen_string_literal: true
+
+RSpec.describe RagDocumentFragment do
+ fab!(:persona) { Fabricate(:ai_persona) }
+ fab!(:upload_1) { Fabricate(:upload) }
+ fab!(:upload_2) { Fabricate(:upload) }
+
+ before do
+ SiteSetting.ai_embeddings_enabled = true
+ SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+ end
+
+ describe ".link_uploads_and_persona" do
+ it "does nothing if there is no persona" do
+ expect { described_class.link_persona_and_uploads(nil, [upload_1.id]) }.not_to change(
+ Jobs::DigestRagUpload.jobs,
+ :size,
+ )
+ end
+
+ it "does nothing if there are no uploads" do
+ expect { described_class.link_persona_and_uploads(persona, []) }.not_to change(
+ Jobs::DigestRagUpload.jobs,
+ :size,
+ )
+ end
+
+ it "queues a job for each upload to generate fragments" do
+ expect {
+ described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
+ }.to change(Jobs::DigestRagUpload.jobs, :size).by(2)
+ end
+
+ it "creates references between the persona and each upload" do
+ described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
+
+ refs = UploadReference.where(target: persona).pluck(:upload_id)
+
+ expect(refs).to contain_exactly(upload_1.id, upload_2.id)
+ end
+ end
+
+ describe ".update_persona_uploads" do
+ it "does nothing if there is no persona" do
+ expect { described_class.update_persona_uploads(nil, [upload_1.id]) }.not_to change(
+ Jobs::DigestRagUpload.jobs,
+ :size,
+ )
+ end
+
+ it "deletes the fragment if it's not present in the uploads list" do
+ fragment = Fabricate(:rag_document_fragment, ai_persona: persona)
+
+ described_class.update_persona_uploads(persona, [])
+
+ expect { fragment.reload }.to raise_error(ActiveRecord::RecordNotFound)
+ end
+
+ it "deletes references between the upload and the persona" do
+ described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
+
+ described_class.update_persona_uploads(persona, [upload_2.id])
+
+ refs = UploadReference.where(target: persona).pluck(:upload_id)
+
+ expect(refs).to contain_exactly(upload_2.id)
+ end
+
+ it "queues jobs to generate new fragments" do
+ expect { described_class.update_persona_uploads(persona, [upload_1.id]) }.to change(
+ Jobs::DigestRagUpload.jobs,
+ :size,
+ ).by(1)
+ end
+ end
+end
diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb
index bd7cbefe..cf0aef7f 100644
--- a/spec/requests/admin/ai_personas_controller_spec.rb
+++ b/spec/requests/admin/ai_personas_controller_spec.rb
@@ -4,7 +4,12 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
fab!(:admin)
fab!(:ai_persona)
- before { sign_in(admin) }
+ before do
+ sign_in(admin)
+
+ SiteSetting.ai_embeddings_enabled = true
+ SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
+ end
describe "GET #index" do
it "returns a success response" do
@@ -125,6 +130,21 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(response).to be_successful
expect(response.parsed_body["ai_persona"]["name"]).to eq(ai_persona.name)
end
+
+ it "includes rag uploads for each persona" do
+ upload = Fabricate(:upload)
+ RagDocumentFragment.link_persona_and_uploads(ai_persona, [upload.id])
+
+ get "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json"
+ expect(response).to be_successful
+
+ serialized_persona = response.parsed_body["ai_persona"]
+
+ expect(serialized_persona.dig("rag_uploads", 0, "id")).to eq(upload.id)
+ expect(serialized_persona.dig("rag_uploads", 0, "original_filename")).to eq(
+ upload.original_filename,
+ )
+ end
end
describe "POST #create" do
@@ -323,6 +343,17 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
end
end
+ describe "POST #upload_file" do
+ it "works" do
+ post "/admin/plugins/discourse-ai/ai-personas/files/upload.json",
+ params: {
+ file: Rack::Test::UploadedFile.new(file_from_fixtures("spec.txt", "md")),
+ }
+
+ expect(response.status).to eq(200)
+ end
+ end
+
describe "DELETE #destroy" do
it "destroys the requested ai_persona" do
expect {
diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js
index 7f30f4ee..cdae4395 100644
--- a/test/javascripts/unit/models/ai-persona-test.js
+++ b/test/javascripts/unit/models/ai-persona-test.js
@@ -48,6 +48,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
+ rag_uploads: [],
};
const aiPersona = AiPersona.create({ ...properties });
@@ -81,6 +82,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
+ rag_uploads: [],
};
const aiPersona = AiPersona.create({ ...properties });