Sam 7f16d3ad43
FEATURE: Cohere Command R support (#558)
- Added Cohere Command models (Command, Command Light, Command R, Command R Plus) to the available model list
- Added a new site setting `ai_cohere_api_key` for configuring the Cohere API key
- Implemented a new `DiscourseAi::Completions::Endpoints::Cohere` class to handle interactions with the Cohere API, including:
   - Translating request parameters to the Cohere API format
   - Parsing Cohere API responses 
   - Supporting streaming and non-streaming completions
   - Supporting "tools", which allow the model to call back into Discourse to look up additional information
- Implemented a new `DiscourseAi::Completions::Dialects::Command` class to translate between the generic Discourse AI prompt format and the Cohere Command format
- Added specs covering the new Cohere endpoint and dialect classes
- Updated `DiscourseAi::AiBot::Bot.guess_model` to map the new Cohere model to the appropriate bot user
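
The streaming support mentioned above is handled in the endpoint class, which is not included in this listing. As a rough, hypothetical sketch (the event names and fields follow Cohere's chat API reference, not code from this commit): the chat endpoint streams newline-delimited JSON events, and text deltas arrive as `text-generation` events.

```ruby
require "json"

# Hypothetical parser: pull text deltas out of a newline-delimited
# stream of Cohere chat events. Assumes each line is one JSON object
# with an "event_type" field; non-JSON lines are ignored.
def extract_stream_text(raw_chunk)
  raw_chunk.each_line.filter_map do |line|
    event = JSON.parse(line) rescue nil
    event["text"] if event && event["event_type"] == "text-generation"
  end.join
end

chunk = <<~EVENTS
  {"event_type":"stream-start"}
  {"event_type":"text-generation","text":"Hello"}
  {"event_type":"text-generation","text":" world"}
  {"event_type":"stream-end","finish_reason":"COMPLETE"}
EVENTS

extract_stream_text(chunk) # => "Hello world"
```

The real endpoint class also has to handle partial chunks split across network reads; this sketch assumes whole lines.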

In summary, this PR adds support for using the Cohere Command family of models with the Discourse AI plugin. It handles configuring API keys, making requests to the Cohere API, and translating between Discourse's generic prompt format and Cohere's specific format. Thorough test coverage was added for the new functionality.
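
For reference, the request body the dialect builds has roughly this shape. The key names (`preamble`, `chat_history`, `message`) and role strings come from the dialect code in this commit; the message contents here are illustrative only.

```ruby
# Shape of the hash returned by Command#translate (illustrative values).
payload = {
  preamble: "You are a helpful assistant.", # system message, plus tools prompt if any
  chat_history: [
    { role: "USER", message: "What is Discourse?" },
    { role: "CHATBOT", message: "A forum platform." },
  ],
  message: "Who maintains it?", # the latest user turn, pulled out of the history
}
```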
2024-04-11 07:24:17 +10:00

108 lines
2.9 KiB
Ruby

# frozen_string_literal: true

# see: https://docs.cohere.com/reference/chat
#
module DiscourseAi
  module Completions
    module Dialects
      class Command < Dialect
        class << self
          def can_translate?(model_name)
            %w[command-light command command-r command-r-plus].include?(model_name)
          end

          def tokenizer
            DiscourseAi::Tokenizer::OpenAiTokenizer
          end
        end

        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
        def translate
          messages = prompt.messages

          # Drop a trailing assistant (:model) message; the final turn is
          # supplied to Cohere via the top-level :message field instead.
          if messages.last[:type] == :model
            messages = messages.dup
            messages.pop
          end
          trimmed_messages = trim_messages(messages)

          chat_history = []
          system_message = nil
          prompt = {}

          trimmed_messages.each do |msg|
            case msg[:type]
            when :system
              if system_message
                chat_history << { role: "SYSTEM", message: msg[:content] }
              else
                system_message = msg[:content]
              end
            when :model
              chat_history << { role: "CHATBOT", message: msg[:content] }
            when :tool_call
              chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
            when :tool
              chat_history << { role: "USER", message: tool_result_to_xml(msg) }
            when :user
              user_message = { role: "USER", message: msg[:content] }
              user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
              chat_history << user_message
            end
          end

          tools_prompt = build_tools_prompt
          prompt[:preamble] = +"#{system_message}"
          if tools_prompt.present?
            prompt[:preamble] << "\n#{tools_prompt}"
            prompt[:preamble] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
          end

          prompt[:chat_history] = chat_history if chat_history.present?

          chat_history.reverse_each do |msg|
            if msg[:role] == "USER"
              prompt[:message] = msg[:message]
              chat_history.delete(msg)
              break
            end
          end

          prompt
        end
        def max_prompt_tokens
          case model_name
          when "command-light"
            4096
          when "command"
            8192
          when "command-r", "command-r-plus"
            131_072
          else
            8192
          end
        end
        private

        def per_message_overhead
          0
        end

        def calculate_message_token(context)
          self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
        end
      end
    end
  end
end
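
As a standalone illustration of the mapping in `translate` above (a simplified re-implementation for demonstration, not the plugin code itself, and omitting tool calls, trimming, and the preamble): generic `:system`/`:user`/`:model` messages become Cohere's `SYSTEM`/`USER`/`CHATBOT` roles, and the most recent user turn is pulled out of the history into the top-level `message` field.

```ruby
# Simplified sketch of the role mapping; assumes at least one :user message.
def to_cohere(messages)
  roles = { system: "SYSTEM", user: "USER", model: "CHATBOT" }
  history = messages.map { |m| { role: roles.fetch(m[:type]), message: m[:content] } }

  # The last USER entry becomes the top-level message rather than history.
  last_user = history.rindex { |m| m[:role] == "USER" }
  {
    chat_history: history[0...last_user] + history[(last_user + 1)..],
    message: history[last_user][:message],
  }
end

result = to_cohere(
  [
    { type: :user, content: "Hi" },
    { type: :model, content: "Hello!" },
    { type: :user, content: "What models do you support?" },
  ],
)

result[:message] # => "What models do you support?"
result[:chat_history]
# => [{ role: "USER", message: "Hi" }, { role: "CHATBOT", message: "Hello!" }]
```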