FEATURE: Add Question Consolidator for robust Upload support in Personas (#596)

This commit introduces a new feature for AI Personas called the "Question Consolidator LLM". The purpose of the Question Consolidator is to consolidate a user's latest question into a self-contained, context-rich question before querying the vector database for relevant fragments. This helps improve the quality and relevance of the retrieved fragments.

Previous to this change we used the last 10 interactions, this is not ideal cause the RAG would "lock on" to an answer. 

EG:

- User: how many cars are there in europe
- Model: detailed answer about cars in europe including the term car and vehicle many times
- User: Nice, what about trains are there in the US

In the above example "trains" and "US" becomes very low signal given there are pages and pages talking about cars and europe. This mean retrieval is sub optimal. 

Instead, we pass the history to the "question consolidator", it would simply consolidate the question to "How many trains are there in the United States", which would make it fare easier for the vector db to find relevant content. 

The llm used for question consolidator can often be less powerful than the model you are talking to, we recommend using lighter weight and fast models cause the task is very simple. This is configurable from the persona ui.

This PR also removes support for {uploads} placeholder, this is too complicated to get right and we want freedom to shift RAG implementation. 

Key changes:

1. Added a new `question_consolidator_llm` column to the `ai_personas` table to store the LLM model used for question consolidation.

2. Implemented the `QuestionConsolidator` module which handles the logic for consolidating the user's latest question. It extracts the relevant user and model messages from the conversation history, truncates them if needed to fit within the token limit, and generates a consolidated question prompt.

3. Updated the `Persona` class to use the Question Consolidator LLM (if configured) when crafting the RAG fragments prompt. It passes the conversation context to the consolidator to generate a self-contained question.

4. Added UI elements in the AI Persona editor to allow selecting the Question Consolidator LLM. Also made some UI tweaks to conditionally show/hide certain options based on persona configuration.

5. Wrote unit tests for the QuestionConsolidator module and updated existing persona tests to cover the new functionality.

This feature enables AI Personas to better understand the context and intent behind a user's question by consolidating the conversation history into a single, focused question. This can lead to more relevant and accurate responses from the AI assistant.
This commit is contained in:
Sam 2024-04-30 13:49:21 +10:00 committed by GitHub
parent 85734fef52
commit 32b3004ce9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
15 changed files with 328 additions and 96 deletions

View File

@ -124,6 +124,7 @@ module DiscourseAi
:rag_chunk_tokens,
:rag_chunk_overlap_tokens,
:rag_conversation_chunks,
:question_consolidator_llm,
allowed_group_ids: [],
rag_uploads: [:id],
)

View File

@ -113,6 +113,7 @@ class AiPersona < ActiveRecord::Base
vision_enabled = self.vision_enabled
vision_max_pixels = self.vision_max_pixels
rag_conversation_chunks = self.rag_conversation_chunks
question_consolidator_llm = self.question_consolidator_llm
persona_class = DiscourseAi::AiBot::Personas::Persona.system_personas_by_id[self.id]
if persona_class
@ -152,6 +153,10 @@ class AiPersona < ActiveRecord::Base
vision_max_pixels
end
persona_class.define_singleton_method :question_consolidator_llm do
question_consolidator_llm
end
persona_class.define_singleton_method :rag_conversation_chunks do
rag_conversation_chunks
end
@ -243,6 +248,10 @@ class AiPersona < ActiveRecord::Base
rag_conversation_chunks
end
define_singleton_method :question_consolidator_llm do
question_consolidator_llm
end
define_singleton_method :to_s do
"#<DiscourseAi::AiBot::Personas::Persona::Custom @name=#{self.name} @allowed_group_ids=#{self.allowed_group_ids.join(",")}>"
end
@ -377,6 +386,7 @@ end
# rag_chunk_tokens :integer default(374), not null
# rag_chunk_overlap_tokens :integer default(10), not null
# rag_conversation_chunks :integer default(10), not null
# question_consolidator_llm :text
#
# Indexes
#

View File

@ -22,7 +22,8 @@ class LocalizedAiPersonaSerializer < ApplicationSerializer
:vision_max_pixels,
:rag_chunk_tokens,
:rag_chunk_overlap_tokens,
:rag_conversation_chunks
:rag_conversation_chunks,
:question_consolidator_llm
has_one :user, serializer: BasicUserSerializer, embed: :object
has_many :rag_uploads, serializer: UploadSerializer, embed: :object

View File

