FEATURE: Add support for Mistral models (#919)

Adds support for mistral models (pixtral and mistral large now have presets)

Also corrects token accounting in AWS bedrock models
This commit is contained in:
Sam 2024-11-19 17:28:09 +11:00 committed by GitHub
parent 0d7f353284
commit 755b63f31f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
12 changed files with 224 additions and 22 deletions

View File

@ -13,6 +13,7 @@ class AiApiAuditLog < ActiveRecord::Base
Cohere = 6 Cohere = 6
Ollama = 7 Ollama = 7
SambaNova = 8 SambaNova = 8
Mistral = 9
end end
def next_log_id def next_log_id

View File

@ -28,6 +28,9 @@ class LlmModel < ActiveRecord::Base
organization: :text, organization: :text,
disable_native_tools: :checkbox, disable_native_tools: :checkbox,
}, },
mistral: {
disable_native_tools: :checkbox,
},
google: { google: {
disable_native_tools: :checkbox, disable_native_tools: :checkbox,
}, },

View File

@ -290,6 +290,8 @@ en:
open_ai-o1-preview: "Open AI's most capabale reasoning model" open_ai-o1-preview: "Open AI's most capabale reasoning model"
samba_nova-Meta-Llama-3-1-8B-Instruct: "Efficient lightweight multilingual model" samba_nova-Meta-Llama-3-1-8B-Instruct: "Efficient lightweight multilingual model"
samba_nova-Meta-Llama-3-1-70B-Instruct": "Powerful multipurpose model" samba_nova-Meta-Llama-3-1-70B-Instruct": "Powerful multipurpose model"
mistral-mistral-large-latest: "Mistral's most powerful model"
mistral-pixtral-large-latest: "Mistral's most powerful vision capable model"
configured: configured:
title: "Configured LLMs" title: "Configured LLMs"
@ -325,6 +327,7 @@ en:
ollama: "Ollama" ollama: "Ollama"
CDCK: "CDCK" CDCK: "CDCK"
samba_nova: "SambaNova" samba_nova: "SambaNova"
mistral: "Mistral"
fake: "Custom" fake: "Custom"
provider_fields: provider_fields:

View File

@ -16,6 +16,7 @@ module DiscourseAi
DiscourseAi::Completions::Dialects::Claude, DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command, DiscourseAi::Completions::Dialects::Command,
DiscourseAi::Completions::Dialects::Ollama, DiscourseAi::Completions::Dialects::Ollama,
DiscourseAi::Completions::Dialects::Mistral,
DiscourseAi::Completions::Dialects::OpenAiCompatible, DiscourseAi::Completions::Dialects::OpenAiCompatible,
] ]
end end

View File

@ -0,0 +1,44 @@
# frozen_string_literal: true
# basically the same as Open AI, except for no support for user names
module DiscourseAi
module Completions
module Dialects
class Mistral < ChatGpt
class << self
def can_translate?(model_provider)
model_provider == "mistral"
end
end
def translate
corrected = super
corrected.each do |msg|
msg[:content] = "" if msg[:tool_calls] && msg[:role] == "assistant"
end
corrected
end
private
def user_msg(msg)
mapped = super
if name = mapped.delete(:name)
if mapped[:content].is_a?(String)
mapped[:content] = "#{name}: #{mapped[:content]}"
else
mapped[:content].each do |inner|
if inner[:text]
inner[:text] = "#{name}: #{inner[:text]}"
break
end
end
end
end
mapped
end
end
end
end
end

View File

@ -19,7 +19,10 @@ module DiscourseAi
end end
def default_options(dialect) def default_options(dialect)
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" } max_tokens = 4096
max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
options = { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? && options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
dialect.prompt.has_tools? dialect.prompt.has_tools?
@ -40,15 +43,7 @@ module DiscourseAi
private private
def prompt_size(prompt) def bedrock_model_id
# approximation
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
def model_uri
region = llm_model.lookup_custom_param("region")
bedrock_model_id =
case llm_model.name case llm_model.name
when "claude-2" when "claude-2"
"anthropic.claude-v2:1" "anthropic.claude-v2:1"
@ -62,9 +57,20 @@ module DiscourseAi
"anthropic.claude-3-opus-20240229-v1:0" "anthropic.claude-3-opus-20240229-v1:0"
when "claude-3-5-sonnet" when "claude-3-5-sonnet"
"anthropic.claude-3-5-sonnet-20241022-v2:0" "anthropic.claude-3-5-sonnet-20241022-v2:0"
when "claude-3-5-haiku"
"anthropic.claude-3-5-haiku-20241022-v1:0"
else else
llm_model.name llm_model.name
end end
end
def prompt_size(prompt)
# approximation
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
def model_uri
region = llm_model.lookup_custom_param("region")
if region.blank? || bedrock_model_id.blank? if region.blank? || bedrock_model_id.blank?
raise CompletionFailed.new(I18n.t("discourse_ai.llm_models.bedrock_invalid_url")) raise CompletionFailed.new(I18n.t("discourse_ai.llm_models.bedrock_invalid_url"))

View File

@ -20,6 +20,7 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::Anthropic, DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::Cohere, DiscourseAi::Completions::Endpoints::Cohere,
DiscourseAi::Completions::Endpoints::SambaNova, DiscourseAi::Completions::Endpoints::SambaNova,
DiscourseAi::Completions::Endpoints::Mistral,
] ]
endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development? endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?

View File

@ -0,0 +1,17 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class Mistral < OpenAi
def self.can_contact?(model_provider)
model_provider == "mistral"
end
def provider_id
AiApiAuditLog::Provider::Mistral
end
end
end
end
end

View File

