From 564d2de534d4b1960b0fa04621bd61b1c4d5cd0b Mon Sep 17 00:00:00 2001 From: Sam Date: Tue, 4 Jun 2024 08:59:15 +1000 Subject: [PATCH] FEATURE: Add native Cohere tool support (#655) Add native Cohere tool support - Introduce CohereTools class for tool translation and result processing - Update Command dialect to integrate with CohereTools - Modify Cohere endpoint to support passing tools and processing tool calls - Add spec for testing tool triggering with Cohere endpoint --- lib/ai_bot/bot.rb | 9 ++ lib/completions/dialects/cohere_tools.rb | 93 +++++++++++++++++++ lib/completions/dialects/command.rb | 54 +++++++++-- lib/completions/endpoints/cohere.rb | 51 +++++++++- spec/lib/completions/endpoints/cohere_spec.rb | 91 +++++++++++++++--- 5 files changed, 271 insertions(+), 27 deletions(-) create mode 100644 lib/completions/dialects/cohere_tools.rb diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 90526b57..defc072f 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -63,6 +63,8 @@ module DiscourseAi llm_kwargs[:temperature] = persona.temperature if persona.temperature llm_kwargs[:top_p] = persona.top_p if persona.top_p + needs_newlines = false + while total_completions <= MAX_COMPLETIONS && ongoing_chain tool_found = false @@ -72,11 +74,18 @@ module DiscourseAi if (tools.present?) tool_found = true + # a bit hacky, but extra newlines do no harm + if needs_newlines + update_blk.call("\n\n", cancel, nil) + needs_newlines = false + end + tools[0..MAX_TOOLS].each do |tool| process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) ongoing_chain &&= tool.chain_next_response? end else + needs_newlines = true update_blk.call(partial, cancel, nil) end end diff --git a/lib/completions/dialects/cohere_tools.rb b/lib/completions/dialects/cohere_tools.rb new file mode 100644 index 00000000..7a49c19c --- /dev/null +++ b/lib/completions/dialects/cohere_tools.rb @@ -0,0 +1,93 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Dialects + class CohereTools + def initialize(tools) + @raw_tools = tools + end + + def tool_results(messages) + pairs = [] + + current_pair = nil + messages.each do |msg| + if current_pair == nil && msg[:type] == :tool_call + current_pair = [msg] + elsif current_pair && msg[:type] == :tool + current_pair << msg + pairs << current_pair + current_pair = nil + else + current_pair = nil + end + end + + pairs.map do |call, result| + params = JSON.parse(call[:content])["arguments"] + { + call: { + name: call[:name] == "search" ? "search_local" : call[:name], + parameters: params, + generation_id: call[:id], + }, + outputs: [JSON.parse(result[:content])], + } + end + end + + def translated_tools + raw_tools.map do |t| + tool = t.dup + + tool[:parameter_definitions] = t[:parameters] + .to_a + .reduce({}) do |memo, p| + name = p[:name] + memo[name] = { + description: p[:description], + type: cohere_type(p[:type], p[:item_type]), + required: p[:required], + } + + memo[name][:default] = p[:default] if p[:default] + memo + end + + { + name: tool[:name] == "search" ? "search_local" : tool[:name], + description: tool[:description], + parameter_definitions: tool[:parameter_definitions], + } + end + end + + def instructions + "" # Noop. Tools are listed separate. + end + + private + + attr_reader :raw_tools + + def cohere_type(type, item_type) + case type + when "string" + "str" + when "number" + item_type == "integer" ? "int" : "float" + when "boolean" + "bool" + when "object" + item_type ? "Dict[#{item_type}]" : "Dict" + when "array" + item_type ? "List[#{item_type}]" : "List" + else + type + end + end + end + end + end +end diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb index 528c8285..b2455140 100644 --- a/lib/completions/dialects/command.rb +++ b/lib/completions/dialects/command.rb @@ -24,13 +24,43 @@ module DiscourseAi system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM" prompt = { preamble: +"#{system_message}" } - prompt[:chat_history] = messages if messages.present? - messages.reverse_each do |msg| - if msg[:role] == "USER" - prompt[:message] = msg[:message] - messages.delete(msg) - break + if messages.present? + with_mapped_tools = [] + + current_pair = nil + messages.each do |msg| + if current_pair == nil && msg[:type] == :tool_call + current_pair = [msg] + elsif current_pair && msg[:type] == :tool + current_pair << msg + tool_results = tools_dialect.tool_results(current_pair) + with_mapped_tools << { role: "TOOL", message: "", tool_results: tool_results } + current_pair = nil + else + with_mapped_tools << msg + current_pair = nil + end + end + + messages = with_mapped_tools + prompt[:chat_history] = messages + end + + tools = tools_dialect.translated_tools + prompt[:tools] = tools if tools.present? + + tool_results = + messages.last && messages.last[:role] == "TOOL" && messages.last[:tool_results] + prompt[:tool_results] = tool_results if tool_results.present? + + if tool_results.blank? + messages.reverse_each do |msg| + if msg[:role] == "USER" + prompt[:message] = msg[:message] + messages.delete(msg) + break + end end end @@ -54,8 +84,16 @@ module DiscourseAi end end + def native_tool_support? + true + end + private + def tools_dialect + @tools_dialect ||= DiscourseAi::Completions::Dialects::CohereTools.new(prompt.tools) + end + def per_message_overhead 0 end @@ -83,11 +121,11 @@ module DiscourseAi end def tool_call_msg(msg) - { role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) } + msg end def tool_msg(msg) - { role: "USER", message: tools_dialect.from_raw_tool(msg) } + msg end def user_msg(msg) diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb index 7e4468bb..e12a9d12 100644 --- a/lib/completions/endpoints/cohere.rb +++ b/lib/completions/endpoints/cohere.rb @@ -29,10 +29,7 @@ module DiscourseAi end def default_options(dialect) - options = { model: "command-r-plus" } - - options[:stop_sequences] = [""] if dialect.prompt.has_tools? - options + { model: "command-r-plus" } end def provider_id @@ -49,7 +46,11 @@ module DiscourseAi def prepare_payload(prompt, model_params, dialect) payload = default_options(dialect).merge(model_params).merge(prompt) - + if prompt[:tools].present? + payload[:tools] = prompt[:tools] + payload[:force_single_step] = false + end + payload[:tool_results] = prompt[:tool_results] if prompt[:tool_results].present? payload[:stream] = true if @streaming_mode payload @@ -70,6 +71,14 @@ module DiscourseAi if @streaming_mode if parsed[:event_type] == "text-generation" parsed[:text] + elsif parsed[:event_type] == "tool-calls-generation" + # could just be random thinking... + if parsed.dig(:tool_calls).present? + @has_tool = true + parsed.dig(:tool_calls).to_json + else + "" + end else if parsed[:event_type] == "stream-end" @input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens) @@ -84,6 +93,38 @@ module DiscourseAi end end + def has_tool?(_ignored) + @has_tool + end + + def native_tool_support? + true + end + + def add_to_function_buffer(function_buffer, partial: nil, payload: nil) + if partial + tools = JSON.parse(partial) + tools.each do |tool| + name = tool["name"] + parameters = tool["parameters"] + xml_params = parameters.map { |k, v| "<#{k}>#{v}\n" }.join + + current_function = function_buffer.at("invoke") + if current_function.nil? || current_function.at("tool_name").content.present? + current_function = + function_buffer.at("function_calls").add_child( + Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"), + ) + end + + current_function.at("tool_name").content = name == "search_local" ? "search" : name + current_function.at("parameters").children = + Nokogiri::HTML5::DocumentFragment.parse(xml_params) + end + end + function_buffer + end + def final_log_update(log) log.request_tokens = @input_tokens if @input_tokens log.response_tokens = @output_tokens if @output_tokens diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb index a6fa8d23..57b6a8c6 100644 --- a/spec/lib/completions/endpoints/cohere_spec.rb +++ b/spec/lib/completions/endpoints/cohere_spec.rb @@ -59,6 +59,83 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do before { SiteSetting.ai_cohere_api_key = "ABC" } + it "is able to trigger a tool" do + body = (<<~TEXT).strip + {"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"} +{"is_finished":false,"event_type":"tool-calls-generation","text":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]} +{"is_finished":true,"event_type":"stream-end","response":{"response_id":"71d8c9e1-1138-4d70-80d1-10ddec41c989","text":"I will search for 'who is sam saffron' and relay the information to the user.","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b","chat_history":[{"role":"USER","message":"sam: who is sam saffron?"},{"role":"CHATBOT","message":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":460,"output_tokens":27},"tokens":{"input_tokens":1227,"output_tokens":27}},"tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]},"finish_reason":"COMPLETE"} + TEXT + + parsed_body = nil + result = +"" + + sig = { + name: "google", + description: "Will search using Google", + parameters: [ + { name: "query", description: "The search query", type: "string", required: true }, + ], + } + + prompt.tools = [sig] + + EndpointMock.with_chunk_array_support do + stub_request(:post, "https://api.cohere.ai/v1/chat").with( + body: + proc do |req_body| + parsed_body = JSON.parse(req_body, symbolize_names: true) + true + end, + headers: { + "Content-Type" => "application/json", + "Authorization" => "Bearer ABC", + }, + ).to_return(status: 200, body: body.split("|")) + + result = llm.generate(prompt, user: user) { |partial, cancel| result << partial } + end + + expected = <<~TEXT + + + google + who is sam saffron + + tool_0 + + + TEXT + + expect(result.strip).to eq(expected.strip) + + expected = { + model: "command-r-plus", + preamble: "You are hello bot", + chat_history: [ + { role: "USER", message: "user1: hello" }, + { role: "CHATBOT", message: "hi user" }, + ], + message: "user1: thanks", + tools: [ + { + name: "google", + description: "Will search using Google", + parameter_definitions: { + query: { + description: "The search query", + type: "str", + required: true, + }, + }, + }, + ], + force_single_step: false, + stream: true, + } + + expect(parsed_body).to eq(expected) + end + it "is able to run tools" do body = { response_id: "0a90275b-273d-4690-abce-8018edcec7d0", @@ -99,20 +176,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do result = llm.generate(prompt_with_tool_results, user: user) expect(parsed_body[:preamble]).to include("You are weather bot") - expect(parsed_body[:preamble]).to include("") - - expected_message = <<~MESSAGE - - - weather - - {"weather":"22c"} - - - - MESSAGE - - expect(parsed_body[:message].strip).to eq(expected_message.strip) expect(result).to eq("Sydney is 22c") audit = AiApiAuditLog.order("id desc").first