discourse-ai/spec/lib/modules/ai_bot/personas/persona_spec.rb

482 lines
15 KiB
Ruby

# frozen_string_literal: true
# Minimal persona used by the specs in this file: it offers a fixed list of
# tools and a system prompt made up solely of template placeholders.
class TestPersona < DiscourseAi::AiBot::Personas::Persona
  # Tool classes this persona offers.
  def tools
    tool_namespace = DiscourseAi::AiBot::Tools

    [tool_namespace::ListTags, tool_namespace::Search, tool_namespace::Image]
  end

  # Prompt template with one placeholder per line.
  def system_prompt
    <<~PROMPT
      {site_url}
      {site_title}
      {site_description}
      {participants}
      {time}
    PROMPT
  end
end
RSpec.describe DiscourseAi::AiBot::Personas::Persona do
let(:persona) { TestPersona.new }

# Unsaved topic whose allowed users are "joe" and "jane".
let(:topic_with_users) do
  Topic.new.tap do |topic|
    topic.allowed_users = %w[joe jane].map { |username| User.new(username: username) }
  end
end

after do
  # Transactions are rolled back between examples, which can leave a poisoned
  # persona cache behind, so flush it after every example.
  AiPersona.persona_cache.flush!
end

# Values handed to craft_prompt; the keys match the placeholders in
# TestPersona#system_prompt.
let(:context) do
  {
    site_url: Discourse.base_url,
    site_title: "test site title",
    site_description: "test site description",
    time: Time.zone.now,
    participants: topic_with_users.allowed_users.map { |allowed| allowed.username }.join(", "),
  }
end

fab!(:admin)
fab!(:user)
fab!(:upload)
it "renders the system prompt" do
  freeze_time

  prompt = persona.craft_prompt(context)
  system_message = prompt.messages.first[:content]

  # Every placeholder of the template must have been substituted.
  expect(system_message).to include(Discourse.base_url)
  [
    "test site title",
    "test site description",
    "joe, jane",
  ].each { |expected_text| expect(system_message).to include(expected_text) }
  expect(system_message).to include(Time.zone.now.to_s)

  tool_names = prompt.tools.map { |tool| tool[:name] }
  expect(tool_names).to include("search")
  expect(tool_names).to include("tags")
  # needs to be configured so it is not available
  expect(tool_names).not_to include("image")
end
it "can parse string that are wrapped in quotes" do
  SiteSetting.ai_stability_api_key = "123"

  image_parameters = { prompts: ["cat oil painting", "big car"], aspect_ratio: "16:9" }
  image_call =
    DiscourseAi::Completions::ToolCall.new(
      name: "image",
      id: "call_JtYQMful5QKqw97XFsHzPweB",
      parameters: image_parameters,
    )

  # Resolve the call through the Artist persona; no bot user, llm or context is supplied.
  artist = DiscourseAi::AiBot::Personas::Artist.new
  resolved_tool = artist.find_tool(image_call, bot_user: nil, llm: nil, context: nil)

  expect(resolved_tool.parameters[:prompts]).to eq(["cat oil painting", "big car"])
  expect(resolved_tool.parameters[:aspect_ratio]).to eq("16:9")
end
it "enforces enums" do
  # Builds a "search" tool call carrying the given status and resolves it
  # through a fresh General persona.
  resolve_search_with = ->(status) do
    search_call =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "call_JtYQMful5QKqw97XFsHzPweB",
        parameters: {
          max_posts: "3.2",
          status: status,
          foo: "bar",
        },
      )

    DiscourseAi::AiBot::Personas::General.new.find_tool(
      search_call,
      bot_user: nil,
      llm: nil,
      context: nil,
    )
  end

  # "cow" is expected to be dropped from the parameters altogether.
  rejected = resolve_search_with.call("cow")
  expect(rejected.parameters.key?(:status)).to eq(false)

  # "open" is expected to be kept as-is.
  accepted = resolve_search_with.call("open")
  expect(accepted.parameters[:status]).to eq("open")
