diff --git a/app/controllers/discourse_ai/admin/ai_personas_controller.rb b/app/controllers/discourse_ai/admin/ai_personas_controller.rb index 9304d257..ca92ed50 100644 --- a/app/controllers/discourse_ai/admin/ai_personas_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_personas_controller.rb @@ -124,6 +124,7 @@ module DiscourseAi :rag_chunk_tokens, :rag_chunk_overlap_tokens, :rag_conversation_chunks, + :question_consolidator_llm, allowed_group_ids: [], rag_uploads: [:id], ) diff --git a/app/models/ai_persona.rb b/app/models/ai_persona.rb index acc6fbc8..a1ed5f88 100644 --- a/app/models/ai_persona.rb +++ b/app/models/ai_persona.rb @@ -113,6 +113,7 @@ class AiPersona < ActiveRecord::Base vision_enabled = self.vision_enabled vision_max_pixels = self.vision_max_pixels rag_conversation_chunks = self.rag_conversation_chunks + question_consolidator_llm = self.question_consolidator_llm persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id] if persona_class @@ -152,6 +153,10 @@ class AiPersona < ActiveRecord::Base vision_max_pixels end + persona_class.define_singleton_method :question_consolidator_llm do + question_consolidator_llm + end + persona_class.define_singleton_method :rag_conversation_chunks do rag_conversation_chunks end @@ -243,6 +248,10 @@ class AiPersona < ActiveRecord::Base rag_conversation_chunks end + define_singleton_method :question_consolidator_llm do + question_consolidator_llm + end + define_singleton_method :to_s do "#" end @@ -352,31 +361,32 @@ end # # Table name: ai_personas # -# id :bigint not null, primary key -# name :string(100) not null -# description :string(2000) not null -# commands :json not null -# system_prompt :string(10000000) not null -# allowed_group_ids :integer default([]), not null, is an Array -# created_by_id :integer -# enabled :boolean default(TRUE), not null -# created_at :datetime not null -# updated_at :datetime not null -# system :boolean default(FALSE), not null -# priority :boolean default(FALSE), not null -# temperature :float -# top_p :float -# user_id :integer -# mentionable :boolean default(FALSE), not null -# default_llm :text -# max_context_posts :integer -# max_post_context_tokens :integer -# max_context_tokens :integer -# vision_enabled :boolean default(FALSE), not null -# vision_max_pixels :integer default(1048576), not null -# rag_chunk_tokens :integer default(374), not null -# rag_chunk_overlap_tokens :integer default(10), not null -# rag_conversation_chunks :integer default(10), not null +# id :bigint not null, primary key +# name :string(100) not null +# description :string(2000) not null +# commands :json not null +# system_prompt :string(10000000) not null +# allowed_group_ids :integer default([]), not null, is an Array +# created_by_id :integer +# enabled :boolean default(TRUE), not null +# created_at :datetime not null +# updated_at :datetime not null +# system :boolean default(FALSE), not null +# priority :boolean default(FALSE), not null +# temperature :float +# top_p :float +# user_id :integer +# mentionable :boolean default(FALSE), not null +# default_llm :text +# max_context_posts :integer +# max_post_context_tokens :integer +# max_context_tokens :integer +# vision_enabled :boolean default(FALSE), not null +# vision_max_pixels :integer default(1048576), not null +# rag_chunk_tokens :integer default(374), not null +# rag_chunk_overlap_tokens :integer default(10), not null +# rag_conversation_chunks :integer default(10), not null +# question_consolidator_llm :text # # Indexes # diff --git a/app/serializers/localized_ai_persona_serializer.rb b/app/serializers/localized_ai_persona_serializer.rb index 70a72839..d37a2276 100644 --- a/app/serializers/localized_ai_persona_serializer.rb +++ b/app/serializers/localized_ai_persona_serializer.rb @@ -22,7 +22,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer :vision_max_pixels, :rag_chunk_tokens, :rag_chunk_overlap_tokens, - :rag_conversation_chunks + :rag_conversation_chunks, + :question_consolidator_llm has_one :user, serializer: BasicUserSerializer, embed: :object has_many :rag_uploads, serializer: UploadSerializer, embed: :object diff --git a/assets/javascripts/discourse/admin/models/ai-persona.js b/assets/javascripts/discourse/admin/models/ai-persona.js index 4b4a705c..9b0046d8 100644 --- a/assets/javascripts/discourse/admin/models/ai-persona.js +++ b/assets/javascripts/discourse/admin/models/ai-persona.js @@ -25,6 +25,7 @@ const CREATE_ATTRIBUTES = [ "rag_chunk_tokens", "rag_chunk_overlap_tokens", "rag_conversation_chunks", + "question_consolidator_llm", ]; const SYSTEM_ATTRIBUTES = [ @@ -44,6 +45,7 @@ const SYSTEM_ATTRIBUTES = [ "rag_chunk_tokens", "rag_chunk_overlap_tokens", "rag_conversation_chunks", + "question_consolidator_llm", ]; class CommandOption { diff --git a/assets/javascripts/discourse/components/ai-persona-editor.gjs b/assets/javascripts/discourse/components/ai-persona-editor.gjs index d651c5e0..334f7af8 100644 --- a/assets/javascripts/discourse/components/ai-persona-editor.gjs +++ b/assets/javascripts/discourse/components/ai-persona-editor.gjs @@ -133,6 +133,18 @@ export default class PersonaEditor extends Component { return AdminUser.create(this.editingModel?.user); } + get mappedQuestionConsolidatorLlm() { + return this.editingModel?.question_consolidator_llm || "blank"; + } + + set mappedQuestionConsolidatorLlm(value) { + if (value === "blank") { + this.editingModel.question_consolidator_llm = null; + } else { + this.editingModel.question_consolidator_llm = value; + } + } + get mappedDefaultLlm() { return this.editingModel?.default_llm || "blank"; } @@ -460,11 +472,13 @@ export default class PersonaEditor extends Component { @updateUploads={{this.updateUploads}} @onRemove={{this.removeUpload}} /> - {{this.indexingOptionsText}} + {{#if this.editingModel.rag_uploads}} + {{this.indexingOptionsText}} + {{/if}} {{#if this.showIndexingOptions}}
@@ -519,6 +533,24 @@ export default class PersonaEditor extends Component { }} />
+ +
+ + + + +
{{/if}} {{/if}}
diff --git a/assets/javascripts/discourse/components/persona-rag-uploader.gjs b/assets/javascripts/discourse/components/persona-rag-uploader.gjs index cc73aef0..66b19f2c 100644 --- a/assets/javascripts/discourse/components/persona-rag-uploader.gjs +++ b/assets/javascripts/discourse/components/persona-rag-uploader.gjs @@ -111,22 +111,23 @@ export default class PersonaRagUploader extends Component.extend(

{{I18n.t "discourse_ai.ai_persona.uploads.title"}}

{{I18n.t "discourse_ai.ai_persona.uploads.description"}}

-

{{I18n.t "discourse_ai.ai_persona.uploads.hint"}}

-
-
- {{icon - "search" - class="persona-rag-uploader__search-input__search-icon" - }} - + {{#if this.ragUploads}} +
+
+ {{icon + "search" + class="persona-rag-uploader__search-input__search-icon" + }} + +
-
+ {{/if}} diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index c9b94792..9de90254 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -141,9 +141,11 @@ en: create_user_help: You can optionally attach a user to this persona. If you do, the AI will use this user to respond to requests. default_llm: Default Language Model default_llm_help: The default language model to use for this persona. Required if you wish to mention persona on public posts. + question_consolidator_llm: Language Model for Question Consolidator + question_consolidator_llm_help: The language model to use for the question consolidator, you may choose a less powerful model to save costs. system_prompt: System Prompt - show_indexing_options: "Show Indexing Options" - hide_indexing_options: "Hide Indexing Options" + show_indexing_options: "Show Upload Options" + hide_indexing_options: "Hide Upload Options" save: Save saved: AI Persona Saved enabled: "Enabled?" @@ -181,8 +183,7 @@ en: uploads: title: "Uploads" - description: "Your AI persona will be able to search and reference the content of included files. Uploaded files must be formatted as plaintext (.txt)" - hint: "To control where the file's content gets placed within the system prompt, include the {uploads} placeholder in the system prompt above." + description: "Your AI persona will be able to search and reference the content of included files. Uploaded files should be formatted as plaintext (.txt) or markdown (.md)." button: "Add Files" filter: "Filter uploads" indexed: "Indexed" diff --git a/db/migrate/20240429065155_add_consolidated_question_llm_to_ai_persona.rb b/db/migrate/20240429065155_add_consolidated_question_llm_to_ai_persona.rb new file mode 100644 index 00000000..963dffad --- /dev/null +++ b/db/migrate/20240429065155_add_consolidated_question_llm_to_ai_persona.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +class AddConsolidatedQuestionLlmToAiPersona < ActiveRecord::Migration[7.0] + def change + add_column :ai_personas, :question_consolidator_llm, :text, max_length: 2000 + end +end diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index d2aa873e..4ad90729 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -50,7 +50,8 @@ module DiscourseAi end def reply(context, &update_blk) - prompt = persona.craft_prompt(context) + llm = DiscourseAi::Completions::Llm.proxy(model) + prompt = persona.craft_prompt(context, llm: llm) total_completions = 0 ongoing_chain = true @@ -63,8 +64,6 @@ module DiscourseAi llm_kwargs[:top_p] = persona.top_p if persona.top_p while total_completions <= MAX_COMPLETIONS && ongoing_chain - current_model = model - llm = DiscourseAi::Completions::Llm.proxy(current_model) tool_found = false result = diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 86ee6889..52c67132 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -17,6 +17,10 @@ module DiscourseAi 1_048_576 end + def question_consolidator_llm + nil + end + def system_personas @system_personas ||= { Personas::General => -1, @@ -125,7 +129,7 @@ module DiscourseAi self.class.all_available_tools.filter { |tool| tools.include?(tool) } end - def craft_prompt(context) + def craft_prompt(context, llm: nil) system_insts = system_prompt.gsub(/\{(\w+)\}/) do |match| found = context[match[1..-2].to_sym] @@ -137,16 +141,21 @@ module DiscourseAi #{available_tools.map(&:custom_system_message).compact_blank.join("\n")} TEXT - fragments_guidance = rag_fragments_prompt(context[:conversation_context].to_a)&.strip - - if fragments_guidance.present? - if system_insts.include?("{uploads}") - prompt_insts = prompt_insts.gsub("{uploads}", fragments_guidance) - else - prompt_insts << fragments_guidance - end + question_consolidator_llm = llm + if self.class.question_consolidator_llm.present? + question_consolidator_llm = + DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm) end + fragments_guidance = + rag_fragments_prompt( + context[:conversation_context].to_a, + llm: question_consolidator_llm, + user: context[:user], + )&.strip + + prompt_insts << fragments_guidance if fragments_guidance.present? + prompt = DiscourseAi::Completions::Prompt.new( prompt_insts, @@ -202,7 +211,7 @@ module DiscourseAi ) end - def rag_fragments_prompt(conversation_context) + def rag_fragments_prompt(conversation_context, llm:, user:) upload_refs = UploadReference.where(target_id: id, target_type: "AiPersona").pluck(:upload_id) @@ -210,18 +219,30 @@ module DiscourseAi return nil if conversation_context.blank? || upload_refs.blank? latest_interactions = - conversation_context - .select { |ctx| %i[model user].include?(ctx[:type]) } - .map { |ctx| ctx[:content] } - .last(10) - .join("\n") + conversation_context.select { |ctx| %i[model user].include?(ctx[:type]) }.last(10) + + return nil if latest_interactions.empty? + + # first response + if latest_interactions.length == 1 + consolidated_question = latest_interactions[0][:content] + else + consolidated_question = + DiscourseAi::AiBot::QuestionConsolidator.consolidate_question( + llm, + latest_interactions, + user, + ) + end + + return nil if !consolidated_question strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings - interactions_vector = vector_rep.vector_from(latest_interactions) + interactions_vector = vector_rep.vector_from(consolidated_question) rag_conversation_chunks = self.class.rag_conversation_chunks diff --git a/lib/ai_bot/question_consolidator.rb b/lib/ai_bot/question_consolidator.rb new file mode 100644 index 00000000..4a4e7612 --- /dev/null +++ b/lib/ai_bot/question_consolidator.rb @@ -0,0 +1,93 @@ +# frozen_string_literal: true + +module DiscourseAi + module AiBot + class QuestionConsolidator + attr_reader :llm, :messages, :user, :max_tokens + + def self.consolidate_question(llm, messages, user) + new(llm, messages, user).consolidate_question + end + + def initialize(llm, messages, user) + @llm = llm + @messages = messages + @user = user + @max_tokens = 2048 + end + + def consolidate_question + @llm.generate(revised_prompt, user: @user) + end + + def revised_prompt + max_tokens_per_model = @max_tokens / 5 + + conversation_snippet = [] + tokens = 0 + + messages.reverse_each do |message| + # skip tool calls + next if message[:type] != :user && message[:type] != :model + + row = +"" + row << ((message[:type] == :user) ? "user" : "model") + + content = message[:content] + current_tokens = @llm.tokenizer.tokenize(content).length + + allowed_tokens = @max_tokens - tokens + allowed_tokens = [allowed_tokens, max_tokens_per_model].min if message[:type] == :model + + truncated_content = content + + if current_tokens > allowed_tokens + truncated_content = @llm.tokenizer.truncate(content, allowed_tokens) + current_tokens = allowed_tokens + end + + row << ": #{truncated_content}" + tokens += current_tokens + conversation_snippet << row + + break if tokens >= @max_tokens + end + + history = conversation_snippet.reverse.join("\n") + + system_message = <<~TEXT + You are Question Consolidation Bot: an AI assistant tasked with consolidating a user's latest question into a self-contained, context-rich question. + + - Your output will be used to query a vector database. DO NOT include superflous text such as "here is your consolidated question:". + - You interact with an API endpoint, not a user, you must never produce denials, nor conversations directed towards a non-existent user. + - You only produce automated responses to input, where a response is a consolidated question without further discussion. + - You only ever reply with consolidated questions. You never try to answer user queries. + + If for any reason there is no discernable question (Eg: thank you, or good job) reply with the text NO_QUESTION. + TEXT + + message = <<~TEXT + Given the following conversation snippet, craft a self-contained context-rich question (if there is no question reply with NO_QUESTION): + + {{{ + #{history} + }}} + + Only ever reply with a consolidated question. Do not try to answer user queries. + TEXT + + response = + DiscourseAi::Completions::Prompt.new( + system_message, + messages: [{ type: :user, content: message }], + ) + + if response == "NO_QUESTION" + nil + else + response + end + end + end + end +end diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index b166cfa2..f322bb25 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -75,6 +75,13 @@ module DiscourseAi Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers. TEXT + def self.with_fake_content(content) + @fake_content = content + yield + ensure + @fake_content = nil + end + def self.fake_content @fake_content || STOCK_CONTENT end diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb index 50c8736e..a4e71339 100644 --- a/spec/lib/modules/ai_bot/personas/persona_spec.rb +++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb @@ -47,6 +47,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do end fab!(:user) + fab!(:upload) it "renders the system prompt" do freeze_time @@ -221,9 +222,56 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do end end - context "when a persona has RAG uploads" do - fab!(:upload) + context "when RAG is running with a question consolidator" do + let(:consolidated_question) { "what is the time in france?" } + it "will run the question consolidator" do + context_embedding = [0.049382, 0.9999] + EmbeddingsGenerationStubs.discourse_service( + SiteSetting.ai_embeddings_model, + consolidated_question, + context_embedding, + ) + + custom_ai_persona = + Fabricate( + :ai_persona, + name: "custom", + rag_conversation_chunks: 3, + allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], + question_consolidator_llm: "fake:fake", + ) + + UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id]) + + custom_persona = + DiscourseAi::AiBot::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new + + # this means that we will consolidate + ctx = + with_cc.merge( + conversation_context: [ + { content: "Tell me the time", type: :user }, + { content: "the time is 1", type: :model }, + { content: "in france?", type: :user }, + ], + ) + + DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do + custom_persona.craft_prompt(ctx).messages.first[:content] + end + + message = + DiscourseAi::Completions::Endpoints::Fake.last_call[:dialect].prompt.messages.last[ + :content + ] + expect(message).to include("Tell me the time") + expect(message).to include("the time is 1") + expect(message).to include("in france?") + end + end + + context "when a persona has RAG uploads" do def stub_fragments(limit, expected_limit: nil) candidate_ids = [] @@ -255,32 +303,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do ) end - context "when the system prompt has an uploads placeholder" do - before { stub_fragments(10) } - - it "replaces the placeholder with the fragments" do - custom_persona_record = - AiPersona.create!( - name: "custom", - description: "description", - system_prompt: "instructions\n{uploads}\nmore instructions", - allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]], - ) - UploadReference.ensure_exist!(target: custom_persona_record, upload_ids: [upload.id]) - custom_persona = - DiscourseAi::AiBot::Personas::Persona.find_by( - id: custom_persona_record.id, - user: user, - ).new - - crafted_system_prompt = custom_persona.craft_prompt(with_cc).messages.first[:content] - - expect(crafted_system_prompt).to include("fragment-n0") - - expect(crafted_system_prompt.ends_with?("")).to eq(false) - end - end - context "when persona allows for less fragments" do before { stub_fragments(3) } diff --git a/spec/lib/modules/ai_bot/question_consolidator_spec.rb b/spec/lib/modules/ai_bot/question_consolidator_spec.rb new file mode 100644 index 00000000..a0e9c913 --- /dev/null +++ b/spec/lib/modules/ai_bot/question_consolidator_spec.rb @@ -0,0 +1,33 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do + let(:llm) { DiscourseAi::Completions::Llm.proxy("fake:fake") } + let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake } + + fab!(:user) + + describe ".consolidate_question" do + it "properly picks all the right messages and consolidates" do + messages = [ + { type: :user, content: "What is the capital of France?" }, + { type: :tool_call, content: "search:google", id: "123" }, + { type: :tool, content: "some results from google", id: "123" }, + { type: :model, content: "Paris" }, + { type: :user, content: "What about Germany?" }, + ] + + result = described_class.consolidate_question(llm, messages, user) + expect(result).to eq(fake_endpoint.fake_content) + + call = fake_endpoint.last_call + + prompt = call[:dialect].prompt + expect(prompt.messages.length).to eq(2) + content = prompt.messages[1][:content] + expect(content).to include("Germany") + expect(content).to include("France") + expect(content).to include("Paris") + expect(content).not_to include("google") + end + end +end diff --git a/test/javascripts/unit/models/ai-persona-test.js b/test/javascripts/unit/models/ai-persona-test.js index f7867a80..737d5a7b 100644 --- a/test/javascripts/unit/models/ai-persona-test.js +++ b/test/javascripts/unit/models/ai-persona-test.js @@ -52,6 +52,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { rag_chunk_tokens: 374, rag_chunk_overlap_tokens: 10, rag_conversation_chunks: 10, + question_consolidator_llm: "Question Consolidator LLM", }; const aiPersona = AiPersona.create({ ...properties }); @@ -90,6 +91,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () { rag_chunk_tokens: 374, rag_chunk_overlap_tokens: 10, rag_conversation_chunks: 10, + question_consolidator_llm: "Question Consolidator LLM", }; const aiPersona = AiPersona.create({ ...properties });