diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 2fa0a214..c8d02c51 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -13,6 +13,7 @@ class AiApiAuditLog < ActiveRecord::Base
     Cohere = 6
     Ollama = 7
     SambaNova = 8
+    Mistral = 9
   end
 
   def next_log_id
diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb
index 877c7534..e8c21f21 100644
--- a/app/models/llm_model.rb
+++ b/app/models/llm_model.rb
@@ -28,6 +28,9 @@ class LlmModel < ActiveRecord::Base
         organization: :text,
         disable_native_tools: :checkbox,
       },
+      mistral: {
+        disable_native_tools: :checkbox,
+      },
       google: {
         disable_native_tools: :checkbox,
       },
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 6d5bb721..483c3fd4 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -290,6 +290,8 @@ en:
             open_ai-o1-preview: "Open AI's most capable reasoning model"
             samba_nova-Meta-Llama-3-1-8B-Instruct: "Efficient lightweight multilingual model"
             samba_nova-Meta-Llama-3-1-70B-Instruct: "Powerful multipurpose model"
+            mistral-mistral-large-latest: "Mistral's most powerful model"
+            mistral-pixtral-large-latest: "Mistral's most powerful vision capable model"
 
           configured:
             title: "Configured LLMs"
@@ -325,6 +327,7 @@ en:
             ollama: "Ollama"
             CDCK: "CDCK"
             samba_nova: "SambaNova"
+            mistral: "Mistral"
             fake: "Custom"
 
         provider_fields:
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 53505214..f97da195 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -16,6 +16,7 @@ module DiscourseAi
           DiscourseAi::Completions::Dialects::Claude,
           DiscourseAi::Completions::Dialects::Command,
           DiscourseAi::Completions::Dialects::Ollama,
+          DiscourseAi::Completions::Dialects::Mistral,
           DiscourseAi::Completions::Dialects::OpenAiCompatible,
         ]
       end
diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb
new file mode 100644
index 00000000..d6968e82
--- /dev/null
+++ b/lib/completions/dialects/mistral.rb
@@ -0,0 +1,44 @@
+# frozen_string_literal: true
+
+# basically the same as Open AI, except for no support for user names
+
+module DiscourseAi
+  module Completions
+    module Dialects
+      class Mistral < ChatGpt
+        class << self
+          def can_translate?(model_provider)
+            model_provider == "mistral"
+          end
+        end
+
+        def translate
+          corrected = super
+          corrected.each do |msg|
+            msg[:content] = "" if msg[:tool_calls] && msg[:role] == "assistant"
+          end
+          corrected
+        end
+
+        private
+
+        def user_msg(msg)
+          mapped = super
+          if name = mapped.delete(:name)
+            if mapped[:content].is_a?(String)
+              mapped[:content] = "#{name}: #{mapped[:content]}"
+            else
+              mapped[:content].each do |inner|
+                if inner[:text]
+                  inner[:text] = "#{name}: #{inner[:text]}"
+                  break
+                end
+              end
+            end
+          end
+          mapped
+        end
+      end
+    end
+  end
+end
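A note on the user_msg override above: Mistral's chat completion API does not accept the per-message "name" attribute that OpenAI supports, so the dialect folds the name into the message text instead. The standalone Ruby sketch below mirrors that transformation on plain hashes (fold_name is an illustrative helper, not plugin code):

require "json"

# Move an OpenAI-style :name field into the message text, handling both
# plain-string and multipart (vision) content.
def fold_name(mapped)
  if (name = mapped.delete(:name))
    if mapped[:content].is_a?(String)
      mapped[:content] = "#{name}: #{mapped[:content]}"
    else
      # multipart content: prefix only the first text part
      part = mapped[:content].find { |inner| inner[:text] }
      part[:text] = "#{name}: #{part[:text]}" if part
    end
  end
  mapped
end

puts fold_name({ role: "user", name: "bob", content: "Hello, I am Bob" }).to_json
# => {"role":"user","content":"bob: Hello, I am Bob"}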
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index c17a051f..1c6f67f1 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -19,7 +19,10 @@ module DiscourseAi
       end
 
       def default_options(dialect)
-        options = { max_tokens: 3_000, anthropic_version: "bedrock-2023-05-31" }
+        max_tokens = 4096
+        max_tokens = 8192 if bedrock_model_id.match?(/3.5/)
+
+        options = { max_tokens: max_tokens, anthropic_version: "bedrock-2023-05-31" }
 
         options[:stop_sequences] = ["</function_calls>"] if !dialect.native_tool_support? &&
           dialect.prompt.has_tools?
@@ -40,6 +43,27 @@ module DiscourseAi
 
       private
 
+      def bedrock_model_id
+        case llm_model.name
+        when "claude-2"
+          "anthropic.claude-v2:1"
+        when "claude-3-haiku"
+          "anthropic.claude-3-haiku-20240307-v1:0"
+        when "claude-3-sonnet"
+          "anthropic.claude-3-sonnet-20240229-v1:0"
+        when "claude-instant-1"
+          "anthropic.claude-instant-v1"
+        when "claude-3-opus"
+          "anthropic.claude-3-opus-20240229-v1:0"
+        when "claude-3-5-sonnet"
+          "anthropic.claude-3-5-sonnet-20241022-v2:0"
+        when "claude-3-5-haiku"
+          "anthropic.claude-3-5-haiku-20241022-v1:0"
+        else
+          llm_model.name
+        end
+      end
+
       def prompt_size(prompt)
         # approximation
         tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
@@ -48,24 +72,6 @@ module DiscourseAi
 
       def model_uri
         region = llm_model.lookup_custom_param("region")
-        bedrock_model_id =
-          case llm_model.name
-          when "claude-2"
-            "anthropic.claude-v2:1"
-          when "claude-3-haiku"
-            "anthropic.claude-3-haiku-20240307-v1:0"
-          when "claude-3-sonnet"
-            "anthropic.claude-3-sonnet-20240229-v1:0"
-          when "claude-instant-1"
-            "anthropic.claude-instant-v1"
-          when "claude-3-opus"
-            "anthropic.claude-3-opus-20240229-v1:0"
-          when "claude-3-5-sonnet"
-            "anthropic.claude-3-5-sonnet-20241022-v2:0"
-          else
-            llm_model.name
-          end
-
         if region.blank? || bedrock_model_id.blank?
           raise CompletionFailed.new(I18n.t("discourse_ai.llm_models.bedrock_invalid_url"))
         end
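Two things worth noting in the aws_bedrock.rb hunks above: the inline case statement moved out of model_uri into a reusable bedrock_model_id method (gaining a claude-3-5-haiku mapping along the way), and default_options now derives its token ceiling from it. The unescaped dot in /3.5/ is what lets the pattern match the hyphenated "3-5" in the Bedrock model ids. A minimal sketch of the ceiling rule (max_tokens_for is an illustrative name, not plugin code):

# Claude 3.5 family models accept a larger output budget on Bedrock,
# so they get 8192 output tokens while older models stay at 4096.
def max_tokens_for(bedrock_model_id)
  # the "." matches the "-" in "3-5", so the 3.5 ids hit the 8192 branch
  bedrock_model_id.match?(/3.5/) ? 8192 : 4096
end

puts max_tokens_for("anthropic.claude-3-5-sonnet-20241022-v2:0") # => 8192
puts max_tokens_for("anthropic.claude-v2:1")                     # => 4096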
providers << "fake" diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb index 421c2a6c..3195b3f5 100644 --- a/spec/fabricators/llm_model_fabricator.rb +++ b/spec/fabricators/llm_model_fabricator.rb @@ -90,6 +90,16 @@ Fabricator(:ollama_model, from: :llm_model) do provider_params { { enable_native_tool: true } } end +Fabricator(:mistral_model, from: :llm_model) do + display_name "Mistral Large" + name "mistral-large-latest" + provider "mistral" + api_key "ABC" + tokenizer "DiscourseAi::Tokenizer::MixtralTokenizer" + url "https://api.mistral.ai/v1/chat/completions" + provider_params { { disable_native_tools: false } } +end + Fabricator(:seeded_model, from: :llm_model) do id "-2" display_name "CDCK Hosted Model" diff --git a/spec/lib/completions/dialects/mistral_spec.rb b/spec/lib/completions/dialects/mistral_spec.rb new file mode 100644 index 00000000..2e373bc5 --- /dev/null +++ b/spec/lib/completions/dialects/mistral_spec.rb @@ -0,0 +1,81 @@ +# frozen_string_literal: true + +require "rails_helper" +require_relative "dialect_context" + +RSpec.describe DiscourseAi::Completions::Dialects::Mistral do + fab!(:model) { Fabricate(:mistral_model) } + let(:context) { DialectContext.new(described_class, model) } + let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") } + let(:upload100x100) do + UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id) + end + + it "does not include user names" do + prompt = + DiscourseAi::Completions::Prompt.new( + messages: [type: :user, content: "Hello, I am Bob", id: "bob"], + ) + + dialect = described_class.new(prompt, model) + + # mistral has no support for name + expect(dialect.translate).to eq([{ role: "user", content: "bob: Hello, I am Bob" }]) + end + + it "can properly encode images" do + model.update!(vision_enabled: true) + + prompt = + DiscourseAi::Completions::Prompt.new( + "You are image bot", + messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]], + ) + + encoded = prompt.encoded_uploads(prompt.messages.last) + + image = "data:image/jpeg;base64,#{encoded[0][:base64]}" + + dialect = described_class.new(prompt, model) + + content = dialect.translate[1][:content] + + expect(content).to eq( + [{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }], + ) + end + + it "can properly map tool calls to mistral format" do + result = [ + { + role: "system", + content: + "I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. 
diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb
index 421c2a6c..3195b3f5 100644
--- a/spec/fabricators/llm_model_fabricator.rb
+++ b/spec/fabricators/llm_model_fabricator.rb
@@ -90,6 +90,16 @@ Fabricator(:ollama_model, from: :llm_model) do
   provider_params { { enable_native_tool: true } }
 end
 
+Fabricator(:mistral_model, from: :llm_model) do
+  display_name "Mistral Large"
+  name "mistral-large-latest"
+  provider "mistral"
+  api_key "ABC"
+  tokenizer "DiscourseAi::Tokenizer::MixtralTokenizer"
+  url "https://api.mistral.ai/v1/chat/completions"
+  provider_params { { disable_native_tools: false } }
+end
+
 Fabricator(:seeded_model, from: :llm_model) do
   id "-2"
   display_name "CDCK Hosted Model"
diff --git a/spec/lib/completions/dialects/mistral_spec.rb b/spec/lib/completions/dialects/mistral_spec.rb
new file mode 100644
index 00000000..2e373bc5
--- /dev/null
+++ b/spec/lib/completions/dialects/mistral_spec.rb
@@ -0,0 +1,81 @@
+# frozen_string_literal: true
+
+require "rails_helper"
+require_relative "dialect_context"
+
+RSpec.describe DiscourseAi::Completions::Dialects::Mistral do
+  fab!(:model) { Fabricate(:mistral_model) }
+  let(:context) { DialectContext.new(described_class, model) }
+  let(:image100x100) { plugin_file_from_fixtures("100x100.jpg") }
+  let(:upload100x100) do
+    UploadCreator.new(image100x100, "image.jpg").create_for(Discourse.system_user.id)
+  end
+
+  it "does not include user names" do
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        messages: [type: :user, content: "Hello, I am Bob", id: "bob"],
+      )
+
+    dialect = described_class.new(prompt, model)
+
+    # mistral has no support for name
+    expect(dialect.translate).to eq([{ role: "user", content: "bob: Hello, I am Bob" }])
+  end
+
+  it "can properly encode images" do
+    model.update!(vision_enabled: true)
+
+    prompt =
+      DiscourseAi::Completions::Prompt.new(
+        "You are image bot",
+        messages: [type: :user, id: "user1", content: "hello", upload_ids: [upload100x100.id]],
+      )
+
+    encoded = prompt.encoded_uploads(prompt.messages.last)
+
+    image = "data:image/jpeg;base64,#{encoded[0][:base64]}"
+
+    dialect = described_class.new(prompt, model)
+
+    content = dialect.translate[1][:content]
+
+    expect(content).to eq(
+      [{ type: "image_url", image_url: { url: image } }, { type: "text", text: "user1: hello" }],
+    )
+  end
+
+  it "can properly map tool calls to mistral format" do
+    result = [
+      {
+        role: "system",
+        content:
+          "I want you to act as a title generator for written pieces. I will provide you with a text,\nand you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,\nand ensure that the meaning is maintained. Replies will utilize the language type of the topic.\n",
+      },
+      { role: "user", content: "user1: This is a message by a user" },
+      { role: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+      { role: "user", content: "user1: This is a new message by a user" },
+      {
+        role: "assistant",
+        content: "",
+        tool_calls: [
+          {
+            type: "function",
+            function: {
+              arguments: "{\"location\":\"Sydney\",\"unit\":\"c\"}",
+              name: "get_weather",
+            },
+            id: "tool_id",
+          },
+        ],
+      },
+      {
+        role: "tool",
+        tool_call_id: "tool_id",
+        content: "\"I'm a tool result\"",
+        name: "get_weather",
+      },
+    ]
+    expect(context.multi_turn_scenario).to eq(result)
+  end
+end
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 2a9cc77f..f5329d3d 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -26,6 +26,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
     Aws::EventStream::Encoder.new.encode(aws_message)
   end
 
+  it "should provide accurate max token count" do
+    prompt = DiscourseAi::Completions::Prompt.new("hello")
+    dialect = DiscourseAi::Completions::Dialects::Claude.new(prompt, model)
+    endpoint = DiscourseAi::Completions::Endpoints::AwsBedrock.new(model)
+
+    model.name = "claude-2"
+    expect(endpoint.default_options(dialect)[:max_tokens]).to eq(4096)
+
+    model.name = "claude-3-5-sonnet"
+    expect(endpoint.default_options(dialect)[:max_tokens]).to eq(8192)
+
+    model.name = "claude-3-5-haiku"
+    options = endpoint.default_options(dialect)
+    expect(options[:max_tokens]).to eq(8192)
+  end
+
   describe "function calling" do
     it "supports old school xml function calls" do
       model.provider_params["disable_native_tools"] = true
@@ -246,7 +262,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
         expect(response).to eq(expected_response)
 
         expected = {
-          "max_tokens" => 3000,
+          "max_tokens" => 4096,
           "anthropic_version" => "bedrock-2023-05-31",
           "messages" => [{ "role" => "user", "content" => "what is the weather in sydney" }],
           "tools" => [
@@ -305,7 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
         expect(request.headers["X-Amz-Content-Sha256"]).to be_present
 
         expected = {
-          "max_tokens" => 3000,
+          "max_tokens" => 4096,
           "anthropic_version" => "bedrock-2023-05-31",
           "messages" => [{ "role" => "user", "content" => "hello world" }],
           "system" => "You are a helpful bot",
@@ -354,7 +370,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
         expect(request.headers["X-Amz-Content-Sha256"]).to be_present
 
         expected = {
-          "max_tokens" => 3000,
+          "max_tokens" => 4096,
           "anthropic_version" => "bedrock-2023-05-31",
           "messages" => [{ "role" => "user", "content" => "hello world" }],
           "system" => "You are a helpful bot",
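Back on the Mistral dialect spec above: its multi_turn_scenario expectation pins down one more behavior — assistant messages that carry tool_calls are sent with content set to an empty string. A standalone sketch of the fix-up the dialect's translate override applies after the inherited ChatGpt mapping (plain hashes, not plugin code):

messages = [
  { role: "assistant", content: nil, tool_calls: [{ id: "tool_id" }] },
  { role: "user", content: "user1: This is a new message by a user" },
]

# blank the content on assistant tool-call messages, as the spec expects
messages.each do |msg|
  msg[:content] = "" if msg[:tool_calls] && msg[:role] == "assistant"
end

p messages.first[:content] # => ""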