end
it "can coerce integers" do
  raw_parameters = { max_posts: "3.2", search_query: "hello world", foo: "bar" }
  search_call =
    DiscourseAi::Completions::ToolCall.new(
      name: "search",
      id: "call_JtYQMful5QKqw97XFsHzPweB",
      parameters: raw_parameters,
    )

  general = DiscourseAi::AiBot::Personas::General.new
  search = general.find_tool(search_call, bot_user: nil, llm: nil, context: nil)
  resolved = search.parameters

  # "3.2" arrives as a string and is expected to come back as the integer 3.
  expect(resolved[:max_posts]).to eq(3)
  expect(resolved[:search_query]).to eq("hello world")
  # :foo is expected to be removed from the resolved parameters.
  expect(resolved.key?(:foo)).to eq(false)
end
it "can correctly parse arrays in tools" do
  SiteSetting.ai_openai_api_key = "123"

  dall_e_call =
    DiscourseAi::Completions::ToolCall.new(
      name: "dall_e",
      id: "call_JtYQMful5QKqw97XFsHzPweB",
      parameters: {
        prompts: ["cat oil painting", "big car"],
      },
    )

  dall_e_persona = DiscourseAi::AiBot::Personas::DallE3.new
  resolved_tool = dall_e_persona.find_tool(dall_e_call, bot_user: nil, llm: nil, context: nil)

  # The array parameter must come through unchanged.
  expect(resolved_tool.parameters[:prompts]).to eq(["cat oil painting", "big car"])
end
describe "custom personas" do
  it "is able to find custom personas" do
    Group.refresh_automatic_groups!

    # Re-queries, on every call, the last persona class listed for the regular user.
    last_visible = -> { DiscourseAi::AiBot::Personas::Persona.all(user: user).last }

    # define an ai persona everyone can see
    pun_bot =
      AiPersona.create!(
        name: "zzzpun_bot",
        description: "you write puns",
        system_prompt: "you are pun bot",
        tools: ["Image"],
        allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
      )

    found = last_visible.call
    expect(found.name).to eq("zzzpun_bot")
    expect(found.description).to eq("you write puns")

    instance = found.new
    expect(instance.tools).to eq([DiscourseAi::AiBot::Tools::Image])
    expect(instance.craft_prompt(context).messages.first[:content]).to eq("you are pun bot")

    # should update
    pun_bot.update!(name: "zzzpun_bot2")
    expect(last_visible.call.name).to eq("zzzpun_bot2")

    # can be disabled
    pun_bot.update!(enabled: false)
    expect(last_visible.call.name).not_to eq("zzzpun_bot2")

    pun_bot.update!(enabled: true)

    # no groups have access
    pun_bot.update!(allowed_group_ids: [])
    expect(last_visible.call.name).not_to eq("zzzpun_bot2")
  end
end
describe "available personas" do
  it "includes all personas by default" do
    Group.refresh_automatic_groups!

    personas = DiscourseAi::AiBot::Personas
    # Persona classes listed for the given user, queried afresh on each call.
    visible_to = ->(viewer) { personas::Persona.all(user: viewer) }

    # must be enabled to see it
    SiteSetting.ai_stability_api_key = "abc"
    SiteSetting.ai_google_custom_search_api_key = "abc"
    SiteSetting.ai_google_custom_search_cx = "abc123"

    # should be ordered by priority and then alpha
    expect(visible_to.call(user)).to eq(
      [
        personas::General,
        personas::Artist,
        personas::Creative,
        personas::DiscourseHelper,
        personas::GithubHelper,
        personas::Researcher,
        personas::SettingsExplorer,
        personas::SqlHelper,
      ],
    )

    # it should allow staff access to WebArtifactCreator
    expect(visible_to.call(admin)).to eq(
      [
        personas::General,
        personas::Artist,
        personas::Creative,
        personas::DiscourseHelper,
        personas::GithubHelper,
        personas::Researcher,
        personas::SettingsExplorer,
        personas::SqlHelper,
        personas::WebArtifactCreator,
      ],
    )

    # omits personas if key is missing
    SiteSetting.ai_stability_api_key = ""
    SiteSetting.ai_google_custom_search_api_key = ""
    SiteSetting.ai_artifact_security = "disabled"

    # Personas expected to remain once the keys above are cleared (order not asserted).
    remaining_without_general = [
      personas::SqlHelper,
      personas::SettingsExplorer,
      personas::Creative,
      personas::DiscourseHelper,
      personas::GithubHelper,
    ]

    expect(visible_to.call(admin)).to contain_exactly(
      personas::General,
      *remaining_without_general,
    )

    # Disable the stored record backing the General system persona.
    general_id = personas::Persona.system_personas[personas::General]
    AiPersona.find(general_id).update!(enabled: false)

    expect(visible_to.call(user)).to contain_exactly(*remaining_without_general)
  end
