Add support for both Mistral and Mixtral models. Also includes inference support via vLLM's OpenAI-compatible API.

Co-authored-by: Roman Rizzi <rizziromanalejandro@gmail.com>
This commit is contained in:
Rafael dos Santos Silva 2023-12-26 14:49:55 -03:00 committed by GitHub
parent cb325bb883
commit 5db7bf6e68
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 91631 additions and 14 deletions

View File

@ -6,6 +6,7 @@ class AiApiAuditLog < ActiveRecord::Base
Anthropic = 2
HuggingFaceTextGeneration = 3
Gemini = 4
Vllm = 5
end
end

View File

@ -153,6 +153,11 @@ discourse_ai:
ai_gemini_api_key:
default: ""
secret: true
ai_vllm_endpoint:
default: ""
ai_vllm_endpoint_srv:
default: ""
hidden: true
composer_ai_helper_enabled:
default: false
@ -177,6 +182,8 @@ discourse_ai:
- stable-beluga-2
- Llama2-chat-hf
- gemini-pro
- mistralai/Mixtral-8x7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
ai_helper_custom_prompts_allowed_groups:
client: true
type: group_list
@ -241,6 +248,8 @@ discourse_ai:
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
- gemini-pro
- mistralai/Mixtral-8x7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
ai_summarization_discourse_service_api_endpoint: ""
ai_summarization_discourse_service_api_key:

View File

@ -16,6 +16,7 @@ module DiscourseAi
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mixtral,
]
dialect = dialects.find { |d| d.can_translate?(model_name) }
@ -87,6 +88,7 @@ module DiscourseAi
def trim_context(conversation_context)
prompt_limit = max_prompt_tokens
current_token_count = calculate_token_count_without_context
message_step_size = (max_prompt_tokens / 25).to_i * -1
conversation_context.reduce([]) do |memo, context|
break(memo) if current_token_count >= prompt_limit
@ -98,7 +100,7 @@ module DiscourseAi
# Trimming content to make sure we respect token limit.
while dupped_context[:content].present? &&
message_tokens + current_token_count + per_message_overhead > prompt_limit
dupped_context[:content] = dupped_context[:content][0..-100] || ""
dupped_context[:content] = dupped_context[:content][0..message_step_size] || ""
message_tokens = calculate_message_token(dupped_context)
end

View File

@ -0,0 +1,76 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Dialects
      # Translates the plugin's generic prompt format into the
      # Mistral/Mixtral instruction format ("<s> [INST] ... [/INST]").
      class Mixtral < Dialect
        class << self
          # Only the two Mistral AI instruct models are handled by this dialect.
          def can_translate?(model_name)
            %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
              model_name,
            )
          end

          def tokenizer
            DiscourseAi::Tokenizer::MixtralTokenizer
          end
        end

        # Builds the full prompt string: system instructions (plus optional
        # tool instructions and post-instructions), optional few-shot
        # examples, prior conversation context, and finally the current input.
        def translate
          mixtral_prompt = +<<~TEXT
            <s> [INST]
            #{prompt[:insts]}
            #{build_tools_prompt}#{prompt[:post_insts]}
            [/INST] Ok </s>
          TEXT

          if prompt[:examples]
            # Each example pair renders as a user turn followed by the
            # expected assistant reply.
            prompt[:examples].each do |example_pair|
              mixtral_prompt << "[INST] #{example_pair.first} [/INST]\n"
              mixtral_prompt << "#{example_pair.second}\n"
            end
          end

          mixtral_prompt << conversation_context if prompt[:conversation_context].present?

          mixtral_prompt << "[INST] #{prompt[:input]} [/INST]\n"
        end

        # Renders prior turns oldest-first (the stored context is newest-first,
        # hence the #reverse). Tool results are wrapped in <function_results>
        # XML; user turns are wrapped in [INST]...[/INST]; assistant turns are
        # emitted bare.
        def conversation_context
          return "" if prompt[:conversation_context].blank?

          # Drop older messages that would exceed max_prompt_tokens.
          trimmed_context = trim_context(prompt[:conversation_context])

          trimmed_context
            .reverse
            .reduce(+"") do |memo, context|
              memo << "[INST] " if context[:type] == "user"

              if context[:type] == "tool"
                memo << <<~TEXT
                  <function_results>
                  <result>
                  <tool_name>#{context[:name]}</tool_name>
                  <json>
                  #{context[:content]}
                  </json>
                  </result>
                  </function_results>
                TEXT
              else
                memo << context[:content] << "\n"
                memo << "[/INST]" if context[:type] == "user"
              end

              memo
            end
        end

        # Mixtral supports a 32k-token context window.
        def max_prompt_tokens
          32_000
        end
      end
    end
  end
