From 4182af230ae38999f7419f98a0e1b674f1051889 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Tue, 2 Jan 2024 11:21:13 -0300 Subject: [PATCH] FIX: Correctly translate and read tools for Claude and Chat GPT. (#393) I tested against the live models for the AI bot migration. It ensures Open AI's tool syntax is correct and we can correctly read the replies. : --- lib/completions/dialects/chat_gpt.rb | 49 ++++++++++++----- lib/completions/dialects/claude.rb | 1 + lib/completions/dialects/gemini.rb | 2 +- lib/completions/dialects/llama2_classic.rb | 1 + lib/completions/dialects/mixtral.rb | 1 + lib/completions/dialects/orca_style.rb | 1 + lib/completions/endpoints/aws_bedrock.rb | 2 +- lib/completions/endpoints/base.rb | 52 ++++++++++-------- lib/completions/endpoints/gemini.rb | 23 ++------ lib/completions/endpoints/open_ai.rb | 54 +++++++++---------- .../lib/completions/dialects/chat_gpt_spec.rb | 20 ++++++- .../completions/endpoints/anthropic_spec.rb | 2 + .../completions/endpoints/aws_bedrock_spec.rb | 2 + .../endpoints/endpoint_examples.rb | 2 +- spec/lib/completions/endpoints/gemini_spec.rb | 2 + .../endpoints/hugging_face_spec.rb | 2 + .../lib/completions/endpoints/open_ai_spec.rb | 36 +++++++------ spec/lib/completions/endpoints/vllm_spec.rb | 2 + 18 files changed, 154 insertions(+), 100 deletions(-) diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 777f4184..0c9676a8 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -33,7 +33,7 @@ module DiscourseAi end end - open_ai_prompt.concat!(conversation_context) if prompt[:conversation_context] + open_ai_prompt.concat(conversation_context) if prompt[:conversation_context] open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input] @@ -43,7 +43,25 @@ module DiscourseAi def tools return if prompt[:tools].blank? - prompt[:tools].map { |t| { type: "function", tool: t } } + prompt[:tools].map do |t| + tool = t.dup + + if tool[:parameters] + tool[:parameters] = t[:parameters].reduce( + { type: "object", properties: {}, required: [] }, + ) do |memo, p| + name = p[:name] + memo[:required] << name if p[:required] + + memo[:properties][name] = p.except(:name, :required, :item_type) + + memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type] + memo + end + end + + { type: "function", function: tool } + end end def conversation_context @@ -52,18 +70,25 @@ module DiscourseAi trimmed_context = trim_context(prompt[:conversation_context]) trimmed_context.reverse.map do |context| - translated = context.slice(:content) - translated[:role] = context[:type] + if context[:type] == "tool_call" + { + role: "assistant", + tool_calls: [{ type: "function", function: context[:content], id: context[:name] }], + } + else + translated = context.slice(:content) + translated[:role] = context[:type] - if context[:name] - if translated[:role] == "tool" - translated[:tool_call_id] = context[:name] - else - translated[:name] = context[:name] + if context[:name] + if translated[:role] == "tool" + translated[:tool_call_id] = context[:name] + else + translated[:name] = context[:name] + end end - end - translated + translated + end end end @@ -94,7 +119,7 @@ module DiscourseAi def model_max_tokens case model_name - when "gpt-3.5-turbo", "gpt-3.5-turbo-16k" + when "gpt-3.5-turbo-16k" 16_384 when "gpt-4" 8192 diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 28fff8ee..73f9d231 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -44,6 +44,7 @@ module DiscourseAi trimmed_context .reverse .reduce(+"") do |memo, context| + next(memo) if context[:type] == "tool_call" memo << (context[:type] == "user" ? "Human:" : "Assistant:") if context[:type] == "tool" diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index b6b938bf..72f3a8ff 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -32,7 +32,7 @@ module DiscourseAi end end - gemini_prompt.concat!(conversation_context) if prompt[:conversation_context] + gemini_prompt.concat(conversation_context) if prompt[:conversation_context] gemini_prompt << { role: "user", parts: { text: prompt[:input] } } end diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb index 0470a61c..26e541a3 100644 --- a/lib/completions/dialects/llama2_classic.rb +++ b/lib/completions/dialects/llama2_classic.rb @@ -44,6 +44,7 @@ module DiscourseAi trimmed_context .reverse .reduce(+"") do |memo, context| + next(memo) if context[:type] == "tool_call" if context[:type] == "tool" memo << <<~TEXT [INST] diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb index 6fb93d04..75e0f954 100644 --- a/lib/completions/dialects/mixtral.rb +++ b/lib/completions/dialects/mixtral.rb @@ -44,6 +44,7 @@ module DiscourseAi trimmed_context .reverse .reduce(+"") do |memo, context| + next(memo) if context[:type] == "tool_call" memo << "[INST] " if context[:type] == "user" if context[:type] == "tool" diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb index fd76f3b5..b89dca01 100644 --- a/lib/completions/dialects/orca_style.rb +++ b/lib/completions/dialects/orca_style.rb @@ -41,6 +41,7 @@ module DiscourseAi trimmed_context .reverse .reduce(+"") do |memo, context| + next(memo) if context[:type] == "tool_call" memo << (context[:type] == "user" ? "### User:" : "### Assistant:") if context[:type] == "tool" diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index 902a375c..98f29634 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -14,7 +14,7 @@ module DiscourseAi end def default_options - { max_tokens_to_sample: 2_000 } + { max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", ""] } end def provider_id diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index a3cb6d5b..da433c18 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -74,7 +74,7 @@ module DiscourseAi response_data = extract_completion_from(response_raw) partials_raw = response_data.to_s - if has_tool?("", response_data) + if has_tool?(response_data) function_buffer = build_buffer # Nokogiri document function_buffer = add_to_buffer(function_buffer, "", response_data) @@ -125,26 +125,19 @@ module DiscourseAi begin partial = extract_completion_from(raw_partial) + next if response_data.empty? && partial.blank? next if partial.nil? - if has_tool?(response_data, partial) - function_buffer = add_to_buffer(function_buffer, response_data, partial) - - if buffering_finished?(dialect.tools, function_buffer) - invocation = +function_buffer.at("function_calls").to_s - invocation << "\n" - - partials_raw << partial.to_s - response_data << invocation - - yield invocation, cancel - end + # Skip yield for tools. We'll buffer and yield later. + if has_tool?(partials_raw) + function_buffer = add_to_buffer(function_buffer, partials_raw, partial) else - partials_raw << partial response_data << partial yield partial, cancel if partial end + + partials_raw << partial.to_s rescue JSON::ParserError leftover = redo_chunk json_error = true @@ -162,6 +155,17 @@ module DiscourseAi raise if !cancelled end + # Once we have the full response, try to return the tool as a XML doc. + if has_tool?(partials_raw) + if function_buffer.at("tool_name").text.present? + invocation = +function_buffer.at("function_calls").to_s + invocation << "\n" + + response_data << invocation + yield invocation, cancel + end + end + return response_data ensure if log @@ -236,12 +240,22 @@ module DiscourseAi TEXT end - def has_tool?(response, partial) - (response + partial).include?("") + def has_tool?(response) + response.include?("").first + "\n" if raw_data.split( + "", + ).length > 1 + + return function_buffer unless raw_data.include?("") + + read_function = Nokogiri::HTML5.fragment(raw_data) if tool_name = read_function.at("tool_name").text function_buffer.at("tool_name").inner_html = tool_name @@ -264,10 +278,6 @@ module DiscourseAi function_buffer end - - def buffering_finished?(_available_functions, buffer) - buffer.to_s.include?("") - end end end end diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index 8fa8a2d3..231309b2 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -43,8 +43,8 @@ module DiscourseAi response_h = parsed.dig(:candidates, 0, :content, :parts, 0) - has_function_call = response_h.dig(:functionCall).present? - has_function_call ? response_h[:functionCall] : response_h.dig(:text) + @has_function_call ||= response_h.dig(:functionCall).present? + @has_function_call ? response_h[:functionCall] : response_h.dig(:text) end def partials_from(decoded_chunk) @@ -68,8 +68,8 @@ module DiscourseAi prompt.to_s end - def has_tool?(_response_data, partial) - partial.is_a?(Hash) && partial.has_key?(:name) # Has function name + def has_tool?(_response_data) + @has_function_call end def add_to_buffer(function_buffer, _response_data, partial) @@ -91,21 +91,6 @@ module DiscourseAi function_buffer end - - def buffering_finished?(available_functions, buffer) - tool_name = buffer.at("tool_name")&.text - return false if tool_name.blank? - - signature = - available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name } - - signature[:parameters].reduce(true) do |memo, param| - param_present = buffer.at(param[:name]).present? - next(memo) if param_present || !signature[:required].include?(param[:name]) - - memo && param_present - end - end end end end diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index bb51090d..2a1d29cb 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -83,8 +83,9 @@ module DiscourseAi response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) - has_function_call = response_h.dig(:tool_calls).present? - has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content) + @has_function_call ||= response_h.dig(:tool_calls).present? + + @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content) end def partials_from(decoded_chunk) @@ -101,41 +102,38 @@ module DiscourseAi prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") end - def has_tool?(_response_data, partial) - partial.is_a?(Hash) && partial.has_key?(:name) # Has function name + def has_tool?(_response_data) + @has_function_call end def add_to_buffer(function_buffer, _response_data, partial) - function_buffer.at("tool_name").content = partial[:name] if partial[:name].present? - function_buffer.at("tool_id").content = partial[:id] if partial[:id].present? + @args_buffer ||= +"" - if partial[:arguments] - argument_fragments = - partial[:arguments].reduce(+"") do |memo, (arg_name, value)| - memo << "\n<#{arg_name}>#{value}" - end - argument_fragments << "\n" + f_name = partial.dig(:function, :name) + function_buffer.at("tool_name").content = f_name if f_name + function_buffer.at("tool_id").content = partial[:id] if partial[:id] - function_buffer.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) + if partial.dig(:function, :arguments).present? + @args_buffer << partial.dig(:function, :arguments) + + begin + json_args = JSON.parse(@args_buffer, symbolize_names: true) + + argument_fragments = + json_args.reduce(+"") do |memo, (arg_name, value)| + memo << "\n<#{arg_name}>#{value}" + end + argument_fragments << "\n" + + function_buffer.at("parameters").children = + Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) + rescue JSON::ParserError + return function_buffer + end end function_buffer end - - def buffering_finished?(available_functions, buffer) - tool_name = buffer.at("tool_name")&.text - return false if tool_name.blank? - - signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool] - - signature[:parameters].reduce(true) do |memo, param| - param_present = buffer.at(param[:name]).present? - next(memo) if param_present && !param[:required] - - memo && param_present - end - end end end end diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb index 2792dbf6..84e348bc 100644 --- a/spec/lib/completions/dialects/chat_gpt_spec.rb +++ b/spec/lib/completions/dialects/chat_gpt_spec.rb @@ -115,7 +115,25 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do describe "#tools" do it "returns a list of available tools" do - open_ai_tool_f = { type: "function", tool: tool } + open_ai_tool_f = { + function: { + description: tool[:description], + name: tool[:name], + parameters: { + properties: + tool[:parameters].reduce({}) do |memo, p| + memo[p[:name]] = { description: p[:description], type: p[:type] } + + memo[p[:name]][:enum] = p[:enum] if p[:enum] + + memo + end, + required: %w[location unit], + type: "object", + }, + }, + type: "function", + } expect(subject.tools).to contain_exactly(open_ai_tool_f) end diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb index 0a57ad29..4c39c2e7 100644 --- a/spec/lib/completions/endpoints/anthropic_spec.rb +++ b/spec/lib/completions/endpoints/anthropic_spec.rb @@ -13,6 +13,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do let(:request_body) { model.default_options.merge(prompt: prompt).to_json } let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json } + let(:tool_id) { "get_weather" } + def response(content) { completion: content, diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb index 2c866898..65999393 100644 --- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb +++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb @@ -16,6 +16,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do let(:request_body) { model.default_options.merge(prompt: prompt).to_json } let(:stream_request_body) { request_body } + let(:tool_id) { "get_weather" } + before do SiteSetting.ai_bedrock_access_key_id = "123456" SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd" diff --git a/spec/lib/completions/endpoints/endpoint_examples.rb b/spec/lib/completions/endpoints/endpoint_examples.rb index 0de7ef73..0448fcbf 100644 --- a/spec/lib/completions/endpoints/endpoint_examples.rb +++ b/spec/lib/completions/endpoints/endpoint_examples.rb @@ -58,7 +58,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic get_weather - get_weather + #{tool_id || "get_weather"} Sydney c diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index df855e73..195351e3 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -10,6 +10,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) } let(:prompt) { dialect.translate } + let(:tool_id) { "get_weather" } + let(:tool_payload) do { name: "get_weather", diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb index 087ca1fc..de69f8ed 100644 --- a/spec/lib/completions/endpoints/hugging_face_spec.rb +++ b/spec/lib/completions/endpoints/hugging_face_spec.rb @@ -12,6 +12,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do end let(:prompt) { dialect.translate } + let(:tool_id) { "get_weather" } + let(:request_body) do model .default_options diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index 2b72164f..7caf1a1f 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -10,45 +10,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) } let(:prompt) { dialect.translate } + let(:tool_id) { "eujbuebfe" } + let(:tool_deltas) do [ - { id: "get_weather", name: "get_weather", arguments: {} }, - { id: "get_weather", name: "get_weather", arguments: { location: "" } }, - { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }, + { id: tool_id, function: {} }, + { id: tool_id, function: { name: "get_weather", arguments: "" } }, + { id: tool_id, function: { name: "get_weather", arguments: "" } }, + { id: tool_id, function: { name: "get_weather", arguments: "{" } }, + { id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } }, + { id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } }, ] end let(:tool_call) do - { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } } + { + id: tool_id, + function: { + name: "get_weather", + arguments: { location: "Sydney", unit: "c" }.to_json, + }, + } end let(:request_body) do model .default_options .merge(messages: prompt) - .tap do |b| - b[:tools] = generic_prompt[:tools].map do |t| - { type: "function", tool: t } - end if generic_prompt[:tools] - end + .tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] } .to_json end + let(:stream_request_body) do model .default_options .merge(messages: prompt, stream: true) - .tap do |b| - b[:tools] = generic_prompt[:tools].map do |t| - { type: "function", tool: t } - end if generic_prompt[:tools] - end + .tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] } .to_json end def response(content, tool_call: false) message_content = if tool_call - { tool_calls: [{ function: content }] } + { tool_calls: [content] } else { content: content } end @@ -79,7 +83,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do def stream_line(delta, finish_reason: nil, tool_call: false) message_content = if tool_call - { tool_calls: [{ function: delta }] } + { tool_calls: [delta] } else { content: delta } end diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb index 99bb0151..54d9955d 100644 --- a/spec/lib/completions/endpoints/vllm_spec.rb +++ b/spec/lib/completions/endpoints/vllm_spec.rb @@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do before { SiteSetting.ai_vllm_endpoint = "https://test.dev" } + let(:tool_id) { "get_weather" } + def response(content) { id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",