FEATURE: Add native Cohere tool support (#655)
Add native Cohere tool support:
- Introduce CohereTools class for tool translation and result processing
- Update Command dialect to integrate with CohereTools
- Modify Cohere endpoint to support passing tools and processing tool calls
- Add spec for testing tool triggering with Cohere endpoint
This commit is contained in:
parent
97afda278b
commit
564d2de534
|
@ -63,6 +63,8 @@ module DiscourseAi
|
|||
llm_kwargs[:temperature] = persona.temperature if persona.temperature
|
||||
llm_kwargs[:top_p] = persona.top_p if persona.top_p
|
||||
|
||||
needs_newlines = false
|
||||
|
||||
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
||||
tool_found = false
|
||||
|
||||
|
@ -72,11 +74,18 @@ module DiscourseAi
|
|||
|
||||
if (tools.present?)
|
||||
tool_found = true
|
||||
# a bit hacky, but extra newlines do no harm
|
||||
if needs_newlines
|
||||
update_blk.call("\n\n", cancel, nil)
|
||||
needs_newlines = false
|
||||
end
|
||||
|
||||
tools[0..MAX_TOOLS].each do |tool|
|
||||
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
|
||||
ongoing_chain &&= tool.chain_next_response?
|
||||
end
|
||||
else
|
||||
needs_newlines = true
|
||||
update_blk.call(partial, cancel, nil)
|
||||
end
|
||||
end
|
||||
|
|
|
@ -0,0 +1,93 @@
|
|||
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      # Translates the generic tool definitions used by the dialects into
      # Cohere's native tool schema, and converts tool-call/tool-result
      # message pairs into the `tool_results` payload the Cohere chat API
      # expects.
      class CohereTools
        # @param tools [Array<Hash>] raw tool definitions
        #   (each with :name, :description and an array of :parameters)
        def initialize(tools)
          @raw_tools = tools
        end

        # Pairs each :tool_call message with the :tool (result) message that
        # immediately follows it, dropping unpaired entries, and maps each
        # pair into Cohere's tool_results entry format.
        #
        # @param messages [Array<Hash>] dialect messages (may contain
        #   non-tool entries, which reset any pending pairing)
        # @return [Array<Hash>] entries of the form { call: {...}, outputs: [...] }
        def tool_results(messages)
          pairs = []

          current_pair = nil
          messages.each do |msg|
            if current_pair.nil? && msg[:type] == :tool_call
              current_pair = [msg]
            elsif current_pair && msg[:type] == :tool
              current_pair << msg
              pairs << current_pair
              current_pair = nil
            else
              # any interruption breaks the call/result pairing
              current_pair = nil
            end
          end

          pairs.map do |call, result|
            params = JSON.parse(call[:content])["arguments"]
            {
              call: {
                # "search" is reserved by Cohere's connector API, so we map
                # our tool to "search_local" on the way out (and back on the
                # way in, see the endpoint's add_to_function_buffer)
                name: call[:name] == "search" ? "search_local" : call[:name],
                parameters: params,
                generation_id: call[:id],
              },
              outputs: [JSON.parse(result[:content])],
            }
          end
        end

        # Maps the raw tool definitions into Cohere's native tool format,
        # keyed parameter_definitions included.
        #
        # @return [Array<Hash>] tools in Cohere's native schema
        def translated_tools
          raw_tools.map do |t|
            parameter_definitions =
              t[:parameters].to_a.reduce({}) do |memo, p|
                name = p[:name]
                memo[name] = {
                  description: p[:description],
                  type: cohere_type(p[:type], p[:item_type]),
                  required: p[:required],
                }

                # use key? so an explicit `default: false` (or nil) is
                # still forwarded instead of being silently dropped
                memo[name][:default] = p[:default] if p.key?(:default)
                memo
              end

            {
              name: t[:name] == "search" ? "search_local" : t[:name],
              description: t[:description],
              parameter_definitions: parameter_definitions,
            }
          end
        end

        # Noop. Tools are listed separately in the payload, so no prompt
        # preamble instructions are needed.
        def instructions
          ""
        end

        private

        attr_reader :raw_tools

        # Converts JSON-schema style type names into the Python-flavored
        # type names Cohere's parameter_definitions use.
        def cohere_type(type, item_type)
          case type
          when "string"
            "str"
          when "number"
            item_type == "integer" ? "int" : "float"
          when "boolean"
            "bool"
          when "object"
            item_type ? "Dict[#{item_type}]" : "Dict"
          when "array"
            item_type ? "List[#{item_type}]" : "List"
          else
            # pass unknown types through unchanged
            type
          end
        end
      end
    end
  end
end
|
|
@ -24,13 +24,43 @@ module DiscourseAi
|
|||
system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
|
||||
|
||||
prompt = { preamble: +"#{system_message}" }
|
||||
prompt[:chat_history] = messages if messages.present?
|
||||
|
||||
messages.reverse_each do |msg|
|
||||
if msg[:role] == "USER"
|
||||
prompt[:message] = msg[:message]
|
||||
messages.delete(msg)
|
||||
break
|
||||
if messages.present?
|
||||
with_mapped_tools = []
|
||||
|
||||
current_pair = nil
|
||||
messages.each do |msg|
|
||||
if current_pair == nil && msg[:type] == :tool_call
|
||||
current_pair = [msg]
|
||||
elsif current_pair && msg[:type] == :tool
|
||||
current_pair << msg
|
||||
tool_results = tools_dialect.tool_results(current_pair)
|
||||
with_mapped_tools << { role: "TOOL", message: "", tool_results: tool_results }
|
||||
current_pair = nil
|
||||
else
|
||||
with_mapped_tools << msg
|
||||
current_pair = nil
|
||||
end
|
||||
end
|
||||
|
||||
messages = with_mapped_tools
|
||||
prompt[:chat_history] = messages
|
||||
end
|
||||
|
||||
tools = tools_dialect.translated_tools
|
||||
prompt[:tools] = tools if tools.present?
|
||||
|
||||
tool_results =
|
||||
messages.last && messages.last[:role] == "TOOL" && messages.last[:tool_results]
|
||||
prompt[:tool_results] = tool_results if tool_results.present?
|
||||
|
||||
if tool_results.blank?
|
||||
messages.reverse_each do |msg|
|
||||
if msg[:role] == "USER"
|
||||
prompt[:message] = msg[:message]
|
||||
messages.delete(msg)
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -54,8 +84,16 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
true
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def tools_dialect
|
||||
@tools_dialect ||= DiscourseAi::Completions::Dialects::CohereTools.new(prompt.tools)
|
||||
end
|
||||
|
||||
def per_message_overhead
|
||||
0
|
||||
end
|
||||
|
@ -83,11 +121,11 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def tool_call_msg(msg)
|
||||
{ role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
|
||||
msg
|
||||
end
|
||||
|
||||
def tool_msg(msg)
|
||||
{ role: "USER", message: tools_dialect.from_raw_tool(msg) }
|
||||
msg
|
||||
end
|
||||
|
||||
def user_msg(msg)
|
||||
|
|
|
@ -29,10 +29,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options(dialect)
|
||||
options = { model: "command-r-plus" }
|
||||
|
||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
||||
options
|
||||
{ model: "command-r-plus" }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -49,7 +46,11 @@ module DiscourseAi
|
|||
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
payload = default_options(dialect).merge(model_params).merge(prompt)
|
||||
|
||||
if prompt[:tools].present?
|
||||
payload[:tools] = prompt[:tools]
|
||||
payload[:force_single_step] = false
|
||||
end
|
||||
payload[:tool_results] = prompt[:tool_results] if prompt[:tool_results].present?
|
||||
payload[:stream] = true if @streaming_mode
|
||||
|
||||
payload
|
||||
|
@ -70,6 +71,14 @@ module DiscourseAi
|
|||
if @streaming_mode
|
||||
if parsed[:event_type] == "text-generation"
|
||||
parsed[:text]
|
||||
elsif parsed[:event_type] == "tool-calls-generation"
|
||||
# could just be random thinking...
|
||||
if parsed.dig(:tool_calls).present?
|
||||
@has_tool = true
|
||||
parsed.dig(:tool_calls).to_json
|
||||
else
|
||||
""
|
||||
end
|
||||
else
|
||||
if parsed[:event_type] == "stream-end"
|
||||
@input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
|
||||
|
@ -84,6 +93,38 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def has_tool?(_ignored)
|
||||
@has_tool
|
||||
end
|
||||
|
||||
def native_tool_support?
|
||||
true
|
||||
end
|
||||
|
||||
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||
if partial
|
||||
tools = JSON.parse(partial)
|
||||
tools.each do |tool|
|
||||
name = tool["name"]
|
||||
parameters = tool["parameters"]
|
||||
xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
|
||||
|
||||
current_function = function_buffer.at("invoke")
|
||||
if current_function.nil? || current_function.at("tool_name").content.present?
|
||||
current_function =
|
||||
function_buffer.at("function_calls").add_child(
|
||||
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
|
||||
)
|
||||
end
|
||||
|
||||
current_function.at("tool_name").content = name == "search_local" ? "search" : name
|
||||
current_function.at("parameters").children =
|
||||
Nokogiri::HTML5::DocumentFragment.parse(xml_params)
|
||||
end
|
||||
end
|
||||
function_buffer
|
||||
end
|
||||
|
||||
def final_log_update(log)
|
||||
log.request_tokens = @input_tokens if @input_tokens
|
||||
log.response_tokens = @output_tokens if @output_tokens
|
||||
|
|
|
@ -59,6 +59,83 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
|||
|
||||
before { SiteSetting.ai_cohere_api_key = "ABC" }
|
||||
|
||||
it "is able to trigger a tool" do
|
||||
body = (<<~TEXT).strip
|
||||
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
|
||||
{"is_finished":false,"event_type":"tool-calls-generation","text":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}
|
||||
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"71d8c9e1-1138-4d70-80d1-10ddec41c989","text":"I will search for 'who is sam saffron' and relay the information to the user.","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b","chat_history":[{"role":"USER","message":"sam: who is sam saffron?"},{"role":"CHATBOT","message":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":460,"output_tokens":27},"tokens":{"input_tokens":1227,"output_tokens":27}},"tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]},"finish_reason":"COMPLETE"}
|
||||
TEXT
|
||||
|
||||
parsed_body = nil
|
||||
result = +""
|
||||
|
||||
sig = {
|
||||
name: "google",
|
||||
description: "Will search using Google",
|
||||
parameters: [
|
||||
{ name: "query", description: "The search query", type: "string", required: true },
|
||||
],
|
||||
}
|
||||
|
||||
prompt.tools = [sig]
|
||||
|
||||
EndpointMock.with_chunk_array_support do
|
||||
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
|
||||
body:
|
||||
proc do |req_body|
|
||||
parsed_body = JSON.parse(req_body, symbolize_names: true)
|
||||
true
|
||||
end,
|
||||
headers: {
|
||||
"Content-Type" => "application/json",
|
||||
"Authorization" => "Bearer ABC",
|
||||
},
|
||||
).to_return(status: 200, body: body.split("|"))
|
||||
|
||||
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
|
||||
end
|
||||
|
||||
expected = <<~TEXT
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>google</tool_name>
|
||||
<parameters><query>who is sam saffron</query>
|
||||
</parameters>
|
||||
<tool_id>tool_0</tool_id>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
expect(result.strip).to eq(expected.strip)
|
||||
|
||||
expected = {
|
||||
model: "command-r-plus",
|
||||
preamble: "You are hello bot",
|
||||
chat_history: [
|
||||
{ role: "USER", message: "user1: hello" },
|
||||
{ role: "CHATBOT", message: "hi user" },
|
||||
],
|
||||
message: "user1: thanks",
|
||||
tools: [
|
||||
{
|
||||
name: "google",
|
||||
description: "Will search using Google",
|
||||
parameter_definitions: {
|
||||
query: {
|
||||
description: "The search query",
|
||||
type: "str",
|
||||
required: true,
|
||||
},
|
||||
},
|
||||
},
|
||||
],
|
||||
force_single_step: false,
|
||||
stream: true,
|
||||
}
|
||||
|
||||
expect(parsed_body).to eq(expected)
|
||||
end
|
||||
|
||||
it "is able to run tools" do
|
||||
body = {
|
||||
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
|
||||
|
@ -99,20 +176,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
|||
result = llm.generate(prompt_with_tool_results, user: user)
|
||||
|
||||
expect(parsed_body[:preamble]).to include("You are weather bot")
|
||||
expect(parsed_body[:preamble]).to include("<tools>")
|
||||
|
||||
expected_message = <<~MESSAGE
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>weather</tool_name>
|
||||
<json>
|
||||
{"weather":"22c"}
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
MESSAGE
|
||||
|
||||
expect(parsed_body[:message].strip).to eq(expected_message.strip)
|
||||
|
||||
expect(result).to eq("Sydney is 22c")
|
||||
audit = AiApiAuditLog.order("id desc").first
|
||||
|
|
Loading…
Reference in New Issue