diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 4de02196..01c3440e 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -7,6 +7,7 @@ class AiApiAuditLog < ActiveRecord::Base
     HuggingFaceTextGeneration = 3
     Gemini = 4
     Vllm = 5
+    Cohere = 6
   end
 end
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 57a2fd3b..bad17fc3 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -268,6 +268,7 @@ en:
           claude-3-opus: "Claude 3 Opus"
           claude-3-sonnet: "Claude 3 Sonnet"
           claude-3-haiku: "Claude 3 Haiku"
+          cohere-command-r-plus: "Cohere Command R Plus"
           gpt-4: "GPT-4"
           gpt-4-turbo: "GPT-4 Turbo"
           gpt-3:
diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml
index c47cdd5a..c6f6209b 100644
--- a/config/locales/server.en.yml
+++ b/config/locales/server.en.yml
@@ -50,6 +50,7 @@ en:
     ai_openai_embeddings_url: "Custom URL used for the OpenAI embeddings API. (in the case of Azure it can be: https://COMPANY.openai.azure.com/openai/deployments/DEPLOYMENT/embeddings?api-version=2023-05-15)"
     ai_openai_api_key: "API key for OpenAI API"
     ai_anthropic_api_key: "API key for Anthropic API"
+    ai_cohere_api_key: "API key for Cohere API"
     ai_hugging_face_api_url: "Custom URL used for OpenSource LLM inference. Compatible with https://github.com/huggingface/text-generation-inference"
     ai_hugging_face_api_key: API key for Hugging Face API
     ai_hugging_face_token_limit: Max tokens Hugging Face API can use per request
diff --git a/config/settings.yml b/config/settings.yml
index 26013f6a..e656205d 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -110,6 +110,9 @@ discourse_ai:
   ai_anthropic_api_key:
     default: ""
     secret: true
+  ai_cohere_api_key:
+    default: ""
+    secret: true
   ai_stability_api_key:
     default: ""
     secret: true
@@ -336,6 +339,7 @@ discourse_ai:
       - claude-3-opus
       - claude-3-sonnet
       - claude-3-haiku
+      - cohere-command-r-plus
   ai_bot_add_to_header:
     default: true
     client: true
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 0f1a553a..f0bb1856 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -180,6 +180,8 @@ module DiscourseAi
       when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
         # no bedrock support yet 18-03
         "anthropic:claude-3-opus"
+      when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
+        "cohere:command-r-plus"
       when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
         if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
              "claude-3-sonnet",
diff --git a/lib/ai_bot/entry_point.rb b/lib/ai_bot/entry_point.rb
index 928753d9..59e206c7 100644
--- a/lib/ai_bot/entry_point.rb
+++ b/lib/ai_bot/entry_point.rb
@@ -17,6 +17,7 @@ module DiscourseAi
     CLAUDE_3_OPUS_ID = -117
     CLAUDE_3_SONNET_ID = -118
     CLAUDE_3_HAIKU_ID = -119
+    COHERE_COMMAND_R_PLUS = -120

     BOTS = [
       [GPT4_ID, "gpt4_bot", "gpt-4"],
@@ -29,6 +30,7 @@ module DiscourseAi
       [CLAUDE_3_OPUS_ID, "claude_3_opus_bot", "claude-3-opus"],
       [CLAUDE_3_SONNET_ID, "claude_3_sonnet_bot", "claude-3-sonnet"],
       [CLAUDE_3_HAIKU_ID, "claude_3_haiku_bot", "claude-3-haiku"],
+      [COHERE_COMMAND_R_PLUS, "cohere_command_bot", "cohere-command-r-plus"],
     ]

     BOT_USER_IDS = BOTS.map(&:first)
@@ -67,6 +69,8 @@ module DiscourseAi
         CLAUDE_3_SONNET_ID
       in "claude-3-haiku"
         CLAUDE_3_HAIKU_ID
+      in "cohere-command-r-plus"
+        COHERE_COMMAND_R_PLUS
       else
         nil
       end
diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb
new file mode 100644
index 00000000..8b4bf67d
--- /dev/null
+++ b/lib/completions/dialects/command.rb
@@ -0,0 +1,107 @@
+# frozen_string_literal: true
+
+# see: https://docs.cohere.com/reference/chat
+#
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Command < Dialect
+        class << self
+          def can_translate?(model_name)
+            %w[command-light command command-r command-r-plus].include?(model_name)
+          end
+
+          def tokenizer
+            DiscourseAi::Tokenizer::OpenAiTokenizer
+          end
+        end
+
+        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
+
+        def translate
+          messages = prompt.messages
+
+          # drop a trailing model reply; the Cohere chat API generates the next turn itself
+          if messages.last[:type] == :model
+            messages = messages.dup
+            messages.pop
+          end
+
+          trimmed_messages = trim_messages(messages)
+
+          chat_history = []
+          system_message = nil
+
+          prompt = {}
+
+          trimmed_messages.each do |msg|
+            case msg[:type]
+            when :system
+              if system_message
+                chat_history << { role: "SYSTEM", message: msg[:content] }
+              else
+                system_message = msg[:content]
+              end
+            when :model
+              chat_history << { role: "CHATBOT", message: msg[:content] }
+            when :tool_call
+              chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
+            when :tool
+              chat_history << { role: "USER", message: tool_result_to_xml(msg) }
+            when :user
+              user_message = { role: "USER", message: msg[:content] }
+              user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
+              chat_history << user_message
+            end
+          end
+
+          tools_prompt = build_tools_prompt
+          prompt[:preamble] = +"#{system_message}"
+          if tools_prompt.present?
+            prompt[:preamble] << "\n#{tools_prompt}"
+            prompt[
+              :preamble
+            ] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
+          end
+
+          prompt[:chat_history] = chat_history if chat_history.present?
+
+          chat_history.reverse_each do |msg|
+            if msg[:role] == "USER"
+              prompt[:message] = msg[:message]
+              chat_history.delete(msg)
+              break
+            end
+          end
+
+          prompt
+        end
+
+        def max_prompt_tokens
+          case model_name
+          when "command-light"
+            4096
+          when "command"
+            8192
+          when "command-r"
+            131_072
+          when "command-r-plus"
+            131_072
+          else
+            8192
+          end
+        end
+
+        private
+
+        def per_message_overhead
+          0
+        end
+
+        def calculate_message_token(context)
+          self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        end
+      end
+    end
+  end
+end
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 3baaa68d..d84d9e38 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -17,6 +17,7 @@ module DiscourseAi
           DiscourseAi::Completions::Dialects::Gemini,
           DiscourseAi::Completions::Dialects::Mixtral,
           DiscourseAi::Completions::Dialects::Claude,
+          DiscourseAi::Completions::Dialects::Command,
         ]

         if Rails.env.test? || Rails.env.development?
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 51c8c012..268ea510 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -16,6 +16,7 @@ module DiscourseAi
           DiscourseAi::Completions::Endpoints::Gemini,
           DiscourseAi::Completions::Endpoints::Vllm,
           DiscourseAi::Completions::Endpoints::Anthropic,
+          DiscourseAi::Completions::Endpoints::Cohere,
         ]

         if Rails.env.test? || Rails.env.development?
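
The Command dialect above flattens a Discourse prompt into Cohere's chat payload: the system message becomes preamble, earlier turns become chat_history entries with USER/CHATBOT roles (user ids are prefixed onto the message text), and the most recent USER entry is pulled out of the history into message. A minimal Ruby sketch of the resulting hash — illustrative only, not part of the patch; the values mirror the streaming spec below:

    # Discourse-side messages:
    #   system "You are hello bot"
    #   user(id: "user1") "hello" -> model "hi user" -> user(id: "user1") "thanks"
    #
    # Command#translate result, i.e. the hash POSTed to /v1/chat:
    {
      preamble: "You are hello bot",
      chat_history: [
        { role: "USER", message: "user1: hello" },
        { role: "CHATBOT", message: "hi user" },
      ],
      message: "user1: thanks", # last USER entry, removed from chat_history
    }
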
diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb
new file mode 100644
index 00000000..4903fe00
--- /dev/null
+++ b/lib/completions/endpoints/cohere.rb
@@ -0,0 +1,114 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Endpoints
+      class Cohere < Base
+        class << self
+          def can_contact?(endpoint_name, model_name)
+            return false unless endpoint_name == "cohere"
+
+            %w[command-light command command-r command-r-plus].include?(model_name)
+          end
+
+          def dependant_setting_names
+            %w[ai_cohere_api_key]
+          end
+
+          def correctly_configured?(model_name)
+            SiteSetting.ai_cohere_api_key.present?
+          end
+
+          def endpoint_name(model_name)
+            "Cohere - #{model_name}"
+          end
+        end
+
+        def normalize_model_params(model_params)
+          model_params = model_params.dup
+
+          # Cohere calls the nucleus sampling parameter :p rather than :top_p
+          model_params[:p] = model_params.delete(:top_p) if model_params[:top_p]
+
+          model_params
+        end
+
+        def default_options(dialect)
+          options = { model: "command-r-plus" }
+
+          options[:stop_sequences] = ["</function_calls>"] if dialect.prompt.has_tools?
+          options
+        end
+
+        def provider_id
+          AiApiAuditLog::Provider::Cohere
+        end
+
+        private
+
+        def model_uri
+          URI("https://api.cohere.ai/v1/chat")
+        end
+
+        def prepare_payload(prompt, model_params, dialect)
+          payload = default_options(dialect).merge(model_params).merge(prompt)
+
+          payload[:stream] = true if @streaming_mode
+
+          payload
+        end
+
+        def prepare_request(payload)
+          headers = {
+            "Content-Type" => "application/json",
+            "Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}",
+          }
+
+          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
+        end
+
+        def extract_completion_from(response_raw)
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+
+          if @streaming_mode
+            if parsed[:event_type] == "text-generation"
+              parsed[:text]
+            else
+              if parsed[:event_type] == "stream-end"
+                @input_tokens = parsed.dig(:response, :meta, :billed_units, :input_tokens)
+                @output_tokens = parsed.dig(:response, :meta, :billed_units, :output_tokens)
+              end
+              nil
+            end
+          else
+            @input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
+            @output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
+            parsed[:text].to_s
+          end
+        end
+
+        def final_log_update(log)
+          log.request_tokens = @input_tokens if @input_tokens
+          log.response_tokens = @output_tokens if @output_tokens
+        end
+
+        def partials_from(decoded_chunk)
+          decoded_chunk.split("\n").compact
+        end
+
+        def extract_prompt_for_tokenizer(prompt)
+          text = +""
+          if prompt[:chat_history]
+            # chat_history entries use the :message key
+            text << prompt[:chat_history]
+              .map { |message| message[:message] || message["message"] || "" }
+              .join("\n")
+          end
+
+          text << prompt[:message] if prompt[:message]
+          text << prompt[:preamble] if prompt[:preamble]
+
+          text
+        end
+      end
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb
new file mode 100644
index 00000000..52a73721
--- /dev/null
+++ b/spec/lib/completions/endpoints/cohere_spec.rb
@@ -0,0 +1,248 @@
+# frozen_string_literal: true
+require_relative "endpoint_compliance"
+
+RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
+  let(:llm) { DiscourseAi::Completions::Llm.proxy("cohere:command-r-plus") }
+  fab!(:user)
+
+  let(:prompt) do
+    DiscourseAi::Completions::Prompt.new(
+      "You are hello bot",
+      messages: [
+        { type: :user, id: "user1", content: "hello" },
+        { type: :model, content: "hi user" },
+        { type: :user, id: "user1", content: "thanks" },
+      ],
+    )
+  end
+
+  let(:weather_tool) do
+    {
+      name: "weather",
+      description: "lookup weather in a city",
+      parameters: [{ name: "city", type: "string", description: "city name", required: true }],
+    }
+  end
+
+  let(:prompt_with_tools) do
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are weather bot",
+        messages: [
+          { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
+        ],
+      )
+
+    prompt.tools = [weather_tool]
+    prompt
+  end
+
+  let(:prompt_with_tool_results) do
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are weather bot",
+        messages: [
+          { type: :user, id: "user1", content: "what is the weather in sydney and melbourne?" },
+          {
+            type: :tool_call,
+            id: "tool_call_1",
+            name: "weather",
+            content: { arguments: [%w[city Sydney]] }.to_json,
+          },
+          { type: :tool, id: "tool_call_1", name: "weather", content: { weather: "22c" }.to_json },
+        ],
+      )
+
+    prompt.tools = [weather_tool]
+    prompt
+  end
+
+  before { SiteSetting.ai_cohere_api_key = "ABC" }
+
+  it "is able to run tools" do
+    body = {
+      response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
+      text: "Sydney is 22c",
+      generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
+      chat_history: [],
+      finish_reason: "COMPLETE",
+      token_count: {
+        prompt_tokens: 29,
+        response_tokens: 11,
+        total_tokens: 40,
+        billed_tokens: 25,
+      },
+      meta: {
+        api_version: {
+          version: "1",
+        },
+        billed_units: {
+          input_tokens: 17,
+          output_tokens: 22,
+        },
+      },
+    }.to_json
+
+    parsed_body = nil
+    stub_request(:post, "https://api.cohere.ai/v1/chat").with(
+      body:
+        proc do |req_body|
+          parsed_body = JSON.parse(req_body, symbolize_names: true)
+          true
+        end,
+      headers: {
+        "Content-Type" => "application/json",
+        "Authorization" => "Bearer ABC",
+      },
+    ).to_return(status: 200, body: body)
+
+    result = llm.generate(prompt_with_tool_results, user: user)
+
+    expect(parsed_body[:preamble]).to include("You are weather bot")
+    expect(parsed_body[:preamble]).to include("<tools>")
+
+    expected_message = <<~MESSAGE
+      <function_results>
+      <result>
+      <tool_name>weather</tool_name>
+      <json>
+      {"weather":"22c"}
+      </json>
+      </result>
+      </function_results>
+    MESSAGE
+
+    expect(parsed_body[:message].strip).to eq(expected_message.strip)
+
+    expect(result).to eq("Sydney is 22c")
+    audit = AiApiAuditLog.order("id desc").first
+
+    # billing should be picked up from the response metadata
+    expect(audit.request_tokens).to eq(17)
+    expect(audit.response_tokens).to eq(22)
+  end
+
+  it "is able to perform streaming completions" do
+    body = <<~TEXT
+      {"is_finished":false,"event_type":"stream-start","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf"}
+      {"is_finished":false,"event_type":"text-generation","text":"You"}
+      {"is_finished":false,"event_type":"text-generation","text":"'re"}
+      {"is_finished":false,"event_type":"text-generation","text":" welcome"}
+      {"is_finished":false,"event_type":"text-generation","text":"!"}
+      {"is_finished":false,"event_type":"text-generation","text":" Is"}
+      {"is_finished":false,"event_type":"text-generation","text":" there"}
+      {"is_finished":false,"event_type":"text-generation","text":" anything"}|
+      {"is_finished":false,"event_type":"text-generation","text":" else"}
+      {"is_finished":false,"event_type":"text-generation","text":" I"}
+      {"is_finished":false,"event_type":"text-generation","text":" can"}
+      {"is_finished":false,"event_type":"text-generation","text":" help"}|
+      {"is_finished":false,"event_type":"text-generation","text":" you"}
+      {"is_finished":false,"event_type":"text-generation","text":" with"}
+      {"is_finished":false,"event_type":"text-generation","text":"?"}|
+      {"is_finished":true,"event_type":"stream-end","response":{"response_id":"d235db17-8555-493b-8d91-e601f76de3f9","text":"You're welcome! Is there anything else I can help you with?","generation_id":"eb889b0f-c27d-45ea-98cf-567bdb7fc8bf","chat_history":[{"role":"USER","message":"user1: hello"},{"role":"CHATBOT","message":"hi user"},{"role":"USER","message":"user1: thanks"},{"role":"CHATBOT","message":"You're welcome! Is there anything else I can help you with?"}],"token_count":{"prompt_tokens":29,"response_tokens":14,"total_tokens":43,"billed_tokens":28},"meta":{"api_version":{"version":"1"},"billed_units":{"input_tokens":14,"output_tokens":14}}},"finish_reason":"COMPLETE"}
+    TEXT
+
+    parsed_body = nil
+    result = +""
+
+    EndpointMock.with_chunk_array_support do
+      stub_request(:post, "https://api.cohere.ai/v1/chat").with(
+        body:
+          proc do |req_body|
+            parsed_body = JSON.parse(req_body, symbolize_names: true)
+            true
+          end,
+        headers: {
+          "Content-Type" => "application/json",
+          "Authorization" => "Bearer ABC",
+        },
+      ).to_return(status: 200, body: body.split("|"))
+
+      result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
+    end
+
+    expect(parsed_body[:preamble]).to eq("You are hello bot")
+    expect(parsed_body[:chat_history]).to eq(
+      [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
+    )
+    expect(parsed_body[:message]).to eq("user1: thanks")
+
+    expect(result).to eq("You're welcome! Is there anything else I can help you with?")
+    audit = AiApiAuditLog.order("id desc").first
+
+    # billing should be picked up from the stream-end event
+    expect(audit.request_tokens).to eq(14)
+    expect(audit.response_tokens).to eq(14)
+  end
+
+  it "is able to perform non streaming completions" do
+    body = {
+      response_id: "0a90275b-273d-4690-abce-8018edcec7d0",
+      text: "You're welcome! How can I help you today?",
+      generation_id: "cc2742f7-622c-4e42-8fd4-d95b21012e52",
+      chat_history: [
+        { role: "USER", message: "user1: hello" },
+        { role: "CHATBOT", message: "hi user" },
+        { role: "USER", message: "user1: thanks" },
+        { role: "CHATBOT", message: "You're welcome! How can I help you today?" },
+      ],
+      finish_reason: "COMPLETE",
+      token_count: {
+        prompt_tokens: 29,
+        response_tokens: 11,
+        total_tokens: 40,
+        billed_tokens: 25,
+      },
+      meta: {
+        api_version: {
+          version: "1",
+        },
+        billed_units: {
+          input_tokens: 14,
+          output_tokens: 11,
+        },
+      },
+    }.to_json
+
+    parsed_body = nil
+    stub_request(:post, "https://api.cohere.ai/v1/chat").with(
+      body:
+        proc do |req_body|
+          parsed_body = JSON.parse(req_body, symbolize_names: true)
+          true
+        end,
+      headers: {
+        "Content-Type" => "application/json",
+        "Authorization" => "Bearer ABC",
+      },
+    ).to_return(status: 200, body: body)
+
+    result =
+      llm.generate(
+        prompt,
+        user: user,
+        temperature: 0.1,
+        top_p: 0.5,
+        max_tokens: 100,
+        stop_sequences: ["stop"],
+      )
+
+    expect(parsed_body[:temperature]).to eq(0.1)
+    expect(parsed_body[:p]).to eq(0.5)
+    expect(parsed_body[:max_tokens]).to eq(100)
+    expect(parsed_body[:stop_sequences]).to eq(["stop"])
+
+    expect(parsed_body[:preamble]).to eq("You are hello bot")
+    expect(parsed_body[:chat_history]).to eq(
+      [{ role: "USER", message: "user1: hello" }, { role: "CHATBOT", message: "hi user" }],
+    )
+    expect(parsed_body[:message]).to eq("user1: thanks")
+
+    expect(result).to eq("You're welcome! How can I help you today?")
+    audit = AiApiAuditLog.order("id desc").first
+
+    # billing should be picked up from the response metadata
+    expect(audit.request_tokens).to eq(14)
+    expect(audit.response_tokens).to eq(11)
+  end
+end