@ -25,6 +25,7 @@ const CREATE_ATTRIBUTES = [
"rag_chunk_tokens",
"rag_chunk_overlap_tokens",
"rag_conversation_chunks",
"question_consolidator_llm",
];
const SYSTEM_ATTRIBUTES = [
@ -44,6 +45,7 @@ const SYSTEM_ATTRIBUTES = [
"rag_chunk_tokens",
"rag_chunk_overlap_tokens",
"rag_conversation_chunks",
"question_consolidator_llm",
];
class CommandOption {

View File

@ -133,6 +133,18 @@ export default class PersonaEditor extends Component {
return AdminUser.create(this.editingModel?.user);
}
get mappedQuestionConsolidatorLlm() {
return this.editingModel?.question_consolidator_llm || "blank";
}
set mappedQuestionConsolidatorLlm(value) {
if (value === "blank") {
this.editingModel.question_consolidator_llm = null;
} else {
this.editingModel.question_consolidator_llm = value;
}
}
get mappedDefaultLlm() {
return this.editingModel?.default_llm || "blank";
}
@ -460,11 +472,13 @@ export default class PersonaEditor extends Component {
@updateUploads={{this.updateUploads}}
@onRemove={{this.removeUpload}}
/>
{{#if this.editingModel.rag_uploads}}
<a
href="#"
class="ai-persona-editor__indexing-options"
{{on "click" this.toggleIndexingOptions}}
>{{this.indexingOptionsText}}</a>
{{/if}}
</div>
{{#if this.showIndexingOptions}}
<div class="control-group">
@ -519,6 +533,24 @@ export default class PersonaEditor extends Component {
}}
/>
</div>
<div class="control-group">
<label>{{I18n.t
"discourse_ai.ai_persona.question_consolidator_llm"
}}</label>
<AiLlmSelector
class="ai-persona-editor__llms"
@value={{this.mappedQuestionConsolidatorLlm}}
@llms={{@personas.resultSetMeta.llms}}
/>
<DTooltip
@icon="question-circle"
@content={{I18n.t
"discourse_ai.ai_persona.question_consolidator_llm_help"
}}
/>
</div>
{{/if}}
{{/if}}
<div class="control-group ai-persona-editor__action_panel">

View File

@ -111,8 +111,8 @@ export default class PersonaRagUploader extends Component.extend(
<div class="persona-rag-uploader" {{willDestroy this.removeListener}}>
<h3>{{I18n.t "discourse_ai.ai_persona.uploads.title"}}</h3>
<p>{{I18n.t "discourse_ai.ai_persona.uploads.description"}}</p>
<p>{{I18n.t "discourse_ai.ai_persona.uploads.hint"}}</p>
{{#if this.ragUploads}}
<div class="persona-rag-uploader__search-input-container">
<div class="persona-rag-uploader__search-input">
{{icon
@ -127,6 +127,7 @@ export default class PersonaRagUploader extends Component.extend(
/>
</div>
</div>
{{/if}}
<table class="persona-rag-uploader__uploads-list">
<tbody>

View File

@ -141,9 +141,11 @@ en:
create_user_help: You can optionally attach a user to this persona. If you do, the AI will use this user to respond to requests.
default_llm: Default Language Model
default_llm_help: The default language model to use for this persona. Required if you wish to mention persona on public posts.
question_consolidator_llm: Language Model for Question Consolidator
question_consolidator_llm_help: The language model to use for the question consolidator, you may choose a less powerful model to save costs.
system_prompt: System Prompt
show_indexing_options: "Show Indexing Options"
hide_indexing_options: "Hide Indexing Options"
show_indexing_options: "Show Upload Options"
hide_indexing_options: "Hide Upload Options"
save: Save
saved: AI Persona Saved
enabled: "Enabled?"
@ -181,8 +183,7 @@ en:
uploads:
title: "Uploads"
description: "Your AI persona will be able to search and reference the content of included files. Uploaded files must be formatted as plaintext (.txt)"
hint: "To control where the file's content gets placed within the system prompt, include the {uploads} placeholder in the system prompt above."
description: "Your AI persona will be able to search and reference the content of included files. Uploaded files should be formatted as plaintext (.txt) or markdown (.md)."
button: "Add Files"
filter: "Filter uploads"
indexed: "Indexed"

View File

@ -0,0 +1,7 @@
# frozen_string_literal: true
class AddConsolidatedQuestionLlmToAiPersona < ActiveRecord::Migration[7.0]
def change
add_column :ai_personas, :question_consolidator_llm, :text, max_length: 2000
end
end

View File

@ -50,7 +50,8 @@ module DiscourseAi
end
def reply(context, &update_blk)
prompt = persona.craft_prompt(context)
llm = DiscourseAi::Completions::Llm.proxy(model)
prompt = persona.craft_prompt(context, llm: llm)
total_completions = 0
ongoing_chain = true
@ -63,8 +64,6 @@ module DiscourseAi
llm_kwargs[:top_p] = persona.top_p if persona.top_p
while total_completions <= MAX_COMPLETIONS && ongoing_chain
current_model = model
llm = DiscourseAi::Completions::Llm.proxy(current_model)
tool_found = false
result =

View File

@ -17,6 +17,10 @@ module DiscourseAi
1_048_576
end
def question_consolidator_llm
nil
end
def system_personas
@system_personas ||= {
Personas::General => -1,
@ -125,7 +129,7 @@ module DiscourseAi
self.class.all_available_tools.filter { |tool| tools.include?(tool) }
end
def craft_prompt(context)
def craft_prompt(context, llm: nil)
system_insts =
system_prompt.gsub(/\{(\w+)\}/) do |match|
found = context[match[1..-2].to_sym]
@ -137,15 +141,20 @@ module DiscourseAi
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
TEXT
fragments_guidance = rag_fragments_prompt(context[:conversation_context].to_a)&.strip
question_consolidator_llm = llm
if self.class.question_consolidator_llm.present?
question_consolidator_llm =
DiscourseAi::Completions::Llm.proxy(self.class.question_consolidator_llm)
end
if fragments_guidance.present?
if system_insts.include?("{uploads}")
prompt_insts = prompt_insts.gsub("{uploads}", fragments_guidance)
else
prompt_insts << fragments_guidance
end
end
fragments_guidance =
rag_fragments_prompt(
context[:conversation_context].to_a,
llm: question_consolidator_llm,
user: context[:user],
)&.strip
prompt_insts << fragments_guidance if fragments_guidance.present?
prompt =
DiscourseAi::Completions::Prompt.new(
@ -202,7 +211,7 @@ module DiscourseAi
)
end
def rag_fragments_prompt(conversation_context)
def rag_fragments_prompt(conversation_context, llm:, user:)
upload_refs =
UploadReference.where(target_id: id, target_type: "AiPersona").pluck(:upload_id)
@ -210,18 +219,30 @@ module DiscourseAi
return nil if conversation_context.blank? || upload_refs.blank?
latest_interactions =
conversation_context
.select { |ctx| %i[model user].include?(ctx[:type]) }
.map { |ctx| ctx[:content] }
.last(10)
.join("\n")
conversation_context.select { |ctx| %i[model user].include?(ctx[:type]) }.last(10)
return nil if latest_interactions.empty?
# first response
if latest_interactions.length == 1
consolidated_question = latest_interactions[0][:content]
else
consolidated_question =
DiscourseAi::AiBot::QuestionConsolidator.consolidate_question(
llm,
latest_interactions,
user,
)
end
return nil if !consolidated_question
strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
vector_rep =
DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings
interactions_vector = vector_rep.vector_from(latest_interactions)
interactions_vector = vector_rep.vector_from(consolidated_question)
rag_conversation_chunks = self.class.rag_conversation_chunks

View File

@ -0,0 +1,93 @@
# frozen_string_literal: true
module DiscourseAi
module AiBot
class QuestionConsolidator
attr_reader :llm, :messages, :user, :max_tokens
def self.consolidate_question(llm, messages, user)
new(llm, messages, user).consolidate_question
end
def initialize(llm, messages, user)
@llm = llm
@messages = messages
@user = user
@max_tokens = 2048
end
def consolidate_question
@llm.generate(revised_prompt, user: @user)
end
def revised_prompt
max_tokens_per_model = @max_tokens / 5
conversation_snippet = []
tokens = 0
messages.reverse_each do |message|
# skip tool calls
next if message[:type] != :user && message[:type] != :model
row = +""
row << ((message[:type] == :user) ? "user" : "model")
content = message[:content]
current_tokens = @llm.tokenizer.tokenize(content).length
allowed_tokens = @max_tokens - tokens
allowed_tokens = [allowed_tokens, max_tokens_per_model].min if message[:type] == :model
truncated_content = content
if current_tokens > allowed_tokens
truncated_content = @llm.tokenizer.truncate(content, allowed_tokens)
current_tokens = allowed_tokens
end
row << ": #{truncated_content}"
tokens += current_tokens
conversation_snippet << row
break if tokens >= @max_tokens
end
history = conversation_snippet.reverse.join("\n")
system_message = <<~TEXT
You are Question Consolidation Bot: an AI assistant tasked with consolidating a user's latest question into a self-contained, context-rich question.
- Your output will be used to query a vector database. DO NOT include superflous text such as "here is your consolidated question:".
- You interact with an API endpoint, not a user, you must never produce denials, nor conversations directed towards a non-existent user.
- You only produce automated responses to input, where a response is a consolidated question without further discussion.
- You only ever reply with consolidated questions. You never try to answer user queries.
If for any reason there is no discernable question (Eg: thank you, or good job) reply with the text NO_QUESTION.
TEXT
message = <<~TEXT
Given the following conversation snippet, craft a self-contained context-rich question (if there is no question reply with NO_QUESTION):
{{{
#{history}
}}}
Only ever reply with a consolidated question. Do not try to answer user queries.
TEXT
response =
DiscourseAi::Completions::Prompt.new(
system_message,
messages: [{ type: :user, content: message }],
)
if response == "NO_QUESTION"
nil
else
response
end
end
end
end
end

View File

@ -75,6 +75,13 @@ module DiscourseAi
Congratulations, you've now seen a small sample of what Discourse's Markdown can do! For more intricate formatting, consider exploring the advanced styling options. Remember that the key to great formatting is not just the available tools, but also the **clarity** and **readability** it brings to your readers.
TEXT
def self.with_fake_content(content)
@fake_content = content
yield
ensure
@fake_content = nil
end
def self.fake_content
@fake_content || STOCK_CONTENT
end

View File

@ -47,6 +47,7 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
fab!(:user)
fab!(:upload)
it "renders the system prompt" do
freeze_time
@ -221,9 +222,56 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
end
end
context "when a persona has RAG uploads" do
fab!(:upload)
context "when RAG is running with a question consolidator" do
let(:consolidated_question) { "what is the time in france?" }
it "will run the question consolidator" do
context_embedding = [0.049382, 0.9999]
EmbeddingsGenerationStubs.discourse_service(
SiteSetting.ai_embeddings_model,
consolidated_question,
context_embedding,
)
custom_ai_persona =
Fabricate(
:ai_persona,
name: "custom",
rag_conversation_chunks: 3,
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
question_consolidator_llm: "fake:fake",
)
UploadReference.ensure_exist!(target: custom_ai_persona, upload_ids: [upload.id])
custom_persona =
DiscourseAi::AiBot::Personas::Persona.find_by(id: custom_ai_persona.id, user: user).new
# this means that we will consolidate
ctx =
with_cc.merge(
conversation_context: [
{ content: "Tell me the time", type: :user },
{ content: "the time is 1", type: :model },
{ content: "in france?", type: :user },
],
)
DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do
custom_persona.craft_prompt(ctx).messages.first[:content]
end
message =
DiscourseAi::Completions::Endpoints::Fake.last_call[:dialect].prompt.messages.last[
:content
]
expect(message).to include("Tell me the time")
expect(message).to include("the time is 1")
expect(message).to include("in france?")
end
end
context "when a persona has RAG uploads" do
def stub_fragments(limit, expected_limit: nil)
candidate_ids = []
@ -255,32 +303,6 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
)
end
context "when the system prompt has an uploads placeholder" do
before { stub_fragments(10) }
it "replaces the placeholder with the fragments" do
custom_persona_record =
AiPersona.create!(
name: "custom",
description: "description",
system_prompt: "instructions\n{uploads}\nmore instructions",
allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
)
UploadReference.ensure_exist!(target: custom_persona_record, upload_ids: [upload.id])
custom_persona =
DiscourseAi::AiBot::Personas::Persona.find_by(
id: custom_persona_record.id,
user: user,
).new
crafted_system_prompt = custom_persona.craft_prompt(with_cc).messages.first[:content]
expect(crafted_system_prompt).to include("fragment-n0")
expect(crafted_system_prompt.ends_with?("</guidance>")).to eq(false)
end
end
context "when persona allows for less fragments" do
before { stub_fragments(3) }

View File

@ -0,0 +1,33 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::AiBot::QuestionConsolidator do
let(:llm) { DiscourseAi::Completions::Llm.proxy("fake:fake") }
let(:fake_endpoint) { DiscourseAi::Completions::Endpoints::Fake }
fab!(:user)
describe ".consolidate_question" do
it "properly picks all the right messages and consolidates" do
messages = [
{ type: :user, content: "What is the capital of France?" },
{ type: :tool_call, content: "search:google", id: "123" },
{ type: :tool, content: "some results from google", id: "123" },
{ type: :model, content: "Paris" },
{ type: :user, content: "What about Germany?" },
]
result = described_class.consolidate_question(llm, messages, user)
expect(result).to eq(fake_endpoint.fake_content)
call = fake_endpoint.last_call
prompt = call[:dialect].prompt
expect(prompt.messages.length).to eq(2)
content = prompt.messages[1][:content]
expect(content).to include("Germany")
expect(content).to include("France")
expect(content).to include("Paris")
expect(content).not_to include("google")
end
end
end

View File

@ -52,6 +52,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
rag_chunk_tokens: 374,
rag_chunk_overlap_tokens: 10,
rag_conversation_chunks: 10,
question_consolidator_llm: "Question Consolidator LLM",
};
const aiPersona = AiPersona.create({ ...properties });
@ -90,6 +91,7 @@ module("Discourse AI | Unit | Model | ai-persona", function () {
rag_chunk_tokens: 374,
rag_chunk_overlap_tokens: 10,
rag_conversation_chunks: 10,
question_consolidator_llm: "Question Consolidator LLM",
};
const aiPersona = AiPersona.create({ ...properties });