This re-implements tool support in DiscourseAi::Completions::Llm#generate. Previously, tool calls were always returned as XML and it was the caller's responsibility to parse them; endpoints now return ToolCall objects instead. This also simplifies and clarifies the Llm endpoint interface: endpoints must implement decode and decode_chunk (for streaming), and it is the implementer's responsibility to decode chunks, since the base class no longer does. To make this easy we ship a flexible JSON decoder that is simple to wire up. Also new: better debugging for PMs (next/previous buttons to step through all the Llm messages associated with a PM), and fixed token accounting for vLLM (we were not counting tokens correctly).
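In practice the caller just pattern-matches on the streamed partials instead of parsing XML. A minimal sketch, assuming an existing LlmModel record (llm_model), a prompt with tools attached, and a user — the setup names are illustrative, while the proxy/generate/ToolCall calls mirror what the specs below exercise:

llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")

llm.generate(prompt, user: user) do |partial, _cancel|
  case partial
  when DiscourseAi::Completions::ToolCall
    # tool invocations arrive as structured objects with id, name and parameters
    puts "call #{partial.name} with #{partial.parameters.inspect}"
  when String
    print partial # plain text chunks stream through unchanged
  end
end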
# frozen_string_literal: true

require_relative "endpoint_compliance"

RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
  fab!(:cohere_model)
  let(:llm) { DiscourseAi::Completions::Llm.proxy("custom:#{cohere_model.id}") }
  fab!(:user)

  let(:prompt) do
    DiscourseAi::Completions::Prompt.new(
      "You are hello bot",
      messages: [
        { type: :user, id: "user1", content: "hello" },
        { type: :model, content: "hi user" },
        { type: :user, id: "user1", content: "thanks" },
      ],
    )
  end

  let(:weather_tool) do
    {
      name: "weather",
      description: "lookup weather in a city",
      parameters: [{ name: "city", type: "string", description: "city name", required: true }],
    }
  end

  let(:prompt_with_tools) do
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are weather bot",
        messages: [
          { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
        ],
      )

    prompt.tools = [weather_tool]
    prompt
  end

  let(:prompt_with_tool_results) do
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are weather bot",
        messages: [
          { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
          {
            type: :tool_call,
            id: "tool_call_1",
            name: "weather",
            content: { arguments: [%w[city Sydney]] }.to_json,
          },
          { type: :tool, id: "tool_call_1", name: "weather", content: { weather: "22c" }.to_json },
        ],
      )

    prompt.tools = [weather_tool]
    prompt
  end

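  # Streaming tool trigger: the mocked Cohere stream below emits a
  # "tool-calls-generation" event, which the endpoint should surface as a
  # DiscourseAi::Completions::ToolCall partial after the text partial.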
it "is able to trigger a tool" do
|
|
body = (<<~TEXT).strip
|
|
{"is_finished":false,"event_type":"stream-start","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b"}
|
|
{"is_finished":false,"event_type":"tool-calls-generation","text":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}
|
|
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"71d8c9e1-1138-4d70-80d1-10ddec41c989","text":"I will search for 'who is sam saffron' and relay the information to the user.","generation_id":"1648206e-1fe4-4bb6-90cf-360dd55f575b","chat_history":[{"role":"USER","message":"sam: who is sam saffron?"},{"role":"CHATBOT","message":"I will search for 'who is sam saffron' and relay the information to the user.","tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]}],"finish_reason":"COMPLETE","meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":460,"output_tokens":27},"tokens":{"input_tokens":1227,"output_tokens":27}},"tool_calls":[{"name":"google","parameters":{"query":"who is sam saffron"}}]},"finish_reason":"COMPLETE"}
|
|
TEXT
|
|
|
|
parsed_body = nil
|
|
result = []
|
|
|
|
sig = {
|
|
name: "google",
|
|
description: "Will search using Google",
|
|
parameters: [
|
|
{ name: "query", description: "The search query", type: "string", required: true },
|
|
],
|
|
}
|
|
|
|
prompt.tools = [sig]
|
|
|
|
EndpointMock.with_chunk_array_support do
|
|
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
|
|
body:
|
|
proc do |req_body|
|
|
parsed_body = JSON.parse(req_body, symbolize_names: true)
|
|
true
|
|
end,
|
|
headers: {
|
|
"Content-Type" => "application/json",
|
|
"Authorization" => "Bearer ABC",
|
|
},
|
|
).to_return(status: 200, body: body.split("|"))
|
|
|
|
llm.generate(prompt, user: user) { |partial, cancel| result << partial }
|
|
end
|
|
|
|
text = "I will search for 'who is sam saffron' and relay the information to the user."
|
|
tool_call =
|
|
DiscourseAi::Completions::ToolCall.new(
|
|
id: "tool_0",
|
|
name: "google",
|
|
parameters: {
|
|
query: "who is sam saffron",
|
|
},
|
|
)
|
|
|
|
expect(result).to eq([text, tool_call])
|
|
|
|
expected = {
|
|
model: "command-r-plus",
|
|
preamble: "You are hello bot",
|
|
chat_history: [
|
|
{ role: "USER", message: "user1: hello" },
|
|
{ role: "CHATBOT", message: "hi user" },
|
|
],
|
|
message: "user1: thanks",
|
|
tools: [
|
|
{
|
|
name: "google",
|
|
description: "Will search using Google",
|
|
parameter_definitions: {
|
|
query: {
|
|
description: "The search query",
|
|
type: "str",
|
|
required: true,
|
|
},
|
|
},
|
|
},
|
|
],
|
|
force_single_step: false,
|
|
stream: true,
|
|
}
|
|
|
|
expect(parsed_body).to eq(expected)
|
|
end
|
|
|
|
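  # Non-streaming request that replays tool_call/tool messages from the prompt;
  # token accounting should come from meta.billed_units, not token_count.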
it "is able to run tools" do
|
|
body = {
|
|
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
|
|
text: "Sydney is 22c",
|
|
generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
|
|
chat_history: [],
|
|
finish_reason: "COMPLETE",
|
|
token_count: {
|
|
prompt_tokens: 29,
|
|
response_tokens: 11,
|
|
total_tokens: 40,
|
|
billed_tokens: 25,
|
|
},
|
|
meta: {
|
|
api_version: {
|
|
version: "1",
|
|
},
|
|
billed_units: {
|
|
input_tokens: 17,
|
|
output_tokens: 22,
|
|
},
|
|
},
|
|
}.to_json
|
|
|
|
parsed_body = nil
|
|
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
|
|
body:
|
|
proc do |req_body|
|
|
parsed_body = JSON.parse(req_body, symbolize_names: true)
|
|
true
|
|
end,
|
|
headers: {
|
|
"Content-Type" => "application/json",
|
|
"Authorization" => "Bearer ABC",
|
|
},
|
|
).to_return(status: 200, body: body)
|
|
|
|
result = llm.generate(prompt_with_tool_results, user: user)
|
|
|
|
expect(parsed_body[:preamble]).to include("You are weather bot")
|
|
|
|
expect(result).to eq("Sydney is 22c")
|
|
audit = AiApiAuditLog.order("id desc").first
|
|
|
|
# billing should be picked
|
|
expect(audit.request_tokens).to eq(17)
|
|
expect(audit.response_tokens).to eq(22)
|
|
|
|
expect(audit.language_model).to eq("command-r-plus")
|
|
end
|
|
|
|
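  # The "|" separators in the mocked body split the response into multiple
  # chunks, so the endpoint's streaming decode_chunk path is exercised rather
  # than a single-shot decode.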
it "is able to perform streaming completions" do
|
|
body = <<~TEXT
|
|
{"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"}
|
|
{"is_finished":false,"event_type":"text-generation","text":"You"}
|
|
{"is_finished":false,"event_type":"text-generation","text":"'re"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" welcome"}
|
|
{"is_finished":false,"event_type":"text-generation","text":"!"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" Is"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" there"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" anything"}|
|
|
{"is_finished":false,"event_type":"text-generation","text":" else"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" I"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" can"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" help"}|
|
|
{"is_finished":false,"event_type":"text-generation","text":" you"}
|
|
{"is_finished":false,"event_type":"text-generation","text":" with"}
|
|
{"is_finished":false,"event_type":"text-generation","text":"?"}|
|
|
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"You're welcome! Is there anything else I can help you with?","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"}
|
|
TEXT
|
|
|
|
parsed_body = nil
|
|
result = +""
|
|
|
|
EndpointMock.with_chunk_array_support do
|
|
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
|
|
body:
|
|
proc do |req_body|
|
|
parsed_body = JSON.parse(req_body, symbolize_names: true)
|
|
true
|
|
end,
|
|
headers: {
|
|
"Content-Type" => "application/json",
|
|
"Authorization" => "Bearer ABC",
|
|
},
|
|
).to_return(status: 200, body: body.split("|"))
|
|
|
|
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
|
|
end
|
|
|
|
expect(parsed_body[:preamble]).to eq("You are hello bot")
|
|
expect(parsed_body[:chat_history]).to eq(
|
|
[{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
|
|
)
|
|
expect(parsed_body[:message]).to eq("user1: thanks")
|
|
|
|
expect(result).to eq("You're welcome! Is there anything else I can help you with?")
|
|
audit = AiApiAuditLog.order("id desc").first
|
|
|
|
# billing should be picked
|
|
expect(audit.request_tokens).to eq(14)
|
|
expect(audit.response_tokens).to eq(14)
|
|
end
|
|
|
|
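  # Non-streaming completion; sampling params are forwarded (top_p maps to
  # Cohere's "p") and billed_units drive the audit log token counts.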
it "is able to perform non streaming completions" do
|
|
body = {
|
|
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
|
|
text: "You're welcome! How can I help you today?",
|
|
generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
|
|
chat_history: [
|
|
{ role: "USER", message: "user1: hello" },
|
|
{ role: "CHATBOT", message: "hi user" },
|
|
{ role: "USER", message: "user1: thanks" },
|
|
{ role: "CHATBOT", message: "You're welcome! How can I help you today?" },
|
|
],
|
|
finish_reason: "COMPLETE",
|
|
token_count: {
|
|
prompt_tokens: 29,
|
|
response_tokens: 11,
|
|
total_tokens: 40,
|
|
billed_tokens: 25,
|
|
},
|
|
meta: {
|
|
api_version: {
|
|
version: "1",
|
|
},
|
|
billed_units: {
|
|
input_tokens: 14,
|
|
output_tokens: 11,
|
|
},
|
|
},
|
|
}.to_json
|
|
|
|
parsed_body = nil
|
|
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
|
|
body:
|
|
proc do |req_body|
|
|
parsed_body = JSON.parse(req_body, symbolize_names: true)
|
|
true
|
|
end,
|
|
headers: {
|
|
"Content-Type" => "application/json",
|
|
"Authorization" => "Bearer ABC",
|
|
},
|
|
).to_return(status: 200, body: body)
|
|
|
|
result =
|
|
llm.generate(
|
|
prompt,
|
|
user: user,
|
|
temperature: 0.1,
|
|
top_p: 0.5,
|
|
max_tokens: 100,
|
|
stop_sequences: ["stop"],
|
|
)
|
|
|
|
expect(parsed_body[:temperature]).to eq(0.1)
|
|
expect(parsed_body[:p]).to eq(0.5)
|
|
expect(parsed_body[:max_tokens]).to eq(100)
|
|
expect(parsed_body[:stop_sequences]).to eq(["stop"])
|
|
|
|
expect(parsed_body[:preamble]).to eq("You are hello bot")
|
|
expect(parsed_body[:chat_history]).to eq(
|
|
[{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
|
|
)
|
|
expect(parsed_body[:message]).to eq("user1: thanks")
|
|
|
|
expect(result).to eq("You're welcome! How can I help you today?")
|
|
audit = AiApiAuditLog.order("id desc").first
|
|
|
|
# billing should be picked
|
|
expect(audit.request_tokens).to eq(14)
|
|
expect(audit.response_tokens).to eq(11)
|
|
end
|
|
end
|