end

View File

@ -16,6 +16,7 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace,
DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm,
].detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
ek.can_contact?(model_name)
end
@ -228,7 +229,8 @@ module DiscourseAi
<invoke>
<tool_name></tool_name>
<tool_id></tool_id>
<parameters></parameters>
<parameters>
</parameters>
</invoke>
</function_calls>
TEXT
@ -239,17 +241,28 @@ module DiscourseAi
end
def add_to_buffer(function_buffer, response_data, partial)
new_buffer = Nokogiri::HTML5.fragment(response_data + partial)
if tool_name = new_buffer.at("tool_name").text
if new_buffer.at("tool_id").nil?
tool_id_node =
Nokogiri::HTML5::DocumentFragment.parse("\n<tool_id>#{tool_name}</tool_id>")
read_function = Nokogiri::HTML5.fragment(response_data + partial)
new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node)
end
if tool_name = read_function.at("tool_name").text
function_buffer.at("tool_name").inner_html = tool_name
function_buffer.at("tool_id").inner_html = tool_name
end
new_buffer
read_parameters =
read_function
.at("parameters")
.elements
.each do |elem|
if paramenter = function_buffer.at(elem.name)&.text
function_buffer.at(elem.name).inner_html = paramenter
else
param_node = read_function.at(elem.name)
function_buffer.at("parameters").add_child(param_node)
function_buffer.at("parameters").add_child("\n")
end
end
function_buffer
end
def buffering_finished?(_available_functions, buffer)

View File

@ -5,9 +5,14 @@ module DiscourseAi
module Endpoints
class HuggingFace < Base
def self.can_contact?(model_name)
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2 Llama2-*-chat-hf Llama2-chat-hf].include?(
model_name,
)
%w[
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
].include?(model_name)
end
def default_options

View File

@ -0,0 +1,67 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      # Talks to a vLLM server through its OpenAI-compatible
      # /v1/completions API (plain and streamed/SSE responses).
      class Vllm < Base
        def self.can_contact?(model_name)
          %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
            model_name,
          )
        end

        # Default payload merged into every request.
        # FIX: the OpenAI-compatible completions API expects `max_tokens`;
        # `max_tokens_to_sample` is Anthropic's parameter name and would be
        # ignored by vLLM, leaving the server's default completion limit.
        def default_options
          { max_tokens: 2000, model: model }
        end

        def provider_id
          AiApiAuditLog::Provider::Vllm
        end

        private

        # Resolves the endpoint URI, preferring a DNS SRV lookup
        # (ai_vllm_endpoint_srv) over the static ai_vllm_endpoint setting.
        def model_uri
          service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
          if service.present?
            api_endpoint = "https://#{service.target}:#{service.port}/v1/completions"
          else
            api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/completions"
          end
          @uri ||= URI(api_endpoint)
        end

        # Builds the JSON-able request payload; `stream: true` is added only
        # when the caller asked for a streamed completion.
        def prepare_payload(prompt, model_params, _dialect)
          default_options
            .merge(model_params)
            .merge(prompt: prompt)
            .tap { |payload| payload[:stream] = true if @streaming_mode }
        end

        def prepare_request(payload)
          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        # Extracts the completion text from a response body. Returns nil when
        # the body does not parse into a full first choice (e.g. only half a
        # line has arrived so far on a streamed connection).
        def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
          return if !parsed

          parsed.dig(:text)
        end

        # Splits an SSE chunk into its individual "data: ..." payloads,
        # dropping the terminating "[DONE]" sentinel and non-data lines.
        def partials_from(decoded_chunk)
          decoded_chunk
            .split("\n")
            .map do |line|
              data = line.split("data: ", 2)[1]
              data == "[DONE]" ? nil : data
            end
            .compact
        end
      end
    end
  end
end

View File

@ -0,0 +1,25 @@
# frozen_string_literal: true

module DiscourseAi
  module Summarization
    module Models
      # Summarization backend descriptor for Mistral AI models.
      class Mixtral < Base
        # Human-readable model name shown in the admin UI.
        def display_name
          "MistralAI's #{model}"
        end

        # Usable when either a Hugging Face TGI endpoint or a vLLM SRV
        # record has been configured.
        def correctly_configured?
          endpoints = [SiteSetting.ai_hugging_face_api_url, SiteSetting.ai_vllm_endpoint_srv]
          endpoints.any?(&:present?)
        end

        # Hint telling admins which site settings still need to be filled in.
        def configuration_hint
          settings_list = %w[ai_hugging_face_api_url ai_vllm_endpoint_srv]

          I18n.t(
            "discourse_ai.summarization.configuration_hint",
            count: 1,
            settings: settings_list,
          )
        end
      end
    end
  end
