# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      # Dialect that renders a prompt's messages as Mistral-style instruction
      # text: "[INST]...[/INST]" turns, "</s>"-terminated model turns, and
      # XML-wrapped tool results.
      class Mixtral < Dialect
        class << self
          # Tells whether this dialect handles the given model.
          #
          # @param model_name [String] model identifier to check
          # @return [Boolean] true only for the two model names listed below
          def can_translate?(model_name)
            supported_models = %w[
              mistralai/Mixtral-8x7B-Instruct-v0.1
              mistralai/Mistral-7B-Instruct-v0.2
            ]

            supported_models.include?(model_name)
          end

          # @return the tokenizer constant associated with this dialect
          def tokenizer
            DiscourseAi::Tokenizer::MixtralTokenizer
          end
        end

        # Assembles the prompt string from `prompt.messages`, after they have
        # been passed through `trim_messages`.
        #
        # Rendering per message type:
        # - :tool_call  -> contributes nothing to the output
        # - :system     -> "<s> [INST]" block holding the system content and the
        #                  tools prompt, closed with "[/INST] Ok </s>"
        # - :model      -> the content on a new line, followed by "</s>"
        # - :tool       -> a <function_results> XML block on a new line
        # - anything else -> the content wrapped in "[INST]...[/INST]" on a new line
        #
        # @return [String] the assembled prompt; a newline is appended when the
        #   text would otherwise end with "[/INST]"
        def translate
          rendered =
            trim_messages(prompt.messages).each_with_object(+"") do |msg, buffer|
              case msg[:type]
              when :tool_call
                # Tool calls are deliberately left out of the rendered prompt.
                next
              when :system
                system_block = <<~TEXT
                  <s> [INST]
                  #{msg[:content]}
                  #{build_tools_prompt}
                  [/INST] Ok </s>
                TEXT
                buffer << system_block.strip
              when :model
                buffer << "\n#{msg[:content]}</s>"
              when :tool
                # The message id is what gets emitted inside <tool_name>.
                tool_block = <<~TEXT
                  <function_results>
                  <result>
                  <tool_name>#{msg[:id]}</tool_name>
                  <json>
                  #{msg[:content]}
                  </json>
                  </result>
                  </function_results>
                TEXT
                buffer << "\n" << tool_block.strip
              else
                buffer << "\n[INST]#{msg[:content]}[/INST]"
              end
            end

          # Finish on a fresh line when the last thing rendered was an instruction turn.
          rendered << "\n" if rendered.end_with?("[/INST]")
          rendered
        end

        # @return [Integer] upper bound, in tokens, on prompt size for this dialect
        def max_prompt_tokens
          32_000
        end
      end
    end
  end
end