end
describe "#craft_prompt" do
before do
  Group.refresh_automatic_groups!
  SiteSetting.ai_embeddings_discourse_service_api_endpoint = "http://test.com"
  SiteSetting.ai_embeddings_enabled = true
end

# Instance of the first persona class listed for the regular user.
let(:ai_persona) do
  first_persona_class = DiscourseAi::AiBot::Personas::Persona.all(user: user).first
  first_persona_class.new
end

# The shared context extended with a single-message conversation.
let(:with_cc) do
  context.merge(
    conversation_context: [
      { content: "Tell me the time", type: :user },
    ],
  )
end
context "when a persona has no uploads" do
  it "doesn't include RAG guidance" do
    system_message = ai_persona.craft_prompt(with_cc).messages.first[:content]

    # This sentence must be absent from the system message when there are no uploads.
    expect(system_message).not_to include(
      "The following texts will give you additional guidance to elaborate a response.",
    )
  end
end
context "when RAG is running with a question consolidator" do
  let(:consolidated_question) { "what is the time in france?" }

  fab!(:llm_model) { Fabricate(:fake_model) }

  it "will run the question consolidator" do
    vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation
    random_embedding = Array.new(vector_rep.dimensions) { rand(-1.0...1.0) }

    # Stub the embeddings service for the consolidated question.
    EmbeddingsGenerationStubs.discourse_service(
      SiteSetting.ai_embeddings_model,
      consolidated_question,
      random_embedding,
    )

    consolidating_record =
      Fabricate(
        :ai_persona,
        name: "custom",
        rag_conversation_chunks: 3,
        allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
        question_consolidator_llm: "custom:#{llm_model.id}",
      )

    UploadReference.ensure_exist!(target: consolidating_record, upload_ids: [upload.id])

    persona_class =
      DiscourseAi::AiBot::Personas::Persona.find_by(id: consolidating_record.id, user: user)
    custom_persona = persona_class.new

    # this means that we will consolidate
    multi_turn_context =
      with_cc.merge(
        conversation_context: [
          { content: "Tell me the time", type: :user },
          { content: "the time is 1", type: :model },
          { content: "in france?", type: :user },
        ],
      )

    DiscourseAi::Completions::Endpoints::Fake.with_fake_content(consolidated_question) do
      custom_persona.craft_prompt(multi_turn_context).messages.first[:content]
    end

    # Inspect the last message of the prompt recorded by the fake endpoint.
    recorded_call = DiscourseAi::Completions::Endpoints::Fake.last_call
    message = recorded_call[:dialect].prompt.messages.last[:content]

    ["Tell me the time", "the time is 1", "in france?"].each do |turn|
      expect(message).to include(turn)
    end
  end
end
context "when a persona has RAG uploads" do
let(:vector_rep) { DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation }

# Base component value shared by every stubbed embedding.
let(:embedding_value) { 0.04381 }

