DEV: improve internal design of ai persona and bug fix (#495)

* DEV: improve internal design of ai persona and bug fix

- Fixes bug where OpenAI could not describe images
- Fixes bug where mentionable personas could not be mentioned unless overarching bot was enabled
- Improves internal design of playground and bot to better allow for non-"bot" users
- Allows PMs directly to persona users (previously the bot user also had to be in the PM)
- Simplify internal code


Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
Sam 2024-02-28 16:46:32 +11:00 committed by GitHub
parent a1b607db80
commit 484fd1435b
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
8 changed files with 116 additions and 93 deletions

View File

@ -6,7 +6,7 @@ module ::Jobs
def execute(args) def execute(args)
return unless bot_user = User.find_by(id: args[:bot_user_id]) return unless bot_user = User.find_by(id: args[:bot_user_id])
return unless bot = DiscourseAi::AiBot::Bot.as(bot_user) return unless bot = DiscourseAi::AiBot::Bot.as(bot_user, model: args[:model])
return unless post = Post.includes(:topic).find_by(id: args[:post_id]) return unless post = Post.includes(:topic).find_by(id: args[:post_id])
return unless post.topic.custom_fields[DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE] return unless post.topic.custom_fields[DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE]

View File

@ -60,21 +60,28 @@ class AiPersona < ActiveRecord::Base
.map(&:class_instance) .map(&:class_instance)
end end
def self.mentionables def self.mentionables(user: nil)
persona_cache[:mentionable_usernames] ||= AiPersona all_mentionables =
.where(mentionable: true) persona_cache[:mentionable_usernames] ||= AiPersona
.where(enabled: true) .where(mentionable: true)
.joins(:user) .where(enabled: true)
.pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm") .joins(:user)
.map do |id, user_id, username, allowed_group_ids, default_llm| .pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm")
{ .map do |id, user_id, username, allowed_group_ids, default_llm|
id: id, {
user_id: user_id, id: id,
username: username, user_id: user_id,
allowed_group_ids: allowed_group_ids, username: username,
default_llm: default_llm, allowed_group_ids: allowed_group_ids,
} default_llm: default_llm,
end }
end
if user
all_mentionables.select { |mentionable| user.in_any_groups?(mentionable[:allowed_group_ids]) }
else
all_mentionables
end
end end
after_commit :bump_cache after_commit :bump_cache

View File

@ -3,16 +3,19 @@
module DiscourseAi module DiscourseAi
module AiBot module AiBot
class Bot class Bot
attr_reader :model
BOT_NOT_FOUND = Class.new(StandardError) BOT_NOT_FOUND = Class.new(StandardError)
MAX_COMPLETIONS = 5 MAX_COMPLETIONS = 5
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new) def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
new(bot_user, persona) new(bot_user, persona, model)
end end
def initialize(bot_user, persona) def initialize(bot_user, persona, model = nil)
@bot_user = bot_user @bot_user = bot_user
@persona = persona @persona = persona
@model = model || self.class.guess_model(bot_user) || @persona.class.default_llm
end end
attr_reader :bot_user attr_reader :bot_user
@ -46,7 +49,6 @@ module DiscourseAi
total_completions = 0 total_completions = 0
ongoing_chain = true ongoing_chain = true
low_cost = false
raw_context = [] raw_context = []
user = context[:user] user = context[:user]
@ -56,7 +58,7 @@ module DiscourseAi
llm_kwargs[:top_p] = persona.top_p if persona.top_p llm_kwargs[:top_p] = persona.top_p if persona.top_p
while total_completions <= MAX_COMPLETIONS && ongoing_chain while total_completions <= MAX_COMPLETIONS && ongoing_chain
current_model = model(prefer_low_cost: low_cost) current_model = model
llm = DiscourseAi::Completions::Llm.proxy(current_model) llm = DiscourseAi::Completions::Llm.proxy(current_model)
tool_found = false tool_found = false
@ -65,7 +67,6 @@ module DiscourseAi
if (tool = persona.find_tool(partial)) if (tool = persona.find_tool(partial))
tool_found = true tool_found = true
ongoing_chain = tool.chain_next_response? ongoing_chain = tool.chain_next_response?
low_cost = tool.low_cost?
tool_call_id = tool.tool_call_id tool_call_id = tool.tool_call_id
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
@ -129,43 +130,36 @@ module DiscourseAi
result result
end end
def model(prefer_low_cost: false) def self.guess_model(bot_user)
# HACK(roman): We'll do this until we define how we represent different providers in the bot settings # HACK(roman): We'll do this until we define how we represent different providers in the bot settings
default_model = case bot_user.id
case bot_user.id when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") "aws_bedrock:claude-2"
"aws_bedrock:claude-2"
else
"anthropic:claude-2"
end
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
"open_ai:gpt-4"
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
"open_ai:gpt-4-turbo"
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
)
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
else
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro"
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake"
else else
nil "anthropic:claude-2"
end end
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
if %w[open_ai:gpt-4 open_ai:gpt-4-turbo].include?(default_model) && prefer_low_cost "open_ai:gpt-4"
return "open_ai:gpt-3.5-turbo-16k" when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
"open_ai:gpt-4-turbo"
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
)
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
else
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro"
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
"fake:fake"
else
nil
end end
default_model
end end
def tool_invocation?(partial) def tool_invocation?(partial)

