FEATURE: Add support for Mistral models (#919)
Adds support for mistral models (pixtral and mistral large now have presets) Also corrects token accounting in AWS bedrock models
This commit is contained in:
parent
0d7f353284
commit
755b63f31f
|
@ -13,6 +13,7 @@ class AiApiAuditLog < ActiveRecord::Base
|
|||
Cohere = 6
|
||||
Ollama = 7
|
||||
SambaNova = 8
|
||||
Mistral = 9
|
||||
end
|
||||
|
||||
def next_log_id
|
||||
|
|
|
@ -28,6 +28,9 @@ class LlmModel < ActiveRecord::Base
|
|||
organization: :text,
|
||||
disable_native_tools: :checkbox,
|
||||
},
|
||||
mistral: {
|
||||
disable_native_tools: :checkbox,
|
||||
},
|
||||
google: {
|
||||
disable_native_tools: :checkbox,
|
||||
},
|
||||
|
|
|
@ -290,6 +290,8 @@ en:
|
|||
open_ai-o1-preview: "Open AI's most capabale reasoning model"
|
||||
samba_nova-Meta-Llama-3-1-8B-Instruct: "Efficient lightweight multilingual model"
|
||||
samba_nova-Meta-Llama-3-1-70B-Instruct": "Powerful multipurpose model"
|
||||
mistral-mistral-large-latest: "Mistral's most powerful model"
|
||||
mistral-pixtral-large-latest: "Mistral's most powerful vision capable model"
|
||||
|
||||
configured:
|
||||
title: "Configured LLMs"
|
||||
|
@ -325,6 +327,7 @@ en:
|
|||
ollama: "Ollama"
|
||||
CDCK: "CDCK"
|
||||
samba_nova: "SambaNova"
|
||||
mistral: "Mistral"
|
||||
fake: "Custom"
|
||||
|
||||
provider_fields:
|
||||
|
|
|
@ -16,6 +16,7 @@ module DiscourseAi
|
|||
DiscourseAi::Completions::Dialects::Claude,
|
||||
DiscourseAi::Completions::Dialects::Command,
|
||||
DiscourseAi::Completions::Dialects::Ollama,
|
||||
DiscourseAi::Completions::Dialects::Mistral,
|
||||
DiscourseAi::Completions::Dialects::OpenAiCompatible,
|
||||
]
|
||||
end
|
||||
|
|
|
@ -0,0 +1,44 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
# basically the same as Open AI, except for no support for user names
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class Mistral < ChatGpt
|
||||
class << self
|
||||
def can_translate?(model_provider)
|
||||
model_provider == "mistral"
|
||||
end
|
||||
end
|
||||
|
||||
def translate
|
||||
corrected = super
|
||||
corrected.each do |msg|
|
||||
msg[:content] = "" if msg[:tool_calls] && msg[:role] == "assistant"
|
||||
end
|
||||
corrected
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def user_msg(msg)
|
||||
mapped = super
|
||||
if name = mapped.delete(:name)
|
||||
if mapped[:content].is_a?(String)
|
||||
mapped[:content] = "#{name}: #{mapped[:content]}"
|
||||
else
|
||||
mapped[:content].each do |inner|
|
||||
if inner[:text]
|
||||
inner[:text] = "#{name}: #{inner[:text]}"
|
||||
break
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
mapped
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -19,7 +19,10 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options(dialect)
|
||||
options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
|
||||
max_tokens = 4096
|
||||
max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
|
||||
|
||||
options = { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
|
||||
|
||||
options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
|
||||
dialect.prompt.has_tools?
|
||||
|
@ -40,15 +43,7 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
def prompt_size(prompt)
|
||||
# approximation
|
||||
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
|
||||
end
|
||||
|
||||
def model_uri
|
||||
region = llm_model.lookup_custom_param("region")
|
||||
|
||||
bedrock_model_id =
|
||||
def bedrock_model_id
|
||||
case llm_model.name
|
||||
when "claude-2"
|
||||
"anthropic.claude-v2:1"
|
||||
|
@ -62,9 +57,20 @@ module DiscourseAi
|
|||
"anthropic.claude-3-opus-20240229-v1:0"
|
||||
when "claude-3-5-sonnet"
|
||||
"anthropic.claude-3-5-sonnet-20241022-v2:0"
|
||||
when "claude-3-5-haiku"
|
||||
"anthropic.claude-3-5-haiku-20241022-v1:0"
|
||||
else
|
||||
llm_model.name
|
||||
end
|
||||
end
|
||||
|
||||
def prompt_size(prompt)
|
||||
# approximation
|
||||
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
|
||||
end
|
||||
|
||||
def model_uri
|
||||
region = llm_model.lookup_custom_param("region")
|
||||
|
||||
if region.blank? || bedrock_model_id.blank?
|
||||
raise CompletionFailed.new(I18n.t("discourse_ai.llm_models.bedrock_invalid_url"))
|
||||
|
|
|
@ -20,6 +20,7 @@ module DiscourseAi
|
|||
DiscourseAi::Completions::Endpoints::Anthropic,
|
||||
DiscourseAi::Completions::Endpoints::Cohere,
|
||||
DiscourseAi::Completions::Endpoints::SambaNova,
|
||||
DiscourseAi::Completions::Endpoints::Mistral,
|
||||
]
|
||||
|
||||
endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Endpoints
|
||||
class Mistral < OpenAi
|
||||
def self.can_contact?(model_provider)
|
||||
model_provider == "mistral"
|
||||
end
|
||||
|
||||
def provider_id
|
||||
AiApiAuditLog::Provider::Mistral
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -90,6 +90,24 @@ module DiscourseAi
|
|||
endpoint: "https://api.sambanova.ai/v1/chat/completions",
|
||||
provider: "samba_nova",
|
||||
},
|
||||
{
|
||||
id: "mistral",
|
||||
models: [
|
||||
{
|
||||
name: "mistral-large-latest",
|
||||
tokens: 128_000,
|
||||
display_name: "Mistral Large",
|
||||
},
|
||||
{
|
||||
name: "pixtral-large-latest",
|
||||
tokens: 128_000,
|
||||
display_name: "Pixtral Large",
|
||||
},
|
||||
],
|
||||
tokenizer: DiscourseAi::Tokenizer::MixtralTokenizer,
|
||||
endpoint: "https://api.mistral.ai/v1/chat/completions",
|
||||
provider: "mistral",
|
||||
},
|
||||
]
|
||||
end
|
||||
end
|
||||
|
@ -105,6 +123,7 @@ module DiscourseAi
|
|||
google
|
||||
azure
|
||||
samba_nova
|
||||
mistral
|
||||
]
|
||||
if !Rails.env.production?
|
||||
providers << "fake"
|
||||
|
|
|
@ -90,6 +90,16 @@ Fabricator(:ollama_model, from: :llm_model) do
|
|||
provider_params { { enable_native_tool: true } }
|
||||
end
|
||||
|
||||
Fabricator(:mistral_model, from: :llm_model) do
|
||||
display_name "Mistral Large"
|
||||
name "mistral-large-latest"
|
||||
provider "mistral"
|
||||
api_key "ABC"
|
||||
tokenizer "DiscourseAi::Tokenizer::MixtralTokenizer"
|
||||
url "https://api.mistral.ai/v1/chat/completions"
|
||||
provider_params { { disable_native_tools: false } }
|
||||
end
|
||||
|
||||
Fabricator(:seeded_model, from: :llm_model) do
|
||||
id "-2"
|
||||
display_name "CDCK Hosted Model"
|
||||
|
|
|
@ -0,0 +1,81 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
require "rails_helper"
|
||||
require_relative "dialect_context"
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
|
||||
fab!(:model) { Fabricate(:mistral_model) }
|
||||
let(:context) { DialectContext.new(described_class, model) }
|
||||
let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
|
||||
let(:upload100x100) do
|
||||
UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
|
||||
end
|
||||
|
||||
it "does not include user names" do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
messages: [type: :user, content: "Hello, I am Bob", id: "bob"],
|
||||
)
|
||||
|
||||
dialect = described_class.new(prompt, model)
|
||||
|
||||
# mistral has no support for name
|
||||
expect(dialect.translate).to eq([{ role: "user", content: "bob: Hello, I am Bob" }])
|
||||
end
|
||||
|
||||
it "can properly encode images" do
|
||||
model.update!(vision_enabled: true)
|
||||
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are image bot",
|
||||
messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
|
||||
)
|
||||
|
||||
encoded = prompt.encoded_uploads(prompt.messages.last)
|
||||
|
||||
image = "data:image/jpeg;base64,#{encoded[0][:base64]}"
|
||||
|
||||
dialect = described_class.new(prompt, model)
|
||||
|
||||
content = dialect.translate[1][:content]
|
||||
|
||||
expect(content).to eq(
|
||||
[{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }],
|
||||
)
|
||||
end
|
||||
|
||||
it "can properly map tool calls to mistral format" do
|
||||
result = [
|
||||
{
|
||||
role: "system",
|
||||
content:
|
||||
"I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
|
||||
},
|
||||
{ role: "user", content: "user1: This is a message by a user" },
|
||||
{ role: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||
{ role: "user", content: "user1: This is a new message by a user" },
|
||||
{
|
||||
role: "assistant",
|
||||
content: "",
|
||||
tool_calls: [
|
||||
{
|
||||
type: "function",
|
||||
function: {
|
||||
arguments: "{\"location\":\"Sydney\",\"unit\":\"c\"}",
|
||||
name: "get_weather",
|
||||
},
|
||||
id: "tool_id",
|
||||
},
|
||||
],
|
||||
},
|
||||
{
|
||||
role: "tool",
|
||||
tool_call_id: "tool_id",
|
||||
content: "\"I'm a tool result\"",
|
||||
name: "get_weather",
|
||||
},
|
||||
]
|
||||
expect(context.multi_turn_scenario).to eq(result)
|
||||
end
|
||||
end
|
|
@ -26,6 +26,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
Aws::EventStream::Encoder.new.encode(aws_message)
|
||||
end
|
||||
|
||||
it "should provide accurate max token count" do
|
||||
prompt = DiscourseAi::Completions::Prompt.new("hello")
|
||||
dialect = DiscourseAi::Completions::Dialects::Claude.new(prompt, model)
|
||||
endpoint = DiscourseAi::Completions::Endpoints::AwsBedrock.new(model)
|
||||
|
||||
model.name = "claude-2"
|
||||
expect(endpoint.default_options(dialect)[:max_tokens]).to eq(4096)
|
||||
|
||||
model.name = "claude-3-5-sonnet"
|
||||
expect(endpoint.default_options(dialect)[:max_tokens]).to eq(8192)
|
||||
|
||||
model.name = "claude-3-5-haiku"
|
||||
options = endpoint.default_options(dialect)
|
||||
expect(options[:max_tokens]).to eq(8192)
|
||||
end
|
||||
|
||||
describe "function calling" do
|
||||
it "supports old school xml function calls" do
|
||||
model.provider_params["disable_native_tools"] = true
|
||||
|
@ -246,7 +262,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
expect(response).to eq(expected_response)
|
||||
|
||||
expected = {
|
||||
"max_tokens" => 3000,
|
||||
"max_tokens" => 4096,
|
||||
"anthropic_version" => "bedrock-2023-05-31",
|
||||
"messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
|
||||
"tools" => [
|
||||
|
@ -305,7 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
|
||||
|
||||
expected = {
|
||||
"max_tokens" => 3000,
|
||||
"max_tokens" => 4096,
|
||||
"anthropic_version" => "bedrock-2023-05-31",
|
||||
"messages" => [{ "role" => "user", "content" => "hello world" }],
|
||||
"system" => "You are a helpful bot",
|
||||
|
@ -354,7 +370,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
|||
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
|
||||
|
||||
expected = {
|
||||
"max_tokens" => 3000,
|
||||
"max_tokens" => 4096,
|
||||
"anthropic_version" => "bedrock-2023-05-31",
|
||||
"messages" => [{ "role" => "user", "content" => "hello world" }],
|
||||
"system" => "You are a helpful bot",
|
||||
|
|
Loading…
Reference in New Issue