FEATURE: AI Bot RAG support. (#537)

This PR lets you associate uploads to an AI persona, which we'll split and generate embeddings from. When building the system prompt to get a bot reply, we'll do a similarity search followed by a re-ranking (if available). This will let us find the most relevant fragments from the body of knowledge you associated with the persona, resulting in better, more informed responses.

For now, we'll only allow plain-text files, but this will change in the future.

Commits:

* FEATURE: RAG embeddings for the AI Bot

This first commit introduces a UI where admins can upload text files, which we'll store, split into fragments,
and generate embeddings of. In a later commit, we'll use those to give the bot additional information during
conversations.

* Basic asymmetric similarity search to provide guidance in system prompt

* Fix tests and lint

* Apply reranker to fragments

* Uploads filter, css adjustments and file validations

* Add placeholder for rag fragments

* Update annotations
This commit is contained in:
Roman Rizzi 2024-04-01 13:43:34 -03:00 committed by GitHub
parent a2018d4a04
commit 1f1c94e5c6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
29 changed files with 1132 additions and 35 deletions

View File

@ -5,6 +5,7 @@ export default DiscourseRoute.extend({
async model() {
const record = this.store.createRecord("ai-persona");
record.set("allowed_group_ids", [AUTO_GROUPS.trust_level_0.id]);
record.set("rag_uploads", []);
return record;
},

View File

@ -30,8 +30,10 @@ module DiscourseAi
end
def create
ai_persona = AiPersona.new(ai_persona_params)
ai_persona = AiPersona.new(ai_persona_params.except(:rag_uploads))
if ai_persona.save
RagDocumentFragment.link_persona_and_uploads(ai_persona, attached_upload_ids)
render json: { ai_persona: ai_persona }, status: :created
else
render_json_error ai_persona
@ -44,7 +46,9 @@ module DiscourseAi
end
def update
if @ai_persona.update(ai_persona_params)
if @ai_persona.update(ai_persona_params.except(:rag_uploads))
RagDocumentFragment.update_persona_uploads(@ai_persona, attached_upload_ids)
render json: @ai_persona
else
render_json_error @ai_persona
@ -59,12 +63,43 @@ module DiscourseAi
end
end
# Accepts a single plain-text file upload destined for RAG indexing.
#
# Raises Discourse::InvalidAccess when embeddings are disabled and
# Discourse::InvalidParameters when no file was sent or it fails validation.
def upload_file
if !SiteSetting.ai_embeddings_enabled?
raise Discourse::InvalidAccess.new("Embeddings not enabled")
end
# Guard against requests that include neither :file nor :files —
# the previous `params[:files].first` raised NoMethodError (HTTP 500) on nil.
file = params[:file] || params[:files]&.first
raise Discourse::InvalidParameters.new(:file) if file.blank?
validate_extension!(file.original_filename)
validate_file_size!(file.tempfile.size)
# hijack releases the web worker while the (potentially slow) upload is created.
hijack do
upload =
UploadCreator.new(
file.tempfile,
file.original_filename,
type: "discourse_ai_rag_upload",
skip_validations: true,
).create_for(current_user.id)
if upload.persisted?
render json: UploadSerializer.new(upload)
else
render json: failed_json.merge(errors: upload.errors.full_messages), status: 422
end
end
end
private
# before_action: loads the persona targeted by params[:id];
# AiPersona.find raises ActiveRecord::RecordNotFound (404) when absent.
def find_ai_persona
@ai_persona = AiPersona.find(params[:id])
end
# IDs of the uploads the client attached via rag_uploads: [{ id: ... }].
# Returns [] when the param is absent (nil.to_a).
def attached_upload_ids
ai_persona_params[:rag_uploads].to_a.map { |h| h[:id] }
end
def ai_persona_params
permitted =
params.require(:ai_persona).permit(
@ -82,6 +117,7 @@ module DiscourseAi
:vision_enabled,
:vision_max_pixels,
allowed_group_ids: [],
rag_uploads: [:id],
)
if commands = params.dig(:ai_persona, :commands)
@ -105,6 +141,28 @@ module DiscourseAi
end
end
end
# Ensures the uploaded file has an authorized extension (only "txt" for now).
# Comparison is case-insensitive so e.g. "NOTES.TXT" is accepted too.
#
# @param filename [String] the original client-side file name
# @raise [Discourse::InvalidParameters] when the extension is not authorized
def validate_extension!(filename)
# File.extname("notes") => "", and ""[1..-1] is nil, hence the || "" fallback.
extension = File.extname(filename)[1..-1] || ""
authorized_extension = "txt"
if extension.downcase != authorized_extension
raise Discourse::InvalidParameters.new(
I18n.t("upload.unauthorized", authorized_extensions: authorized_extension),
)
end
end
# Rejects uploads whose size exceeds the 20 MB limit for RAG files.
#
# @param filesize [Integer] size in bytes
# @raise [Discourse::InvalidParameters] with a humanized size message
def validate_file_size!(filesize)
max_size_bytes = 20.megabytes
return if filesize <= max_size_bytes

raise Discourse::InvalidParameters.new(
I18n.t(
"upload.attachments.too_large_humanized",
max_size: ActiveSupport::NumberHelper.number_to_human_size(max_size_bytes),
),
)
end
end
end
end

View File

@ -0,0 +1,68 @@
# frozen_string_literal: true
module ::Jobs
# Splits a persona's uploaded plain-text document into overlapping fragments
# and enqueues embedding generation jobs for them.
class DigestRagUpload < ::Jobs::Base
# TODO(roman): Add a way to automatically recover from errors, resulting in unindexed uploads.
def execute(args)
return if (upload = Upload.find_by(id: args[:upload_id])).nil?
return if (ai_persona = AiPersona.find_by(id: args[:ai_persona_id])).nil?
fragment_ids = RagDocumentFragment.where(ai_persona: ai_persona, upload: upload).pluck(:id)
# Check if this is the first time we process this upload.
if fragment_ids.empty?
document = get_uploaded_file(upload)
return if document.nil?
# ~1024-char fragments with 64 chars of overlap between neighbours, so
# sentences split at a fragment boundary remain searchable.
chunk_size = 1024
chunk_overlap = 64
chunks = []
overlap = ""
splitter =
Baran::RecursiveCharacterTextSplitter.new(
chunk_size: chunk_size,
chunk_overlap: chunk_overlap,
separators: ["\n\n", "\n", " ", ""],
)
# Stream the document in 2 KB reads, carrying the tail of the last chunk
# forward so fragments can span read boundaries.
# NOTE(review): read(2048) counts bytes, so a multi-byte character could be
# split across two reads — presumably acceptable for .txt uploads; verify.
while raw_text = document.read(2048)
splitter.chunks(overlap + raw_text).each { |chunk| chunks << chunk[:text] }
# [-chunk_overlap..-1] is nil when the last chunk is shorter than the overlap.
overlap = chunks.last[-chunk_overlap..-1] || chunks.last
end
ActiveRecord::Base.transaction do
fragment_ids =
chunks.each_with_index.map do |fragment_text, idx|
RagDocumentFragment.create!(
ai_persona: ai_persona,
fragment: Encodings.to_utf8(fragment_text),
fragment_number: idx + 1,
upload: upload,
).id
end
end
end
# Batch embedding jobs (50 fragments each) so large documents don't
# enqueue one job per fragment.
fragment_ids.each_slice(50) do |slice|
Jobs.enqueue(:generate_rag_embeddings, fragment_ids: slice)
end
end
private
# Returns an IO with the upload's contents, downloading first when the
# store is external (e.g. S3). Memoized for the lifetime of the job instance.
def get_uploaded_file(upload)
store = Discourse.store
@file ||=
if store.external?
# Upload#filesize could be approximate.
# add two extra Mbs to make sure that we'll be able to download the upload.
max_filesize = upload.filesize + 2.megabytes
store.download(upload, max_file_size_kb: max_filesize)
else
File.open(store.path_for(upload))
end
end
end
end

View File

@ -0,0 +1,17 @@
# frozen_string_literal: true
module ::Jobs
# Generates (or refreshes) vector embeddings for a batch of RAG document fragments.
class GenerateRagEmbeddings < ::Jobs::Base
# @param args [Hash] expects :fragment_ids (Array<Integer>)
def execute(args)
return if (fragments = RagDocumentFragment.where(id: args[:fragment_ids].to_a)).empty?
truncation = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
# generate_representation_from compares the digest value to make sure
# the embedding is only generated once per fragment unless something changes.
fragments.map { |fragment| vector_rep.generate_representation_from(fragment) }
end
end
end

View File

@ -16,6 +16,13 @@ class AiPersona < ActiveRecord::Base
belongs_to :created_by, class_name: "User"
belongs_to :user
has_many :upload_references, as: :target, dependent: :destroy
has_many :uploads, through: :upload_references
has_many :rag_document_fragment, dependent: :destroy
has_many :rag_document_fragments, through: :ai_persona_rag_document_fragments
before_destroy :ensure_not_system
class MultisiteHash
@ -238,6 +245,10 @@ class AiPersona < ActiveRecord::Base
super(*args, **kwargs)
end
define_method :persona_id do
@ai_persona&.id
end
define_method :tools do
tools
end
@ -257,6 +268,10 @@ class AiPersona < ActiveRecord::Base
define_method :system_prompt do
@ai_persona&.system_prompt || "You are a helpful bot."
end
define_method :uploads do
@ai_persona&.uploads
end
end
end
@ -320,26 +335,26 @@ end
#
# Table name: ai_personas
#
# id :bigint not null, primary key
# name :string(100) not null
# description :string(2000) not null
# commands :json not null
# system_prompt :string(10000000) not null
# allowed_group_ids :integer default([]), not null, is an Array
# created_by_id :integer
# enabled :boolean default(TRUE), not null
# created_at :datetime not null
# updated_at :datetime not null
# system :boolean default(FALSE), not null
# priority :boolean default(FALSE), not null
# temperature :float
# top_p :float
# user_id :integer
# mentionable :boolean default(FALSE), not null
# default_llm :text
# max_context_posts :integer
# max_post_context_tokens :integer
# max_context_tokens :integer
# id :bigint not null, primary key
# name :string(100) not null
# description :string(2000) not null
# commands :json not null
# system_prompt :string(10000000) not null
# allowed_group_ids :integer default([]), not null, is an Array
# created_by_id :integer
# enabled :boolean default(TRUE), not null
# created_at :datetime not null
# updated_at :datetime not null
# system :boolean default(FALSE), not null
# priority :boolean default(FALSE), not null
# temperature :float
# top_p :float
# user_id :integer
# mentionable :boolean default(FALSE), not null
# default_llm :text
# max_context_posts :integer
# vision_enabled :boolean default(FALSE), not null
# vision_max_pixels :integer default(1048576), not null
#
# Indexes
#

View File

@ -7,3 +7,18 @@ end
class ::Post
has_one :post_custom_prompt, dependent: :destroy
end
# == Schema Information
#
# Table name: post_custom_prompts
#
# id :bigint not null, primary key
# post_id :integer not null
# custom_prompt :json not null
# created_at :datetime not null
# updated_at :datetime not null
#
# Indexes
#
# index_post_custom_prompts_on_post_id (post_id) UNIQUE
#

View File

@ -0,0 +1,46 @@
# frozen_string_literal: true
# A text fragment extracted from one of an AI persona's RAG uploads.
# Class-level helpers keep the persona's upload references and fragment rows in sync.
class RagDocumentFragment < ActiveRecord::Base
  belongs_to :upload
  belongs_to :ai_persona

  # Attaches +upload_ids+ to +persona+ and enqueues one digest job per upload.
  # No-ops when the persona or uploads are missing, or embeddings are disabled.
  def self.link_persona_and_uploads(persona, upload_ids)
    return if persona.blank? || upload_ids.blank?
    return unless SiteSetting.ai_embeddings_enabled?

    UploadReference.ensure_exist!(upload_ids: upload_ids, target: persona)

    upload_ids.each do |upload_id|
      Jobs.enqueue(:digest_rag_upload, ai_persona_id: persona.id, upload_id: upload_id)
    end
  end

  # Reconciles the persona's uploads with +upload_ids+: fragments belonging to
  # detached uploads are destroyed, newly attached uploads get linked/digested.
  def self.update_persona_uploads(persona, upload_ids)
    return if persona.blank?
    return unless SiteSetting.ai_embeddings_enabled?

    if upload_ids.blank?
      # Everything was detached: drop all fragments and upload references.
      RagDocumentFragment.where(ai_persona: persona).destroy_all
      UploadReference.where(target: persona).destroy_all
    else
      RagDocumentFragment.where(ai_persona: persona).where.not(upload_id: upload_ids).destroy_all
      link_persona_and_uploads(persona, upload_ids)
    end
  end
end
# == Schema Information
#
# Table name: rag_document_fragments
#
# id :bigint not null, primary key
# fragment :text not null
# ai_persona_id :integer not null
# upload_id :integer not null
# fragment_number :integer not null
# created_at :datetime not null
# updated_at :datetime not null
#

View File

@ -22,6 +22,11 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:vision_max_pixels
has_one :user, serializer: BasicUserSerializer, embed: :object
has_many :rag_uploads, serializer: UploadSerializer, embed: :object
# Serialized under :rag_uploads — the uploads currently attached to the persona.
def rag_uploads
object.uploads
end
def name
object.class_instance.name

View File

@ -20,6 +20,7 @@ const ATTRIBUTES = [
"max_context_posts",
"vision_enabled",
"vision_max_pixels",
"rag_uploads",
];
const SYSTEM_ATTRIBUTES = [
@ -34,6 +35,7 @@ const SYSTEM_ATTRIBUTES = [
"max_context_posts",
"vision_enabled",
"vision_max_pixels",
"rag_uploads",
];
class CommandOption {

View File

@ -23,18 +23,19 @@ import DTooltip from "float-kit/components/d-tooltip";
import AiCommandSelector from "./ai-command-selector";
import AiLlmSelector from "./ai-llm-selector";
import AiPersonaCommandOptions from "./ai-persona-command-options";
import PersonaRagUploader from "./persona-rag-uploader";
export default class PersonaEditor extends Component {
@service router;
@service store;
@service dialog;
@service toasts;
@service siteSettings;
@tracked allGroups = [];
@tracked isSaving = false;
@tracked editingModel = null;
@tracked showDelete = false;
@tracked maxPixelsValue = null;
@action
@ -190,6 +191,20 @@ export default class PersonaEditor extends Component {
}
}
@action
addUpload(upload) {
const newUpload = upload;
newUpload.status = "uploaded";
newUpload.statusText = I18n.t("discourse_ai.ai_persona.uploads.uploaded");
this.editingModel.rag_uploads.addObject(newUpload);
}
@action
removeUpload(upload) {
this.editingModel.rag_uploads.removeObject(upload);
this.save();
}
async toggleField(field, sortPersonas) {
this.args.model.set(field, !this.args.model[field]);
this.editingModel.set(field, this.args.model[field]);
@ -392,8 +407,8 @@ export default class PersonaEditor extends Component {
/>
</div>
{{/if}}
{{#if this.showTemperature}}
<div class="control-group">
<div class="control-group">
{{#if this.showTemperature}}
<label>{{I18n.t "discourse_ai.ai_persona.temperature"}}</label>
<Input
@type="number"
@ -407,10 +422,8 @@ export default class PersonaEditor extends Component {
@icon="question-circle"
@content={{I18n.t "discourse_ai.ai_persona.temperature_help"}}
/>
</div>
{{/if}}
{{#if this.showTopP}}
<div class="control-group">
{{/if}}
{{#if this.showTopP}}
<label>{{I18n.t "discourse_ai.ai_persona.top_p"}}</label>
<Input
@type="number"
@ -424,6 +437,15 @@ export default class PersonaEditor extends Component {
@icon="question-circle"
@content={{I18n.t "discourse_ai.ai_persona.top_p_help"}}
/>
{{/if}}
</div>
{{#if this.siteSettings.ai_embeddings_enabled}}
<div class="control-group">
<PersonaRagUploader
@ragUploads={{this.editingModel.rag_uploads}}
@onAdd={{this.addUpload}}
@onRemove={{this.removeUpload}}
/>
</div>
{{/if}}
<div class="control-group ai-persona-editor__action_panel">

View File

@ -0,0 +1,153 @@
import { tracked } from "@glimmer/tracking";
import Component, { Input } from "@ember/component";
import { fn } from "@ember/helper";
import { on } from "@ember/modifier";
import { action } from "@ember/object";
import { inject as service } from "@ember/service";
import DButton from "discourse/components/d-button";
import UppyUploadMixin from "discourse/mixins/uppy-upload";
import icon from "discourse-common/helpers/d-icon";
import discourseDebounce from "discourse-common/lib/debounce";
import I18n from "discourse-i18n";
// Uploader widget for attaching plain-text RAG files to an AI persona.
// Wraps UppyUploadMixin to POST files to the admin upload endpoint and renders
// the uploaded + in-progress file lists with a client-side filename filter.
export default class PersonaRagUploader extends Component.extend(
UppyUploadMixin
) {
@service appEvents;
@tracked term = null;
@tracked filteredUploads = null;
id = "discourse-ai-persona-rag-uploader";
maxFiles = 20;
uploadUrl = "/admin/plugins/discourse-ai/ai-personas/files/upload";
preventDirectS3Uploads = true;
didReceiveAttrs() {
super.didReceiveAttrs(...arguments);
// Receiving new attrs while uploads are in flight cancels them all.
if (this.inProgressUploads?.length > 0) {
this._uppyInstance?.cancelAll();
}
this.filteredUploads = this.ragUploads || [];
}
uploadDone(uploadedFile) {
// Notify the parent (persona editor) and refresh the filtered list.
this.onAdd(uploadedFile.upload);
this.debouncedSearch();
}
@action
submitFiles() {
// Opens the hidden file input rendered at the bottom of the template.
this.fileInputEl.click();
}
@action
cancelUploading(upload) {
this.appEvents.trigger(`upload-mixin:${this.id}:cancel-upload`, {
fileId: upload.id,
});
}
// Case-insensitive filename filter over the persona's uploads.
@action
search() {
if (this.term) {
this.filteredUploads = this.ragUploads.filter((u) => {
return (
u.original_filename.toUpperCase().indexOf(this.term.toUpperCase()) >
-1
);
});
} else {
this.filteredUploads = this.ragUploads;
}
}
// Debounced (100ms) wrapper so typing in the filter box doesn't thrash.
@action
debouncedSearch() {
discourseDebounce(this, this.search, 100);
}
<template>
<div class="persona-rag-uploader">
<h3>{{I18n.t "discourse_ai.ai_persona.uploads.title"}}</h3>
<p>{{I18n.t "discourse_ai.ai_persona.uploads.description"}}</p>
<p>{{I18n.t "discourse_ai.ai_persona.uploads.hint"}}</p>
<div class="persona-rag-uploader__search-input-container">
<div class="persona-rag-uploader__search-input">
{{icon
"search"
class="persona-rag-uploader__search-input__search-icon"
}}
<Input
class="persona-rag-uploader__search-input__input"
placeholder={{I18n.t "discourse_ai.ai_persona.uploads.filter"}}
@value={{this.term}}
{{on "keyup" this.debouncedSearch}}
/>
</div>
</div>
<table class="persona-rag-uploader__uploads-list">
<tbody>
{{#each this.filteredUploads as |upload|}}
<tr>
<td>
<span class="persona-rag-uploader__rag-file-icon">{{icon
"file"
}}</span>
{{upload.original_filename}}</td>
<td class="persona-rag-uploader__upload-status">{{icon "check"}}
{{I18n.t "discourse_ai.ai_persona.uploads.complete"}}</td>
<td class="persona-rag-uploader__remove-file">
<DButton
@icon="times"
@title="discourse_ai.ai_persona.uploads.remove"
@action={{fn @onRemove upload}}
@class="btn-flat"
/>
</td>
</tr>
{{/each}}
{{#each this.inProgressUploads as |upload|}}
<tr>
<td><span class="persona-rag-uploader__rag-file-icon">{{icon
"file"
}}</span>
{{upload.original_filename}}</td>
<td class="persona-rag-uploader__upload-status">
<div class="spinner small"></div>
<span>{{I18n.t "discourse_ai.ai_persona.uploads.uploading"}}
{{upload.uploadProgress}}%</span>
</td>
<td class="persona-rag-uploader__remove-file">
<DButton
@icon="times"
@title="discourse_ai.ai_persona.uploads.remove"
@action={{fn this.cancelUploading upload}}
@class="btn-flat"
/>
</td>
</tr>
{{/each}}
</tbody>
</table>
<input
class="hidden-upload-field"
disabled={{this.uploading}}
type="file"
multiple="multiple"
accept=".txt"
/>
<DButton
@label="discourse_ai.ai_persona.uploads.button"
@icon="plus"
@title="discourse_ai.ai_persona.uploads.button"
@action={{this.submitFiles}}
class="btn-default"
/>
</div>
</template>
}

View File

@ -76,4 +76,69 @@
display: flex;
align-items: center;
}
.persona-rag-uploader {
width: 500px;
&__search-input {
display: flex;
align-items: center;
border: 1px solid var(--primary-400);
width: 100%;
box-sizing: border-box;
height: 35px;
padding: 0 0.5rem;
&:focus,
&:focus-within {
@include default-focus();
}
&-container {
display: flex;
flex-grow: 1;
}
&__search-icon {
background: none !important;
color: var(--primary-medium);
}
&__input {
width: 100% !important;
}
&__input,
&__input:focus {
margin: 0 !important;
border: 0 !important;
appearance: none !important;
outline: none !important;
background: none !important;
}
}
&__uploads-list {
margin-bottom: 20px;
tbody {
border-top: none;
}
}
&__upload-status {
text-align: right;
padding-right: 0;
color: var(--success);
}
&__remove-file {
text-align: right;
padding-left: 0;
}
&__rag-file-icon {
margin-right: 5px;
}
}
}

View File

@ -164,6 +164,14 @@ en:
#### Group-Specific Access to AI Personas
Moreover, you can set it up so that certain user groups have access to specific personas. This means you can have different AI behaviors for different sections of your forum, further enhancing the diversity and richness of your community's interactions.
uploads:
title: "Uploads"
description: "Your AI persona will be able to search and reference the content of included files. Uploaded files must be formatted as plaintext (.txt)"
hint: "To control where the file's content gets placed within the system prompt, include the {uploads} placeholder in the system prompt above."
button: "Add Files"
filter: "Filter uploads"
complete: "Complete"
related_topics:
title: "Related Topics"

View File

@ -41,5 +41,7 @@ Discourse::Application.routes.draw do
controller: "discourse_ai/admin/ai_personas"
post "/ai-personas/:id/create-user", to: "discourse_ai/admin/ai_personas#create_user"
post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file"
put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file"
end
end

View File

@ -0,0 +1,13 @@
# frozen_string_literal: true
# Creates the table storing text fragments split from a persona's RAG uploads.
class CreateRagDocumentFragmentTable < ActiveRecord::Migration[7.0]
def change
create_table :rag_document_fragments do |t|
t.text :fragment, null: false
t.integer :upload_id, null: false
t.integer :ai_persona_id, null: false
# 1-based position of the fragment within its source document.
t.integer :fragment_number, null: false
t.timestamps
end
end
end

View File

@ -0,0 +1,96 @@
# frozen_string_literal: true
# Creates one pgvector embeddings table per supported embedding model
# (table suffix <model_id>_1), differing only in vector dimension.
class EmbeddingTablesForRagUploads < ActiveRecord::Migration[7.0]
  # model id => embedding dimension, matching the per-model tables that the
  # original hand-expanded version of this migration created one by one.
  EMBEDDING_DIMENSIONS = {
    1 => 768,
    2 => 1536,
    3 => 1024,
    4 => 1024,
    5 => 768,
    6 => 1536,
    7 => 2000,
  }.freeze

  def change
    EMBEDDING_DIMENSIONS.each do |model_id, dimensions|
      create_table :"ai_document_fragment_embeddings_#{model_id}_1", id: false do |t|
        t.integer :rag_document_fragment_id, null: false
        t.integer :model_version, null: false
        t.integer :strategy_version, null: false
        # Digest of the source text; lets us skip regenerating unchanged embeddings.
        t.text :digest, null: false
        t.column :embeddings, "vector(#{dimensions})", null: false
        t.timestamps
        # One embedding row per fragment (upserted on conflict).
        t.index :rag_document_fragment_id,
                unique: true,
                name: "rag_document_fragment_id_embeddings_#{model_id}_1"
      end
    end
  end
end

View File

@ -200,6 +200,17 @@ module DiscourseAi
if plugin.respond_to?(:register_editable_topic_custom_field)
plugin.register_editable_topic_custom_field(:ai_persona_id)
end
plugin.on(:site_setting_changed) do |name, old_value, new_value|
if name == "ai_embeddings_model" && SiteSetting.ai_embeddings_enabled? &&
new_value != old_value
RagDocumentFragment.find_in_batches do |batch|
batch.each_slice(100) do |fragments|
Jobs.enqueue(:generate_rag_embeddings, fragment_ids: fragments.map(&:id))
end
end
end
end
end
end
end

View File

@ -93,6 +93,10 @@ module DiscourseAi
end
end
def id
@ai_persona&.id || self.class.system_personas[self.class]
end
def tools
[]
end
@ -124,12 +128,24 @@ module DiscourseAi
found.nil? ? match : found.to_s
end
prompt_insts = <<~TEXT.strip
#{system_insts}
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
TEXT
fragments_guidance = rag_fragments_prompt(context[:conversation_context].to_a)&.strip
if fragments_guidance.present?
if system_insts.include?("{uploads}")
prompt_insts = prompt_insts.gsub("{uploads}", fragments_guidance)
else
prompt_insts << fragments_guidance
end
end
prompt =
DiscourseAi::Completions::Prompt.new(
<<~TEXT.strip,
#{system_insts}
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
TEXT
prompt_insts,
messages: context[:conversation_context].to_a,
topic_id: context[:topic_id],
post_id: context[:post_id],
@ -181,6 +197,68 @@ module DiscourseAi
persona_options: options[tool_klass].to_h,
)
end
# Builds a <guidance> block containing the persona's RAG fragments most
# relevant to the recent conversation, or nil when embeddings are disabled,
# there is no conversation, or the persona has no uploads.
def rag_fragments_prompt(conversation_context)
upload_refs =
UploadReference.where(target_id: id, target_type: "AiPersona").pluck(:upload_id)
return nil if !SiteSetting.ai_embeddings_enabled?
return nil if conversation_context.blank? || upload_refs.blank?
# Use the last 10 user/model messages, joined, as the similarity-search query.
latest_interactions =
conversation_context
.select { |ctx| %i[model user].include?(ctx[:type]) }
.map { |ctx| ctx[:content] }
.last(10)
.join("\n")
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
interactions_vector = vector_rep.vector_from(latest_interactions)
# Over-fetch 50 candidates when a reranker is available so it has room to reorder.
candidate_fragment_ids =
vector_rep.asymmetric_rag_fragment_similarity_search(
interactions_vector,
persona_id: id,
limit: reranker.reranker_configured? ? 50 : 10,
offset: 0,
)
# NOTE(review): where(id: [...]) does not preserve the similarity ordering of
# candidate_fragment_ids, so without a reranker the fragments appear in
# arbitrary DB order — verify whether ordering by relevance was intended.
guidance =
RagDocumentFragment.where(upload_id: upload_refs, id: candidate_fragment_ids).pluck(
:fragment,
)
if reranker.reranker_configured?
# Rerank against the most recent message only and keep the 10 best fragments.
ranks =
DiscourseAi::Inference::HuggingFaceTextEmbeddings
.rerank(conversation_context.last[:content], guidance)
.to_a
.take(10)
.map { _1[:index] }
if ranks.empty?
guidance = guidance.take(10)
else
guidance = ranks.map { |idx| guidance[idx] }
end
end
<<~TEXT
<guidance>
The following texts will give you additional guidance to elaborate a response.
We included them because we believe they are relevant to this conversation topic.
Take them into account to elaborate a response.
Texts:
#{guidance.join("\n")}
</guidance>
TEXT
end
end
end
end

View File

@ -18,6 +18,8 @@ module DiscourseAi
topic_truncation(target, tokenizer, max_length)
when Post
post_truncation(target, tokenizer, max_length)
when RagDocumentFragment
tokenizer.truncate(target.fragment, max_length)
else
raise ArgumentError, "Invalid target type"
end

View File

@ -155,6 +155,18 @@ module DiscourseAi
text = @strategy.prepare_text_from(target, tokenizer, max_sequence_length - 2)
return if text.blank?
target_column =
case target
when Topic
"topic_id"
when Post
"post_id"
when RagDocumentFragment
"rag_document_fragment_id"
else
raise ArgumentError, "Invalid target type"
end
new_digest = OpenSSL::Digest::SHA1.hexdigest(text)
current_digest = DB.query_single(<<~SQL, target_id: target.id).first
SELECT
@ -162,7 +174,7 @@ module DiscourseAi
FROM
#{table_name(target)}
WHERE
#{target.is_a?(Topic) ? "topic_id" : "post_id"} = :target_id
#{target_column} = :target_id
LIMIT 1
SQL
return if current_digest == new_digest
@ -248,6 +260,47 @@ module DiscourseAi
raise MissingEmbeddingError
end
# Finds the persona's RAG fragments whose embeddings are closest to
# +raw_vector+ (asymmetric search: query text vs. stored documents).
#
# @param raw_vector [Array<Float>] embedding of the query text
# @param persona_id [Integer] restrict results to this persona's fragments
# @param limit [Integer] maximum rows returned
# @param offset [Integer] pagination offset
# @param return_distance [Boolean] when true, return [fragment_id, distance] pairs
# @return [Array] fragment ids (or pairs) ordered by ascending distance
# @raise [MissingEmbeddingError] when the underlying query fails
def asymmetric_rag_fragment_similarity_search(
  raw_vector,
  persona_id:,
  limit:,
  offset:,
  return_distance: false
)
  results =
    DB.query(
      <<~SQL,
        #{probes_sql(rag_fragments_table_name)}
        SELECT
          rag_document_fragment_id,
          embeddings #{pg_function} '[:query_embedding]' AS distance
        FROM
          #{rag_fragments_table_name}
        INNER JOIN
          rag_document_fragments AS rdf ON rdf.id = rag_document_fragment_id
        WHERE
          rdf.ai_persona_id = :persona_id
        ORDER BY
          embeddings #{pg_function} '[:query_embedding]'
        LIMIT :limit
        OFFSET :offset
      SQL
      query_embedding: raw_vector,
      persona_id: persona_id,
      limit: limit,
      offset: offset,
    )

  if return_distance
    results.map { |r| [r.rag_document_fragment_id, r.distance] }
  else
    results.map(&:rag_document_fragment_id)
  end
rescue PG::Error => e
  Rails.logger.error("Error #{e} querying embeddings for model #{name}")
  raise MissingEmbeddingError
end
def symmetric_topics_similarity_search(topic)
DB.query(<<~SQL, topic_id: topic.id).map(&:topic_id)
#{probes_sql(topic_table_name)}
@ -282,12 +335,18 @@ module DiscourseAi
"ai_post_embeddings_#{id}_#{@strategy.id}"
end
# Name of the pgvector table holding fragment embeddings for this
# model (id) and strategy combination.
def rag_fragments_table_name
"ai_document_fragment_embeddings_#{id}_#{@strategy.id}"
end
def table_name(target)
case target
when Topic
topic_table_name
when Post
post_table_name
when RagDocumentFragment
rag_fragments_table_name
else
raise ArgumentError, "Invalid target type"
end
@ -375,6 +434,25 @@ module DiscourseAi
digest: digest,
embeddings: vector,
)
elsif target.is_a?(RagDocumentFragment)
DB.exec(
<<~SQL,
INSERT INTO #{rag_fragments_table_name} (rag_document_fragment_id, model_version, strategy_version, digest, embeddings, created_at, updated_at)
VALUES (:fragment_id, :model_version, :strategy_version, :digest, '[:embeddings]', CURRENT_TIMESTAMP, CURRENT_TIMESTAMP)
ON CONFLICT (rag_document_fragment_id)
DO UPDATE SET
model_version = :model_version,
strategy_version = :strategy_version,
digest = :digest,
embeddings = '[:embeddings]',
updated_at = CURRENT_TIMESTAMP
SQL
fragment_id: target.id,
model_version: version,
strategy_version: @strategy.version,
digest: digest,
embeddings: vector,
)
else
raise ArgumentError, "Invalid target type"
end

View File

@ -56,6 +56,11 @@ module ::DiscourseAi
JSON.parse(response.body, symbolize_names: true)
end
# True when a HuggingFace TEI reranker endpoint (direct or SRV) is configured.
def reranker_configured?
SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present?
end
def configured?
SiteSetting.ai_hugging_face_tei_endpoint.present? ||
SiteSetting.ai_hugging_face_tei_endpoint_srv.present?

View File

@ -10,6 +10,7 @@
gem "tokenizers", "0.4.3"
gem "tiktoken_ruby", "0.0.7"
gem "baran", "0.1.10"
enabled_site_setting :discourse_ai_enabled

View File

@ -0,0 +1,7 @@
# frozen_string_literal: true
# Test factory for RagDocumentFragment. ai_persona is intentionally not set
# here (the column is NOT NULL), so callers must supply one explicitly.
Fabricator(:rag_document_fragment) do
fragment { sequence(:fragment) { |n| "Document fragment #{n}" } }
upload
fragment_number { sequence(:fragment_number) { |n| n + 1 } }
end

View File

@ -0,0 +1,58 @@
# frozen_string_literal: true
RSpec.describe Jobs::DigestRagUpload do
fab!(:persona) { Fabricate(:ai_persona) }
fab!(:upload)
# 1800 chars of text ("some text" is 9 chars x 200) — enough for two fragments.
let(:document_file) { StringIO.new("some text" * 200) }
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
before do
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
# Stub the embeddings HTTP service so the job never hits the network.
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
end
describe "#execute" do
context "when processing an upload for the first time" do
# Serve the in-memory document instead of reading from the upload store.
before { File.expects(:open).returns(document_file) }
it "splits an upload into chunks" do
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
created_fragment = RagDocumentFragment.last
expect(created_fragment).to be_present
expect(created_fragment.fragment).to be_present
# The last fragment created from the 1800-char document is number 2.
expect(created_fragment.fragment_number).to eq(2)
end
it "queue jobs to generate embeddings for each fragment" do
# Both fragments fit in a single batch of 50, hence exactly one job.
expect { subject.execute(upload_id: upload.id, ai_persona_id: persona.id) }.to change(
Jobs::GenerateRagEmbeddings.jobs,
:size,
).by(1)
end
end
it "doesn't generate new fragments if we already processed the upload" do
Fabricate(:rag_document_fragment, upload: upload, ai_persona: persona)
previous_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
subject.execute(upload_id: upload.id, ai_persona_id: persona.id)
updated_count = RagDocumentFragment.where(upload: upload, ai_persona: persona).count
expect(updated_count).to eq(previous_count)
end
end
end

View File

@ -0,0 +1,38 @@
# frozen_string_literal: true
RSpec.describe Jobs::GenerateRagEmbeddings do
describe "#execute" do
let(:truncation) { DiscourseAi::Embeddings::Strategies::Truncation.new }
let(:vector_rep) do
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(truncation)
end
let(:expected_embedding) { [0.0038493] * vector_rep.dimensions }
fab!(:ai_persona)
fab!(:rag_document_fragment_1) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
fab!(:rag_document_fragment_2) { Fabricate(:rag_document_fragment, ai_persona: ai_persona) }
before do
SiteSetting.ai_embeddings_enabled = true
SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
# Stub the embeddings HTTP endpoint so the job never hits the network.
WebMock.stub_request(
:post,
"#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
).to_return(status: 200, body: JSON.dump(expected_embedding))
end
it "generates a new vector for each fragment" do
expected_embeddings = 2
subject.execute(fragment_ids: [rag_document_fragment_1.id, rag_document_fragment_2.id])
# One row per fragment should land in the model's embeddings table.
embeddings_count =
DB.query_single("SELECT COUNT(*) from #{vector_rep.rag_fragments_table_name}").first
expect(embeddings_count).to eq(expected_embeddings)
end
end
end

View File

@ -196,4 +196,128 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
)
end
end
describe "#craft_prompt" do
  before do
    Group.refresh_automatic_groups!
    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
    SiteSetting.ai_embeddings_enabled = true
  end

  let(:ai_persona) { DiscourseAi::AiBot::Personas::Persona.all(user: user).first.new }

  # Context including a user message; the latest user message is what gets
  # embedded for the similarity search against the persona's fragments.
  let(:with_cc) do
    context.merge(conversation_context: [{ content: "Tell me the time", type: :user }])
  end

  context "when a persona has no uploads" do
    it "doesn't include RAG guidance" do
      guidance_fragment =
        "The following texts will give you additional guidance to elaborate a response."

      expect(ai_persona.craft_prompt(with_cc).messages.first[:content]).not_to include(
        guidance_fragment,
      )
    end
  end

  context "when a persona has RAG uploads" do
    fab!(:upload)

    # Creates `limit` fragments for the persona and stubs the similarity
    # search to return all of their ids, bypassing real vector math.
    def stub_fragments(limit)
      candidate_ids = []

      limit.times do |i|
        candidate_ids << Fabricate(
          :rag_document_fragment,
          fragment: "fragment-n#{i}",
          ai_persona_id: ai_persona.id,
          upload: upload,
        ).id
      end

      DiscourseAi::Embeddings::VectorRepresentations::BgeLargeEn
        .any_instance
        .expects(:asymmetric_rag_fragment_similarity_search)
        .returns(candidate_ids)
    end

    before do
      # Link the upload to the persisted persona record backing this persona class.
      stored_ai_persona = AiPersona.find(ai_persona.id)
      UploadReference.ensure_exist!(target: stored_ai_persona, upload_ids: [upload.id])

      # Stub the embedding generated for the conversation's user message.
      context_embedding = [0.049382, 0.9999]
      EmbeddingsGenerationStubs.discourse_service(
        SiteSetting.ai_embeddings_model,
        with_cc.dig(:conversation_context, 0, :content),
        context_embedding,
      )
    end

    context "when the system prompt has an uploads placeholder" do
      before { stub_fragments(10) }

      it "replaces the placeholder with the fragments" do
        custom_persona_record =
          AiPersona.create!(
            name: "custom",
            description: "description",
            # {uploads} marks where the RAG guidance should be spliced in.
            system_prompt: "instructions\n{uploads}\nmore instructions",
            allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
          )
        UploadReference.ensure_exist!(target: custom_persona_record, upload_ids: [upload.id])
        custom_persona =
          DiscourseAi::AiBot::Personas::Persona.find_by(
            id: custom_persona_record.id,
            user: user,
          ).new

        crafted_system_prompt = custom_persona.craft_prompt(with_cc).messages.first[:content]

        expect(crafted_system_prompt).to include("fragment-n0")
        # Guidance was inserted at the placeholder, not appended at the end.
        expect(crafted_system_prompt.ends_with?("</guidance>")).to eq(false)
      end
    end

    context "when the reranker is available" do
      before do
        SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com"
        stub_fragments(15) # Mimic limit being more than 10 results
      end

      it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do
        # The stubbed reranker returns candidates in reverse order, so the
        # highest-numbered fragments become the top-ranked ones.
        expected_reranked = (0..14).to_a.reverse.map { |idx| { index: idx } }

        WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return(
          status: 200,
          body: JSON.dump(expected_reranked),
        )

        crafted_system_prompt = ai_persona.craft_prompt(with_cc).messages.first[:content]

        expect(crafted_system_prompt).to include("fragment-n14")
        expect(crafted_system_prompt).to include("fragment-n13")
        expect(crafted_system_prompt).to include("fragment-n12")

        expect(crafted_system_prompt).not_to include("fragment-n4") # Fragment #11 not included
      end
    end

    context "when the reranker is not available" do
      before { stub_fragments(10) }

      it "picks the first 10 candidates from the similarity search" do
        crafted_system_prompt = ai_persona.craft_prompt(with_cc).messages.first[:content]

        expect(crafted_system_prompt).to include("fragment-n0")
        expect(crafted_system_prompt).to include("fragment-n1")
        expect(crafted_system_prompt).to include("fragment-n2")

        expect(crafted_system_prompt).not_to include("fragment-n10") # Fragment #10 not included
      end
    end
  end
end
end

View File

@ -0,0 +1,76 @@
# frozen_string_literal: true

RSpec.describe RagDocumentFragment do
  fab!(:persona) { Fabricate(:ai_persona) }
  fab!(:upload_1) { Fabricate(:upload) }
  fab!(:upload_2) { Fabricate(:upload) }

  before do
    SiteSetting.ai_embeddings_enabled = true
    SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
  end

  # Fixed: this describe previously read ".link_uploads_and_persona", which
  # doesn't match the method under test (link_persona_and_uploads).
  describe ".link_persona_and_uploads" do
    it "does nothing if there is no persona" do
      expect { described_class.link_persona_and_uploads(nil, [upload_1.id]) }.not_to change(
        Jobs::DigestRagUpload.jobs,
        :size,
      )
    end

    it "does nothing if there are no uploads" do
      expect { described_class.link_persona_and_uploads(persona, []) }.not_to change(
        Jobs::DigestRagUpload.jobs,
        :size,
      )
    end

    it "queues a job for each upload to generate fragments" do
      expect {
        described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])
      }.to change(Jobs::DigestRagUpload.jobs, :size).by(2)
    end

    it "creates references between the persona and each upload" do
      described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])

      refs = UploadReference.where(target: persona).pluck(:upload_id)

      expect(refs).to contain_exactly(upload_1.id, upload_2.id)
    end
  end

  describe ".update_persona_uploads" do
    it "does nothing if there is no persona" do
      expect { described_class.update_persona_uploads(nil, [upload_1.id]) }.not_to change(
        Jobs::DigestRagUpload.jobs,
        :size,
      )
    end

    it "deletes the fragment if it's not present in the uploads list" do
      fragment = Fabricate(:rag_document_fragment, ai_persona: persona)

      described_class.update_persona_uploads(persona, [])

      expect { fragment.reload }.to raise_error(ActiveRecord::RecordNotFound)
    end

    it "deletes references between the upload and the persona" do
      described_class.link_persona_and_uploads(persona, [upload_1.id, upload_2.id])

      described_class.update_persona_uploads(persona, [upload_2.id])

      refs = UploadReference.where(target: persona).pluck(:upload_id)

      expect(refs).to contain_exactly(upload_2.id)
    end

    it "queues jobs to generate new fragments" do
      expect { described_class.update_persona_uploads(persona, [upload_1.id]) }.to change(
        Jobs::DigestRagUpload.jobs,
        :size,
      ).by(1)
    end
  end
end

View File

@ -4,7 +4,12 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
fab!(:admin)
fab!(:ai_persona)
before { sign_in(admin) }
before do
  sign_in(admin)
  # RAG-related endpoints (upload handling / serialized rag_uploads) require
  # embeddings to be enabled with a configured service endpoint.
  SiteSetting.ai_embeddings_enabled = true
  SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
end
describe "GET #index" do
it "returns a success response" do
@ -125,6 +130,21 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(response).to be_successful
expect(response.parsed_body["ai_persona"]["name"]).to eq(ai_persona.name)
end
# The show endpoint must surface the persona's linked uploads so the admin UI
# can render the existing RAG files.
it "includes rag uploads for each persona" do
  upload = Fabricate(:upload)
  RagDocumentFragment.link_persona_and_uploads(ai_persona, [upload.id])

  get "/admin/plugins/discourse-ai/ai-personas/#{ai_persona.id}.json"

  expect(response).to be_successful

  serialized_persona = response.parsed_body["ai_persona"]

  expect(serialized_persona.dig("rag_uploads", 0, "id")).to eq(upload.id)
  expect(serialized_persona.dig("rag_uploads", 0, "original_filename")).to eq(
    upload.original_filename,
  )
end
end
describe "POST #create" do
@ -323,6 +343,17 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
end
end
describe "POST #upload_file" do
  # Smoke test for the persona file-upload endpoint: a plain-text file is the
  # only allowed type for now, so a .txt fixture must be accepted.
  # Fixed: the description previously read "works", which says nothing about
  # the expected behavior.
  # NOTE(review): only the status code is asserted here — the response body
  # shape is presumably covered elsewhere; confirm before relying on it.
  it "accepts a plain-text file upload" do
    post "/admin/plugins/discourse-ai/ai-personas/files/upload.json",
         params: {
           file: Rack::Test::UploadedFile.new(file_from_fixtures("spec.txt", "md")),
         }

    expect(response.status).to eq(200)
  end
end
describe "DELETE #destroy" do
it "destroys the requested ai_persona" do
expect {

View File

@ -48,6 +48,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
rag_uploads: [],
};
const aiPersona = AiPersona.create({ ...properties });
@ -81,6 +82,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
max_context_posts: 5,
vision_enabled: true,
vision_max_pixels: 100,
rag_uploads: [],
};
const aiPersona = AiPersona.create({ ...properties });