end

View File

@ -0,0 +1,11 @@
# frozen_string_literal: true

module DiscourseAi
  module Tokenizer
    # Tokenizer for Mistral/Mixtral models, backed by the bundled
    # tokenizers/mixtral.json vocabulary file.
    class MixtralTokenizer < BasicTokenizer
      def self.tokenizer
        # Lazily load and memoize the tokenizer (file load is expensive).
        # NOTE(review): @@tokenizer is a class variable; if BasicTokenizer or
        # a sibling subclass also defines @@tokenizer at an ancestor level,
        # they would share one instance — confirm, a class instance variable
        # (@tokenizer) would be the safer idiom.
        @@tokenizer ||= Tokenizers.from_file("./plugins/discourse-ai/tokenizers/mixtral.json")
      end
    end
  end
end

View File

@ -8,7 +8,7 @@
# url: https://meta.discourse.org/t/discourse-ai/259214
# required_version: 2.7.0
gem "tokenizers", "0.3.3"
gem "tokenizers", "0.4.2"
gem "tiktoken_ruby", "0.0.5"
enabled_site_setting :discourse_ai_enabled

View File

@ -0,0 +1,189 @@
# frozen_string_literal: true

# Spec for the Mixtral dialect: verifies translation of the plugin's generic
# prompt format into the Mistral "[INST] ... [/INST]" instruction format,
# including examples, tools, and conversation context.
RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
  subject(:dialect) { described_class.new(prompt, "mistralai/Mixtral-8x7B-Instruct-v0.1") }

  # Sample tool definition used by the tool-related examples.
  let(:tool) do
    {
      name: "get_weather",
      description: "Get the weather in a city",
      parameters: [
        { name: "location", type: "string", description: "the city name", required: true },
        {
          name: "unit",
          type: "string",
          description: "the unit of measurement celcius c or fahrenheit f",
          enum: %w[c f],
          required: true,
        },
      ],
    }
  end

  # Generic-format prompt the dialect translates in every example.
  let(:prompt) do
    {
      insts: <<~TEXT,
        I want you to act as a title generator for written pieces. I will provide you with a text,
        and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
        and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
      TEXT
      input: <<~TEXT,
        Here is the text, inside <input></input> XML tags:
        <input>
        To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
        discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
        defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.
        Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
        a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
        slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
        dies so that a scene may be repeated.
        </input>
      TEXT
      post_insts:
        "Please put the translation between <ai></ai> tags and separate each title with a comma.",
    }
  end

  describe "#translate" do
    it "translates a prompt written in our generic format to the Open AI format" do
      orca_style_version = <<~TEXT
        <s> [INST]
        #{prompt[:insts]}
        #{prompt[:post_insts]}
        [/INST] Ok </s>
        [INST] #{prompt[:input]} [/INST]
      TEXT

      translated = dialect.translate

      expect(translated).to eq(orca_style_version)
    end

    it "include examples in the translated prompt" do
      prompt[:examples] = [
        [
          "<input>In the labyrinth of time, a solitary horse, etched in gold by the setting sun, embarked on an infinite journey.</input>",
          "<ai>The solitary horse.,The horse etched in gold.,A horse's infinite journey.,A horse lost in time.,A horse's last ride.</ai>",
        ],
      ]

      # Examples render as a user/assistant turn pair after the system block.
      orca_style_version = <<~TEXT
        <s> [INST]
        #{prompt[:insts]}
        #{prompt[:post_insts]}
        [/INST] Ok </s>
        [INST] #{prompt[:examples][0][0]} [/INST]
        #{prompt[:examples][0][1]}
        [INST] #{prompt[:input]} [/INST]
      TEXT

      translated = dialect.translate

      expect(translated).to eq(orca_style_version)
    end

    it "include tools inside the prompt" do
      prompt[:tools] = [tool]

      orca_style_version = <<~TEXT
        <s> [INST]
        #{prompt[:insts]}
        In this environment you have access to a set of tools you can use to answer the user's question.
        You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
        <function_calls>
        <invoke>
        <tool_name>$TOOL_NAME</tool_name>
        <parameters>
        <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
        ...
        </parameters>
        </invoke>
        </function_calls>
        Here are the tools available:
        <tools>
        #{dialect.tools}</tools>
        #{prompt[:post_insts]}
        [/INST] Ok </s>
        [INST] #{prompt[:input]} [/INST]
      TEXT

      translated = dialect.translate

      expect(translated).to eq(orca_style_version)
    end
  end

  describe "#conversation_context" do
    # Stored context is newest-first; the dialect must reverse it.
    let(:context) do
      [
        { type: "user", name: "user1", content: "This is a new message by a user" },
        { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
        { type: "tool", name: "tool_id", content: "I'm a tool result" },
      ]
    end

    it "adds conversation in reverse order (first == newer)" do
      prompt[:conversation_context] = context

      expected = <<~TEXT
        <function_results>
        <result>
        <tool_name>tool_id</tool_name>
        <json>
        #{context.last[:content]}
        </json>
        </result>
        </function_results>
        #{context.second[:content]}
        [INST] #{context.first[:content]}
        [/INST]
      TEXT

      translated_context = dialect.conversation_context

      expect(translated_context.strip).to eq(expected.strip)
    end

    it "trims content if it's getting too long" do
      # Inflate the oldest message well past the 32k-token budget.
      context.last[:content] = context.last[:content] * 6_000
      prompt[:conversation_context] = context

      translated_context = dialect.conversation_context

      expect(translated_context.length).to be < context.last[:content].length
    end
  end

  describe "#tools" do
    it "translates tools to the tool syntax" do
      prompt[:tools] = [tool]

      translated_tool = <<~TEXT
        <tool_description>
        <tool_name>get_weather</tool_name>
        <description>Get the weather in a city</description>
        <parameters>
        <parameter>
        <name>location</name>
        <type>string</type>
        <description>the city name</description>
        <required>true</required>
        </parameter>
        <parameter>
        <name>unit</name>
        <type>string</type>
        <description>the unit of measurement celcius c or fahrenheit f</description>
        <required>true</required>
        <options>c,f</options>
        </parameter>
        </parameters>
        </tool_description>
      TEXT

      expect(dialect.tools).to eq(translated_tool)
    end
  end
end

View File

@ -0,0 +1,93 @@
# frozen_string_literal: true

require_relative "endpoint_examples"

# Spec for the vLLM endpoint, exercised through the shared
# "endpoint that can communicate with a completion service" examples.
RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::MixtralTokenizer) }

  let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
  let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
  let(:prompt) { dialect.translate }

  # Request bodies the WebMock stubs match, derived from the endpoint's own defaults.
  let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
  let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }

  before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }

  # Canned OpenAI-style /v1/completions response wrapping `content`.
  def response(content)
    {
      id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
      object: "text_completion",
      created: 1_678_464_820,
      model: model_name,
      usage: {
        prompt_tokens: 337,
        completion_tokens: 162,
        total_tokens: 499,
      },
      choices: [{ text: content, finish_reason: "stop", index: 0 }],
    }
  end

  def stub_response(prompt, response_text, tool_call: false)
    WebMock
      .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
      .with(body: request_body)
      .to_return(status: 200, body: JSON.dump(response(response_text)))
  end

  # One server-sent-events line carrying a single streamed delta.
  def stream_line(delta, finish_reason: nil)
    +"data: " << {
      id: "cmpl-#{SecureRandom.hex}",
      created: 1_681_283_881,
      model: model_name,
      choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
      index: 0,
    }.to_json
  end

  def stub_streamed_response(prompt, deltas, tool_call: false)
    chunks =
      deltas.each_with_index.map do |_, index|
        if index == (deltas.length - 1)
          stream_line(deltas[index], finish_reason: "stop_sequence")
        else
          stream_line(deltas[index])
        end
      end

    # Split into single characters to simulate the response arriving in
    # arbitrarily small network chunks.
    chunks = (chunks.join("\n\n") << "data: [DONE]").split("")

    WebMock
      .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
      .with(body: stream_request_body)
      .to_return(status: 200, body: chunks)
  end

  # Streamed tool-call deltas: the first chunk deliberately splits the
  # "<function_calls>" opening tag to exercise partial-buffer handling.
  let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
    _calls>
    <invoke>
    <tool_name>get_weather</tool_name>
    <parameters>
    <location>Sydney</location>
    <unit>c</unit>
    </parameters>
    </invoke>
    </function_calls>
  REPLY
    <function_calls>
    <invoke>
    <tool_name>get_weather</tool_name>
    <parameters>
    <location>Sydney</location>
    <unit>c</unit>
    </parameters>
    </invoke>
    </function_calls>
  REPLY

  # NOTE(review): `invocation` is presumably provided by endpoint_examples —
  # verify it is defined in the shared example group.
  let(:tool_call) { invocation }

  it_behaves_like "an endpoint that can communicate with a completion service"
end

View File

@ -21,3 +21,7 @@ Licensed under MIT License
## bge-large-en
Licensed under MIT License
## mixtral
Licensed under Apache 2.0 License

91122
tokenizers/mixtral.json Normal file

File diff suppressed because it is too large Load Diff