DEV: improve internal design of ai persona and bug fix (#495)
* DEV: improve internal design of ai persona and bug fix - Fixes bug where OpenAI could not describe images - Fixes bug where mentionable personas could not be mentioned unless overarching bot was enabled - Improves internal design of playground and bot to allow better for non "bot" users - Allow PMs directly to persona users (previously bot user would also have to be in PM) - Simplify internal code Co-authored-by: Martin Brennan <martin@discourse.org>
This commit is contained in:
parent
a1b607db80
commit
484fd1435b
|
@ -6,7 +6,7 @@ module ::Jobs
|
|||
|
||||
def execute(args)
|
||||
return unless bot_user = User.find_by(id: args[:bot_user_id])
|
||||
return unless bot = DiscourseAi::AiBot::Bot.as(bot_user)
|
||||
return unless bot = DiscourseAi::AiBot::Bot.as(bot_user, model: args[:model])
|
||||
return unless post = Post.includes(:topic).find_by(id: args[:post_id])
|
||||
|
||||
return unless post.topic.custom_fields[DiscourseAi::AiBot::EntryPoint::REQUIRE_TITLE_UPDATE]
|
||||
|
|
|
@ -60,21 +60,28 @@ class AiPersona < ActiveRecord::Base
|
|||
.map(&:class_instance)
|
||||
end
|
||||
|
||||
def self.mentionables
|
||||
persona_cache[:mentionable_usernames] ||= AiPersona
|
||||
.where(mentionable: true)
|
||||
.where(enabled: true)
|
||||
.joins(:user)
|
||||
.pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm")
|
||||
.map do |id, user_id, username, allowed_group_ids, default_llm|
|
||||
{
|
||||
id: id,
|
||||
user_id: user_id,
|
||||
username: username,
|
||||
allowed_group_ids: allowed_group_ids,
|
||||
default_llm: default_llm,
|
||||
}
|
||||
end
|
||||
def self.mentionables(user: nil)
|
||||
all_mentionables =
|
||||
persona_cache[:mentionable_usernames] ||= AiPersona
|
||||
.where(mentionable: true)
|
||||
.where(enabled: true)
|
||||
.joins(:user)
|
||||
.pluck("ai_personas.id, users.id, users.username_lower, allowed_group_ids, default_llm")
|
||||
.map do |id, user_id, username, allowed_group_ids, default_llm|
|
||||
{
|
||||
id: id,
|
||||
user_id: user_id,
|
||||
username: username,
|
||||
allowed_group_ids: allowed_group_ids,
|
||||
default_llm: default_llm,
|
||||
}
|
||||
end
|
||||
|
||||
if user
|
||||
all_mentionables.select { |mentionable| user.in_any_groups?(mentionable[:allowed_group_ids]) }
|
||||
else
|
||||
all_mentionables
|
||||
end
|
||||
end
|
||||
|
||||
after_commit :bump_cache
|
||||
|
|
|
@ -3,16 +3,19 @@
|
|||
module DiscourseAi
|
||||
module AiBot
|
||||
class Bot
|
||||
attr_reader :model
|
||||
|
||||
BOT_NOT_FOUND = Class.new(StandardError)
|
||||
MAX_COMPLETIONS = 5
|
||||
|
||||
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new)
|
||||
new(bot_user, persona)
|
||||
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
|
||||
new(bot_user, persona, model)
|
||||
end
|
||||
|
||||
def initialize(bot_user, persona)
|
||||
def initialize(bot_user, persona, model = nil)
|
||||
@bot_user = bot_user
|
||||
@persona = persona
|
||||
@model = model || self.class.guess_model(bot_user) || @persona.class.default_llm
|
||||
end
|
||||
|
||||
attr_reader :bot_user
|
||||
|
@ -46,7 +49,6 @@ module DiscourseAi
|
|||
|
||||
total_completions = 0
|
||||
ongoing_chain = true
|
||||
low_cost = false
|
||||
raw_context = []
|
||||
|
||||
user = context[:user]
|
||||
|
@ -56,7 +58,7 @@ module DiscourseAi
|
|||
llm_kwargs[:top_p] = persona.top_p if persona.top_p
|
||||
|
||||
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
||||
current_model = model(prefer_low_cost: low_cost)
|
||||
current_model = model
|
||||
llm = DiscourseAi::Completions::Llm.proxy(current_model)
|
||||
tool_found = false
|
||||
|
||||
|
@ -65,7 +67,6 @@ module DiscourseAi
|
|||
if (tool = persona.find_tool(partial))
|
||||
tool_found = true
|
||||
ongoing_chain = tool.chain_next_response?
|
||||
low_cost = tool.low_cost?
|
||||
tool_call_id = tool.tool_call_id
|
||||
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
|
||||
|
||||
|
@ -129,43 +130,36 @@ module DiscourseAi
|
|||
result
|
||||
end
|
||||
|
||||
def model(prefer_low_cost: false)
|
||||
def self.guess_model(bot_user)
|
||||
# HACK(roman): We'll do this until we define how we represent different providers in the bot settings
|
||||
default_model =
|
||||
case bot_user.id
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
|
||||
"aws_bedrock:claude-2"
|
||||
else
|
||||
"anthropic:claude-2"
|
||||
end
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
|
||||
"open_ai:gpt-4"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
|
||||
"open_ai:gpt-4-turbo"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
|
||||
"open_ai:gpt-3.5-turbo-16k"
|
||||
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
|
||||
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
)
|
||||
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
else
|
||||
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
end
|
||||
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
|
||||
"google:gemini-pro"
|
||||
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
|
||||
"fake:fake"
|
||||
case bot_user.id
|
||||
when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
|
||||
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
|
||||
"aws_bedrock:claude-2"
|
||||
else
|
||||
nil
|
||||
"anthropic:claude-2"
|
||||
end
|
||||
|
||||
if %w[open_ai:gpt-4 open_ai:gpt-4-turbo].include?(default_model) && prefer_low_cost
|
||||
return "open_ai:gpt-3.5-turbo-16k"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4_ID
|
||||
"open_ai:gpt-4"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
|
||||
"open_ai:gpt-4-turbo"
|
||||
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
|
||||
"open_ai:gpt-3.5-turbo-16k"
|
||||
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
|
||||
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
|
||||
"mistralai/Mixtral-8x7B-Instruct-v0.1",
|
||||
)
|
||||
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
else
|
||||
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
|
||||
end
|
||||
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
|
||||
"google:gemini-pro"
|
||||
when DiscourseAi::AiBot::EntryPoint::FAKE_ID
|
||||
"fake:fake"
|
||||
else
|
||||
nil
|
||||
end
|
||||
|
||||
default_model
|
||||
end
|
||||
|
||||
def tool_invocation?(partial)
|
||||
|
|
|
@ -11,46 +11,46 @@ module DiscourseAi
|
|||
|
||||
REQUIRE_TITLE_UPDATE = "discourse-ai-title-update"
|
||||
|
||||
def self.schedule_reply(post)
|
||||
def self.is_bot_user_id?(user_id)
|
||||
bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS
|
||||
bot_ids.include?(user_id) ||
|
||||
begin
|
||||
mentionable_ids = AiPersona.mentionables.map { |mentionable| mentionable[:user_id] }
|
||||
mentionable_ids.include?(user_id)
|
||||
end
|
||||
end
|
||||
|
||||
return if bot_ids.include?(post.user_id)
|
||||
if AiPersona.mentionables.any? { |mentionable| mentionable[:user_id] == post.user_id }
|
||||
return
|
||||
end
|
||||
def self.schedule_reply(post)
|
||||
return if is_bot_user_id?(post.user_id)
|
||||
|
||||
bot_ids = DiscourseAi::AiBot::EntryPoint::BOT_USER_IDS
|
||||
mentionables = AiPersona.mentionables(user: post.user)
|
||||
|
||||
bot_user = nil
|
||||
mentioned = nil
|
||||
|
||||
if post.topic.private_message?
|
||||
bot_user = post.topic.topic_allowed_users.where(user_id: bot_ids).first&.user
|
||||
bot_user ||=
|
||||
post
|
||||
.topic
|
||||
.topic_allowed_users
|
||||
.where(user_id: mentionables.map { |m| m[:user_id] })
|
||||
.first
|
||||
&.user
|
||||
end
|
||||
|
||||
if AiPersona.mentionables.length > 0
|
||||
if mentionables.present?
|
||||
mentions = post.mentions.map(&:downcase)
|
||||
mentioned =
|
||||
AiPersona.mentionables.find do |mentionable|
|
||||
mentions.include?(mentionable[:username]) &&
|
||||
(post.user.group_ids & mentionable[:allowed_group_ids]).present?
|
||||
end
|
||||
mentioned = mentionables.find { |mentionable| mentions.include?(mentionable[:username]) }
|
||||
|
||||
# PM always takes precedence
|
||||
if mentioned && !bot_user
|
||||
model_without_provider = mentioned[:default_llm].split(":").last
|
||||
user_id =
|
||||
DiscourseAi::AiBot::EntryPoint.map_bot_model_to_user_id(model_without_provider)
|
||||
|
||||
if !user_id
|
||||
Rails.logger.warn(
|
||||
"Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}",
|
||||
)
|
||||
if Rails.env.development? || Rails.env.test?
|
||||
raise "Model #{mentioned[:default_llm]} not found for persona #{mentioned[:username]}"
|
||||
end
|
||||
else
|
||||
bot_user = User.find_by(id: user_id)
|
||||
end
|
||||
# direct PM to mentionable
|
||||
if !mentioned && bot_user
|
||||
mentioned = mentionables.find { |mentionable| bot_user.id == mentionable[:user_id] }
|
||||
end
|
||||
|
||||
# public topic so we need to use the persona user
|
||||
bot_user ||= User.find_by(id: mentioned[:user_id]) if mentioned
|
||||
end
|
||||
|
||||
if bot_user
|
||||
|
@ -309,6 +309,7 @@ module DiscourseAi
|
|||
:update_ai_bot_pm_title,
|
||||
post_id: post.id,
|
||||
bot_user_id: bot.bot_user.id,
|
||||
model: bot.model,
|
||||
)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -44,10 +44,6 @@ module DiscourseAi
|
|||
true
|
||||
end
|
||||
|
||||
def low_cost?
|
||||
true
|
||||
end
|
||||
|
||||
def custom_raw
|
||||
@last_summary || I18n.t("discourse_ai.ai_bot.topic_not_found")
|
||||
end
|
||||
|
|
|
@ -75,10 +75,6 @@ module DiscourseAi
|
|||
false
|
||||
end
|
||||
|
||||
def low_cost?
|
||||
false
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def accepted_options
|
||||
|
|
|
@ -147,8 +147,7 @@ module DiscourseAi
|
|||
{
|
||||
type: "text",
|
||||
text:
|
||||
"Describe this image in a single sentence" +
|
||||
custom_locale_instructions(user),
|
||||
"Describe this image in a single sentence#{custom_locale_instructions(user)}",
|
||||
},
|
||||
{ type: "image_url", image_url: image_url },
|
||||
],
|
||||
|
|
|
@ -69,7 +69,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
)
|
||||
|
||||
persona.create_user!
|
||||
persona.update!(default_llm: "claude-2", mentionable: true)
|
||||
persona.update!(default_llm: "anthropic:claude-2", mentionable: true)
|
||||
persona
|
||||
end
|
||||
|
||||
|
@ -92,6 +92,9 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
end
|
||||
|
||||
it "allows mentioning a persona" do
|
||||
# we still should be able to mention with no bots
|
||||
SiteSetting.ai_bot_enabled_chat_bots = ""
|
||||
|
||||
post = nil
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(["Yes I can"]) do
|
||||
post =
|
||||
|
@ -107,6 +110,33 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
expect(last_post.user_id).to eq(persona.user_id)
|
||||
end
|
||||
|
||||
it "allows PMing a persona even when no particular bots are enabled" do
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
SiteSetting.ai_bot_enabled_chat_bots = ""
|
||||
post = nil
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(
|
||||
["Magic title", "Yes I can"],
|
||||
llm: "anthropic:claude-2",
|
||||
) do
|
||||
post =
|
||||
create_post(
|
||||
title: "I just made a PM",
|
||||
raw: "Hey there #{persona.user.username}, can you help me?",
|
||||
target_usernames: "#{user.username},#{persona.user.username}",
|
||||
archetype: Archetype.private_message,
|
||||
user: admin,
|
||||
)
|
||||
end
|
||||
|
||||
last_post = post.topic.posts.order(:post_number).last
|
||||
expect(last_post.raw).to eq("Yes I can")
|
||||
expect(last_post.user_id).to eq(persona.user_id)
|
||||
|
||||
last_post.topic.reload
|
||||
expect(last_post.topic.allowed_users.pluck(:user_id)).to include(persona.user_id)
|
||||
end
|
||||
|
||||
it "picks the correct llm for persona in PMs" do
|
||||
# If you start a PM with GPT 3.5 bot, replies should come from it, not from Claude
|
||||
SiteSetting.ai_bot_enabled = true
|
||||
|
|
Loading…
Reference in New Issue