FIX: Mixtral models have system role support. (#703)

Using the assistant role for system messages produces an error because
these models expect strictly alternating roles (user/assistant/user and so on),
so a prompt cannot start with the assistant role.
This commit is contained in:
Roman Rizzi 2024-07-04 13:23:03 -03:00 committed by GitHub
parent eab2f74b58
commit 442681a3d3
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 16 additions and 64 deletions

View File

@ -13,7 +13,6 @@ module DiscourseAi
[ [
DiscourseAi::Completions::Dialects::ChatGpt, DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::Gemini, DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mistral,
DiscourseAi::Completions::Dialects::Claude, DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command, DiscourseAi::Completions::Dialects::Command,
DiscourseAi::Completions::Dialects::OpenAiCompatible, DiscourseAi::Completions::Dialects::OpenAiCompatible,

View File

@ -1,59 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
  module Completions
    module Dialects
      # Dialect that converts generic prompt messages into the role/content
      # hashes used for the Mistral and Mixtral instruct models named in
      # `.can_translate?`.
      class Mistral < Dialect
        # @param model_name [String] model identifier to check
        # @return [Boolean] true when this dialect handles the given model
        def self.can_translate?(model_name)
          supported_models = %w[
            mistralai/Mixtral-8x7B-Instruct-v0.1
            mistralai/Mistral-7B-Instruct-v0.2
            mistral
          ]

          supported_models.include?(model_name)
        end

        # Tokenizer configured on the LLM model when one is present,
        # otherwise the Mixtral tokenizer.
        def tokenizer
          configured = llm_model&.tokenizer_class
          configured || DiscourseAi::Tokenizer::MixtralTokenizer
        end

        # Memoized tool definitions as produced by `tools_dialect`
        # (defined outside this class).
        def tools
          return @tools if @tools

          @tools = tools_dialect.translated_tools
        end

        # Prompt token budget: the model's own setting when available,
        # falling back to 32_000.
        def max_prompt_tokens
          llm_model&.max_prompt_tokens || 32_000
        end

        private

        # System messages are sent under the "assistant" role, with the
        # content wrapped in <s>...</s> markers.
        def system_msg(msg)
          wrapped = "<s>#{msg[:content]}</s>"
          { role: "assistant", content: wrapped }
        end

        # Model replies are passed through unchanged under the "assistant" role.
        def model_msg(msg)
          reply = msg[:content]
          { role: "assistant", content: reply }
        end

        # Tool invocations are delegated to the tools dialect.
        def tool_call_msg(msg)
          tools_dialect.from_raw_tool_call(msg)
        end

        # Tool results are delegated to the tools dialect.
        def tool_msg(msg)
          tools_dialect.from_raw_tool(msg)
        end

        # User messages are prefixed with "<id>: " when the message has an id.
        def user_msg(msg)
          sender = msg[:id]
          # `dup` yields a mutable string so the content can be appended in place.
          text = (sender ? "#{sender}: " : "").dup
          text << msg[:content]
          { role: "user", content: text }
        end
      end
    end
  end
end

View File

@ -94,7 +94,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
let(:hf_mock) { HuggingFaceMock.new(endpoint) } let(:hf_mock) { HuggingFaceMock.new(endpoint) }
let(:compliance) do let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user) EndpointsCompliance.new(
self,
endpoint,
DiscourseAi::Completions::Dialects::OpenAiCompatible,
user,
)
end end
describe "#perform_completion!" do describe "#perform_completion!" do

View File

@ -69,10 +69,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
let(:anthropic_mock) { VllmMock.new(endpoint) } let(:anthropic_mock) { VllmMock.new(endpoint) }
let(:compliance) do let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user) EndpointsCompliance.new(
self,
endpoint,
DiscourseAi::Completions::Dialects::OpenAiCompatible,
user,
)
end end
let(:dialect) { DiscourseAi::Completions::Dialects::Mistral.new(generic_prompt, model_name) } let(:dialect) do
DiscourseAi::Completions::Dialects::OpenAiCompatible.new(generic_prompt, model_name)
end
let(:prompt) { dialect.translate } let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(messages: prompt).to_json } let(:request_body) { model.default_options.merge(messages: prompt).to_json }

View File

@ -3,7 +3,7 @@
RSpec.describe DiscourseAi::Completions::Llm do RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do subject(:llm) do
described_class.new( described_class.new(
DiscourseAi::Completions::Dialects::Mistral, DiscourseAi::Completions::Dialects::OpenAiCompatible,
canned_response, canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2", "hugging_face:Upstage-Llama-2-*-instruct-v2",
gateway: canned_response, gateway: canned_response,