FEATURE: Cohere Command R support (#558)

- Added Cohere Command models (Command, Command Light, Command R, Command R Plus) to the available model list
- Added a new site setting `ai_cohere_api_key` for configuring the Cohere API key
- Implemented a new `DiscourseAi::Completions::Endpoints::Cohere` class to handle interactions with the Cohere API, including:
   - Translating request parameters to the Cohere API format
   - Parsing Cohere API responses 
   - Supporting streaming and non-streaming completions
   - Supporting "tools" which allow the model to call back to discourse to lookup additional information
- Implemented a new `DiscourseAi::Completions::Dialects::Command` class to translate between the generic Discourse AI prompt format and the Cohere Command format
- Added specs covering the new Cohere endpoint and dialect classes
- Updated `DiscourseAi::AiBot::Bot.guess_model` to map the new Cohere model to the appropriate bot user

In summary, this PR adds support for using the Cohere Command family of models with the Discourse AI plugin. It handles configuring API keys, making requests to the Cohere API, and translating between Discourse's generic prompt format and Cohere's specific format. Thorough test coverage was added for the new functionality.
This commit is contained in:
Sam 2024-04-11 07:24:17 +10:00 committed by GitHub
parent eb93b21769
commit 7f16d3ad43
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 484 additions and 0 deletions

View File

@ -7,6 +7,7 @@ class AiApiAuditLog < ActiveRecord::Base
HuggingFaceTextGeneration = 3
Gemini = 4
Vllm = 5
Cohere = 6
end
end

View File

@ -268,6 +268,7 @@ en:
claude-3-opus: "Claude 3 Opus"
claude-3-sonnet: "Claude 3 Sonnet"
claude-3-haiku: "Claude 3 Haiku"
cohere-command-r-plus: "Cohere Command R Plus"
gpt-4: "GPT-4"
gpt-4-turbo: "GPT-4 Turbo"
gpt-3:

View File

@ -50,6 +50,7 @@ en:
ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
ai_openai_api_key: "API key for OpenAI API"
ai_anthropic_api_key: "API key for Anthropic API"
ai_cohere_api_key: "API key for Cohere API"
ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference"
ai_hugging_face_api_key: API key for Hugging Face API
ai_hugging_face_token_limit: Max tokens Hugging Face API can use per request

View File

@ -110,6 +110,9 @@ discourse_ai:
ai_anthropic_api_key:
default: ""
secret: true
ai_cohere_api_key:
default: ""
secret: true
ai_stability_api_key:
default: ""
secret: true
@ -336,6 +339,7 @@ discourse_ai:
- claude-3-opus
- claude-3-sonnet
- claude-3-haiku
- cohere-command-r-plus
ai_bot_add_to_header:
default: true
client: true

View File

@ -180,6 +180,8 @@ module DiscourseAi
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
# no bedrock support yet 18-03
"anthropic:claude-3-opus"
when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
"cohere:command-r-plus"
when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
"claude-3-sonnet",

View File

@ -17,6 +17,7 @@ module DiscourseAi
CLAUDE_3_OPUS_ID = -117
CLAUDE_3_SONNET_ID = -118
CLAUDE_3_HAIKU_ID = -119
COHERE_COMMAND_R_PLUS = -120
BOTS = [
[GPT4_ID, "gpt4_bot", "gpt-4"],
@ -29,6 +30,7 @@ module DiscourseAi
[CLAUDE_3_OPUS_ID, "claude_3_opus_bot", "claude-3-opus"],
[CLAUDE_3_SONNET_ID, "claude_3_sonnet_bot", "claude-3-sonnet"],
[CLAUDE_3_HAIKU_ID, "claude_3_haiku_bot", "claude-3-haiku"],
[COHERE_COMMAND_R_PLUS, "cohere_command_bot", "cohere-command-r-plus"],
]
BOT_USER_IDS = BOTS.map(&:first)
@ -67,6 +69,8 @@ module DiscourseAi
CLAUDE_3_SONNET_ID
in "claude-3-haiku"
CLAUDE_3_HAIKU_ID
in "cohere-command-r-plus"
COHERE_COMMAND_R_PLUS
else
nil
end

View File

@ -0,0 +1,107 @@
# frozen_string_literal: true
# see: https://docs.cohere.com/reference/chat
#
module DiscourseAi
module Completions
module Dialects
class Command < Dialect
class << self
def can_translate?(model_name)
%w[command-light command command-r command-r-plus].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def translate
messages = prompt.messages
# ChatGPT doesn't use an assistant msg to improve long-context responses.
if messages.last[:type] == :model
messages = messages.dup
messages.pop
end
trimmed_messages = trim_messages(messages)
chat_history = []
system_message = nil
prompt = {}
trimmed_messages.each do |msg|
case msg[:type]
when :system
if system_message
chat_history << { role: "SYSTEM", message: msg[:content] }
else
system_message = msg[:content]
end
when :model
chat_history << { role: "CHATBOT", message: msg[:content] }
when :tool_call
chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
when :tool
chat_history << { role: "USER", message: tool_result_to_xml(msg) }
when :user
user_message = { role: "USER", message: msg[:content] }
user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
chat_history << user_message
end
end
tools_prompt = build_tools_prompt
prompt[:preamble] = +"#{system_message}"
if tools_prompt.present?
prompt[:preamble] << "\n#{tools_prompt}"
prompt[
:preamble
] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
end
prompt[:chat_history] = chat_history if chat_history.present?
chat_history.reverse_each do |msg|
if msg[:role] == "USER"
prompt[:message] = msg[:message]
chat_history.delete(msg)
break
end
end
prompt
end
def max_prompt_tokens
case model_name
when "command-light"
4096
when "command"
8192
when "command-r"
131_072
when "command-r-plus"
131_072
else
8192
end
end
private
def per_message_overhead
0
end
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
end
end
end
end

View File

@ -17,6 +17,7 @@ module DiscourseAi
DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mixtral,
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command,
]
if Rails.env.test? || Rails.env.development?