View File

@ -11,46 +11,46 @@ module DiscourseAi
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update" REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
def self.schedule_reply(post) def self.is_bot_user_id?(user_id)
bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS
bot_ids.include?(user_id) ||
begin
mentionable_ids = AiPersona.mentionables.map { |mentionable| mentionable[:user_id] }
mentionable_ids.include?(user_id)
end
end
return if bot_ids.include?(post.user_id) def self.schedule_reply(post)
if AiPersona.mentionables.any? { |mentionable| mentionable[:user_id] == post.user_id } return if is_bot_user_id?(post.user_id)
return
end bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS
mentionables = AiPersona.mentionables(user: post.user)
bot_user = nil bot_user = nil
mentioned = nil mentioned = nil
if post.topic.private_message? if post.topic.private_message?
bot_user = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user bot_user = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user
bot_user ||=
post
.topic
.topic_allowed_users
.where(user_id: mentionables.map { |m| m[:user_id] })
.first
&.user
end end
if AiPersona.mentionables.length > 0 if mentionables.present?
mentions = post.mentions.map(&:downcase) mentions = post.mentions.map(&:downcase)
mentioned = mentioned = mentionables.find { |mentionable| mentions.include?(mentionable[:username]) }
AiPersona.mentionables.find do |mentionable|
mentions.include?(mentionable[:username]) &&
(post.user.group_ids & mentionable[:allowed_group_ids]).present?
end
# PM always takes precedence # direct PM to mentionable
if mentioned && !bot_user if !mentioned && bot_user
model_without_provider = mentioned[:default_llm].split(":").last mentioned = mentionables.find { |mentionable| bot_user.id == mentionable[:user_id] }
user_id =
DiscourseAi::AiBot::EntryPoint.map_bot_model_to_user_id(model_without_provider)
if !user_id
Rails.logger.warn(
"Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}",
)
if Rails.env.development? || Rails.env.test?
raise "Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}"
end
else
bot_user = User.find_by(id: user_id)
end
end end
# public topic so we need to use the persona user
bot_user ||= User.find_by(id: mentioned[:user_id]) if mentioned
end end
if bot_user if bot_user
@ -309,6 +309,7 @@ module DiscourseAi
:update_ai_bot_pm_title, :update_ai_bot_pm_title,
post_id: post.id, post_id: post.id,
bot_user_id: bot.bot_user.id, bot_user_id: bot.bot_user.id,
model: bot.model,
) )
end end
end end

View File

@ -44,10 +44,6 @@ module DiscourseAi
true true
end end
def low_cost?
true
end
def custom_raw def custom_raw
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found") @last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
end end

View File

@ -75,10 +75,6 @@ module DiscourseAi
false false
end end
def low_cost?
false
end
protected protected
def accepted_options def accepted_options

View File

@ -147,8 +147,7 @@ module DiscourseAi
{ {
type: "text", type: "text",
text: text:
"Describe this image in a single sentence" + "Describe this image in a single sentence#{custom_locale_instructions(user)}",
custom_locale_instructions(user),
}, },
{ type: "image_url", image_url: image_url }, { type: "image_url", image_url: image_url },
], ],

View File

@ -69,7 +69,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
) )
persona.create_user! persona.create_user!
persona.update!(default_llm: "claude-2", mentionable: true) persona.update!(default_llm: "anthropic:claude-2", mentionable: true)
persona persona
end end
@ -92,6 +92,9 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end end
it "allows mentioning a persona" do it "allows mentioning a persona" do
# we still should be able to mention with no bots
SiteSetting.ai_bot_enabled_chat_bots = ""
post = nil post = nil
DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do
post = post =
@ -107,6 +110,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do
expect(last_post.user_id).to eq(persona.user_id) expect(last_post.user_id).to eq(persona.user_id)
end end
it "allows PMing a persona even when no particular bots are enabled" do
SiteSetting.ai_bot_enabled = true
SiteSetting.ai_bot_enabled_chat_bots = ""
post = nil
DiscourseAi::Completions::Llm.with_prepared_responses(
["Magic title", "Yes I can"],
llm: "anthropic:claude-2",
) do
post =
create_post(
title: "I just made a PM",
raw: "Hey there #{persona.user.username}, can you help me?",
target_usernames: "#{user.username},#{persona.user.username}",
archetype: Archetype.private_message,
user: admin,
)
end
last_post = post.topic.posts.order(:post_number).last
expect(last_post.raw).to eq("Yes I can")
expect(last_post.user_id).to eq(persona.user_id)
last_post.topic.reload
expect(last_post.topic.allowed_users.pluck(:user_id)).to include(persona.user_id)
end
it "picks the correct llm for persona in PMs" do it "picks the correct llm for persona in PMs" do
# If you start a PM with GPT 3.5 bot, replies should come from it, not from Claude # If you start a PM with GPT 3.5 bot, replies should come from it, not from Claude
SiteSetting.ai_bot_enabled = true SiteSetting.ai_bot_enabled = true