From 7f16d3ad43705474c25b155ecc574d48f4cf0f17 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 11 Apr 2024 07:24:17 +1000 Subject: [PATCH] FEATURE: Cohere Command R support (#558) - Added Cohere Command models (Command, Command Light, Command R, Command R Plus) to the available model list - Added a new site setting `ai_cohere_api_key` for configuring the Cohere API key - Implemented a new `DiscourseAi::Completions::Endpoints::Cohere` class to handle interactions with the Cohere API, including: - Translating request parameters to the Cohere API format - Parsing Cohere API responses - Supporting streaming and non-streaming completions - Supporting "tools" which allow the model to call back to discourse to lookup additional information - Implemented a new `DiscourseAi::Completions::Dialects::Command` class to translate between the generic Discourse AI prompt format and the Cohere Command format - Added specs covering the new Cohere endpoint and dialect classes - Updated `DiscourseAi::AiBot::Bot.guess_model` to map the new Cohere model to the appropriate bot user In summary, this PR adds support for using the Cohere Command family of models with the Discourse AI plugin. It handles configuring API keys, making requests to the Cohere API, and translating between Discourse's generic prompt format and Cohere's specific format. Thorough test coverage was added for the new functionality. --- app/models/ai_api_audit_log.rb | 1 + config/locales/client.en.yml | 1 + config/locales/server.en.yml | 1 + config/settings.yml | 4 + lib/ai_bot/bot.rb | 2 + lib/ai_bot/entry_point.rb | 4 + lib/completions/dialects/command.rb | 107 ++++++++ lib/completions/dialects/dialect.rb | 1 + lib/completions/endpoints/base.rb | 1 + lib/completions/endpoints/cohere.rb | 114 ++++++++ spec/lib/completions/endpoints/cohere_spec.rb | 248 ++++++++++++++++++ 11 files changed, 484 insertions(+) create mode 100644 lib/completions/dialects/command.rb create mode 100644 lib/completions/endpoints/cohere.rb create mode 100644 spec/lib/completions/endpoints/cohere_spec.rb diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb index 4de02196..01c3440e 100644 --- a/app/models/ai_api_audit_log.rb +++ b/app/models/ai_api_audit_log.rb @@ -7,6 +7,7 @@ class AiApiAuditLog < ActiveRecord::Base HuggingFaceTextGeneration = 3 Gemini = 4 Vllm = 5 + Cohere = 6 end end diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 57a2fd3b..bad17fc3 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -268,6 +268,7 @@ en: claude-3-opus: "Claude 3 Opus" claude-3-sonnet: "Claude 3 Sonnet" claude-3-haiku: "Claude 3 Haiku" + cohere-command-r-plus: "Cohere Command R Plus" gpt-4: "GPT-4" gpt-4-turbo: "GPT-4 Turbo" gpt-3: diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index c47cdd5a..c6f6209b 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -50,6 +50,7 @@ en: ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)" ai_openai_api_key: "API key for OpenAI API" ai_anthropic_api_key: "API key for Anthropic API" + ai_cohere_api_key: "API key for Cohere API" ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference" ai_hugging_face_api_key: API key for Hugging Face API ai_hugging_face_token_limit: Max tokens Hugging Face API can use per request diff --git a/config/settings.yml b/config/settings.yml index 26013f6a..e656205d 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -110,6 +110,9 @@ discourse_ai: ai_anthropic_api_key: default: "" secret: true + ai_cohere_api_key: + default: "" + secret: true ai_stability_api_key: default: "" secret: true @@ -336,6 +339,7 @@ discourse_ai: - claude-3-opus - claude-3-sonnet - claude-3-haiku + - cohere-command-r-plus ai_bot_add_to_header: default: true client: true diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 0f1a553a..f0bb1856 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -180,6 +180,8 @@ module DiscourseAi when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID # no bedrock support yet 18-03 "anthropic:claude-3-opus" + when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS + "cohere:command-r-plus" when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?( "claude-3-sonnet", diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb index 928753d9..59e206c7 100644 --- a/lib/ai_bot/entry_point.rb +++ b/lib/ai_bot/entry_point.rb @@ -17,6 +17,7 @@ module DiscourseAi CLAUDE_3_OPUS_ID = -117 CLAUDE_3_SONNET_ID = -118 CLAUDE_3_HAIKU_ID = -119 + COHERE_COMMAND_R_PLUS = -120 BOTS = [ [GPT4_ID, "gpt4_bot", "gpt-4"], @@ -29,6 +30,7 @@ module DiscourseAi [CLAUDE_3_OPUS_ID, "claude_3_opus_bot", "claude-3-opus"], [CLAUDE_3_SONNET_ID, "claude_3_sonnet_bot", "claude-3-sonnet"], [CLAUDE_3_HAIKU_ID, "claude_3_haiku_bot", "claude-3-haiku"], + [COHERE_COMMAND_R_PLUS, "cohere_command_bot", "cohere-command-r-plus"], ] BOT_USER_IDS = BOTS.map(&:first) @@ -67,6 +69,8 @@ module DiscourseAi CLAUDE_3_SONNET_ID in "claude-3-haiku" CLAUDE_3_HAIKU_ID + in "cohere-command-r-plus" + COHERE_COMMAND_R_PLUS else nil end diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb new file mode 100644 index 00000000..8b4bf67d --- /dev/null +++ b/lib/completions/dialects/command.rb @@ -0,0 +1,107 @@ +# frozen_string_literal: true + +# see: https://docs.cohere.com/reference/chat +# +module DiscourseAi + module Completions + module Dialects + class Command < Dialect + class << self + def can_translate?(model_name) + %w[command-light command command-r command-r-plus].include?(model_name) + end + + def tokenizer + DiscourseAi::Tokenizer::OpenAiTokenizer + end + end + + VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ + + def translate + messages = prompt.messages + + # ChatGPT doesn't use an assistant msg to improve long-context responses. + if messages.last[:type] == :model + messages = messages.dup + messages.pop + end + + trimmed_messages = trim_messages(messages) + + chat_history = [] + system_message = nil + + prompt = {} + + trimmed_messages.each do |msg| + case msg[:type] + when :system + if system_message + chat_history << { role: "SYSTEM", message: msg[:content] } + else + system_message = msg[:content] + end + when :model + chat_history << { role: "CHATBOT", message: msg[:content] } + when :tool_call + chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) } + when :tool + chat_history << { role: "USER", message: tool_result_to_xml(msg) } + when :user + user_message = { role: "USER", message: msg[:content] } + user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id] + chat_history << user_message + end + end + + tools_prompt = build_tools_prompt + prompt[:preamble] = +"#{system_message}" + if tools_prompt.present? + prompt[:preamble] << "\n#{tools_prompt}" + prompt[ + :preamble + ] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it." + end + + prompt[:chat_history] = chat_history if chat_history.present? + + chat_history.reverse_each do |msg| + if msg[:role] == "USER" + prompt[:message] = msg[:message] + chat_history.delete(msg) + break + end + end + + prompt + end + + def max_prompt_tokens + case model_name + when "command-light" + 4096 + when "command" + 8192 + when "command-r" + 131_072 + when "command-r-plus" + 131_072 + else + 8192 + end + end + + private + + def per_message_overhead + 0 + end + + def calculate_message_token(context) + self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) + end + end + end + end +end diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 3baaa68d..d84d9e38 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -17,6 +17,7 @@ module DiscourseAi DiscourseAi::Completions::Dialects::Gemini, DiscourseAi::Completions::Dialects::Mixtral, DiscourseAi::Completions::Dialects::Claude, + DiscourseAi::Completions::Dialects::Command, ] if Rails.env.test? || Rails.env.development? diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index 51c8c012..268ea510 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -16,6 +16,7 @@ module DiscourseAi DiscourseAi::Completions::Endpoints::Gemini, DiscourseAi::Completions::Endpoints::Vllm, DiscourseAi::Completions::Endpoints::Anthropic, + DiscourseAi::Completions::Endpoints::Cohere, ] if Rails.env.test? || Rails.env.development? diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb new file mode 100644 index 00000000..4903fe00 --- /dev/null +++ b/lib/completions/endpoints/cohere.rb @@ -0,0 +1,114 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + module Endpoints + class Cohere < Base + class << self + def can_contact?(endpoint_name, model_name) + return false unless endpoint_name == "cohere" + + %w[command-light command command-r command-r-plus].include?(model_name) + end + + def dependant_setting_names + %w[ai_cohere_api_key] + end + + def correctly_configured?(model_name) + SiteSetting.ai_cohere_api_key.present? + end + + def endpoint_name(model_name) + "Cohere - #{model_name}" + end + end + + def normalize_model_params(model_params) + model_params = model_params.dup + + model_params[:p] = model_params.delete(:top_p) if model_params[:top_p] + + model_params + end + + def default_options(dialect) + options = { model: "command-r-plus" } + + options[:stop_sequences] = [""] if dialect.prompt.has_tools? + options + end + + def provider_id + AiApiAuditLog::Provider::Cohere + end + + private + + def model_uri + URI("https://api.cohere.ai/v1/chat") + end + + def prepare_payload(prompt, model_params, dialect) + payload = default_options(dialect).merge(model_params).merge(prompt) + + payload[:stream] = true if @streaming_mode + + payload + end + + def prepare_request(payload) + headers = { + "Content-Type" => "application/json", + "Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}", + } + + Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } + end + + def extract_completion_from(response_raw) + parsed = JSON.parse(response_raw, symbolize_names: true) + + if @streaming_mode + if parsed[:event_type] == "text-generation" + parsed[:text] + else + if parsed[:event_type] == "stream-end" + @input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens) + @output_tokens = parsed.dig(:response, :meta, :billed_units, :output_tokens) + end + nil + end + else + @input_tokens = parsed.dig(:meta, :billed_units, :input_tokens) + @output_tokens = parsed.dig(:meta, :billed_units, :output_tokens) + parsed[:text].to_s + end + end + + def final_log_update(log) + log.request_tokens = @input_tokens if @input_tokens + log.response_tokens = @output_tokens if @output_tokens + end + + def partials_from(decoded_chunk) + decoded_chunk.split("\n").compact + end + + def extract_prompt_for_tokenizer(prompt) + text = +"" + if prompt[:chat_history] + text << prompt[:chat_history] + .map { |message| message[:content] || message["content"] || "" } + .join("\n") + end + + text << prompt[:message] if prompt[:message] + text << prompt[:preamble] if prompt[:preamble] + + text + end + end + end + end +end diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb new file mode 100644 index 00000000..52a73721 --- /dev/null +++ b/spec/lib/completions/endpoints/cohere_spec.rb @@ -0,0 +1,248 @@ +# frozen_string_literal: true +require_relative "endpoint_compliance" + +RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do + let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") } + fab!(:user) + + let(:prompt) do + DiscourseAi::Completions::Prompt.new( + "You are hello bot", + messages: [ + { type: :user, id: "user1", content: "hello" }, + { type: :model, content: "hi user" }, + { type: :user, id: "user1", content: "thanks" }, + ], + ) + end + + let(:weather_tool) do + { + name: "weather", + description: "lookup weather in a city", + parameters: [{ name: "city", type: "string", description: "city name", required: true }], + } + end + + let(:prompt_with_tools) do + prompt = + DiscourseAi::Completions::Prompt.new( + "You are weather bot", + messages: [ + { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" }, + ], + ) + + prompt.tools = [weather_tool] + prompt + end + + let(:prompt_with_tool_results) do + prompt = + DiscourseAi::Completions::Prompt.new( + "You are weather bot", + messages: [ + { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" }, + { + type: :tool_call, + id: "tool_call_1", + name: "weather", + content: { arguments: [%w[city Sydney]] }.to_json, + }, + { type: :tool, id: "tool_call_1", name: "weather", content: { weather: "22c" }.to_json }, + ], + ) + + prompt.tools = [weather_tool] + prompt + end + + before { SiteSetting.ai_cohere_api_key = "ABC" } + + it "is able to run tools" do + body = { + response_id: "0a90275b-273d-4690-abce-8018edcec7d0", + text: "Sydney is 22c", + generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52", + chat_history: [], + finish_reason: "COMPLETE", + token_count: { + prompt_tokens: 29, + response_tokens: 11, + total_tokens: 40, + billed_tokens: 25, + }, + meta: { + api_version: { + version: "1", + }, + billed_units: { + input_tokens: 17, + output_tokens: 22, + }, + }, + }.to_json + + parsed_body = nil + stub_request(:post, "https://api.cohere.ai/v1/chat").with( + body: + proc do |req_body| + parsed_body = JSON.parse(req_body, symbolize_names: true) + true + end, + headers: { + "Content-Type" => "application/json", + "Authorization" => "Bearer ABC", + }, + ).to_return(status: 200, body: body) + + result = llm.generate(prompt_with_tool_results, user: user) + + expect(parsed_body[:preamble]).to include("You are weather bot") + expect(parsed_body[:preamble]).to include("") + + expected_message = <<~MESSAGE + + + weather + + {"weather":"22c"} + + + + MESSAGE + + expect(parsed_body[:message].strip).to eq(expected_message.strip) + + expect(result).to eq("Sydney is 22c") + audit = AiApiAuditLog.order("id desc").first + + # billing should be picked + expect(audit.request_tokens).to eq(17) + expect(audit.response_tokens).to eq(22) + end + + it "is able to perform streaming completions" do + body = <<~TEXT + {"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"} + {"is_finished":false,"event_type":"text-generation","text":"You"} + {"is_finished":false,"event_type":"text-generation","text":"'re"} + {"is_finished":false,"event_type":"text-generation","text":" welcome"} + {"is_finished":false,"event_type":"text-generation","text":"!"} + {"is_finished":false,"event_type":"text-generation","text":" Is"} + {"is_finished":false,"event_type":"text-generation","text":" there"} + {"is_finished":false,"event_type":"text-generation","text":" anything"}| + {"is_finished":false,"event_type":"text-generation","text":" else"} + {"is_finished":false,"event_type":"text-generation","text":" I"} + {"is_finished":false,"event_type":"text-generation","text":" can"} + {"is_finished":false,"event_type":"text-generation","text":" help"}| + {"is_finished":false,"event_type":"text-generation","text":" you"} + {"is_finished":false,"event_type":"text-generation","text":" with"} + {"is_finished":false,"event_type":"text-generation","text":"?"}| + {"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"You're welcome! Is there anything else I can help you with?","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"} + TEXT + + parsed_body = nil + result = +"" + + EndpointMock.with_chunk_array_support do + stub_request(:post, "https://api.cohere.ai/v1/chat").with( + body: + proc do |req_body| + parsed_body = JSON.parse(req_body, symbolize_names: true) + true + end, + headers: { + "Content-Type" => "application/json", + "Authorization" => "Bearer ABC", + }, + ).to_return(status: 200, body: body.split("|")) + + result = llm.generate(prompt, user: user) { |partial, cancel| result << partial } + end + + expect(parsed_body[:preamble]).to eq("You are hello bot") + expect(parsed_body[:chat_history]).to eq( + [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }], + ) + expect(parsed_body[:message]).to eq("user1: thanks") + + expect(result).to eq("You're welcome! Is there anything else I can help you with?") + audit = AiApiAuditLog.order("id desc").first + + # billing should be picked + expect(audit.request_tokens).to eq(14) + expect(audit.response_tokens).to eq(14) + end + + it "is able to perform non streaming completions" do + body = { + response_id: "0a90275b-273d-4690-abce-8018edcec7d0", + text: "You're welcome! How can I help you today?", + generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52", + chat_history: [ + { role: "USER", message: "user1: hello" }, + { role: "CHATBOT", message: "hi user" }, + { role: "USER", message: "user1: thanks" }, + { role: "CHATBOT", message: "You're welcome! How can I help you today?" }, + ], + finish_reason: "COMPLETE", + token_count: { + prompt_tokens: 29, + response_tokens: 11, + total_tokens: 40, + billed_tokens: 25, + }, + meta: { + api_version: { + version: "1", + }, + billed_units: { + input_tokens: 14, + output_tokens: 11, + }, + }, + }.to_json + + parsed_body = nil + stub_request(:post, "https://api.cohere.ai/v1/chat").with( + body: + proc do |req_body| + parsed_body = JSON.parse(req_body, symbolize_names: true) + true + end, + headers: { + "Content-Type" => "application/json", + "Authorization" => "Bearer ABC", + }, + ).to_return(status: 200, body: body) + + result = + llm.generate( + prompt, + user: user, + temperature: 0.1, + top_p: 0.5, + max_tokens: 100, + stop_sequences: ["stop"], + ) + + expect(parsed_body[:temperature]).to eq(0.1) + expect(parsed_body[:p]).to eq(0.5) + expect(parsed_body[:max_tokens]).to eq(100) + expect(parsed_body[:stop_sequences]).to eq(["stop"]) + + expect(parsed_body[:preamble]).to eq("You are hello bot") + expect(parsed_body[:chat_history]).to eq( + [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }], + ) + expect(parsed_body[:message]).to eq("user1: thanks") + + expect(result).to eq("You're welcome! How can I help you today?") + audit = AiApiAuditLog.order("id desc").first + + # billing should be picked + expect(audit.request_tokens).to eq(14) + expect(audit.response_tokens).to eq(11) + end +end