mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-02 23:39:30 +00:00
- Added Cohere Command models (Command, Command Light, Command R, Command R Plus) to the available model list
- Added a new site setting `ai_cohere_api_key` for configuring the Cohere API key
- Implemented a new `DiscourseAi::Completions::Endpoints::Cohere` class to handle interactions with the Cohere API, including:
  - Translating request parameters to the Cohere API format
  - Parsing Cohere API responses
  - Supporting streaming and non-streaming completions
  - Supporting "tools", which allow the model to call back to Discourse to look up additional information
- Implemented a new `DiscourseAi::Completions::Dialects::Command` class to translate between the generic Discourse AI prompt format and the Cohere Command format
- Added specs covering the new Cohere endpoint and dialect classes
- Updated `DiscourseAi::AiBot::Bot.guess_model` to map the new Cohere model to the appropriate bot user

In summary, this PR adds support for using the Cohere Command family of models with the Discourse AI plugin. It handles configuring API keys, making requests to the Cohere API, and translating between Discourse's generic prompt format and Cohere's specific format. Thorough test coverage was added for the new functionality.
249 lines
8.3 KiB
Ruby
249 lines
8.3 KiB
Ruby
# frozen_string_literal: true
|
|
require_relative "endpoint_compliance"
|
|
|
|
# Specs for the Cohere completions endpoint. Every example stubs
# POST https://api.cohere.ai/v1/chat with WebMock, captures the JSON request
# body that was sent, and then checks both the outgoing payload and the
# completion / audit log produced from the canned response.
RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
  let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") }

  # User passed to `llm.generate` in every example.
  fab!(:user)

  # Plain three-turn conversation: user, model, user.
  let(:prompt) do
    DiscourseAi::Completions::Prompt.new(
      "You are hello bot",
      messages: [
        { type: :user, id: "user1", content: "hello" },
        { type: :model, content: "hi user" },
        { type: :user, id: "user1", content: "thanks" },
      ],
    )
  end

  # Tool definition shared by the tool-enabled prompts below.
  let(:weather_tool) do
    {
      name: "weather",
      description: "lookup weather in a city",
      parameters: [{ name: "city", type: "string", description: "city name", required: true }],
    }
  end

  # Prompt that offers the weather tool but has no tool activity yet.
  let(:prompt_with_tools) do
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are weather bot",
        messages: [
          { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
        ],
      )

    prompt.tools = [weather_tool]
    prompt
  end

  # Prompt whose history already contains a tool call and its result, so the
  # last message handed to the endpoint is a tool result rather than user text.
  let(:prompt_with_tool_results) do
    prompt =
      DiscourseAi::Completions::Prompt.new(
        "You are weather bot",
        messages: [
          { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
          {
            type: :tool_call,
            id: "tool_call_1",
            name: "weather",
            content: { arguments: [%w[city Sydney]] }.to_json,
          },
          { type: :tool, id: "tool_call_1", name: "weather", content: { weather: "22c" }.to_json },
        ],
      )

    prompt.tools = [weather_tool]
    prompt
  end

  # The stubs below only match requests carrying "Bearer ABC".
  before { SiteSetting.ai_cohere_api_key = "ABC" }

  it "is able to run tools" do
    body = {
      response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
      text: "Sydney is 22c",
      generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
      chat_history: [],
      finish_reason: "COMPLETE",
      token_count: {
        prompt_tokens: 29,
        response_tokens: 11,
        total_tokens: 40,
        billed_tokens: 25,
      },
      meta: {
        api_version: {
          version: "1",
        },
        billed_units: {
          input_tokens: 17,
          output_tokens: 22,
        },
      },
    }.to_json

    # The body matcher always returns true; it exists only to capture the
    # request payload for the assertions further down.
    parsed_body = nil
    stub_request(:post, "https://api.cohere.ai/v1/chat").with(
      body:
        proc do |req_body|
          parsed_body = JSON.parse(req_body, symbolize_names: true)
          true
        end,
      headers: {
        "Content-Type" => "application/json",
        "Authorization" => "Bearer ABC",
      },
    ).to_return(status: 200, body: body)

    result = llm.generate(prompt_with_tool_results, user: user)

    # Tool definitions are expected inside the preamble, after the system text.
    expect(parsed_body[:preamble]).to include("You are weather bot")
    expect(parsed_body[:preamble]).to include("<tools>")

    # The tool result is expected as the `message` of the request.
    expected_message = <<~MESSAGE
      <function_results>
      <result>
      <tool_name>weather</tool_name>
      <json>
      {"weather":"22c"}
      </json>
      </result>
      </function_results>
    MESSAGE

    expect(parsed_body[:message].strip).to eq(expected_message.strip)

    expect(result).to eq("Sydney is 22c")
    audit = AiApiAuditLog.order("id desc").first

    # Usage must come from meta.billed_units (17 / 22), not token_count (29 / 11).
    expect(audit.request_tokens).to eq(17)
    expect(audit.response_tokens).to eq(22)
  end

  it "is able to perform streaming completions" do
    # One JSON event per line. The trailing "|" characters are not part of the
    # payload: the fixture is split on them below so the response body is
    # supplied as several chunks instead of one string.
    body = <<~TEXT
      {"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"}
      {"is_finished":false,"event_type":"text-generation","text":"You"}
      {"is_finished":false,"event_type":"text-generation","text":"'re"}
      {"is_finished":false,"event_type":"text-generation","text":" welcome"}
      {"is_finished":false,"event_type":"text-generation","text":"!"}
      {"is_finished":false,"event_type":"text-generation","text":" Is"}
      {"is_finished":false,"event_type":"text-generation","text":" there"}
      {"is_finished":false,"event_type":"text-generation","text":" anything"}|
      {"is_finished":false,"event_type":"text-generation","text":" else"}
      {"is_finished":false,"event_type":"text-generation","text":" I"}
      {"is_finished":false,"event_type":"text-generation","text":" can"}
      {"is_finished":false,"event_type":"text-generation","text":" help"}|
      {"is_finished":false,"event_type":"text-generation","text":" you"}
      {"is_finished":false,"event_type":"text-generation","text":" with"}
      {"is_finished":false,"event_type":"text-generation","text":"?"}|
      {"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"You're welcome! Is there anything else I can help you with?","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"}
    TEXT

    parsed_body = nil
    # Kept separate from `result` so the partials yielded to the block are
    # verified on their own instead of being overwritten by the return value.
    streamed = +""
    result = nil

    # NOTE(review): presumably this helper (from endpoint_compliance) lets the
    # stub accept an array body and deliver it chunk by chunk -- confirm.
    EndpointMock.with_chunk_array_support do
      stub_request(:post, "https://api.cohere.ai/v1/chat").with(
        body:
          proc do |req_body|
            parsed_body = JSON.parse(req_body, symbolize_names: true)
            true
          end,
        headers: {
          "Content-Type" => "application/json",
          "Authorization" => "Bearer ABC",
        },
      ).to_return(status: 200, body: body.split("|"))

      result = llm.generate(prompt, user: user) { |partial, _cancel| streamed << partial }
    end

    # System text goes to `preamble`, earlier turns to `chat_history`, and the
    # final user turn to `message`.
    expect(parsed_body[:preamble]).to eq("You are hello bot")
    expect(parsed_body[:chat_history]).to eq(
      [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
    )
    expect(parsed_body[:message]).to eq("user1: thanks")

    expected_text = "You're welcome! Is there anything else I can help you with?"
    expect(streamed).to eq(expected_text)
    expect(result).to eq(expected_text)
    audit = AiApiAuditLog.order("id desc").first

    # Usage must come from the stream-end event's meta.billed_units (14 / 14).
    expect(audit.request_tokens).to eq(14)
    expect(audit.response_tokens).to eq(14)
  end

  it "is able to perform non streaming completions" do
    body = {
      response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
      text: "You're welcome! How can I help you today?",
      generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
      chat_history: [
        { role: "USER", message: "user1: hello" },
        { role: "CHATBOT", message: "hi user" },
        { role: "USER", message: "user1: thanks" },
        { role: "CHATBOT", message: "You're welcome! How can I help you today?" },
      ],
      finish_reason: "COMPLETE",
      token_count: {
        prompt_tokens: 29,
        response_tokens: 11,
        total_tokens: 40,
        billed_tokens: 25,
      },
      meta: {
        api_version: {
          version: "1",
        },
        billed_units: {
          input_tokens: 14,
          output_tokens: 11,
        },
      },
    }.to_json

    # The body matcher always returns true; it exists only to capture the
    # request payload for the assertions further down.
    parsed_body = nil
    stub_request(:post, "https://api.cohere.ai/v1/chat").with(
      body:
        proc do |req_body|
          parsed_body = JSON.parse(req_body, symbolize_names: true)
          true
        end,
      headers: {
        "Content-Type" => "application/json",
        "Authorization" => "Bearer ABC",
      },
    ).to_return(status: 200, body: body)

    result =
      llm.generate(
        prompt,
        user: user,
        temperature: 0.1,
        top_p: 0.5,
        max_tokens: 100,
        stop_sequences: ["stop"],
      )

    # Generic sampling options are forwarded; note `top_p` is sent as `p`.
    expect(parsed_body[:temperature]).to eq(0.1)
    expect(parsed_body[:p]).to eq(0.5)
    expect(parsed_body[:max_tokens]).to eq(100)
    expect(parsed_body[:stop_sequences]).to eq(["stop"])

    # System text goes to `preamble`, earlier turns to `chat_history`, and the
    # final user turn to `message`.
    expect(parsed_body[:preamble]).to eq("You are hello bot")
    expect(parsed_body[:chat_history]).to eq(
      [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
    )
    expect(parsed_body[:message]).to eq("user1: thanks")

    expect(result).to eq("You're welcome! How can I help you today?")
    audit = AiApiAuditLog.order("id desc").first

    # Usage must come from meta.billed_units (14 / 11), not token_count (29 / 11).
    expect(audit.request_tokens).to eq(14)
    expect(audit.response_tokens).to eq(11)
  end
end
|