View File

@ -16,6 +16,7 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm,
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::Cohere,
]
if Rails.env.test? || Rails.env.development?

View File

@ -0,0 +1,114 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class Cohere < Base
class << self
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "cohere"
%w[command-light command command-r command-r-plus].include?(model_name)
end
def dependant_setting_names
%w[ai_cohere_api_key]
end
def correctly_configured?(model_name)
SiteSetting.ai_cohere_api_key.present?
end
def endpoint_name(model_name)
"Cohere - #{model_name}"
end
end
def normalize_model_params(model_params)
model_params = model_params.dup
model_params[:p] = model_params.delete(:top_p) if model_params[:top_p]
model_params
end
def default_options(dialect)
options = { model: "command-r-plus" }
options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
options
end
def provider_id
AiApiAuditLog::Provider::Cohere
end
private
def model_uri
URI("https://api.cohere.ai/v1/chat")
end
def prepare_payload(prompt, model_params, dialect)
payload = default_options(dialect).merge(model_params).merge(prompt)
payload[:stream] = true if @streaming_mode
payload
end
def prepare_request(payload)
headers = {
"Content-Type" => "application/json",
"Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}",
}
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true)
if @streaming_mode
if parsed[:event_type] == "text-generation"
parsed[:text]
else
if parsed[:event_type] == "stream-end"
@input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
@output_tokens = parsed.dig(:response, :meta, :billed_units, :output_tokens)
end
nil
end
else
@input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
@output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
parsed[:text].to_s
end
end
def final_log_update(log)
log.request_tokens = @input_tokens if @input_tokens
log.response_tokens = @output_tokens if @output_tokens
end
def partials_from(decoded_chunk)
decoded_chunk.split("\n").compact
end
def extract_prompt_for_tokenizer(prompt)
text = +""
if prompt[:chat_history]
text << prompt[:chat_history]
.map { |message| message[:content] || message["content"] || "" }
.join("\n")
end
text << prompt[:message] if prompt[:message]
text << prompt[:preamble] if prompt[:preamble]
text
end
end
end
end
end

View File

