Mixtral (#376)
Add both Mistral and Mixtral support. Also includes vLLM-openAI inference support. Co-authored-by: Roman Rizzi <rizziromanalejandro@gmail.com>
This commit is contained in:
parent
cb325bb883
commit
5db7bf6e68
|
@ -6,6 +6,7 @@ class AiApiAuditLog < ActiveRecord::Base
|
||||||
Anthropic = 2
|
Anthropic = 2
|
||||||
HuggingFaceTextGeneration = 3
|
HuggingFaceTextGeneration = 3
|
||||||
Gemini = 4
|
Gemini = 4
|
||||||
|
Vllm = 5
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -153,6 +153,11 @@ discourse_ai:
|
||||||
ai_gemini_api_key:
|
ai_gemini_api_key:
|
||||||
default: ""
|
default: ""
|
||||||
secret: true
|
secret: true
|
||||||
|
ai_vllm_endpoint:
|
||||||
|
default: ""
|
||||||
|
ai_vllm_endpoint_srv:
|
||||||
|
default: ""
|
||||||
|
hidden: true
|
||||||
|
|
||||||
composer_ai_helper_enabled:
|
composer_ai_helper_enabled:
|
||||||
default: false
|
default: false
|
||||||
|
@ -177,6 +182,8 @@ discourse_ai:
|
||||||
- stable-beluga-2
|
- stable-beluga-2
|
||||||
- Llama2-chat-hf
|
- Llama2-chat-hf
|
||||||
- gemini-pro
|
- gemini-pro
|
||||||
|
- mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||||
|
- mistralai/Mistral-7B-Instruct-v0.2
|
||||||
ai_helper_custom_prompts_allowed_groups:
|
ai_helper_custom_prompts_allowed_groups:
|
||||||
client: true
|
client: true
|
||||||
type: group_list
|
type: group_list
|
||||||
|
@ -241,6 +248,8 @@ discourse_ai:
|
||||||
- StableBeluga2
|
- StableBeluga2
|
||||||
- Upstage-Llama-2-*-instruct-v2
|
- Upstage-Llama-2-*-instruct-v2
|
||||||
- gemini-pro
|
- gemini-pro
|
||||||
|
- mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||||
|
- mistralai/Mistral-7B-Instruct-v0.2
|
||||||
|
|
||||||
ai_summarization_discourse_service_api_endpoint: ""
|
ai_summarization_discourse_service_api_endpoint: ""
|
||||||
ai_summarization_discourse_service_api_key:
|
ai_summarization_discourse_service_api_key:
|
||||||
|
|
|
@ -16,6 +16,7 @@ module DiscourseAi
|
||||||
DiscourseAi::Completions::Dialects::ChatGpt,
|
DiscourseAi::Completions::Dialects::ChatGpt,
|
||||||
DiscourseAi::Completions::Dialects::OrcaStyle,
|
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||||
DiscourseAi::Completions::Dialects::Gemini,
|
DiscourseAi::Completions::Dialects::Gemini,
|
||||||
|
DiscourseAi::Completions::Dialects::Mixtral,
|
||||||
]
|
]
|
||||||
|
|
||||||
dialect = dialects.find { |d| d.can_translate?(model_name) }
|
dialect = dialects.find { |d| d.can_translate?(model_name) }
|
||||||
|
@ -87,6 +88,7 @@ module DiscourseAi
|
||||||
def trim_context(conversation_context)
|
def trim_context(conversation_context)
|
||||||
prompt_limit = max_prompt_tokens
|
prompt_limit = max_prompt_tokens
|
||||||
current_token_count = calculate_token_count_without_context
|
current_token_count = calculate_token_count_without_context
|
||||||
|
message_step_size = (max_prompt_tokens / 25).to_i * -1
|
||||||
|
|
||||||
conversation_context.reduce([]) do |memo, context|
|
conversation_context.reduce([]) do |memo, context|
|
||||||
break(memo) if current_token_count >= prompt_limit
|
break(memo) if current_token_count >= prompt_limit
|
||||||
|
@ -98,7 +100,7 @@ module DiscourseAi
|
||||||
# Trimming content to make sure we respect token limit.
|
# Trimming content to make sure we respect token limit.
|
||||||
while dupped_context[:content].present? &&
|
while dupped_context[:content].present? &&
|
||||||
message_tokens + current_token_count + per_message_overhead > prompt_limit
|
message_tokens + current_token_count + per_message_overhead > prompt_limit
|
||||||
dupped_context[:content] = dupped_context[:content][0..-100] || ""
|
dupped_context[:content] = dupped_context[:content][0..message_step_size] || ""
|
||||||
message_tokens = calculate_message_token(dupped_context)
|
message_tokens = calculate_message_token(dupped_context)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,76 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Completions
|
||||||
|
module Dialects
|
||||||
|
class Mixtral < Dialect
|
||||||
|
class << self
|
||||||
|
def can_translate?(model_name)
|
||||||
|
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
|
||||||
|
model_name,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
DiscourseAi::Tokenizer::MixtralTokenizer
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def translate
|
||||||
|
mixtral_prompt = +<<~TEXT
|
||||||
|
<s> [INST]
|
||||||
|
#{prompt[:insts]}
|
||||||
|
#{build_tools_prompt}#{prompt[:post_insts]}
|
||||||
|
[/INST] Ok </s>
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
if prompt[:examples]
|
||||||
|
prompt[:examples].each do |example_pair|
|
||||||
|
mixtral_prompt << "[INST] #{example_pair.first} [/INST]\n"
|
||||||
|
mixtral_prompt << "#{example_pair.second}\n"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
mixtral_prompt << conversation_context if prompt[:conversation_context].present?
|
||||||
|
|
||||||
|
mixtral_prompt << "[INST] #{prompt[:input]} [/INST]\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
def conversation_context
|
||||||
|
return "" if prompt[:conversation_context].blank?
|
||||||
|
|
||||||
|
trimmed_context = trim_context(prompt[:conversation_context])
|
||||||
|
|
||||||
|
trimmed_context
|
||||||
|
.reverse
|
||||||
|
.reduce(+"") do |memo, context|
|
||||||
|
memo << "[INST] " if context[:type] == "user"
|
||||||
|
|
||||||
|
if context[:type] == "tool"
|
||||||
|
memo << <<~TEXT
|
||||||
|
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>#{context[:name]}</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
TEXT
|
||||||
|
else
|
||||||
|
memo << context[:content] << "\n"
|
||||||
|
memo << "[/INST]" if context[:type] == "user"
|
||||||
|
end
|
||||||
|
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_prompt_tokens
|
||||||
|
32_000
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -16,6 +16,7 @@ module DiscourseAi
|
||||||
DiscourseAi::Completions::Endpoints::OpenAi,
|
DiscourseAi::Completions::Endpoints::OpenAi,
|
||||||
DiscourseAi::Completions::Endpoints::HuggingFace,
|
DiscourseAi::Completions::Endpoints::HuggingFace,
|
||||||
DiscourseAi::Completions::Endpoints::Gemini,
|
DiscourseAi::Completions::Endpoints::Gemini,
|
||||||
|
DiscourseAi::Completions::Endpoints::Vllm,
|
||||||
].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
|
].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
|
||||||
ek.can_contact?(model_name)
|
ek.can_contact?(model_name)
|
||||||
end
|
end
|
||||||
|
@ -228,7 +229,8 @@ module DiscourseAi
|
||||||
<invoke>
|
<invoke>
|
||||||
<tool_name></tool_name>
|
<tool_name></tool_name>
|
||||||
<tool_id></tool_id>
|
<tool_id></tool_id>
|
||||||
<parameters></parameters>
|
<parameters>
|
||||||
|
</parameters>
|
||||||
</invoke>
|
</invoke>
|
||||||
</function_calls>
|
</function_calls>
|
||||||
TEXT
|
TEXT
|
||||||
|
@ -239,17 +241,28 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_to_buffer(function_buffer, response_data, partial)
|
def add_to_buffer(function_buffer, response_data, partial)
|
||||||
new_buffer = Nokogiri::HTML5.fragment(response_data + partial)
|
read_function = Nokogiri::HTML5.fragment(response_data + partial)
|
||||||
if tool_name = new_buffer.at("tool_name").text
|
|
||||||
if new_buffer.at("tool_id").nil?
|
|
||||||
tool_id_node =
|
|
||||||
Nokogiri::HTML5::DocumentFragment.parse("\n<tool_id>#{tool_name}</tool_id>")
|
|
||||||
|
|
||||||
new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node)
|
if tool_name = read_function.at("tool_name").text
|
||||||
|
function_buffer.at("tool_name").inner_html = tool_name
|
||||||
|
function_buffer.at("tool_id").inner_html = tool_name
|
||||||
|
end
|
||||||
|
|
||||||
|
read_parameters =
|
||||||
|
read_function
|
||||||
|
.at("parameters")
|
||||||
|
.elements
|
||||||
|
.each do |elem|
|
||||||
|
if paramenter = function_buffer.at(elem.name)&.text
|
||||||
|
function_buffer.at(elem.name).inner_html = paramenter
|
||||||
|
else
|
||||||
|
param_node = read_function.at(elem.name)
|
||||||
|
function_buffer.at("parameters").add_child(param_node)
|
||||||
|
function_buffer.at("parameters").add_child("\n")
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
new_buffer
|
function_buffer
|
||||||
end
|
end
|
||||||
|
|
||||||
def buffering_finished?(_available_functions, buffer)
|
def buffering_finished?(_available_functions, buffer)
|
||||||
|
|
|
@ -5,9 +5,14 @@ module DiscourseAi
|
||||||
module Endpoints
|
module Endpoints
|
||||||
class HuggingFace < Base
|
class HuggingFace < Base
|
||||||
def self.can_contact?(model_name)
|
def self.can_contact?(model_name)
|
||||||
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf].include?(
|
%w[
|
||||||
model_name,
|
StableBeluga2
|
||||||
)
|
Upstage-Llama-2-*-instruct-v2
|
||||||
|
Llama2-*-chat-hf
|
||||||
|
Llama2-chat-hf
|
||||||
|
mistralai/Mixtral-8x7B-Instruct-v0.1
|
||||||
|
mistralai/Mistral-7B-Instruct-v0.2
|
||||||
|
].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
def default_options
|
def default_options
|
||||||
|
|
|
@ -0,0 +1,67 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Completions
|
||||||
|
module Endpoints
|
||||||
|
class Vllm < Base
|
||||||
|
def self.can_contact?(model_name)
|
||||||
|
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
|
||||||
|
model_name,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def default_options
|
||||||
|
{ max_tokens_to_sample: 2000, model: model }
|
||||||
|
end
|
||||||
|
|
||||||
|
def provider_id
|
||||||
|
AiApiAuditLog::Provider::Vllm
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def model_uri
|
||||||
|
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
|
||||||
|
if service.present?
|
||||||
|
api_endpoint = "https://#{service.target}:#{service.port}/v1/completions"
|
||||||
|
else
|
||||||
|
api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/completions"
|
||||||
|
end
|
||||||
|
@uri ||= URI(api_endpoint)
|
||||||
|
end
|
||||||
|
|
||||||
|
def prepare_payload(prompt, model_params, _dialect)
|
||||||
|
default_options
|
||||||
|
.merge(model_params)
|
||||||
|
.merge(prompt: prompt)
|
||||||
|
.tap { |payload| payload[:stream] = true if @streaming_mode }
|
||||||
|
end
|
||||||
|
|
||||||
|
def prepare_request(payload)
|
||||||
|
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
|
||||||
|
|
||||||
|
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||||
|
end
|
||||||
|
|
||||||
|
def extract_completion_from(response_raw)
|
||||||
|
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
|
||||||
|
|
||||||
|
# half a line sent here
|
||||||
|
return if !parsed
|
||||||
|
|
||||||
|
parsed.dig(:text)
|
||||||
|
end
|
||||||
|
|
||||||
|
def partials_from(decoded_chunk)
|
||||||
|
decoded_chunk
|
||||||
|
.split("\n")
|
||||||
|
.map do |line|
|
||||||
|
data = line.split("data: ", 2)[1]
|
||||||
|
data == "[DONE]" ? nil : data
|
||||||
|
end
|
||||||
|
.compact
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,25 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Summarization
|
||||||
|
module Models
|
||||||
|
class Mixtral < Base
|
||||||
|
def display_name
|
||||||
|
"MistralAI's #{model}"
|
||||||
|
end
|
||||||
|
|
||||||
|
def correctly_configured?
|
||||||
|
SiteSetting.ai_hugging_face_api_url.present? || SiteSetting.ai_vllm_endpoint_srv.present?
|
||||||
|
end
|
||||||
|
|
||||||
|
def configuration_hint
|
||||||
|
I18n.t(
|
||||||
|
"discourse_ai.summarization.configuration_hint",
|
||||||
|
count: 1,
|
||||||
|
settings: %w[ai_hugging_face_api_url ai_vllm_endpoint_srv],
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,11 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Tokenizer
|
||||||
|
class MixtralTokenizer < BasicTokenizer
|
||||||
|
def self.tokenizer
|
||||||
|
@@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/mixtral.json")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -8,7 +8,7 @@
|
||||||
# url: https://meta.discourse.org/t/discourse-ai/259214
|
# url: https://meta.discourse.org/t/discourse-ai/259214
|
||||||
# required_version: 2.7.0
|
# required_version: 2.7.0
|
||||||
|
|
||||||
gem "tokenizers", "0.3.3"
|
gem "tokenizers", "0.4.2"
|
||||||
gem "tiktoken_ruby", "0.0.5"
|
gem "tiktoken_ruby", "0.0.5"
|
||||||
|
|
||||||
enabled_site_setting :discourse_ai_enabled
|
enabled_site_setting :discourse_ai_enabled
|
||||||
|
|
|
@ -0,0 +1,189 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
|
||||||
|
subject(:dialect) { described_class.new(prompt, "mistralai/Mixtral-8x7B-Instruct-v0.1") }
|
||||||
|
|
||||||
|
let(:tool) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:prompt) do
|
||||||
|
{
|
||||||
|
insts: <<~TEXT,
|
||||||
|
I want you to act as a title generator for written pieces. I will provide you with a text,
|
||||||
|
and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
|
||||||
|
and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
|
||||||
|
TEXT
|
||||||
|
input: <<~TEXT,
|
||||||
|
Here is the text, inside <input></input> XML tags:
|
||||||
|
<input>
|
||||||
|
To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
|
||||||
|
discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
|
||||||
|
defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.
|
||||||
|
|
||||||
|
Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
|
||||||
|
a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
|
||||||
|
slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
|
||||||
|
dies so that a scene may be repeated.
|
||||||
|
</input>
|
||||||
|
TEXT
|
||||||
|
post_insts:
|
||||||
|
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#translate" do
|
||||||
|
it "translates a prompt written in our generic format to the Open AI format" do
|
||||||
|
orca_style_version = <<~TEXT
|
||||||
|
<s> [INST]
|
||||||
|
#{prompt[:insts]}
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
[/INST] Ok </s>
|
||||||
|
[INST] #{prompt[:input]} [/INST]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated = dialect.translate
|
||||||
|
|
||||||
|
expect(translated).to eq(orca_style_version)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "include examples in the translated prompt" do
|
||||||
|
prompt[:examples] = [
|
||||||
|
[
|
||||||
|
"<input>In the labyrinth of time, a solitary horse, etched in gold by the setting sun, embarked on an infinite journey.</input>",
|
||||||
|
"<ai>The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.</ai>",
|
||||||
|
],
|
||||||
|
]
|
||||||
|
|
||||||
|
orca_style_version = <<~TEXT
|
||||||
|
<s> [INST]
|
||||||
|
#{prompt[:insts]}
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
[/INST] Ok </s>
|
||||||
|
[INST] #{prompt[:examples][0][0]} [/INST]
|
||||||
|
#{prompt[:examples][0][1]}
|
||||||
|
[INST] #{prompt[:input]} [/INST]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated = dialect.translate
|
||||||
|
|
||||||
|
expect(translated).to eq(orca_style_version)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "include tools inside the prompt" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
orca_style_version = <<~TEXT
|
||||||
|
<s> [INST]
|
||||||
|
#{prompt[:insts]}
|
||||||
|
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||||
|
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>$TOOL_NAME</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||||
|
...
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
|
||||||
|
Here are the tools available:
|
||||||
|
|
||||||
|
<tools>
|
||||||
|
#{dialect.tools}</tools>
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
[/INST] Ok </s>
|
||||||
|
[INST] #{prompt[:input]} [/INST]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated = dialect.translate
|
||||||
|
|
||||||
|
expect(translated).to eq(orca_style_version)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#conversation_context" do
|
||||||
|
let(:context) do
|
||||||
|
[
|
||||||
|
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||||
|
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||||
|
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
it "adds conversation in reverse order (first == newer)" do
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
expected = <<~TEXT
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>tool_id</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context.last[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
#{context.second[:content]}
|
||||||
|
[INST] #{context.first[:content]}
|
||||||
|
[/INST]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context.strip).to eq(expected.strip)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "trims content if it's getting too long" do
|
||||||
|
context.last[:content] = context.last[:content] * 6_000
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context.length).to be < context.last[:content].length
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#tools" do
|
||||||
|
it "translates tools to the tool syntax" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
translated_tool = <<~TEXT
|
||||||
|
<tool_description>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<description>Get the weather in a city</description>
|
||||||
|
<parameters>
|
||||||
|
<parameter>
|
||||||
|
<name>location</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the city name</description>
|
||||||
|
<required>true</required>
|
||||||
|
</parameter>
|
||||||
|
<parameter>
|
||||||
|
<name>unit</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the unit of measurement celcius c or fahrenheit f</description>
|
||||||
|
<required>true</required>
|
||||||
|
<options>c,f</options>
|
||||||
|
</parameter>
|
||||||
|
</parameters>
|
||||||
|
</tool_description>
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
expect(dialect.tools).to eq(translated_tool)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -0,0 +1,93 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
require_relative "endpoint_examples"
|
||||||
|
|
||||||
|
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
|
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::MixtralTokenizer) }
|
||||||
|
|
||||||
|
let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
|
||||||
|
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
|
||||||
|
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
|
||||||
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
|
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||||
|
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
||||||
|
|
||||||
|
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
||||||
|
|
||||||
|
def response(content)
|
||||||
|
{
|
||||||
|
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
||||||
|
object: "text_completion",
|
||||||
|
created: 1_678_464_820,
|
||||||
|
model: model_name,
|
||||||
|
usage: {
|
||||||
|
prompt_tokens: 337,
|
||||||
|
completion_tokens: 162,
|
||||||
|
total_tokens: 499,
|
||||||
|
},
|
||||||
|
choices: [{ text: content, finish_reason: "stop", index: 0 }],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
|
WebMock
|
||||||
|
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
||||||
|
.with(body: request_body)
|
||||||
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
|
end
|
||||||
|
|
||||||
|
def stream_line(delta, finish_reason: nil)
|
||||||
|
+"data: " << {
|
||||||
|
id: "cmpl-#{SecureRandom.hex}",
|
||||||
|
created: 1_681_283_881,
|
||||||
|
model: model_name,
|
||||||
|
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
|
||||||
|
index: 0,
|
||||||
|
}.to_json
|
||||||
|
end
|
||||||
|
|
||||||
|
def stub_streamed_response(prompt, deltas, tool_call: false)
|
||||||
|
chunks =
|
||||||
|
deltas.each_with_index.map do |_, index|
|
||||||
|
if index == (deltas.length - 1)
|
||||||
|
stream_line(deltas[index], finish_reason: "stop_sequence")
|
||||||
|
else
|
||||||
|
stream_line(deltas[index])
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
||||||
|
|
||||||
|
WebMock
|
||||||
|
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
|
||||||
|
.with(body: stream_request_body)
|
||||||
|
.to_return(status: 200, body: chunks)
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
|
||||||
|
_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
REPLY
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
REPLY
|
||||||
|
|
||||||
|
let(:tool_call) { invocation }
|
||||||
|
|
||||||
|
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||||
|
end
|
|
@ -21,3 +21,7 @@ Licensed under MIT License
|
||||||
## bge-large-en
|
## bge-large-en
|
||||||
|
|
||||||
Licensed under MIT License
|
Licensed under MIT License
|
||||||
|
|
||||||
|
## mixtral
|
||||||
|
|
||||||
|
Licensed under Apache 2.0 License
|
||||||
|
|
File diff suppressed because it is too large
Load Diff
Loading…
Reference in New Issue