diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index c656283a..2bb29c83 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -51,6 +51,7 @@ en: ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)" ai_openai_api_key: "API key for OpenAI API" ai_anthropic_api_key: "API key for Anthropic API" + ai_anthropic_native_tool_call_models: "List of models that will use native tool calls vs legacy XML based tools." ai_cohere_api_key: "API key for Cohere API" ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference" ai_hugging_face_api_key: API key for Hugging Face API diff --git a/config/settings.yml b/config/settings.yml index fe880e2c..2b0e56a5 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -111,6 +111,15 @@ discourse_ai: ai_anthropic_api_key: default: "" secret: true + ai_anthropic_native_tool_call_models: + type: list + list_type: compact + default: "claude-3-sonnet|claude-3-haiku" + allow_any: false + choices: + - claude-3-opus + - claude-3-sonnet + - claude-3-haiku ai_cohere_api_key: default: "" secret: true diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 6a8a9543..2c6fd131 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -22,6 +22,10 @@ module DiscourseAi @messages = messages @tools = tools end + + def has_tools? + tools.present? + end end def tokenizer @@ -33,6 +37,10 @@ module DiscourseAi system_prompt = messages.shift[:content] if messages.first[:role] == "system" + if !system_prompt && !native_tool_support? + system_prompt = tools_dialect.instructions.presence + end + interleving_messages = [] previous_message = nil @@ -48,11 +56,10 @@ module DiscourseAi previous_message = message end - ClaudePrompt.new( - system_prompt.presence, - interleving_messages, - tools_dialect.translated_tools, - ) + tools = nil + tools = tools_dialect.translated_tools if native_tool_support? + + ClaudePrompt.new(system_prompt.presence, interleving_messages, tools) end def max_prompt_tokens @@ -62,18 +69,28 @@ module DiscourseAi 200_000 # Claude-3 has a 200k context window for now end + def native_tool_support? + SiteSetting.ai_anthropic_native_tool_call_models_map.include?(model_name) + end + private def tools_dialect - @tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools) + if native_tool_support? + @tools_dialect ||= DiscourseAi::Completions::Dialects::ClaudeTools.new(prompt.tools) + else + super + end end def tool_call_msg(msg) - tools_dialect.from_raw_tool_call(msg) + translated = tools_dialect.from_raw_tool_call(msg) + { role: "assistant", content: translated } end def tool_msg(msg) - tools_dialect.from_raw_tool(msg) + translated = tools_dialect.from_raw_tool(msg) + { role: "user", content: translated } end def model_msg(msg) diff --git a/lib/completions/dialects/claude_tools.rb b/lib/completions/dialects/claude_tools.rb index 8708497f..b42a1833 100644 --- a/lib/completions/dialects/claude_tools.rb +++ b/lib/completions/dialects/claude_tools.rb @@ -15,13 +15,14 @@ module DiscourseAi required = [] if t[:parameters] - properties = - t[:parameters].each_with_object({}) do |param, h| - h[param[:name]] = { - type: param[:type], - description: param[:description], - }.tap { |hash| hash[:items] = { type: param[:item_type] } if param[:item_type] } - end + properties = {} + + t[:parameters].each do |param| + mapped = { type: param[:type], description: param[:description] } + mapped[:items] = { type: param[:item_type] } if param[:item_type] + mapped[:enum] = param[:enum] if param[:enum] + properties[param[:name]] = mapped + end required = t[:parameters].select { |param| param[:required] }.map { |param| param[:name] } end @@ -39,37 +40,24 @@ module DiscourseAi end def instructions - "" # Noop. Tools are listed separate. + "" end def from_raw_tool_call(raw_message) call_details = JSON.parse(raw_message[:content], symbolize_names: true) tool_call_id = raw_message[:id] - - { - role: "assistant", - content: [ - { - type: "tool_use", - id: tool_call_id, - name: raw_message[:name], - input: call_details[:arguments], - }, - ], - } + [ + { + type: "tool_use", + id: tool_call_id, + name: raw_message[:name], + input: call_details[:arguments], + }, + ] end def from_raw_tool(raw_message) - { - role: "user", - content: [ - { - type: "tool_result", - tool_use_id: raw_message[:id], - content: raw_message[:content], - }, - ], - } + [{ type: "tool_result", tool_use_id: raw_message[:id], content: raw_message[:content] }] end private diff --git a/lib/completions/dialects/xml_tools.rb b/lib/completions/dialects/xml_tools.rb index 47988a71..9eabfadf 100644 --- a/lib/completions/dialects/xml_tools.rb +++ b/lib/completions/dialects/xml_tools.rb @@ -41,13 +41,17 @@ module DiscourseAi def instructions return "" if raw_tools.blank? - has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } } + @instructions ||= + begin + has_arrays = + raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } } - (<<~TEXT).strip - #{tool_preamble(include_array_tip: has_arrays)} - - #{translated_tools} - TEXT + (<<~TEXT).strip + #{tool_preamble(include_array_tip: has_arrays)} + + #{translated_tools} + TEXT + end end def from_raw_tool(raw_message) diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index c1ec288f..2739b02d 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -45,7 +45,12 @@ module DiscourseAi raise "Unsupported model: #{model}" end - { model: mapped_model, max_tokens: 3_000 } + options = { model: mapped_model, max_tokens: 3_000 } + + options[:stop_sequences] = [""] if !dialect.native_tool_support? && + dialect.prompt.has_tools? + + options end def provider_id @@ -54,6 +59,14 @@ module DiscourseAi private + def xml_tags_to_strip(dialect) + if dialect.prompt.has_tools? + %w[thinking search_quality_reflection search_quality_score] + else + [] + end + end + # this is an approximation, we will update it later if request goes through def prompt_size(prompt) tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s) @@ -66,11 +79,13 @@ module DiscourseAi end def prepare_payload(prompt, model_params, dialect) + @native_tool_support = dialect.native_tool_support? + payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? payload[:stream] = true if @streaming_mode - payload[:tools] = prompt.tools if prompt.tools.present? + payload[:tools] = prompt.tools if prompt.has_tools? payload end @@ -108,7 +123,7 @@ module DiscourseAi end def native_tool_support? - true + @native_tool_support end def partials_from(decoded_chunk) diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index bbd87749..d0ef6274 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -36,6 +36,9 @@ module DiscourseAi def default_options(dialect) options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" } + + options[:stop_sequences] = [""] if !dialect.native_tool_support? && + dialect.prompt.has_tools? options end @@ -43,6 +46,14 @@ module DiscourseAi AiApiAuditLog::Provider::Anthropic end + def xml_tags_to_strip(dialect) + if dialect.prompt.has_tools? + %w[thinking search_quality_reflection search_quality_score] + else + [] + end + end + private def prompt_size(prompt) @@ -79,9 +90,11 @@ module DiscourseAi end def prepare_payload(prompt, model_params, dialect) + @native_tool_support = dialect.native_tool_support? + payload = default_options(dialect).merge(model_params).merge(messages: prompt.messages) payload[:system] = prompt.system_prompt if prompt.system_prompt.present? - payload[:tools] = prompt.tools if prompt.tools.present? + payload[:tools] = prompt.tools if prompt.has_tools? payload end @@ -169,7 +182,7 @@ module DiscourseAi end def native_tool_support? - true + @native_tool_support end def chunk_to_string(chunk) diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 8d9d5f68..3c0c5984 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -78,11 +78,27 @@ module DiscourseAi end end + def xml_tags_to_strip(dialect) + [] + end + def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk) allow_tools = dialect.prompt.has_tools? model_params = normalize_model_params(model_params) + orig_blk = blk @streaming_mode = block_given? + to_strip = xml_tags_to_strip(dialect) + @xml_stripper = + DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present? + + if @streaming_mode && @xml_stripper + blk = + lambda do |partial, cancel| + partial = @xml_stripper << partial + orig_blk.call(partial, cancel) if partial + end + end prompt = dialect.translate @@ -270,6 +286,11 @@ module DiscourseAi blk.call(function_calls, cancel) end + if @xml_stripper + leftover = @xml_stripper.finish + orig_blk.call(leftover, cancel) if leftover.present? + end + return response_data ensure if log diff --git a/lib/completions/xml_tag_stripper.rb b/lib/completions/xml_tag_stripper.rb new file mode 100644 index 00000000..729c14f7 --- /dev/null +++ b/lib/completions/xml_tag_stripper.rb @@ -0,0 +1,115 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + class XmlTagStripper + def initialize(tags_to_strip) + @tags_to_strip = tags_to_strip + @longest_tag = tags_to_strip.map(&:length).max + @parsed = [] + end + + def <<(text) + if node = @parsed[-1] + if node[:type] == :maybe_tag + @parsed.pop + text = node[:content] + text + end + end + @parsed.concat(parse_tags(text)) + @parsed, result = process_parsed(@parsed) + result + end + + def finish + @parsed.map { |node| node[:content] }.join + end + + def process_parsed(parsed) + output = [] + buffer = [] + stack = [] + + parsed.each do |node| + case node[:type] + when :text + if stack.empty? + output << node[:content] + else + buffer << node + end + when :open_tag + stack << node[:name] + buffer << node + when :close_tag + if stack.empty? + output << node[:content] + else + if stack[0] == node[:name] + buffer = [] + stack = [] + else + buffer << node + end + end + when :maybe_tag + buffer << node + end + end + + result = output.join + result = nil if result.empty? + + [buffer, result] + end + + def parse_tags(text) + parsed = [] + + while true + before, after = text.split("<", 2) + + parsed << { type: :text, content: before } + + break if after.nil? + + tag, after = after.split(">", 2) + + is_end_tag = tag[0] == "/" + tag_name = tag + tag_name = tag[1..-1] || "" if is_end_tag + + if !after + found = false + if tag_name.length <= @longest_tag + @tags_to_strip.each do |tag_to_strip| + if tag_to_strip.start_with?(tag_name) + parsed << { type: :maybe_tag, content: "<" + tag } + found = true + break + end + end + end + parsed << { type: :text, content: "<" + tag } if !found + break + end + + raw_tag = "<" + tag + ">" + + if @tags_to_strip.include?(tag_name) + parsed << { + type: is_end_tag ? :close_tag : :open_tag, + content: raw_tag, + name: tag_name, + } + else + parsed << { type: :text, content: raw_tag } + end + text = after + end + + parsed + end + end + end +end diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb index 8eb00256..a201cf5b 100644 --- a/spec/lib/completions/dialects/claude_spec.rb +++ b/spec/lib/completions/dialects/claude_spec.rb @@ -1,6 +1,10 @@ # frozen_string_literal: true RSpec.describe DiscourseAi::Completions::Dialects::Claude do + let :opus_dialect_klass do + DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus") + end + describe "#translate" do it "can insert OKs to make stuff interleve properly" do messages = [ @@ -13,8 +17,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot", messages: messages) - dialectKlass = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus") - dialect = dialectKlass.new(prompt, "claude-3-opus") + dialect = opus_dialect_klass.new(prompt, "claude-3-opus") translated = dialect.translate expected_messages = [ @@ -29,8 +32,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do expect(translated.messages).to eq(expected_messages) end - it "can properly translate a prompt" do - dialect = DiscourseAi::Completions::Dialects::Dialect.dialect_for("claude-3-opus") + it "can properly translate a prompt (legacy tools)" do + SiteSetting.ai_anthropic_native_tool_call_models = "" tools = [ { @@ -59,7 +62,59 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do tools: tools, ) - dialect = dialect.new(prompt, "claude-3-opus") + dialect = opus_dialect_klass.new(prompt, "claude-3-opus") + translated = dialect.translate + + expect(translated.system_prompt).to start_with("You are a helpful bot") + + expected = [ + { role: "user", content: "user1: echo something" }, + { + role: "assistant", + content: + "\n\necho\n\nsomething\n\n\n", + }, + { + role: "user", + content: + "\n\ntool_id\n\n\"something\"\n\n\n", + }, + { role: "assistant", content: "I did it" }, + { role: "user", content: "user1: echo something else" }, + ] + expect(translated.messages).to eq(expected) + end + + it "can properly translate a prompt (native tools)" do + SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus" + + tools = [ + { + name: "echo", + description: "echo a string", + parameters: [ + { name: "text", type: "string", description: "string to echo", required: true }, + ], + }, + ] + + tool_call_prompt = { name: "echo", arguments: { text: "something" } } + + messages = [ + { type: :user, id: "user1", content: "echo something" }, + { type: :tool_call, name: "echo", id: "tool_id", content: tool_call_prompt.to_json }, + { type: :tool, id: "tool_id", content: "something".to_json }, + { type: :model, content: "I did it" }, + { type: :user, id: "user1", content: "echo something else" }, + ] + + prompt = + DiscourseAi::Completions::Prompt.new( + "You are a helpful bot", + messages: messages, + tools: tools, + ) + dialect = opus_dialect_klass.new(prompt, "claude-3-opus") translated = dialect.translate expect(translated.system_prompt).to start_with("You are a helpful bot") diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index ce185dcf..0c47f0e8 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -48,6 +48,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do before { SiteSetting.ai_anthropic_api_key = "123" } it "does not eat spaces with tool calls" do + SiteSetting.ai_anthropic_native_tool_call_models = "claude-3-opus" body = <<~STRING event: message_start data: {"type":"message_start","message":{"id":"msg_01Ju4j2MiGQb9KV9EEQ522Y3","type":"message","role":"assistant","model":"claude-3-haiku-20240307","content":[],"stop_reason":null,"stop_sequence":null,"usage":{"input_tokens":1293,"output_tokens":1}} } diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index c6c2399b..b6c96112 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -18,6 +18,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Claude, user) end + def encode_message(message) + wrapped = { bytes: Base64.encode64(message.to_json) }.to_json + io = StringIO.new(wrapped) + aws_message = Aws::EventStream::Message.new(payload: io) + Aws::EventStream::Encoder.new.encode(aws_message) + end + before do SiteSetting.ai_bedrock_access_key_id = "123456" SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd" @@ -25,6 +32,85 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do end describe "function calling" do + it "supports old school xml function calls" do + SiteSetting.ai_anthropic_native_tool_call_models = "" + proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") + + incomplete_tool_call = <<~XML.strip + I should be ignored + also ignored + 0 + + + google + sydney weather today + + + XML + + messages = + [ + { type: "message_start", message: { usage: { input_tokens: 9 } } }, + { type: "content_block_delta", delta: { text: "hello\n" } }, + { type: "content_block_delta", delta: { text: incomplete_tool_call } }, + { type: "message_delta", delta: { usage: { output_tokens: 25 } } }, + ].map { |message| encode_message(message) } + + request = nil + bedrock_mock.with_chunk_array_support do + stub_request( + :post, + "https://bedrock-runtime.us-east-1.amazonaws.com/model/anthropic.claude-3-sonnet-20240229-v1:0/invoke-with-response-stream", + ) + .with do |inner_request| + request = inner_request + true + end + .to_return(status: 200, body: messages) + + prompt = + DiscourseAi::Completions::Prompt.new( + messages: [{ type: :user, content: "what is the weather in sydney" }], + ) + + tool = { + name: "google", + description: "Will search using Google", + parameters: [ + { name: "query", description: "The search query", type: "string", required: true }, + ], + } + + prompt.tools = [tool] + response = +"" + proxy.generate(prompt, user: user) { |partial| response << partial } + + expect(request.headers["Authorization"]).to be_present + expect(request.headers["X-Amz-Content-Sha256"]).to be_present + + parsed_body = JSON.parse(request.body) + expect(parsed_body["system"]).to include("") + expect(parsed_body["tools"]).to eq(nil) + expect(parsed_body["stop_sequences"]).to eq([""]) + + # note we now have a tool_id cause we were normalized + function_call = <<~XML.strip + hello + + + + + google + sydney weather today + tool_0 + + + XML + + expect(response.strip).to eq(function_call) + end + end + it "supports streaming function calls" do proxy = DiscourseAi::Completions::Llm.proxy("aws_bedrock:claude-3-sonnet") @@ -48,6 +134,13 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do stop_reason: nil, }, }, + { + type: "content_block_start", + index: 0, + delta: { + text: "I should be ignored", + }, + }, { type: "content_block_start", index: 0, @@ -111,12 +204,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do firstByteLatency: 402, }, }, - ].map do |message| - wrapped = { bytes: Base64.encode64(message.to_json) }.to_json - io = StringIO.new(wrapped) - aws_message = Aws::EventStream::Message.new(payload: io) - Aws::EventStream::Encoder.new.encode(aws_message) - end + ].map { |message| encode_message(message) } messages = messages.join("").split @@ -248,12 +336,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do { type: "content_block_delta", delta: { text: "hello " } }, { type: "content_block_delta", delta: { text: "sam" } }, { type: "message_delta", delta: { usage: { output_tokens: 25 } } }, - ].map do |message| - wrapped = { bytes: Base64.encode64(message.to_json) }.to_json - io = StringIO.new(wrapped) - aws_message = Aws::EventStream::Message.new(payload: io) - Aws::EventStream::Encoder.new.encode(aws_message) - end + ].map { |message| encode_message(message) } # stream 1 letter at a time # cause we need to handle this case diff --git a/spec/lib/completions/xml_tag_stripper_spec.rb b/spec/lib/completions/xml_tag_stripper_spec.rb new file mode 100644 index 00000000..02ac36c4 --- /dev/null +++ b/spec/lib/completions/xml_tag_stripper_spec.rb @@ -0,0 +1,51 @@ +# frozen_string_literal: true + +describe DiscourseAi::Completions::PromptMessagesBuilder do + let(:tag_stripper) { DiscourseAi::Completions::XmlTagStripper.new(%w[thinking results]) } + + it "should strip tags correctly in simple cases" do + result = tag_stripper << "xhelloz" + expect(result).to eq("z") + + result = tag_stripper << "king>hello" + expect(result).to eq("king>hello") + + result = tag_stripper << "123" + expect(result).to eq("123") + end + + it "supports odd nesting" do + text = <<~TEXT + + well lets see what happens if I say here... + + hello + TEXT + + result = tag_stripper << text + expect(result).to eq("\nhello\n") + end + + it "works when nesting unrelated tags it strips correctly" do + text = <<~TEXT + + well lets see what happens if I say

here... + + abc hello + TEXT + + result = tag_stripper << text + + expect(result).to eq("\nabc hello\n") + end + + it "handles maybe tags correctly" do + result = tag_stripper << "