DEV: improve internal design of ai persona and bug fix (#495)

* DEV: improve internal design of ai persona and bug fix

- Fixes bug where OpenAI could not describe images
- Fixes bug where mentionable personas could not be mentioned unless overarching bot was enabled
- Improves internal design of playground and bot to better support non-"bot" users
- Allows PMs directly to persona users (previously the bot user also had to be in the PM)
- Simplify internal code


Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
Sam 2024-02-28 16:46:32 +11:00 committed by GitHub
parent a1b607db80
commit 484fd1435b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 116 additions and 93 deletions

View File

@ -6,7 +6,7 @@ module ::Jobs
def execute(args)
return unless bot_user = User.find_by(id: args[:bot_user_id])
return unless bot = DiscourseAi::AiBot::Bot.as(bot_user)
return unless bot = DiscourseAi::AiBot::Bot.as(bot_user, model: args[:model])
return unless post = Post.includes(:topic).find_by(id: args[:post_id])
return unless post.topic.custom_fields[DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE]

View File

@ -60,21 +60,28 @@ class AiPersona < ActiveRecord::Base
.map(&:class_instance)
end
def self.mentionables
persona_cache[:mentionable_usernames] ||= AiPersona
.where(mentionable: true)
.where(enabled: true)
.joins(:user)
.pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm")
.map do |id, user_id, username, allowed_group_ids, default_llm|
{
id: id,
user_id: user_id,
username: username,
allowed_group_ids: allowed_group_ids,
default_llm: default_llm,
}
end
# Lists every enabled, mentionable persona as a plain hash with the keys
# :id, :user_id, :username, :allowed_group_ids and :default_llm.
# The unfiltered list is memoized in persona_cache under
# :mentionable_usernames.
#
# @param user [#in_any_groups?, nil] when given, only personas for which
#   user.in_any_groups?(allowed_group_ids) is truthy are returned
# @return [Array<Hash>]
def self.mentionables(user: nil)
  cached =
    persona_cache[:mentionable_usernames] ||= begin
      columns = "ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm"
      rows = AiPersona.where(mentionable: true).where(enabled: true).joins(:user).pluck(columns)
      # Each plucked row is positional; pair it with the hash keys in the same order.
      rows.map { |row| %i[id user_id username allowed_group_ids default_llm].zip(row).to_h }
    end

  return cached unless user

  cached.select { |entry| user.in_any_groups?(entry[:allowed_group_ids]) }
end
after_commit :bump_cache

View File

@ -3,16 +3,19 @@
module DiscourseAi
module AiBot
class Bot
attr_reader :model
BOT_NOT_FOUND = Class.new(StandardError)
MAX_COMPLETIONS = 5
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new)
new(bot_user, persona)
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
new(bot_user, persona, model)
end
def initialize(bot_user, persona)
def initialize(bot_user, persona, model = nil)
@bot_user = bot_user
@persona = persona
@model = model || self.class.guess_model(bot_user) || @persona.class.default_llm
end
attr_reader :bot_user
@ -46,7 +49,6 @@ module DiscourseAi
total_completions = 0
ongoing_chain = true
low_cost = false
raw_context = []
user = context[:user]
@ -56,7 +58,7 @@ module DiscourseAi
llm_kwargs[:top_p] = persona.top_p if persona.top_p
while total_completions <= MAX_COMPLETIONS && ongoing_chain
current_model = model(prefer_low_cost: low_cost)
current_model = model
llm = DiscourseAi::Completions::Llm.proxy(current_model)
tool_found = false
@ -65,7 +67,6 @@ module DiscourseAi
if (tool = persona.find_tool(partial))
tool_found = true
ongoing_chain = tool.chain_next_response?
low_cost = tool.low_cost?
tool_call_id = tool.tool_call_id
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
@ -129,43 +130,36 @@ module DiscourseAi
result
end
def model(prefer_low_cost: false)
def self.guess_model(bot_user)
# HACK(roman): We'll do this until we define how we represent different providers in the bot settings
default_model =
case bot_user.id
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
"aws_bedrock:claude-2"
else
"anthropic:claude-2"
end
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
"open_ai:gpt-4"
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
"open_ai:gpt-4-turbo"
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
)
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
else
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro"
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake"
case bot_user.id
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
"aws_bedrock:claude-2"
else
nil
"anthropic:claude-2"
end
if %w[open_ai:gpt-4 open_ai:gpt-4-turbo].include?(default_model) && prefer_low_cost
return "open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
"open_ai:gpt-4"
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
"open_ai:gpt-4-turbo"
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
)
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
else
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro"
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake"
else
nil
end
default_model
end
def tool_invocation?(partial)

View File

