# Mirror of https://github.com/discourse/discourse-ai.git
# (synced 2025-02-17 00:45:12 +00:00)
#
# From the change that added Cohere Command support:
# - Added Cohere Command models (Command, Command Light, Command R, Command R Plus)
#   to the available model list
# - Added a new site setting `ai_cohere_api_key` for configuring the Cohere API key
# - Implemented a new `DiscourseAi::Completions::Endpoints::Cohere` class to handle
#   interactions with the Cohere API, including: translating request parameters to
#   the Cohere API format, parsing Cohere API responses, supporting streaming and
#   non-streaming completions, and supporting "tools" which allow the model to call
#   back to Discourse to look up additional information
# - Implemented a new `DiscourseAi::Completions::Dialects::Command` class (this file)
#   to translate between the generic Discourse AI prompt format and the Cohere
#   Command format
# - Added specs covering the new Cohere endpoint and dialect classes
# - Updated `DiscourseAi::AiBot::Bot.guess_model` to map the new Cohere model to the
#   appropriate bot user
#
# In summary, this adds support for using the Cohere Command family of models with
# the Discourse AI plugin: configuring API keys, making requests to the Cohere API,
# and translating between Discourse's generic prompt format and Cohere's format.
# frozen_string_literal: true
# see: https://docs.cohere.com/reference/chat
module DiscourseAi
  module Completions
    module Dialects
      # Dialect for Cohere's Command model family (command, command-light,
      # command-r, command-r-plus).
      #
      # Translates the engine-agnostic DiscourseAi prompt into the payload
      # shape expected by Cohere's chat API: a :preamble (system prompt),
      # a :chat_history array of { role:, message: } hashes, and the latest
      # user turn as a standalone :message.
      #
      # see: https://docs.cohere.com/reference/chat
      class Command < Dialect
        class << self
          # @param model_name [String]
          # @return [Boolean] whether this dialect handles the given model
          def can_translate?(model_name)
            %w[command-light command command-r command-r-plus].include?(model_name)
          end

          # NOTE(review): Cohere ships no public Ruby tokenizer here, so the
          # OpenAI tokenizer is presumably used as a close approximation for
          # token budgeting — confirm against the endpoint implementation.
          def tokenizer
            DiscourseAi::Tokenizer::OpenAiTokenizer
          end
        end

        VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

        # Builds the Cohere chat payload from the generic prompt.
        #
        # @return [Hash] with :preamble, plus :chat_history and :message
        #   when the conversation contains the corresponding turns
        def translate
          messages = prompt.messages

          # Cohere's chat API takes the latest user turn separately, so a
          # trailing bot (:model) message is dropped before translation.
          if messages.last[:type] == :model
            messages = messages.dup
            messages.pop
          end

          trimmed_messages = trim_messages(messages)

          chat_history = []
          system_message = nil

          # Named `payload` (not `prompt`) so it does not shadow the
          # `prompt` reader used above.
          payload = {}

          trimmed_messages.each do |msg|
            case msg[:type]
            when :system
              if system_message
                chat_history << { role: "SYSTEM", message: msg[:content] }
              else
                # The first system message becomes the preamble; any later
                # ones stay inline as SYSTEM turns.
                system_message = msg[:content]
              end
            when :model
              chat_history << { role: "CHATBOT", message: msg[:content] }
            when :tool_call
              chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
            when :tool
              chat_history << { role: "USER", message: tool_result_to_xml(msg) }
            when :user
              user_message = { role: "USER", message: msg[:content] }
              # Prefix the author id when present so the model can tell
              # participants apart.
              user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
              chat_history << user_message
            end
          end

          tools_prompt = build_tools_prompt
          payload[:preamble] = +"#{system_message}"
          if tools_prompt.present?
            payload[:preamble] << "\n#{tools_prompt}"
            payload[
              :preamble
            ] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
          end

          payload[:chat_history] = chat_history if chat_history.present?

          # Promote the most recent USER turn to the dedicated :message
          # field. Use rindex + delete_at rather than Array#delete: #delete
          # removes EVERY entry equal to the message and would silently drop
          # duplicate user turns from the history.
          last_user_idx = chat_history.rindex { |m| m[:role] == "USER" }
          payload[:message] = chat_history.delete_at(last_user_idx)[:message] if last_user_idx

          payload
        end

        # Context window size (in tokens) for each supported model.
        def max_prompt_tokens
          case model_name
          when "command-light"
            4096
          when "command"
            8192
          when "command-r", "command-r-plus"
            131_072
          else
            # Conservative default for unrecognized model names.
            8192
          end
        end

        private

        # Per-message framing overhead is not modeled for Cohere.
        def per_message_overhead
          0
        end

        # Token count for a single message: content plus optional name.
        def calculate_message_token(context)
          self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
        end
      end
    end
  end
end
|