FEATURE: Add native Cohere tool support (#655)

Add native Cohere tool support

- Introduce CohereTools class for tool translation and result processing
- Update Command dialect to integrate with CohereTools
- Modify Cohere endpoint to support passing tools and processing tool calls
- Add spec for testing tool triggering with Cohere endpoint
This commit is contained in:
Sam 2024-06-04 08:59:15 +10:00 committed by GitHub
parent 97afda278b
commit 564d2de534
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 271 additions and 27 deletions

View File

@ -63,6 +63,8 @@ module DiscourseAi
llm_kwargs[:temperature] = persona.temperature if persona.temperature llm_kwargs[:temperature] = persona.temperature if persona.temperature
llm_kwargs[:top_p] = persona.top_p if persona.top_p llm_kwargs[:top_p] = persona.top_p if persona.top_p
needs_newlines = false
while total_completions <= MAX_COMPLETIONS && ongoing_chain while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false tool_found = false
@ -72,11 +74,18 @@ module DiscourseAi
if (tools.present?) if (tools.present?)
tool_found = true tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
update_blk.call("\n\n", cancel, nil)
needs_newlines = false
end
tools[0..MAX_TOOLS].each do |tool| tools[0..MAX_TOOLS].each do |tool|
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
ongoing_chain &&= tool.chain_next_response? ongoing_chain &&= tool.chain_next_response?
end end
else else
needs_newlines = true
update_blk.call(partial, cancel, nil) update_blk.call(partial, cancel, nil)
end end
end end

View File

@ -0,0 +1,93 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class CohereTools
def initialize(tools)
@raw_tools = tools
end
def tool_results(messages)
pairs = []
current_pair = nil
messages.each do |msg|
if current_pair == nil && msg[:type] == :tool_call
current_pair = [msg]
elsif current_pair && msg[:type] == :tool
current_pair << msg
pairs << current_pair
current_pair = nil
else
current_pair = nil
end
end
pairs.map do |call, result|
params = JSON.parse(call[:content])["arguments"]
{
call: {
name: call[:name] == "search" ? "search_local" : call[:name],
parameters: params,
generation_id: call[:id],
},
outputs: [JSON.parse(result[:content])],
}
end
end
def translated_tools
raw_tools.map do |t|
tool = t.dup
tool[:parameter_definitions] = t[:parameters]
.to_a
.reduce({}) do |memo, p|
name = p[:name]
memo[name] = {
description: p[:description],
type: cohere_type(p[:type], p[:item_type]),
required: p[:required],
}
memo[name][:default] = p[:default] if p[:default]
memo
end
{
name: tool[:name] == "search" ? "search_local" : tool[:name],
description: tool[:description],
parameter_definitions: tool[:parameter_definitions],
}
end
end
def instructions
"" # Noop. Tools are listed separate.
end
private
attr_reader :raw_tools
def cohere_type(type, item_type)
case type
when "string"
"str"
when "number"
item_type == "integer" ? "int" : "float"
when "boolean"
"bool"
when "object"
item_type ? "Dict[#{item_type}]" : "Dict"
when "array"
item_type ? "List[#{item_type}]" : "List"
else
type
end
end
end
end
end
end

View File

@ -24,13 +24,43 @@ module DiscourseAi
system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM" system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
prompt = { preamble: +"#{system_message}" } prompt = { preamble: +"#{system_message}" }
prompt[:chat_history] = messages if messages.present?
messages.reverse_each do |msg| if messages.present?
if msg[:role] == "USER" with_mapped_tools = []
prompt[:message] = msg[:message]
messages.delete(msg) current_pair = nil
break messages.each do |msg|
if current_pair == nil && msg[:type] == :tool_call
current_pair = [msg]
elsif current_pair && msg[:type] == :tool
current_pair << msg
tool_results = tools_dialect.tool_results(current_pair)
with_mapped_tools << { role: "TOOL", message: "", tool_results: tool_results }
current_pair = nil
else
with_mapped_tools << msg
current_pair = nil
end
end
messages = with_mapped_tools
prompt[:chat_history] = messages
end
tools = tools_dialect.translated_tools
prompt[:tools] = tools if tools.present?
tool_results =
messages.last && messages.last[:role] == "TOOL" && messages.last[:tool_results]
prompt[:tool_results] = tool_results if tool_results.present?
if tool_results.blank?
messages.reverse_each do |msg|
if msg[:role] == "USER"
prompt[:message] = msg[:message]
messages.delete(msg)
break
end
end end
end end
@ -54,8 +84,16 @@ module DiscourseAi
end end
end end
def native_tool_support?
true
end
private private
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::CohereTools.new(prompt.tools)
end
def per_message_overhead def per_message_overhead
0 0
end end
@ -83,11 +121,11 @@ module DiscourseAi
end end
def tool_call_msg(msg) def tool_call_msg(msg)
{ role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) } msg
end end
def tool_msg(msg) def tool_msg(msg)
{ role: "USER", message: tools_dialect.from_raw_tool(msg) } msg
end end
def user_msg(msg) def user_msg(msg)

