FEATURE: Add support for Mistral models (#919)
Adds support for Mistral models (Pixtral Large and Mistral Large now have presets). Also corrects token accounting in AWS Bedrock models.
commit 755b63f31f
parent 0d7f353284
@@ -13,6 +13,7 @@ class AiApiAuditLog < ActiveRecord::Base
     Cohere = 6
     Ollama = 7
     SambaNova = 8
+    Mistral = 9
   end
 
   def next_log_id
@@ -28,6 +28,9 @@ class LlmModel < ActiveRecord::Base
         organization: :text,
         disable_native_tools: :checkbox,
       },
+      mistral: {
+        disable_native_tools: :checkbox,
+      },
       google: {
         disable_native_tools: :checkbox,
       },
@@ -290,6 +290,8 @@ en:
           open_ai-o1-preview: "Open AI's most capabale reasoning model"
           samba_nova-Meta-Llama-3-1-8B-Instruct: "Efficient lightweight multilingual model"
           samba_nova-Meta-Llama-3-1-70B-Instruct: "Powerful multipurpose model"
+          mistral-mistral-large-latest: "Mistral's most powerful model"
+          mistral-pixtral-large-latest: "Mistral's most powerful vision capable model"
 
         configured:
           title: "Configured LLMs"
@@ -325,6 +327,7 @@ en:
         ollama: "Ollama"
         CDCK: "CDCK"
         samba_nova: "SambaNova"
+        mistral: "Mistral"
         fake: "Custom"
 
       provider_fields:
@@ -16,6 +16,7 @@ module DiscourseAi
         DiscourseAi::Completions::Dialects::Claude,
         DiscourseAi::Completions::Dialects::Command,
         DiscourseAi::Completions::Dialects::Ollama,
+        DiscourseAi::Completions::Dialects::Mistral,
         DiscourseAi::Completions::Dialects::OpenAiCompatible,
       ]
     end
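Note: the Mistral dialect is registered ahead of OpenAiCompatible, the generic OpenAI-style fallback, so its overrides win the first-match dispatch. A minimal sketch of that lookup, assuming the list above is exposed as `all_dialects` (the method name sits outside this hunk, so treat it as an assumption):

# sketch: the first registered dialect whose can_translate? accepts the provider wins
dialect_klass =
  DiscourseAi::Completions::Dialects::Dialect.all_dialects.find do |d|
    d.can_translate?("mistral")
  end
dialect_klass # => DiscourseAi::Completions::Dialects::Mistral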
@@ -0,0 +1,44 @@
+# frozen_string_literal: true
+
+# basically the same as Open AI, except for no support for user names
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Mistral < ChatGpt
+        class << self
+          def can_translate?(model_provider)
+            model_provider == "mistral"
+          end
+        end
+
+        def translate
+          corrected = super
+          corrected.each do |msg|
+            msg[:content] = "" if msg[:tool_calls] && msg[:role] == "assistant"
+          end
+          corrected
+        end
+
+        private
+
+        def user_msg(msg)
+          mapped = super
+          if name = mapped.delete(:name)
+            if mapped[:content].is_a?(String)
+              mapped[:content] = "#{name}: #{mapped[:content]}"
+            else
+              mapped[:content].each do |inner|
+                if inner[:text]
+                  inner[:text] = "#{name}: #{inner[:text]}"
+                  break
+                end
+              end
+            end
+          end
+          mapped
+        end
+      end
+    end
+  end
+end
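Since Mistral's chat API has no `name` field on user messages, user_msg folds the name into the message text, and translate blanks assistant content whenever tool calls are present. A usage sketch (model here stands for any LlmModel with provider "mistral"; the new spec below exercises this same path):

prompt =
  DiscourseAi::Completions::Prompt.new(
    messages: [{ type: :user, content: "Hello, I am Bob", id: "bob" }],
  )
DiscourseAi::Completions::Dialects::Mistral.new(prompt, model).translate
# => [{ role: "user", content: "bob: Hello, I am Bob" }]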
@@ -19,7 +19,10 @@ module DiscourseAi
       end
 
       def default_options(dialect)
-        options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
+        max_tokens = 4096
+        max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
+
+        options = { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
 
         options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
           dialect.prompt.has_tools?
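Worked example for the corrected token accounting: the dot in /3.5/ matches any character, so the hyphenated Claude 3.5 Bedrock model ids qualify for the 8192-token cap while everything else keeps the 4096 default:

"anthropic.claude-3-5-sonnet-20241022-v2:0".match?(/3.5/) # => true  (8192)
"anthropic.claude-v2:1".match?(/3.5/)                     # => false (4096)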
@@ -40,15 +43,7 @@ module DiscourseAi
 
       private
 
-      def prompt_size(prompt)
-        # approximation
-        tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
-      end
-
-      def model_uri
-        region = llm_model.lookup_custom_param("region")
-
-        bedrock_model_id =
+      def bedrock_model_id
         case llm_model.name
         when "claude-2"
           "anthropic.claude-v2:1"
@@ -62,9 +57,20 @@ module DiscourseAi
           "anthropic.claude-3-opus-20240229-v1:0"
         when "claude-3-5-sonnet"
           "anthropic.claude-3-5-sonnet-20241022-v2:0"
+        when "claude-3-5-haiku"
+          "anthropic.claude-3-5-haiku-20241022-v1:0"
         else
           llm_model.name
         end
+      end
+
+      def prompt_size(prompt)
+        # approximation
+        tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
+      end
+
+      def model_uri
+        region = llm_model.lookup_custom_param("region")
+
         if region.blank? || bedrock_model_id.blank?
           raise CompletionFailed.new(I18n.t("discourse_ai.llm_models.bedrock_invalid_url"))
@@ -20,6 +20,7 @@ module DiscourseAi
           DiscourseAi::Completions::Endpoints::Anthropic,
           DiscourseAi::Completions::Endpoints::Cohere,
           DiscourseAi::Completions::Endpoints::SambaNova,
+          DiscourseAi::Completions::Endpoints::Mistral,
         ]
 
         endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
@@ -0,0 +1,17 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    module Endpoints
+      class Mistral < OpenAi
+        def self.can_contact?(model_provider)
+          model_provider == "mistral"
+        end
+
+        def provider_id
+          AiApiAuditLog::Provider::Mistral
+        end
+      end
+    end
+  end
+end
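Because Mistral's chat completions API is OpenAI-compatible, the endpoint subclasses OpenAi and inherits its request building and streaming; only provider detection and audit-log attribution differ. A quick sanity sketch:

DiscourseAi::Completions::Endpoints::Mistral.can_contact?("mistral") # => true
DiscourseAi::Completions::Endpoints::Mistral.can_contact?("open_ai") # => false
# instances tag audit rows with AiApiAuditLog::Provider::Mistral (= 9)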
@@ -90,6 +90,24 @@ module DiscourseAi
           endpoint: "https://api.sambanova.ai/v1/chat/completions",
           provider: "samba_nova",
         },
+        {
+          id: "mistral",
+          models: [
+            {
+              name: "mistral-large-latest",
+              tokens: 128_000,
+              display_name: "Mistral Large",
+            },
+            {
+              name: "pixtral-large-latest",
+              tokens: 128_000,
+              display_name: "Pixtral Large",
+            },
+          ],
+          tokenizer: DiscourseAi::Tokenizer::MixtralTokenizer,
+          endpoint: "https://api.mistral.ai/v1/chat/completions",
+          provider: "mistral",
+        },
       ]
     end
   end
@@ -105,6 +123,7 @@ module DiscourseAi
         google
         azure
         samba_nova
+        mistral
       ]
       if !Rails.env.production?
        providers << "fake"
@@ -90,6 +90,16 @@ Fabricator(:ollama_model, from: :llm_model) do
   provider_params { { enable_native_tool: true } }
 end
 
+Fabricator(:mistral_model, from: :llm_model) do
+  display_name "Mistral Large"
+  name "mistral-large-latest"
+  provider "mistral"
+  api_key "ABC"
+  tokenizer "DiscourseAi::Tokenizer::MixtralTokenizer"
+  url "https://api.mistral.ai/v1/chat/completions"
+  provider_params { { disable_native_tools: false } }
+end
+
 Fabricator(:seeded_model, from: :llm_model) do
   id "-2"
   display_name "CDCK Hosted Model"
@@ -0,0 +1,81 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+require_relative "dialect_context"
+
+RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
+  fab!(:model) { Fabricate(:mistral_model) }
+  let(:context) { DialectContext.new(described_class, model) }
+  let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
+  let(:upload100x100) do
+    UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
+  end
+
+  it "does not include user names" do
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        messages: [type: :user, content: "Hello, I am Bob", id: "bob"],
+      )
+
+    dialect = described_class.new(prompt, model)
+
+    # mistral has no support for name
+    expect(dialect.translate).to eq([{ role: "user", content: "bob: Hello, I am Bob" }])
+  end
+
+  it "can properly encode images" do
+    model.update!(vision_enabled: true)
+
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are image bot",
+        messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+      )
+
+    encoded = prompt.encoded_uploads(prompt.messages.last)
+
+    image = "data:image/jpeg;base64,#{encoded[0][:base64]}"
+
+    dialect = described_class.new(prompt, model)
+
+    content = dialect.translate[1][:content]
+
+    expect(content).to eq(
+      [{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }],
+    )
+  end
+
+  it "can properly map tool calls to mistral format" do
+    result = [
+      {
+        role: "system",
+        content:
+          "I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
+      },
+      { role: "user", content: "user1: This is a message by a user" },
+      { role: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+      { role: "user", content: "user1: This is a new message by a user" },
+      {
+        role: "assistant",
+        content: "",
+        tool_calls: [
+          {
+            type: "function",
+            function: {
+              arguments: "{\"location\":\"Sydney\",\"unit\":\"c\"}",
+              name: "get_weather",
+            },
+            id: "tool_id",
+          },
+        ],
+      },
+      {
+        role: "tool",
+        tool_call_id: "tool_id",
+        content: "\"I'm a tool result\"",
+        name: "get_weather",
+      },
+    ]
+    expect(context.multi_turn_scenario).to eq(result)
+  end
+end
@@ -26,6 +26,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
     Aws::EventStream::Encoder.new.encode(aws_message)
   end
 
+  it "should provide accurate max token count" do
+    prompt = DiscourseAi::Completions::Prompt.new("hello")
+    dialect = DiscourseAi::Completions::Dialects::Claude.new(prompt, model)
+    endpoint = DiscourseAi::Completions::Endpoints::AwsBedrock.new(model)
+
+    model.name = "claude-2"
+    expect(endpoint.default_options(dialect)[:max_tokens]).to eq(4096)
+
+    model.name = "claude-3-5-sonnet"
+    expect(endpoint.default_options(dialect)[:max_tokens]).to eq(8192)
+
+    model.name = "claude-3-5-haiku"
+    options = endpoint.default_options(dialect)
+    expect(options[:max_tokens]).to eq(8192)
+  end
+
   describe "function calling" do
     it "supports old school xml function calls" do
       model.provider_params["disable_native_tools"] = true
@@ -246,7 +262,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
       expect(response).to eq(expected_response)
 
       expected = {
-        "max_tokens" => 3000,
+        "max_tokens" => 4096,
         "anthropic_version" => "bedrock-2023-05-31",
         "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
         "tools" => [
@@ -305,7 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
       expect(request.headers["X-Amz-Content-Sha256"]).to be_present
 
       expected = {
-        "max_tokens" => 3000,
+        "max_tokens" => 4096,
         "anthropic_version" => "bedrock-2023-05-31",
         "messages" => [{ "role" => "user", "content" => "hello world" }],
         "system" => "You are a helpful bot",
@@ -354,7 +370,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
      expect(request.headers["X-Amz-Content-Sha256"]).to be_present
 
      expected = {
-       "max_tokens" => 3000,
+       "max_tokens" => 4096,
       "anthropic_version" => "bedrock-2023-05-31",
       "messages" => [{ "role" => "user", "content" => "hello world" }],
       "system" => "You are a helpful bot",