FIX: Llm selector / forced tools / search tool (#862)

* FIX: Llm selector / forced tools / search tool


This fixes a few issues:

1. When search was not finding any semantic results we would break the tool
2. Gemini / Anthropic models did not implement forced tools previously despite it being an API option
3. Mechanics around displaying the LLM selector were not right. If the LLM selector was disabled server side, persona PMs did not work correctly.
4. Disabling native tools for Anthropic models was moved out of a site setting and into a per-model option. This deliberately does not migrate the old setting, because this feature is rarely needed now; people who had it set probably did not need it.
5. Updates anthropic model names to latest release

* linting

* fix a couple of tests I missed

* clean up conditional
This commit is contained in:
Sam 2024-10-25 06:24:53 +11:00 committed by GitHub
parent 3022d34613
commit 4923837165
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
16 changed files with 121 additions and 40 deletions

View File

@ -19,6 +19,10 @@ class LlmModel < ActiveRecord::Base
aws_bedrock: { aws_bedrock: {
access_key_id: :text, access_key_id: :text,
region: :text, region: :text,
disable_native_tools: :checkbox,
},
anthropic: {
disable_native_tools: :checkbox,
}, },
open_ai: { open_ai: {
organization: :text, organization: :text,

View File

@ -15,7 +15,7 @@ function isBotMessage(composer, currentUser) {
const reciepients = composer.targetRecipients.split(","); const reciepients = composer.targetRecipients.split(",");
return currentUser.ai_enabled_chat_bots return currentUser.ai_enabled_chat_bots
.filter((bot) => !bot.is_persona) .filter((bot) => bot.username)
.any((bot) => reciepients.any((username) => username === bot.username)); .any((bot) => reciepients.any((username) => username === bot.username));
} }
return false; return false;
@ -43,7 +43,7 @@ export default class BotSelector extends Component {
constructor() { constructor() {
super(...arguments); super(...arguments);
if (this.botOptions && this.composer) { if (this.botOptions && this.botOptions.length && this.composer) {
let personaId = this.preferredPersonaStore.getObject("id"); let personaId = this.preferredPersonaStore.getObject("id");
this._value = this.botOptions[0].id; this._value = this.botOptions[0].id;
@ -57,13 +57,18 @@ export default class BotSelector extends Component {
this.composer.metaData = { ai_persona_id: this._value }; this.composer.metaData = { ai_persona_id: this._value };
this.setAllowLLMSelector(); this.setAllowLLMSelector();
if (this.hasLlmSelector) {
let llm = this.preferredLlmStore.getObject("id"); let llm = this.preferredLlmStore.getObject("id");
const llmOption = const llmOption =
this.llmOptions.find((innerLlmOption) => innerLlmOption.id === llm) || this.llmOptions.find((innerLlmOption) => innerLlmOption.id === llm) ||
this.llmOptions[0]; this.llmOptions[0];
if (llmOption) {
llm = llmOption.id; llm = llmOption.id;
} else {
llm = "";
}
if (llm) { if (llm) {
next(() => { next(() => {
@ -71,15 +76,30 @@ export default class BotSelector extends Component {
}); });
} }
} }
next(() => {
this.resetTargetRecipients();
});
}
} }
get composer() { get composer() {
return this.args?.outletArgs?.model; return this.args?.outletArgs?.model;
} }
get hasLlmSelector() {
return this.currentUser.ai_enabled_chat_bots.any((bot) => !bot.is_persona);
}
get botOptions() { get botOptions() {
if (this.currentUser.ai_enabled_personas) { if (this.currentUser.ai_enabled_personas) {
return this.currentUser.ai_enabled_personas.map((persona) => { let enabledPersonas = this.currentUser.ai_enabled_personas;
if (!this.hasLlmSelector) {
enabledPersonas = enabledPersonas.filter((persona) => persona.username);
}
return enabledPersonas.map((persona) => {
return { return {
id: persona.id, id: persona.id,
name: persona.name, name: persona.name,
@ -106,6 +126,11 @@ export default class BotSelector extends Component {
} }
setAllowLLMSelector() { setAllowLLMSelector() {
if (!this.hasLlmSelector) {
this.allowLLMSelector = false;
return;
}
const persona = this.currentUser.ai_enabled_personas.find( const persona = this.currentUser.ai_enabled_personas.find(
(innerPersona) => innerPersona.id === this._value (innerPersona) => innerPersona.id === this._value
); );

View File

@ -329,6 +329,7 @@ en:
organization: "Optional OpenAI Organization ID" organization: "Optional OpenAI Organization ID"
disable_system_prompt: "Disable system message in prompts" disable_system_prompt: "Disable system message in prompts"
enable_native_tool: "Enable native tool support" enable_native_tool: "Enable native tool support"
disable_native_tools: "Disable native tool support (use XML based tools)"
related_topics: related_topics:
title: "Related Topics" title: "Related Topics"

View File

@ -49,7 +49,6 @@ en:
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)" ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
ai_openai_api_key: "API key for OpenAI API. ONLY used for embeddings and Dall-E. For GPT use the LLM config tab" ai_openai_api_key: "API key for OpenAI API. ONLY used for embeddings and Dall-E. For GPT use the LLM config tab"
ai_anthropic_native_tool_call_models: "List of models that will use native tool calls vs legacy XML based tools."
ai_hugging_face_tei_endpoint: URL where the API is running for the Hugging Face text embeddings inference ai_hugging_face_tei_endpoint: URL where the API is running for the Hugging Face text embeddings inference
ai_hugging_face_tei_api_key: API key for Hugging Face text embeddings inference ai_hugging_face_tei_api_key: API key for Hugging Face text embeddings inference

View File

@ -125,16 +125,6 @@ discourse_ai:
ai_anthropic_api_key: ai_anthropic_api_key:
default: "" default: ""
hidden: true hidden: true
ai_anthropic_native_tool_call_models:
type: list
list_type: compact
default: "claude-3-sonnet|claude-3-haiku"
allow_any: false
choices:
- claude-3-opus
- claude-3-sonnet
- claude-3-haiku
- claude-3-5-sonnet
ai_cohere_api_key: ai_cohere_api_key:
default: "" default: ""
hidden: true hidden: true

View File

@ -145,6 +145,8 @@ module DiscourseAi
persona_users = AiPersona.persona_users(user: scope.user) persona_users = AiPersona.persona_users(user: scope.user)
if persona_users.present? if persona_users.present?
persona_users.filter! { |persona_user| persona_user[:username].present? }
bots_map.concat( bots_map.concat(
persona_users.map do |persona_user| persona_users.map do |persona_user|
{ {

View File

@ -61,7 +61,7 @@ module DiscourseAi
end end
def native_tool_support? def native_tool_support?
SiteSetting.ai_anthropic_native_tool_call_models_map.include?(llm_model.name) !llm_model.lookup_custom_param("disable_native_tools")
end end
private private

View File

@ -27,7 +27,7 @@ module DiscourseAi
when "claude-3-opus" when "claude-3-opus"
"claude-3-opus-20240229" "claude-3-opus-20240229"
when "claude-3-5-sonnet" when "claude-3-5-sonnet"
"claude-3-5-sonnet-20240620" "claude-3-5-sonnet-latest"
else else
llm_model.name llm_model.name
end end
@ -70,7 +70,12 @@ module DiscourseAi
payload[:system] = prompt.system_prompt if prompt.system_prompt.present? payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:stream] = true if @streaming_mode payload[:stream] = true if @streaming_mode
payload[:tools] = prompt.tools if prompt.has_tools? if prompt.has_tools?
payload[:tools] = prompt.tools
if dialect.tool_choice.present?
payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
end
end
payload payload
end end

View File

@ -61,7 +61,7 @@ module DiscourseAi
when "claude-3-opus" when "claude-3-opus"
"anthropic.claude-3-opus-20240229-v1:0" "anthropic.claude-3-opus-20240229-v1:0"
when "claude-3-5-sonnet" when "claude-3-5-sonnet"
"anthropic.claude-3-5-sonnet-20240620-v1:0" "anthropic.claude-3-5-sonnet-20241022-v2:0"
else else
llm_model.name llm_model.name
end end
@ -83,7 +83,13 @@ module DiscourseAi
payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
payload[:system] = prompt.system_prompt if prompt.system_prompt.present? payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
payload[:tools] = prompt.tools if prompt.has_tools?
if prompt.has_tools?
payload[:tools] = prompt.tools
if dialect.tool_choice.present?
payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
end
end
payload payload
end end

View File

@ -67,7 +67,16 @@ module DiscourseAi
} if prompt[:system_instruction].present? } if prompt[:system_instruction].present?
if tools.present? if tools.present?
payload[:tools] = tools payload[:tools] = tools
payload[:tool_config] = { function_calling_config: { mode: "AUTO" } }
function_calling_config = { mode: "AUTO" }
if dialect.tool_choice.present?
function_calling_config = {
mode: "ANY",
allowed_function_names: [dialect.tool_choice],
}
end
payload[:tool_config] = { function_calling_config: function_calling_config }
end end
payload[:generationConfig].merge!(model_params) if model_params.present? payload[:generationConfig].merge!(model_params) if model_params.present?
payload payload
@ -88,8 +97,10 @@ module DiscourseAi
end end
response_h = parsed.dig(:candidates, 0, :content, :parts, 0) response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
if response_h
@has_function_call ||= response_h.dig(:functionCall).present? @has_function_call ||= response_h.dig(:functionCall).present?
@has_function_call ? response_h[:functionCall] : response_h.dig(:text) @has_function_call ? response_h.dig(:functionCall) : response_h.dig(:text)
end
end end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)

View File

@ -79,7 +79,9 @@ module DiscourseAi
search = Search.new(query, { guardian: guardian }) search = Search.new(query, { guardian: guardian })
search_term = search.term search_term = search.term
return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length if search_term.blank? || search_term.length < SiteSetting.min_search_term_length
return Post.none
end
search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term) search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term)

View File

@ -35,7 +35,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
end end
it "can properly translate a prompt (legacy tools)" do it "can properly translate a prompt (legacy tools)" do
SiteSetting.ai_anthropic_native_tool_call_models = "" llm_model.provider_params["disable_native_tools"] = true
llm_model.save!
tools = [ tools = [
{ {
@ -88,8 +89,6 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
end end
it "can properly translate a prompt (native tools)" do it "can properly translate a prompt (native tools)" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
tools = [ tools = [
{ {
name: "echo", name: "echo",

View File

@ -48,7 +48,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end end
it "does not eat spaces with tool calls" do it "does not eat spaces with tool calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
body = <<~STRING body = <<~STRING
event: message_start event: message_start
data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} } data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }
@ -195,8 +194,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end end
it "supports non streaming tool calls" do it "supports non streaming tool calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
tool = { tool = {
name: "calculate", name: "calculate",
description: "calculate something", description: "calculate something",

View File

@ -28,7 +28,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
describe "function calling" do describe "function calling" do
it "supports old school xml function calls" do it "supports old school xml function calls" do
SiteSetting.ai_anthropic_native_tool_call_models = "" model.provider_params["disable_native_tools"] = true
model.save!
proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
incomplete_tool_call = <<~XML.strip incomplete_tool_call = <<~XML.strip

View File

@ -138,6 +138,22 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" }) expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
expect(results[:rows].length).to eq(1) expect(results[:rows].length).to eq(1)
# it also works with no query
search =
described_class.new(
{ order: "likes", user: "sam", status: "public", search_query: "a" },
llm: llm,
bot_user: bot_user,
)
# results will be expanded by semantic search, but it will find nothing
results =
DiscourseAi::Completions::Llm.with_prepared_responses(["<ai>#{query}</ai>"]) do
search.invoke(&progress_blk)
end
expect(results[:rows].length).to eq(0)
end end
end end

View File

@ -22,13 +22,35 @@ RSpec.describe "AI chat channel summarization", type: :system, js: true do
group.add(user) group.add(user)
group.save group.save
allowed_persona = AiPersona.last
allowed_persona.update!(allowed_group_ids: [group.id], enabled: true)
visit "/latest" visit "/latest"
expect(page).to have_selector(".ai-bot-button") expect(page).to have_selector(".ai-bot-button")
find(".ai-bot-button").click find(".ai-bot-button").click
# composer is open find(".gpt-persona").click
expect(page).to have_css(".gpt-persona ul li", count: 1)
find(".llm-selector").click
expect(page).to have_css(".llm-selector ul li", count: 2)
expect(page).to have_selector(".d-editor-container") expect(page).to have_selector(".d-editor-container")
# lets disable bots but still allow 1 persona
allowed_persona.create_user!
allowed_persona.update!(default_llm: "custom:#{gpt_4.id}")
gpt_4.update!(enabled_chat_bot: false)
gpt_3_5_turbo.update!(enabled_chat_bot: false)
visit "/latest"
find(".ai-bot-button").click
find(".gpt-persona").click
expect(page).to have_css(".gpt-persona ul li", count: 1)
expect(page).not_to have_selector(".llm-selector")
SiteSetting.ai_bot_add_to_header = false SiteSetting.ai_bot_add_to_header = false
visit "/latest" visit "/latest"
expect(page).not_to have_selector(".ai-bot-button") expect(page).not_to have_selector(".ai-bot-button")