View File

@ -29,10 +29,7 @@ module DiscourseAi
end end
def default_options(dialect) def default_options(dialect)
options = { model: "command-r-plus" } { model: "command-r-plus" }
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
end end
def provider_id def provider_id
@ -49,7 +46,11 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect) def prepare_payload(prompt, model_params, dialect)
payload = default_options(dialect).merge(model_params).merge(prompt) payload = default_options(dialect).merge(model_params).merge(prompt)
if prompt[:tools].present?
payload[:tools] = prompt[:tools]
payload[:force_single_step] = false
end
payload[:tool_results] = prompt[:tool_results] if prompt[:tool_results].present?
payload[:stream] = true if @streaming_mode payload[:stream] = true if @streaming_mode
payload payload
@ -70,6 +71,14 @@ module DiscourseAi
if @streaming_mode if @streaming_mode
if parsed[:event_type] == "text-generation" if parsed[:event_type] == "text-generation"
parsed[:text] parsed[:text]
elsif parsed[:event_type] == "tool-calls-generation"
# could just be random thinking...
if parsed.dig(:tool_calls).present?
@has_tool = true
parsed.dig(:tool_calls).to_json
else
""
end
else else
if parsed[:event_type] == "stream-end" if parsed[:event_type] == "stream-end"
@input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens) @input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
@ -84,6 +93,38 @@ module DiscourseAi
end end
end end
def has_tool?(_ignored)
@has_tool
end
def native_tool_support?
true
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
if partial
tools = JSON.parse(partial)
tools.each do |tool|
name = tool["name"]
parameters = tool["parameters"]
xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
current_function = function_buffer.at("invoke")
if current_function.nil? || current_function.at("tool_name").content.present?
current_function =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
)
end
current_function.at("tool_name").content = name == "search_local" ? "search" : name
current_function.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(xml_params)
end
end
function_buffer
end
def final_log_update(log) def final_log_update(log)
log.request_tokens = @input_tokens if @input_tokens log.request_tokens = @input_tokens if @input_tokens
log.response_tokens = @output_tokens if @output_tokens log.response_tokens = @output_tokens if @output_tokens

View File

@ -59,6 +59,83 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
before { SiteSetting.ai_cohere_api_key = "ABC" } before { SiteSetting.ai_cohere_api_key = "ABC" }
it "is able to trigger a tool" do
body = (<<~TEXT).strip
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
{"is_finished":false,"event_type":"tool-calls-generation","text":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"71d8c9e1-1138-4d70-80d1-10ddec41c989","text":"I will search for 'who is sam saffron' and relay the information to the user.","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b","chat_history":[{"role":"USER","message":"sam: who is sam saffron?"},{"role":"CHATBOT","message":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":460,"output_tokens":27},"tokens":{"input_tokens":1227,"output_tokens":27}},"tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]},"finish_reason":"COMPLETE"}
TEXT
parsed_body = nil
result = +""
sig = {
name: "google",
description: "Will search using Google",
parameters: [
{ name: "query", description: "The search query", type: "string", required: true },
],
}
prompt.tools = [sig]
EndpointMock.with_chunk_array_support do
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"Authorization" => "Bearer ABC",
},
).to_return(status: 200, body: body.split("|"))
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
end
expected = <<~TEXT
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters><query>who is sam saffron</query>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
TEXT
expect(result.strip).to eq(expected.strip)
expected = {
model: "command-r-plus",
preamble: "You are hello bot",
chat_history: [
{ role: "USER", message: "user1: hello" },
{ role: "CHATBOT", message: "hi user" },
],
message: "user1: thanks",
tools: [
{
name: "google",
description: "Will search using Google",
parameter_definitions: {
query: {
description: "The search query",
type: "str",
required: true,
},
},
},
],
force_single_step: false,
stream: true,
}
expect(parsed_body).to eq(expected)
end
it "is able to run tools" do it "is able to run tools" do
body = { body = {
response_id: "0a90275b-273d-4690-abce-8018edcec7d0", response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
@ -99,20 +176,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
result = llm.generate(prompt_with_tool_results, user: user) result = llm.generate(prompt_with_tool_results, user: user)
expect(parsed_body[:preamble]).to include("You are weather bot") expect(parsed_body[:preamble]).to include("You are weather bot")
expect(parsed_body[:preamble]).to include("<tools>")
expected_message = <<~MESSAGE
<function_results>
<result>
<tool_name>weather</tool_name>
<json>
{"weather":"22c"}
</json>
</result>
</function_results>
MESSAGE
expect(parsed_body[:message].strip).to eq(expected_message.strip)
expect(result).to eq("Sydney is 22c") expect(result).to eq("Sydney is 22c")
audit = AiApiAuditLog.order("id desc").first audit = AiApiAuditLog.order("id desc").first