# Embedding stubbed for the conversation message: embedding_value once per dimension.
let(:prompt_cc_embeddings) { Array.new(vector_rep.dimensions, embedding_value) }
# Fabricates +fragment_count+ RAG document fragments named "fragment-n0",
# "fragment-n1", ... targeting +persona+ and the fabricated upload, and stores
# an embedding for each of them.
#
# @param fragment_count [Integer] how many fragments to create
# @param persona [#id] object whose id is used as the fragments' target_id
#   (defaults to ai_persona)
def stub_fragments(fragment_count, persona: ai_persona)
  schema = DiscourseAi::Embeddings::Schema.for(RagDocumentFragment, vector: vector_rep)

  fragment_count.times do |i|
    fragment =
      Fabricate(
        :rag_document_fragment,
        fragment: "fragment-n#{i}",
        target_id: persona.id,
        target_type: "AiPersona",
        upload: upload,
      )

    # Similarity is determined left-to-right: every fragment sits slightly
    # further away from embedding_value than the previous one.
    # The offset is computed arithmetically. Interpolating the index into a
    # "0.000<i>" string and parsing it made index 10 parse as 0.0001 (the same
    # offset as index 1) and 11..14 land between the offsets of 1 and 2, so the
    # sequence stopped increasing once the index had two digits.
    offset = i / 10_000.0
    embeddings = [embedding_value + offset] * vector_rep.dimensions
    schema.store(fragment, embeddings, "test")
  end
end
before do
  # Attach the fabricated upload to the stored record behind ai_persona.
  UploadReference.ensure_exist!(target: AiPersona.find(ai_persona.id), upload_ids: [upload.id])

  # Stub the embeddings service for the first conversation message.
  embeddings_model = SiteSetting.ai_embeddings_model
  first_message = with_cc.dig(:conversation_context, 0, :content)
  EmbeddingsGenerationStubs.discourse_service(
    embeddings_model,
    first_message,
    prompt_cc_embeddings,
  )
end
context "when persona allows for less fragments" do
  it "will only pick 3 fragments" do
    three_chunk_record =
      Fabricate(
        :ai_persona,
        name: "custom",
        rag_conversation_chunks: 3,
        allowed_group_ids: [Group::AUTO_GROUPS[:trust_level_0]],
      )

    stub_fragments(3, persona: three_chunk_record)
    UploadReference.ensure_exist!(target: three_chunk_record, upload_ids: [upload.id])

    persona_class =
      DiscourseAi::AiBot::Personas::Persona.find_by(id: three_chunk_record.id, user: user)
    custom_persona = persona_class.new

    expect(custom_persona.class.rag_conversation_chunks).to eq(3)

    crafted_system_prompt = custom_persona.craft_prompt(with_cc).messages.first[:content]

    %w[fragment-n0 fragment-n1 fragment-n2].each do |fragment_text|
      expect(crafted_system_prompt).to include(fragment_text)
    end

    # NOTE(review): only fragments n0..n2 are stubbed above, so "fragment-n3"
    # never exists and this negative check cannot fail on its own.
    expect(crafted_system_prompt).not_to include("fragment-n3")
  end
end
context "when the reranker is available" do
  before do
    SiteSetting.ai_hugging_face_tei_reranker_endpoint = "https://test.reranker.com"
    stub_fragments(15)
  end

  it "uses the re-ranker to reorder the fragments and pick the top 10 candidates" do
    # The stubbed re-ranker reverses the similarity search but returns fewer
    # results, which acts as a limit for test purposes.
    expected_reranked = 14.downto(4).map { |idx| { index: idx } }
    reranker_response = { status: 200, body: JSON.dump(expected_reranked) }

    WebMock.stub_request(:post, "https://test.reranker.com/rerank").to_return(reranker_response)

    crafted_system_prompt = ai_persona.craft_prompt(with_cc).messages.first[:content]

    %w[fragment-n14 fragment-n13 fragment-n12].each do |fragment_text|
      expect(crafted_system_prompt).to include(fragment_text)
    end

    # Index 4 is the 11th entry of the reranked list, so it falls outside the top 10.
    expect(crafted_system_prompt).not_to include("fragment-n4")
  end
end
context "when the reranker is not available" do
  before do
    stub_fragments(10)
  end

  it "picks the first 10 candidates from the similarity search" do
    crafted_system_prompt = ai_persona.craft_prompt(with_cc).messages.first[:content]

    %w[fragment-n0 fragment-n1 fragment-n2].each do |fragment_text|
      expect(crafted_system_prompt).to include(fragment_text)
    end

    # NOTE(review): stub_fragments(10) only creates fragments n0..n9, so
    # "fragment-n10" never exists and this negative check cannot fail on its own.
    expect(crafted_system_prompt).not_to include("fragment-n10")
  end
end
end
end
end