@ -11,46 +11,46 @@ module DiscourseAi
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
def self.schedule_reply(post)
# Tells whether user_id belongs to one of the ids in
# DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS or to the user of a
# mentionable persona.
#
# @param user_id [Integer]
# @return [Boolean]
def self.is_bot_user_id?(user_id)
  return true if DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS.include?(user_id)

  # Personas are only consulted when the id is not in BOT_USER_IDS.
  AiPersona.mentionables.any? { |mentionable| mentionable[:user_id] == user_id }
end
return if bot_ids.include?(post.user_id)
if AiPersona.mentionables.any? { |mentionable| mentionable[:user_id] == post.user_id }
return
end
def self.schedule_reply(post)
return if is_bot_user_id?(post.user_id)
bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS
mentionables = AiPersona.mentionables(user: post.user)
bot_user = nil
mentioned = nil
if post.topic.private_message?
bot_user = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user
bot_user ||=
post
.topic
.topic_allowed_users
.where(user_id: mentionables.map { |m| m[:user_id] })
.first
&.user
end
if AiPersona.mentionables.length > 0
if mentionables.present?
mentions = post.mentions.map(&:downcase)
mentioned =
AiPersona.mentionables.find do |mentionable|
mentions.include?(mentionable[:username]) &&
(post.user.group_ids & mentionable[:allowed_group_ids]).present?
end
mentioned = mentionables.find { |mentionable| mentions.include?(mentionable[:username]) }
# PM always takes precedence
if mentioned && !bot_user
model_without_provider = mentioned[:default_llm].split(":").last
user_id =
DiscourseAi::AiBot::EntryPoint.map_bot_model_to_user_id(model_without_provider)
if !user_id
Rails.logger.warn(
"Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}",
)
if Rails.env.development? || Rails.env.test?
raise "Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}"
end
else
bot_user = User.find_by(id: user_id)
end
# direct PM to mentionable
if !mentioned && bot_user
mentioned = mentionables.find { |mentionable| bot_user.id == mentionable[:user_id] }
end
# public topic so we need to use the persona user
bot_user ||= User.find_by(id: mentioned[:user_id]) if mentioned
end
if bot_user
@ -309,6 +309,7 @@ module DiscourseAi
:update_ai_bot_pm_title,
post_id: post.id,
bot_user_id: bot.bot_user.id,
model: bot.model,
)
end
end

View File

@ -44,10 +44,6 @@ module DiscourseAi
true
end
def low_cost?
true
end
# Returns @last_summary when it is set; otherwise the translated
# "discourse_ai.ai_bot.topic_not_found" string.
def custom_raw
  return @last_summary if @last_summary

  I18n.t("discourse_ai.ai_bot.topic_not_found")
end

View File

@ -75,10 +75,6 @@ module DiscourseAi
false
end
def low_cost?
false
end
protected
def accepted_options

View File

@ -147,8 +147,7 @@ module DiscourseAi
{
type: "text",
text:
"Describe this image in a single sentence" +
custom_locale_instructions(user),
"Describe this image in a single sentence#{custom_locale_instructions(user)}",
},
{ type: "image_url", image_url: image_url },
],

View File

@ -69,7 +69,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
)
persona.create_user!
persona.update!(default_llm: "claude-2", mentionable: true)
persona.update!(default_llm: "anthropic:claude-2", mentionable: true)
persona
end
@ -92,6 +92,9 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
it "allows mentioning a persona" do
# we still should be able to mention with no bots
SiteSetting.ai_bot_enabled_chat_bots = ""
post = nil
DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do
post =
@ -107,6 +110,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(last_post.user_id).to eq(persona.user_id)
end
# A PM addressed to the persona's own user should get a reply from that
# user even though the list of enabled chat bots is empty.
it "allows PMing a persona even when no particular bots are enabled" do
SiteSetting.ai_bot_enabled = true
SiteSetting.ai_bot_enabled_chat_bots = ""
post = nil
# Two canned responses are prepared for the persona's llm; the final
# assertions expect "Yes I can" to be the reply posted in the topic.
# NOTE(review): "Magic title" is presumably consumed by the PM title update - confirm.
DiscourseAi::Completions::Llm.with_prepared_responses(
["Magic title", "Yes I can"],
llm: "anthropic:claude-2",
) do
post =
create_post(
title: "I just made a PM",
raw: "Hey there #{persona.user.username}, can you help me?",
target_usernames: "#{user.username},#{persona.user.username}",
archetype: Archetype.private_message,
user: admin,
)
end
last_post = post.topic.posts.order(:post_number).last
expect(last_post.raw).to eq("Yes I can")
expect(last_post.user_id).to eq(persona.user_id)
# The persona user must be among the topic's allowed users.
last_post.topic.reload
expect(last_post.topic.allowed_users.pluck(:user_id)).to include(persona.user_id)
end
it "picks the correct llm for persona in PMs" do
# If you start a PM with GPT 3.5 bot, replies should come from it, not from Claude
SiteSetting.ai_bot_enabled = true