@ -90,6 +90,24 @@ module DiscourseAi
endpoint: "https://api.sambanova.ai/v1/chat/completions", endpoint: "https://api.sambanova.ai/v1/chat/completions",
provider: "samba_nova", provider: "samba_nova",
}, },
{
id: "mistral",
models: [
{
name: "mistral-large-latest",
tokens: 128_000,
display_name: "Mistral Large",
},
{
name: "pixtral-large-latest",
tokens: 128_000,
display_name: "Pixtral Large",
},
],
tokenizer: DiscourseAi::Tokenizer::MixtralTokenizer,
endpoint: "https://api.mistral.ai/v1/chat/completions",
provider: "mistral",
},
] ]
end end
end end
@ -105,6 +123,7 @@ module DiscourseAi
google google
azure azure
samba_nova samba_nova
mistral
] ]
if !Rails.env.production? if !Rails.env.production?
providers << "fake" providers << "fake"

View File

@ -90,6 +90,16 @@ Fabricator(:ollama_model, from: :llm_model) do
provider_params { { enable_native_tool: true } } provider_params { { enable_native_tool: true } }
end end
Fabricator(:mistral_model, from: :llm_model) do
display_name "Mistral Large"
name "mistral-large-latest"
provider "mistral"
api_key "ABC"
tokenizer "DiscourseAi::Tokenizer::MixtralTokenizer"
url "https://api.mistral.ai/v1/chat/completions"
provider_params { { disable_native_tools: false } }
end
Fabricator(:seeded_model, from: :llm_model) do Fabricator(:seeded_model, from: :llm_model) do
id "-2" id "-2"
display_name "CDCK Hosted Model" display_name "CDCK Hosted Model"

View File

@ -0,0 +1,81 @@
# frozen_string_literal: true
require "rails_helper"
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
fab!(:model) { Fabricate(:mistral_model) }
let(:context) { DialectContext.new(described_class, model) }
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
let(:upload100x100) do
UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
end
it "does not include user names" do
prompt =
DiscourseAi::Completions::Prompt.new(
messages: [type: :user, content: "Hello, I am Bob", id: "bob"],
)
dialect = described_class.new(prompt, model)
# mistral has no support for name
expect(dialect.translate).to eq([{ role: "user", content: "bob: Hello, I am Bob" }])
end
it "can properly encode images" do
model.update!(vision_enabled: true)
prompt =
DiscourseAi::Completions::Prompt.new(
"You are image bot",
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
)
encoded = prompt.encoded_uploads(prompt.messages.last)
image = "data:image/jpeg;base64,#{encoded[0][:base64]}"
dialect = described_class.new(prompt, model)
content = dialect.translate[1][:content]
expect(content).to eq(
[{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }],
)
end
it "can properly map tool calls to mistral format" do
result = [
{
role: "system",
content:
"I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
},
{ role: "user", content: "user1: This is a message by a user" },
{ role: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ role: "user", content: "user1: This is a new message by a user" },
{
role: "assistant",
content: "",
tool_calls: [
{
type: "function",
function: {
arguments: "{\"location\":\"Sydney\",\"unit\":\"c\"}",
name: "get_weather",
},
id: "tool_id",
},
],
},
{
role: "tool",
tool_call_id: "tool_id",
content: "\"I'm a tool result\"",
name: "get_weather",
},
]
expect(context.multi_turn_scenario).to eq(result)
end
end

View File

@ -26,6 +26,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
Aws::EventStream::Encoder.new.encode(aws_message) Aws::EventStream::Encoder.new.encode(aws_message)
end end
it "should provide accurate max token count" do
prompt = DiscourseAi::Completions::Prompt.new("hello")
dialect = DiscourseAi::Completions::Dialects::Claude.new(prompt, model)
endpoint = DiscourseAi::Completions::Endpoints::AwsBedrock.new(model)
model.name = "claude-2"
expect(endpoint.default_options(dialect)[:max_tokens]).to eq(4096)
model.name = "claude-3-5-sonnet"
expect(endpoint.default_options(dialect)[:max_tokens]).to eq(8192)
model.name = "claude-3-5-haiku"
options = endpoint.default_options(dialect)
expect(options[:max_tokens]).to eq(8192)
end
describe "function calling" do describe "function calling" do
it "supports old school xml function calls" do it "supports old school xml function calls" do
model.provider_params["disable_native_tools"] = true model.provider_params["disable_native_tools"] = true
@ -246,7 +262,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(response).to eq(expected_response) expect(response).to eq(expected_response)
expected = { expected = {
"max_tokens" => 3000, "max_tokens" => 4096,
"anthropic_version" => "bedrock-2023-05-31", "anthropic_version" => "bedrock-2023-05-31",
"messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }], "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
"tools" => [ "tools" => [
@ -305,7 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(request.headers["X-Amz-Content-Sha256"]).to be_present expect(request.headers["X-Amz-Content-Sha256"]).to be_present
expected = { expected = {
"max_tokens" => 3000, "max_tokens" => 4096,
"anthropic_version" => "bedrock-2023-05-31", "anthropic_version" => "bedrock-2023-05-31",
"messages" => [{ "role" => "user", "content" => "hello world" }], "messages" => [{ "role" => "user", "content" => "hello world" }],
"system" => "You are a helpful bot", "system" => "You are a helpful bot",
@ -354,7 +370,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(request.headers["X-Amz-Content-Sha256"]).to be_present expect(request.headers["X-Amz-Content-Sha256"]).to be_present
expected = { expected = {
"max_tokens" => 3000, "max_tokens" => 4096,
"anthropic_version" => "bedrock-2023-05-31", "anthropic_version" => "bedrock-2023-05-31",
"messages" => [{ "role" => "user", "content" => "hello world" }], "messages" => [{ "role" => "user", "content" => "hello world" }],
"system" => "You are a helpful bot", "system" => "You are a helpful bot",