@ -0,0 +1,248 @@
# frozen_string_literal: true
require_relative "endpoint_compliance"
RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") }
fab!(:user)
let(:prompt) do
DiscourseAi::Completions::Prompt.new(
"You are hello bot",
messages: [
{ type: :user, id: "user1", content: "hello" },
{ type: :model, content: "hi user" },
{ type: :user, id: "user1", content: "thanks" },
],
)
end
let(:weather_tool) do
{
name: "weather",
description: "lookup weather in a city",
parameters: [{ name: "city", type: "string", description: "city name", required: true }],
}
end
let(:prompt_with_tools) do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are weather bot",
messages: [
{ type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
],
)
prompt.tools = [weather_tool]
prompt
end
let(:prompt_with_tool_results) do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are weather bot",
messages: [
{ type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
{
type: :tool_call,
id: "tool_call_1",
name: "weather",
content: { arguments: [%w[city Sydney]] }.to_json,
},
{ type: :tool, id: "tool_call_1", name: "weather", content: { weather: "22c" }.to_json },
],
)
prompt.tools = [weather_tool]
prompt
end
before { SiteSetting.ai_cohere_api_key = "ABC" }
it "is able to run tools" do
body = {
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
text: "Sydney is 22c",
generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
chat_history: [],
finish_reason: "COMPLETE",
token_count: {
prompt_tokens: 29,
response_tokens: 11,
total_tokens: 40,
billed_tokens: 25,
},
meta: {
api_version: {
version: "1",
},
billed_units: {
input_tokens: 17,
output_tokens: 22,
},
},
}.to_json
parsed_body = nil
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"Authorization" => "Bearer ABC",
},
).to_return(status: 200, body: body)
result = llm.generate(prompt_with_tool_results, user: user)
expect(parsed_body[:preamble]).to include("You are weather bot")
expect(parsed_body[:preamble]).to include("<tools>")
expected_message = <<~MESSAGE
<function_results>
<result>
<tool_name>weather</tool_name>
<json>
{"weather":"22c"}
</json>
</result>
</function_results>
MESSAGE
expect(parsed_body[:message].strip).to eq(expected_message.strip)
expect(result).to eq("Sydney is 22c")
audit = AiApiAuditLog.order("id desc").first
# billing should be picked
expect(audit.request_tokens).to eq(17)
expect(audit.response_tokens).to eq(22)
end
it "is able to perform streaming completions" do
body = <<~TEXT
{"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"}
{"is_finished":false,"event_type":"text-generation","text":"You"}
{"is_finished":false,"event_type":"text-generation","text":"'re"}
{"is_finished":false,"event_type":"text-generation","text":" welcome"}
{"is_finished":false,"event_type":"text-generation","text":"!"}
{"is_finished":false,"event_type":"text-generation","text":" Is"}
{"is_finished":false,"event_type":"text-generation","text":" there"}
{"is_finished":false,"event_type":"text-generation","text":" anything"}|
{"is_finished":false,"event_type":"text-generation","text":" else"}
{"is_finished":false,"event_type":"text-generation","text":" I"}
{"is_finished":false,"event_type":"text-generation","text":" can"}
{"is_finished":false,"event_type":"text-generation","text":" help"}|
{"is_finished":false,"event_type":"text-generation","text":" you"}
{"is_finished":false,"event_type":"text-generation","text":" with"}
{"is_finished":false,"event_type":"text-generation","text":"?"}|
{"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"You're welcome! Is there anything else I can help you with?","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"}
TEXT
parsed_body = nil
result = +""
EndpointMock.with_chunk_array_support do
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"Authorization" => "Bearer ABC",
},
).to_return(status: 200, body: body.split("|"))
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
end
expect(parsed_body[:preamble]).to eq("You are hello bot")
expect(parsed_body[:chat_history]).to eq(
[{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
)
expect(parsed_body[:message]).to eq("user1: thanks")
expect(result).to eq("You're welcome! Is there anything else I can help you with?")
audit = AiApiAuditLog.order("id desc").first
# billing should be picked
expect(audit.request_tokens).to eq(14)
expect(audit.response_tokens).to eq(14)
end
it "is able to perform non streaming completions" do
body = {
response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
text: "You're welcome! How can I help you today?",
generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
chat_history: [
{ role: "USER", message: "user1: hello" },
{ role: "CHATBOT", message: "hi user" },
{ role: "USER", message: "user1: thanks" },
{ role: "CHATBOT", message: "You're welcome! How can I help you today?" },
],
finish_reason: "COMPLETE",
token_count: {
prompt_tokens: 29,
response_tokens: 11,
total_tokens: 40,
billed_tokens: 25,
},
meta: {
api_version: {
version: "1",
},
billed_units: {
input_tokens: 14,
output_tokens: 11,
},
},
}.to_json
parsed_body = nil
stub_request(:post, "https://api.cohere.ai/v1/chat").with(
body:
proc do |req_body|
parsed_body = JSON.parse(req_body, symbolize_names: true)
true
end,
headers: {
"Content-Type" => "application/json",
"Authorization" => "Bearer ABC",
},
).to_return(status: 200, body: body)
result =
llm.generate(
prompt,
user: user,
temperature: 0.1,
top_p: 0.5,
max_tokens: 100,
stop_sequences: ["stop"],
)
expect(parsed_body[:temperature]).to eq(0.1)
expect(parsed_body[:p]).to eq(0.5)
expect(parsed_body[:max_tokens]).to eq(100)
expect(parsed_body[:stop_sequences]).to eq(["stop"])
expect(parsed_body[:preamble]).to eq("You are hello bot")
expect(parsed_body[:chat_history]).to eq(
[{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
)
expect(parsed_body[:message]).to eq("user1: thanks")
expect(result).to eq("You're welcome! How can I help you today?")
audit = AiApiAuditLog.order("id desc").first
# billing should be picked
expect(audit.request_tokens).to eq(14)
expect(audit.response_tokens).to eq(11)
end
end