From 492383716543191ab6063353b231059b4ed0ebd2 Mon Sep 17 00:00:00 2001
From: Sam
Date: Fri, 25 Oct 2024 06:24:53 +1100
Subject: [PATCH] FIX: Llm selector / forced tools / search tool (#862)

* FIX: Llm selector / forced tools / search tool

This fixes a few issues:

1. When search found no semantic results we would break the search tool
2. Gemini / Anthropic models did not implement forced tools, despite the API supporting it
3. Mechanics around displaying the LLM selector were not right; if you disabled the LLM selector server side, persona PMs did not work correctly
4. Disabling native tools for Anthropic models moved out of a site setting and into a per-model provider param. This deliberately ships without a migration because the option is rarely needed now; people who had it set probably did not need it
5. Updates Anthropic model names to the latest release

* linting

* fix a couple of tests I missed

* clean up conditional
---
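Note for reviewers: a rough sketch of the request shapes the forced-tool changes
below end up sending (payload shapes are taken from the endpoint diffs; the
"search" tool name is illustrative only):

    # Anthropic / AWS Bedrock, when dialect.tool_choice == "search"
    payload[:tools] = prompt.tools
    payload[:tool_choice] = { type: "tool", name: "search" }

    # Gemini, when dialect.tool_choice == "search"
    payload[:tool_config] = {
      function_calling_config: { mode: "ANY", allowed_function_names: ["search"] },
    }

Relatedly, the semantic search change below returns Post.none (an empty relation)
instead of [] when the search term is blank or too short, so the search tool keeps
working when semantic search finds nothing.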
 app/models/llm_model.rb                       |  4 ++
 .../composer-fields/persona-llm-selector.gjs  | 49 ++++++++++++++-----
 config/locales/client.en.yml                  |  1 +
 config/locales/server.en.yml                  |  1 -
 config/settings.yml                           | 10 ----
 lib/ai_bot/entry_point.rb                     |  2 +
 lib/completions/dialects/claude.rb            |  2 +-
 lib/completions/endpoints/anthropic.rb        |  9 +++-
 lib/completions/endpoints/aws_bedrock.rb      | 10 +++-
 lib/completions/endpoints/gemini.rb           | 17 +++++--
 lib/embeddings/semantic_search.rb             |  4 +-
 spec/lib/completions/dialects/claude_spec.rb  |  5 +-
 .../completions/endpoints/anthropic_spec.rb   |  3 --
 .../completions/endpoints/aws_bedrock_spec.rb |  4 +-
 spec/lib/modules/ai_bot/tools/search_spec.rb  | 16 ++++++
 spec/system/ai_bot/ai_bot_helper_spec.rb      | 24 ++++++++-
 16 files changed, 121 insertions(+), 40 deletions(-)

diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb
index 79e065a6..d4169011 100644
--- a/app/models/llm_model.rb
+++ b/app/models/llm_model.rb
@@ -19,6 +19,10 @@ class LlmModel < ActiveRecord::Base
       aws_bedrock: {
         access_key_id: :text,
         region: :text,
+        disable_native_tools: :checkbox,
+      },
+      anthropic: {
+        disable_native_tools: :checkbox,
       },
       open_ai: {
         organization: :text,
diff --git a/assets/javascripts/discourse/connectors/composer-fields/persona-llm-selector.gjs b/assets/javascripts/discourse/connectors/composer-fields/persona-llm-selector.gjs
index 0a19f8a8..a3827b76 100644
--- a/assets/javascripts/discourse/connectors/composer-fields/persona-llm-selector.gjs
+++ b/assets/javascripts/discourse/connectors/composer-fields/persona-llm-selector.gjs
@@ -15,7 +15,7 @@ function isBotMessage(composer, currentUser) {
     const reciepients = composer.targetRecipients.split(",");
 
     return currentUser.ai_enabled_chat_bots
-      .filter((bot) => !bot.is_persona)
+      .filter((bot) => bot.username)
       .any((bot) => reciepients.any((username) => username === bot.username));
   }
   return false;
@@ -43,7 +43,7 @@ export default class BotSelector extends Component {
   constructor() {
     super(...arguments);
 
-    if (this.botOptions && this.composer) {
+    if (this.botOptions && this.botOptions.length && this.composer) {
       let personaId = this.preferredPersonaStore.getObject("id");
 
       this._value = this.botOptions[0].id;
@@ -57,19 +57,29 @@ export default class BotSelector extends Component {
       this.composer.metaData = { ai_persona_id: this._value };
       this.setAllowLLMSelector();
 
-      let llm = this.preferredLlmStore.getObject("id");
+      if (this.hasLlmSelector) {
+        let llm = this.preferredLlmStore.getObject("id");
 
-      const llmOption =
-        this.llmOptions.find((innerLlmOption) => innerLlmOption.id === llm) ||
-        this.llmOptions[0];
+        const llmOption =
+          this.llmOptions.find((innerLlmOption) => innerLlmOption.id === llm) ||
+          this.llmOptions[0];
 
-      llm = llmOption.id;
+        if (llmOption) {
+          llm = llmOption.id;
+        } else {
+          llm = "";
+        }
 
-      if (llm) {
-        next(() => {
-          this.currentLlm = llm;
-        });
+        if (llm) {
+          next(() => {
+            this.currentLlm = llm;
+          });
+        }
       }
+
+      next(() => {
+        this.resetTargetRecipients();
+      });
     }
   }
@@ -77,9 +87,19 @@
     return this.args?.outletArgs?.model;
   }
 
+  get hasLlmSelector() {
+    return this.currentUser.ai_enabled_chat_bots.any((bot) => !bot.is_persona);
+  }
+
   get botOptions() {
     if (this.currentUser.ai_enabled_personas) {
-      return this.currentUser.ai_enabled_personas.map((persona) => {
+      let enabledPersonas = this.currentUser.ai_enabled_personas;
+
+      if (!this.hasLlmSelector) {
+        enabledPersonas = enabledPersonas.filter((persona) => persona.username);
+      }
+
+      return enabledPersonas.map((persona) => {
         return {
           id: persona.id,
           name: persona.name,
@@ -106,6 +126,11 @@
   }
 
   setAllowLLMSelector() {
+    if (!this.hasLlmSelector) {
+      this.allowLLMSelector = false;
+      return;
+    }
+
     const persona = this.currentUser.ai_enabled_personas.find(
       (innerPersona) => innerPersona.id === this._value
     );
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index a614bc09..7cea7b07 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -329,6 +329,7 @@ en:
           organization: "Optional OpenAI Organization ID"
           disable_system_prompt: "Disable system message in prompts"
           enable_native_tool: "Enable native tool support"
+          disable_native_tools: "Disable native tool support (use XML based tools)"
 
     related_topics:
       title: "Related Topics"
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index 2c2033a0..1ede2ad0 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -49,7 +49,6 @@ en:
     ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
     ai_openai_api_key: "API key for OpenAI API. ONLY used for embeddings and Dall-E. For GPT use the LLM config tab"
 
-    ai_anthropic_native_tool_call_models: "List of models that will use native tool calls vs legacy XML based tools."
 
     ai_hugging_face_tei_endpoint: URL where the API is running for the Hugging Face text embeddings inference
     ai_hugging_face_tei_api_key: API key for Hugging Face text embeddings inference
diff --git a/config/settings.yml b/config/settings.yml
index 07abc071..55f8392f 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -125,16 +125,6 @@ discourse_ai:
   ai_anthropic_api_key:
     default: ""
    hidden: true
-  ai_anthropic_native_tool_call_models:
-    type: list
-    list_type: compact
-    default: "claude-3-sonnet|claude-3-haiku"
-    allow_any: false
-    choices:
-      - claude-3-opus
-      - claude-3-sonnet
-      - claude-3-haiku
-      - claude-3-5-sonnet
   ai_cohere_api_key:
     default: ""
     hidden: true
diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb
index 2906b521..64fd9133 100644
--- a/lib/ai_bot/entry_point.rb
+++ b/lib/ai_bot/entry_point.rb
@@ -145,6 +145,8 @@ module DiscourseAi
         persona_users = AiPersona.persona_users(user: scope.user)
 
         if persona_users.present?
+          persona_users.filter! { |persona_user| persona_user[:username].present? }
+
           bots_map.concat(
             persona_users.map do |persona_user|
               {
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 296b8ffd..916bd90c 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -61,7 +61,7 @@ module DiscourseAi
       end
 
       def native_tool_support?
-        SiteSetting.ai_anthropic_native_tool_call_models_map.include?(llm_model.name)
+        !llm_model.lookup_custom_param("disable_native_tools")
       end
 
       private
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index c76bd04a..44762b88 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -27,7 +27,7 @@ module DiscourseAi
           when "claude-3-opus"
             "claude-3-opus-20240229"
           when "claude-3-5-sonnet"
-            "claude-3-5-sonnet-20240620"
+            "claude-3-5-sonnet-latest"
           else
             llm_model.name
           end
@@ -70,7 +70,12 @@ module DiscourseAi
         payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
         payload[:stream] = true if @streaming_mode
 
-        payload[:tools] = prompt.tools if prompt.has_tools?
+        if prompt.has_tools?
+          payload[:tools] = prompt.tools
+          if dialect.tool_choice.present?
+            payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+          end
+        end
 
         payload
       end
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 409c856b..f3146c2d 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -61,7 +61,7 @@ module DiscourseAi
           when "claude-3-opus"
             "anthropic.claude-3-opus-20240229-v1:0"
           when "claude-3-5-sonnet"
-            "anthropic.claude-3-5-sonnet-20240620-v1:0"
+            "anthropic.claude-3-5-sonnet-20241022-v2:0"
           else
             llm_model.name
           end
@@ -83,7 +83,13 @@ module DiscourseAi
         payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages)
 
         payload[:system] = prompt.system_prompt if prompt.system_prompt.present?
-        payload[:tools] = prompt.tools if prompt.has_tools?
+
+        if prompt.has_tools?
+          payload[:tools] = prompt.tools
+          if dialect.tool_choice.present?
+            payload[:tool_choice] = { type: "tool", name: dialect.tool_choice }
+          end
+        end
 
         payload
       end
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index 92ac29ac..f2f8fa9b 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -67,7 +67,16 @@ module DiscourseAi
           } if prompt[:system_instruction].present?
         if tools.present?
           payload[:tools] = tools
-          payload[:tool_config] = { function_calling_config: { mode: "AUTO" } }
+
+          function_calling_config = { mode: "AUTO" }
+          if dialect.tool_choice.present?
+            function_calling_config = {
+              mode: "ANY",
+              allowed_function_names: [dialect.tool_choice],
+            }
+          end
+
+          payload[:tool_config] = { function_calling_config: function_calling_config }
         end
         payload[:generationConfig].merge!(model_params) if model_params.present?
         payload
@@ -88,8 +97,10 @@ module DiscourseAi
         end
 
         response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
-        @has_function_call ||= response_h.dig(:functionCall).present?
-        @has_function_call ? response_h[:functionCall] : response_h.dig(:text)
+        if response_h
+          @has_function_call ||= response_h.dig(:functionCall).present?
+          @has_function_call ? response_h.dig(:functionCall) : response_h.dig(:text)
+        end
       end
 
       def partials_from(decoded_chunk)
diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb
index 02d8dd11..cae93958 100644
--- a/lib/embeddings/semantic_search.rb
+++ b/lib/embeddings/semantic_search.rb
@@ -79,7 +79,9 @@ module DiscourseAi
       search = Search.new(query, { guardian: guardian })
       search_term = search.term
 
-      return [] if search_term.nil? || search_term.length < SiteSetting.min_search_term_length
+      if search_term.blank? || search_term.length < SiteSetting.min_search_term_length
+        return Post.none
+      end
 
       search_embedding = hyde ? hyde_embedding(search_term) : embedding(search_term)
diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb
index d1657cc0..624c431f 100644
--- a/spec/lib/completions/dialects/claude_spec.rb
+++ b/spec/lib/completions/dialects/claude_spec.rb
@@ -35,7 +35,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
   end
 
   it "can properly translate a prompt (legacy tools)" do
-    SiteSetting.ai_anthropic_native_tool_call_models = ""
+    llm_model.provider_params["disable_native_tools"] = true
+    llm_model.save!
 
     tools = [
       {
@@ -88,8 +89,6 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
   end
 
   it "can properly translate a prompt (native tools)" do
-    SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
-
     tools = [
       {
         name: "echo",
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 4caf077f..3ab38644 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -48,7 +48,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
   end
 
   it "does not eat spaces with tool calls" do
-    SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
    body = <<~STRING
      event: message_start
      data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} }
@@ -195,8 +194,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
   end
 
   it "supports non streaming tool calls" do
-    SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus"
-
     tool = {
       name: "calculate",
       description: "calculate something",
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index eccc4195..d9519344 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -28,7 +28,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
 
   describe "function calling" do
     it "supports old school xml function calls" do
-      SiteSetting.ai_anthropic_native_tool_call_models = ""
+      model.provider_params["disable_native_tools"] = true
+      model.save!
+
       proxy = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
 
       incomplete_tool_call = <<~XML.strip
diff --git a/spec/lib/modules/ai_bot/tools/search_spec.rb b/spec/lib/modules/ai_bot/tools/search_spec.rb
index bbd92878..4f664f1b 100644
--- a/spec/lib/modules/ai_bot/tools/search_spec.rb
+++ b/spec/lib/modules/ai_bot/tools/search_spec.rb
@@ -138,6 +138,22 @@ RSpec.describe DiscourseAi::AiBot::Tools::Search do
 
       expect(results[:args]).to eq({ search_query: "hello world, sam", status: "public" })
       expect(results[:rows].length).to eq(1)
+
+      # it also works with no query
+      search =
+        described_class.new(
+          { order: "likes", user: "sam", status: "public", search_query: "a" },
+          llm: llm,
+          bot_user: bot_user,
+        )
+
+      # results will be expanded by semantic search, but it will find nothing
+      results =
+        DiscourseAi::Completions::Llm.with_prepared_responses(["#{query}"]) do
+          search.invoke(&progress_blk)
+        end
+
+      expect(results[:rows].length).to eq(0)
     end
   end
diff --git a/spec/system/ai_bot/ai_bot_helper_spec.rb b/spec/system/ai_bot/ai_bot_helper_spec.rb
index b4b151b2..97186a05 100644
--- a/spec/system/ai_bot/ai_bot_helper_spec.rb
+++ b/spec/system/ai_bot/ai_bot_helper_spec.rb
@@ -22,13 +22,35 @@ RSpec.describe "AI chat channel summarization", type: :system, js: true do
     group.add(user)
     group.save
 
+    allowed_persona = AiPersona.last
+    allowed_persona.update!(allowed_group_ids: [group.id], enabled: true)
+
     visit "/latest"
     expect(page).to have_selector(".ai-bot-button")
     find(".ai-bot-button").click
 
-    # composer is open
+    find(".gpt-persona").click
+    expect(page).to have_css(".gpt-persona ul li", count: 1)
+
+    find(".llm-selector").click
+    expect(page).to have_css(".llm-selector ul li", count: 2)
+
     expect(page).to have_selector(".d-editor-container")
+    # lets disable bots but still allow 1 persona
+    allowed_persona.create_user!
+    allowed_persona.update!(default_llm: "custom:#{gpt_4.id}")
+
+    gpt_4.update!(enabled_chat_bot: false)
+    gpt_3_5_turbo.update!(enabled_chat_bot: false)
+
+    visit "/latest"
+    find(".ai-bot-button").click
+
+    find(".gpt-persona").click
+    expect(page).to have_css(".gpt-persona ul li", count: 1)
+    expect(page).not_to have_selector(".llm-selector")
+
     SiteSetting.ai_bot_add_to_header = false
     visit "/latest"
     expect(page).not_to have_selector(".ai-bot-button")