FEATURE: Add native Cohere tool support (#655)
Add native Cohere tool support - Introduce CohereTools class for tool translation and result processing - Update Command dialect to integrate with CohereTools - Modify Cohere endpoint to support passing tools and processing tool calls - Add spec for testing tool triggering with Cohere endpoint
This commit is contained in:
parent
97afda278b
commit
564d2de534
|
@ -63,6 +63,8 @@ module DiscourseAi
|
||||||
llm_kwargs[:temperature] = persona.temperature if persona.temperature
|
llm_kwargs[:temperature] = persona.temperature if persona.temperature
|
||||||
llm_kwargs[:top_p] = persona.top_p if persona.top_p
|
llm_kwargs[:top_p] = persona.top_p if persona.top_p
|
||||||
|
|
||||||
|
needs_newlines = false
|
||||||
|
|
||||||
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
while total_completions <= MAX_COMPLETIONS && ongoing_chain
|
||||||
tool_found = false
|
tool_found = false
|
||||||
|
|
||||||
|
@ -72,11 +74,18 @@ module DiscourseAi
|
||||||
|
|
||||||
if (tools.present?)
|
if (tools.present?)
|
||||||
tool_found = true
|
tool_found = true
|
||||||
|
# a bit hacky, but extra newlines do no harm
|
||||||
|
if needs_newlines
|
||||||
|
update_blk.call("\n\n", cancel, nil)
|
||||||
|
needs_newlines = false
|
||||||
|
end
|
||||||
|
|
||||||
tools[0..MAX_TOOLS].each do |tool|
|
tools[0..MAX_TOOLS].each do |tool|
|
||||||
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
|
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
|
||||||
ongoing_chain &&= tool.chain_next_response?
|
ongoing_chain &&= tool.chain_next_response?
|
||||||
end
|
end
|
||||||
else
|
else
|
||||||
|
needs_newlines = true
|
||||||
update_blk.call(partial, cancel, nil)
|
update_blk.call(partial, cancel, nil)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,93 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Completions
|
||||||
|
module Dialects
|
||||||
|
class CohereTools
|
||||||
|
def initialize(tools)
|
||||||
|
@raw_tools = tools
|
||||||
|
end
|
||||||
|
|
||||||
|
def tool_results(messages)
|
||||||
|
pairs = []
|
||||||
|
|
||||||
|
current_pair = nil
|
||||||
|
messages.each do |msg|
|
||||||
|
if current_pair == nil && msg[:type] == :tool_call
|
||||||
|
current_pair = [msg]
|
||||||
|
elsif current_pair && msg[:type] == :tool
|
||||||
|
current_pair << msg
|
||||||
|
pairs << current_pair
|
||||||
|
current_pair = nil
|
||||||
|
else
|
||||||
|
current_pair = nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
pairs.map do |call, result|
|
||||||
|
params = JSON.parse(call[:content])["arguments"]
|
||||||
|
{
|
||||||
|
call: {
|
||||||
|
name: call[:name] == "search" ? "search_local" : call[:name],
|
||||||
|
parameters: params,
|
||||||
|
generation_id: call[:id],
|
||||||
|
},
|
||||||
|
outputs: [JSON.parse(result[:content])],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def translated_tools
|
||||||
|
raw_tools.map do |t|
|
||||||
|
tool = t.dup
|
||||||
|
|
||||||
|
tool[:parameter_definitions] = t[:parameters]
|
||||||
|
.to_a
|
||||||
|
.reduce({}) do |memo, p|
|
||||||
|
name = p[:name]
|
||||||
|
memo[name] = {
|
||||||
|
description: p[:description],
|
||||||
|
type: cohere_type(p[:type], p[:item_type]),
|
||||||
|
required: p[:required],
|
||||||
|
}
|
||||||
|
|
||||||
|
memo[name][:default] = p[:default] if p[:default]
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
|
||||||
|
{
|
||||||
|
name: tool[:name] == "search" ? "search_local" : tool[:name],
|
||||||
|
description: tool[:description],
|
||||||
|
parameter_definitions: tool[:parameter_definitions],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def instructions
|
||||||
|
"" # Noop. Tools are listed separate.
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
attr_reader :raw_tools
|
||||||
|
|
||||||
|
def cohere_type(type, item_type)
|
||||||
|
case type
|
||||||
|
when "string"
|
||||||
|
"str"
|
||||||
|
when "number"
|
||||||
|
item_type == "integer" ? "int" : "float"
|
||||||
|
when "boolean"
|
||||||
|
"bool"
|
||||||
|
when "object"
|
||||||
|
item_type ? "Dict[#{item_type}]" : "Dict"
|
||||||
|
when "array"
|
||||||
|
item_type ? "List[#{item_type}]" : "List"
|
||||||
|
else
|
||||||
|
type
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -24,13 +24,43 @@ module DiscourseAi
|
||||||
system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
|
system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
|
||||||
|
|
||||||
prompt = { preamble: +"#{system_message}" }
|
prompt = { preamble: +"#{system_message}" }
|
||||||
prompt[:chat_history] = messages if messages.present?
|
|
||||||
|
|
||||||
messages.reverse_each do |msg|
|
if messages.present?
|
||||||
if msg[:role] == "USER"
|
with_mapped_tools = []
|
||||||
prompt[:message] = msg[:message]
|
|
||||||
messages.delete(msg)
|
current_pair = nil
|
||||||
break
|
messages.each do |msg|
|
||||||
|
if current_pair == nil && msg[:type] == :tool_call
|
||||||
|
current_pair = [msg]
|
||||||
|
elsif current_pair && msg[:type] == :tool
|
||||||
|
current_pair << msg
|
||||||
|
tool_results = tools_dialect.tool_results(current_pair)
|
||||||
|
with_mapped_tools << { role: "TOOL", message: "", tool_results: tool_results }
|
||||||
|
current_pair = nil
|
||||||
|
else
|
||||||
|
with_mapped_tools << msg
|
||||||
|
current_pair = nil
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
messages = with_mapped_tools
|
||||||
|
prompt[:chat_history] = messages
|
||||||
|
end
|
||||||
|
|
||||||
|
tools = tools_dialect.translated_tools
|
||||||
|
prompt[:tools] = tools if tools.present?
|
||||||
|
|
||||||
|
tool_results =
|
||||||
|
messages.last && messages.last[:role] == "TOOL" && messages.last[:tool_results]
|
||||||
|
prompt[:tool_results] = tool_results if tool_results.present?
|
||||||
|
|
||||||
|
if tool_results.blank?
|
||||||
|
messages.reverse_each do |msg|
|
||||||
|
if msg[:role] == "USER"
|
||||||
|
prompt[:message] = msg[:message]
|
||||||
|
messages.delete(msg)
|
||||||
|
break
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -54,8 +84,16 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def native_tool_support?
|
||||||
|
true
|
||||||
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
|
def tools_dialect
|
||||||
|
@tools_dialect ||= DiscourseAi::Completions::Dialects::CohereTools.new(prompt.tools)
|
||||||
|
end
|
||||||
|
|
||||||
def per_message_overhead
|
def per_message_overhead
|
||||||
0
|
0
|
||||||
end
|
end
|
||||||
|
@ -83,11 +121,11 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def tool_call_msg(msg)
|
def tool_call_msg(msg)
|
||||||
{ role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
|
msg
|
||||||
end
|
end
|
||||||
|
|
||||||
def tool_msg(msg)
|
def tool_msg(msg)
|
||||||
{ role: "USER", message: tools_dialect.from_raw_tool(msg) }
|
msg
|
||||||
end
|
end
|
||||||
|
|
||||||
def user_msg(msg)
|
def user_msg(msg)
|
||||||
|
|
|
@ -29,10 +29,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def default_options(dialect)
|
def default_options(dialect)
|
||||||
options = { model: "command-r-plus" }
|
{ model: "command-r-plus" }
|
||||||
|
|
||||||
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
|
|
||||||
options
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def provider_id
|
def provider_id
|
||||||
|
@ -49,7 +46,11 @@ module DiscourseAi
|
||||||
|
|
||||||
def prepare_payload(prompt, model_params, dialect)
|
def prepare_payload(prompt, model_params, dialect)
|
||||||
payload = default_options(dialect).merge(model_params).merge(prompt)
|
payload = default_options(dialect).merge(model_params).merge(prompt)
|
||||||
|
if prompt[:tools].present?
|
||||||
|
payload[:tools] = prompt[:tools]
|
||||||
|
payload[:force_single_step] = false
|
||||||
|
end
|
||||||
|
payload[:tool_results] = prompt[:tool_results] if prompt[:tool_results].present?
|
||||||
payload[:stream] = true if @streaming_mode
|
payload[:stream] = true if @streaming_mode
|
||||||
|
|
||||||
payload
|
payload
|
||||||
|
@ -70,6 +71,14 @@ module DiscourseAi
|
||||||
if @streaming_mode
|
if @streaming_mode
|
||||||
if parsed[:event_type] == "text-generation"
|
if parsed[:event_type] == "text-generation"
|
||||||
parsed[:text]
|
parsed[:text]
|
||||||
|
elsif parsed[:event_type] == "tool-calls-generation"
|
||||||
|
# could just be random thinking...
|
||||||
|
if parsed.dig(:tool_calls).present?
|
||||||
|
@has_tool = true
|
||||||
|
parsed.dig(:tool_calls).to_json
|
||||||
|
else
|
||||||
|
""
|
||||||
|
end
|
||||||
else
|
else
|
||||||
if parsed[:event_type] == "stream-end"
|
if parsed[:event_type] == "stream-end"
|
||||||
@input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
|
@input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
|
||||||
|
@ -84,6 +93,38 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def has_tool?(_ignored)
|
||||||
|
@has_tool
|
||||||
|
end
|
||||||
|
|
||||||
|
def native_tool_support?
|
||||||
|
true
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
|
||||||
|
if partial
|
||||||
|
tools = JSON.parse(partial)
|
||||||
|
tools.each do |tool|
|
||||||
|
name = tool["name"]
|
||||||
|
parameters = tool["parameters"]
|
||||||
|
xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
|
||||||
|
|
||||||
|
current_function = function_buffer.at("invoke")
|
||||||
|
if current_function.nil? || current_function.at("tool_name").content.present?
|
||||||
|
current_function =
|
||||||
|
function_buffer.at("function_calls").add_child(
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
current_function.at("tool_name").content = name == "search_local" ? "search" : name
|
||||||
|
current_function.at("parameters").children =
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(xml_params)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
function_buffer
|
||||||
|
end
|
||||||
|
|
||||||
def final_log_update(log)
|
def final_log_update(log)
|
||||||
log.request_tokens = @input_tokens if @input_tokens
|
log.request_tokens = @input_tokens if @input_tokens
|
||||||
log.response_tokens = @output_tokens if @output_tokens
|
log.response_tokens = @output_tokens if @output_tokens
|
||||||
|
|
|
@ -59,6 +59,83 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
||||||
|
|
||||||
before { SiteSetting.ai_cohere_api_key = "ABC" }
|
before { SiteSetting.ai_cohere_api_key = "ABC" }
|
||||||
|
|
||||||
|
it "is able to trigger a tool" do
|
||||||
|
body = (<<~TEXT).strip
|
||||||
|
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
|
||||||
|
{"is_finished":false,"event_type":"tool-calls-generation","text":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}
|
||||||
|
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"71d8c9e1-1138-4d70-80d1-10ddec41c989","text":"I will search for 'who is sam saffron' and relay the information to the user.","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b","chat_history":[{"role":"USER","message":"sam: who is sam saffron?"},{"role":"CHATBOT","message":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":460,"output_tokens":27},"tokens":{"input_tokens":1227,"output_tokens":27}},"tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]},"finish_reason":"COMPLETE"}
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
parsed_body = nil
|
||||||
|
result = +""
|
||||||
|
|
||||||
|
sig = {
|
||||||
|
name: "google",
|
||||||
|
description: "Will search using Google",
|
||||||
|
parameters: [
|
||||||
|
{ name: "query", description: "The search query", type: "string", required: true },
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
prompt.tools = [sig]
|
||||||
|
|
||||||
|
EndpointMock.with_chunk_array_support do
|
||||||
|
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
|
||||||
|
body:
|
||||||
|
proc do |req_body|
|
||||||
|
parsed_body = JSON.parse(req_body, symbolize_names: true)
|
||||||
|
true
|
||||||
|
end,
|
||||||
|
headers: {
|
||||||
|
"Content-Type" => "application/json",
|
||||||
|
"Authorization" => "Bearer ABC",
|
||||||
|
},
|
||||||
|
).to_return(status: 200, body: body.split("|"))
|
||||||
|
|
||||||
|
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
|
||||||
|
end
|
||||||
|
|
||||||
|
expected = <<~TEXT
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>google</tool_name>
|
||||||
|
<parameters><query>who is sam saffron</query>
|
||||||
|
</parameters>
|
||||||
|
<tool_id>tool_0</tool_id>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
expect(result.strip).to eq(expected.strip)
|
||||||
|
|
||||||
|
expected = {
|
||||||
|
model: "command-r-plus",
|
||||||
|
preamble: "You are hello bot",
|
||||||
|
chat_history: [
|
||||||
|
{ role: "USER", message: "user1: hello" },
|
||||||
|
{ role: "CHATBOT", message: "hi user" },
|
||||||
|
],
|
||||||
|
message: "user1: thanks",
|
||||||
|
tools: [
|
||||||
|
{
|
||||||
|
name: "google",
|
||||||
|
description: "Will search using Google",
|
||||||
|
parameter_definitions: {
|
||||||
|
query: {
|
||||||
|
description: "The search query",
|
||||||
|
type: "str",
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
force_single_step: false,
|
||||||
|
stream: true,
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(parsed_body).to eq(expected)
|
||||||
|
end
|
||||||
|
|
||||||
it "is able to run tools" do
|
it "is able to run tools" do
|
||||||
body = {
|
body = {
|
||||||
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
|
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
|
||||||
|
@ -99,20 +176,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
||||||
result = llm.generate(prompt_with_tool_results, user: user)
|
result = llm.generate(prompt_with_tool_results, user: user)
|
||||||
|
|
||||||
expect(parsed_body[:preamble]).to include("You are weather bot")
|
expect(parsed_body[:preamble]).to include("You are weather bot")
|
||||||
expect(parsed_body[:preamble]).to include("<tools>")
|
|
||||||
|
|
||||||
expected_message = <<~MESSAGE
|
|
||||||
<function_results>
|
|
||||||
<result>
|
|
||||||
<tool_name>weather</tool_name>
|
|
||||||
<json>
|
|
||||||
{"weather":"22c"}
|
|
||||||
</json>
|
|
||||||
</result>
|
|
||||||
</function_results>
|
|
||||||
MESSAGE
|
|
||||||
|
|
||||||
expect(parsed_body[:message].strip).to eq(expected_message.strip)
|
|
||||||
|
|
||||||
expect(result).to eq("Sydney is 22c")
|
expect(result).to eq("Sydney is 22c")
|
||||||
audit = AiApiAuditLog.order("id desc").first
|
audit = AiApiAuditLog.order("id desc").first
|
||||||
|
|
Loading…
Reference in New Issue