REFACTOR: Migrate Vllm/TGI-served models to the OpenAI format. (#588)

Both endpoints provide OpenAI-compatible servers. The only difference is that Vllm doesn't support passing tools as a separate parameter. Even if the tool param is supported, it ultimately relies on the model's ability to handle native functions, which is not the case with the models we have today.

As a part of this change, we are dropping support for StableBeluga/Llama2 models. They don't have a chat_template, meaning the new API can translate them.

These changes let us remove some of our existing dialects and are a first step in our plan to support any LLM by defining them as data-driven concepts.

 I rewrote the "translate" method to use a template method and extracted the tool support strategies into its classes to simplify the code.

Finally, these changes bring support for Ollama when running in dev mode. It only works with Mistral for now, but it will change soon..
This commit is contained in:
Roman Rizzi 2024-05-07 10:02:16 -03:00 committed by GitHub
parent dacc1b9f28
commit 4f1a3effe0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 665 additions and 845 deletions

View File

@ -11,6 +11,7 @@ class AiApiAuditLog < ActiveRecord::Base
Gemini = 4 Gemini = 4
Vllm = 5 Vllm = 5
Cohere = 6 Cohere = 6
Ollama = 7
end end
end end

View File

@ -185,6 +185,9 @@ discourse_ai:
ai_strict_token_counting: ai_strict_token_counting:
default: false default: false
hidden: true hidden: true
ai_ollama_endpoint:
hidden: true
default: ""
composer_ai_helper_enabled: composer_ai_helper_enabled:
default: false default: false

View File

@ -164,12 +164,15 @@ module DiscourseAi
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"open_ai:gpt-3.5-turbo-16k" "open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?( mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
"mistralai/Mixtral-8x7B-Instruct-v0.1", if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
) "vllm:#{mixtral_model}"
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1" elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
mixtral_model,
)
"hugging_face:#{mixtral_model}"
else else
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1" "ollama:mistral"
end end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro" "google:gemini-pro"

View File

@ -40,8 +40,10 @@ module DiscourseAi
if model.start_with?("mistral") if model.start_with?("mistral")
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model) if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model)
return "vllm:#{model}" return "vllm:#{model}"
elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(model)
"hugging_face:#{model}"
else else
return "hugging_face:#{model}" "ollama:mistral"
end end
end end

View File

@ -6,14 +6,7 @@ module DiscourseAi
class ChatGpt < Dialect class ChatGpt < Dialect
class << self class << self
def can_translate?(model_name) def can_translate?(model_name)
%w[ model_name.starts_with?("gpt-")
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-turbo
gpt-4-vision-preview
].include?(model_name)
end end
def tokenizer def tokenizer
@ -23,72 +16,17 @@ module DiscourseAi
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def native_tool_support?
true
end
def translate def translate
messages = prompt.messages @embed_user_ids =
prompt.messages.any? do |m|
# ChatGPT doesn't use an assistant msg to improve long-context responses.
if messages.last[:type] == :model
messages = messages.dup
messages.pop
end
trimmed_messages = trim_messages(messages)
embed_user_ids =
trimmed_messages.any? do |m|
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX) m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
end end
trimmed_messages.map do |msg| super
if msg[:type] == :system
{ role: "system", content: msg[:content] }
elsif msg[:type] == :model
{ role: "assistant", content: msg[:content] }
elsif msg[:type] == :tool_call
call_details = JSON.parse(msg[:content], symbolize_names: true)
call_details[:arguments] = call_details[:arguments].to_json
call_details[:name] = msg[:name]
{
role: "assistant",
content: nil,
tool_calls: [{ type: "function", function: call_details, id: msg[:id] }],
}
elsif msg[:type] == :tool
{ role: "tool", tool_call_id: msg[:id], content: msg[:content], name: msg[:name] }
else
user_message = { role: "user", content: msg[:content] }
if msg[:id]
if embed_user_ids
user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
else
user_message[:name] = msg[:id]
end
end
user_message[:content] = inline_images(user_message[:content], msg)
user_message
end
end
end
def tools
prompt.tools.map do |t|
tool = t.dup
tool[:parameters] = t[:parameters]
.to_a
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
memo[:properties][name] = p.except(:name, :required, :item_type)
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
memo
end
{ type: "function", function: tool }
end
end end
def max_prompt_tokens def max_prompt_tokens
@ -107,6 +45,41 @@ module DiscourseAi
private private
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::OpenAiTools.new(prompt.tools)
end
def system_msg(msg)
{ role: "system", content: msg[:content] }
end
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
end
def user_msg(msg)
user_message = { role: "user", content: msg[:content] }
if msg[:id]
if @embed_user_ids
user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
else
user_message[:name] = msg[:id]
end
end
user_message[:content] = inline_images(user_message[:content], msg)
user_message
end
def inline_images(content, message) def inline_images(content, message)
if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo" if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo"
content = message[:content] content = message[:content]

View File

@ -27,41 +27,13 @@ module DiscourseAi
end end
def translate def translate
messages = prompt.messages messages = super
system_prompt = +""
messages = system_prompt = messages.shift[:content] if messages.first[:role] == "system"
trim_messages(messages)
.map do |msg|
case msg[:type]
when :system
system_prompt << msg[:content]
nil
when :tool_call
{ role: "assistant", content: tool_call_to_xml(msg) }
when :tool
{ role: "user", content: tool_result_to_xml(msg) }
when :model
{ role: "assistant", content: msg[:content] }
when :user
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
content = inline_images(content, msg)
{ role: "user", content: content }
end
end
.compact
if prompt.tools.present?
system_prompt << "\n\n"
system_prompt << build_tools_prompt
end
interleving_messages = [] interleving_messages = []
previous_message = nil previous_message = nil
messages.each do |message| messages.each do |message|
if previous_message if previous_message
if previous_message[:role] == "user" && message[:role] == "user" if previous_message[:role] == "user" && message[:role] == "user"
@ -84,6 +56,29 @@ module DiscourseAi
private private
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def system_msg(msg)
msg = { role: "system", content: msg[:content] }
if tools_dialect.instructions.present?
msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
end
msg
end
def user_msg(msg)
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
content = inline_images(content, msg)
{ role: "user", content: content }
end
def inline_images(content, message) def inline_images(content, message)
if model_name.include?("claude-3") if model_name.include?("claude-3")
encoded_uploads = prompt.encoded_uploads(message) encoded_uploads = prompt.encoded_uploads(message)

View File

@ -19,57 +19,17 @@ module DiscourseAi
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/ VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def translate def translate
messages = prompt.messages messages = super
# ChatGPT doesn't use an assistant msg to improve long-context responses. system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
if messages.last[:type] == :model
messages = messages.dup
messages.pop
end
trimmed_messages = trim_messages(messages) prompt = { preamble: +"#{system_message}" }
prompt[:chat_history] = messages if messages.present?
chat_history = [] messages.reverse_each do |msg|
system_message = nil
prompt = {}
trimmed_messages.each do |msg|
case msg[:type]
when :system
if system_message
chat_history << { role: "SYSTEM", message: msg[:content] }
else
system_message = msg[:content]
end
when :model
chat_history << { role: "CHATBOT", message: msg[:content] }
when :tool_call
chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
when :tool
chat_history << { role: "USER", message: tool_result_to_xml(msg) }
when :user
user_message = { role: "USER", message: msg[:content] }
user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
chat_history << user_message
end
end
tools_prompt = build_tools_prompt
prompt[:preamble] = +"#{system_message}"
if tools_prompt.present?
prompt[:preamble] << "\n#{tools_prompt}"
prompt[
:preamble
] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
end
prompt[:chat_history] = chat_history if chat_history.present?
chat_history.reverse_each do |msg|
if msg[:role] == "USER" if msg[:role] == "USER"
prompt[:message] = msg[:message] prompt[:message] = msg[:message]
chat_history.delete(msg) messages.delete(msg)
break break
end end
end end
@ -101,6 +61,43 @@ module DiscourseAi
def calculate_message_token(context) def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end end
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
end
def system_msg(msg)
cmd_msg = { role: "SYSTEM", message: msg[:content] }
if tools_dialect.instructions.present?
cmd_msg[:message] = [
msg[:content],
tools_dialect.instructions,
"NEVER attempt to run tools using JSON, always use XML. Lives depend on it.",
].join("\n")
end
cmd_msg
end
def model_msg(msg)
{ role: "CHATBOT", message: msg[:content] }
end
def tool_call_msg(msg)
{ role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
end
def tool_msg(msg)
{ role: "USER", message: tools_dialect.from_raw_tool(msg) }
end
def user_msg(msg)
user_message = { role: "USER", message: msg[:content] }
user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
user_message
end
end end
end end
end end

View File

@ -11,11 +11,9 @@ module DiscourseAi
def dialect_for(model_name) def dialect_for(model_name)
dialects = [ dialects = [
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt, DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini, DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mixtral, DiscourseAi::Completions::Dialects::Mistral,
DiscourseAi::Completions::Dialects::Claude, DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command, DiscourseAi::Completions::Dialects::Command,
] ]
@ -32,40 +30,6 @@ module DiscourseAi
def tokenizer def tokenizer
raise NotImplemented raise NotImplemented
end end
def tool_preamble(include_array_tip: true)
array_tip =
if include_array_tip
<<~TEXT
If a parameter type is an array, return an array of values. For example:
<$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
TEXT
else
""
end
<<~TEXT
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this.
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
#{array_tip}
If you wish to call multiple function in one reply, wrap multiple <invoke>
block in a single <function_calls> block.
Always prefer to lead with tool calls, if you need to execute any.
Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
Here are the complete list of tools available:
TEXT
end
end end
def initialize(generic_prompt, model_name, opts: {}) def initialize(generic_prompt, model_name, opts: {})
@ -74,74 +38,30 @@ module DiscourseAi
@opts = opts @opts = opts
end end
def translate VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
raise NotImplemented
def can_end_with_assistant_msg?
false
end end
def tool_result_to_xml(message) def native_tool_support?
(<<~TEXT).strip false
<function_results>
<result>
<tool_name>#{message[:name] || message[:id]}</tool_name>
<json>
#{message[:content]}
</json>
</result>
</function_results>
TEXT
end
def tool_call_to_xml(message)
parsed = JSON.parse(message[:content], symbolize_names: true)
parameters = +""
if parsed[:arguments]
parameters << "<parameters>\n"
parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
parameters << "</parameters>\n"
end
(<<~TEXT).strip
<function_calls>
<invoke>
<tool_name>#{message[:name] || parsed[:name]}</tool_name>
#{parameters}</invoke>
</function_calls>
TEXT
end end
def tools def tools
tools = +"" @tools ||= tools_dialect.translated_tools
end
prompt.tools.each do |function| def translate
parameters = +"" messages = prompt.messages
if function[:parameters].present?
function[:parameters].each do |parameter|
parameters << <<~PARAMETER
<parameter>
<name>#{parameter[:name]}</name>
<type>#{parameter[:type]}</type>
<description>#{parameter[:description]}</description>
<required>#{parameter[:required]}</required>
PARAMETER
if parameter[:enum]
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
end
parameters << "</parameter>\n"
end
end
tools << <<~TOOLS # Some models use an assistant msg to improve long-context responses.
<tool_description> if messages.last[:type] == :model && can_end_with_assistant_msg?
<tool_name>#{function[:name]}</tool_name> messages = messages.dup
<description>#{function[:description]}</description> messages.pop
<parameters>
#{parameters}</parameters>
</tool_description>
TOOLS
end end
tools trim_messages(messages).map { |msg| send("#{msg[:type]}_msg", msg) }.compact
end end
def conversation_context def conversation_context
@ -154,19 +74,6 @@ module DiscourseAi
attr_reader :prompt attr_reader :prompt
def build_tools_prompt
return "" if prompt.tools.blank?
has_arrays =
prompt.tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
(<<~TEXT).strip
#{self.class.tool_preamble(include_array_tip: has_arrays)}
<tools>
#{tools}</tools>
TEXT
end
private private
attr_reader :model_name, :opts attr_reader :model_name, :opts
@ -230,6 +137,30 @@ module DiscourseAi
def calculate_message_token(msg) def calculate_message_token(msg)
self.class.tokenizer.size(msg[:content].to_s) self.class.tokenizer.size(msg[:content].to_s)
end end
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
end
def system_msg(msg)
raise NotImplemented
end
def assistant_msg(msg)
raise NotImplemented
end
def user_msg(msg)
raise NotImplemented
end
def tool_call_msg(msg)
{ role: "assistant", content: tools_dialect.from_raw_tool_call(msg) }
end
def tool_msg(msg)
{ role: "user", content: tools_dialect.from_raw_tool(msg) }
end
end end
end end
end end

View File

@ -9,14 +9,14 @@ module DiscourseAi
model_name == "fake" model_name == "fake"
end end
def translate
""
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer DiscourseAi::Tokenizer::OpenAiTokenizer
end end
end end
def translate
""
end
end end
end end
end end

View File

@ -14,59 +14,30 @@ module DiscourseAi
end end
end end
def native_tool_support?
true
end
def translate def translate
# Gemini complains if we don't alternate model/user roles. # Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } } noop_model_response = { role: "model", parts: { text: "Ok." } }
messages = super
messages = prompt.messages interleving_messages = []
previous_message = nil
# Gemini doesn't use an assistant msg to improve long-context responses. messages.each do |message|
messages.pop if messages.last[:type] == :model if previous_message
if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
memo = [] message[:role] == "user"
interleving_messages << noop_model_response.dup
trim_messages(messages).each do |msg|
if msg[:type] == :system
memo << { role: "user", parts: { text: msg[:content] } }
memo << noop_model_response.dup
elsif msg[:type] == :model
memo << { role: "model", parts: { text: msg[:content] } }
elsif msg[:type] == :tool_call
call_details = JSON.parse(msg[:content], symbolize_names: true)
memo << {
role: "model",
parts: {
functionCall: {
name: msg[:name] || call_details[:name],
args: call_details[:arguments],
},
},
}
elsif msg[:type] == :tool
memo << {
role: "function",
parts: {
functionResponse: {
name: msg[:name] || msg[:id],
response: {
content: msg[:content],
},
},
},
}
else
# Gemini quirk. Doesn't accept tool -> user or user -> user msgs.
previous_msg_role = memo.last&.dig(:role)
if previous_msg_role == "user" || previous_msg_role == "function"
memo << noop_model_response.dup
end end
memo << { role: "user", parts: { text: msg[:content] } }
end end
interleving_messages << message
previous_message = message
end end
memo interleving_messages
end end
def tools def tools
@ -110,6 +81,46 @@ module DiscourseAi
def calculate_message_token(context) def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end end
def system_msg(msg)
{ role: "user", parts: { text: msg[:content] } }
end
def model_msg(msg)
{ role: "model", parts: { text: msg[:content] } }
end
def user_msg(msg)
{ role: "user", parts: { text: msg[:content] } }
end
def tool_call_msg(msg)
call_details = JSON.parse(msg[:content], symbolize_names: true)
{
role: "model",
parts: {
functionCall: {
name: msg[:name] || call_details[:name],
args: call_details[:arguments],
},
},
}
end
def tool_msg(msg)
{
role: "function",
parts: {
functionResponse: {
name: msg[:name] || msg[:id],
response: {
content: msg[:content],
},
},
},
}
end
end end
end end
end end

View File

@ -1,68 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Llama2Classic < Dialect
class << self
def can_translate?(model_name)
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer
end
end
def translate
messages = prompt.messages
llama2_prompt =
trim_messages(messages).reduce(+"") do |memo, msg|
next(memo) if msg[:type] == :tool_call
if msg[:type] == :system
memo << (<<~TEXT).strip
[INST]
<<SYS>>
#{msg[:content]}
#{build_tools_prompt}
<</SYS>>
[/INST]
TEXT
elsif msg[:type] == :model
memo << "\n#{msg[:content]}"
elsif msg[:type] == :tool
JSON.parse(msg[:content], symbolize_names: true)
memo << "\n[INST]\n"
memo << (<<~TEXT).strip
<function_results>
<result>
<tool_name>#{msg[:id]}</tool_name>
<json>
#{msg[:content]}
</json>
</result>
</function_results>
[/INST]
TEXT
else
memo << "\n[INST]#{msg[:content]}[/INST]"
end
memo
end
llama2_prompt << "\n" if llama2_prompt.ends_with?("[/INST]")
llama2_prompt
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end
end
end

View File

@ -0,0 +1,57 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Mistral < Dialect
class << self
def can_translate?(model_name)
%w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistral
].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::MixtralTokenizer
end
end
def tools
@tools ||= tools_dialect.translated_tools
end
def max_prompt_tokens
32_000
end
private
def system_msg(msg)
{ role: "assistant", content: "<s>#{msg[:content]}</s>" }
end
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
end
def user_msg(msg)
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
{ role: "user", content: content }
end
end
end
end
end

View File

@ -1,57 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Mixtral < Dialect
class << self
def can_translate?(model_name)
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
model_name,
)
end
def tokenizer
DiscourseAi::Tokenizer::MixtralTokenizer
end
end
def translate
messages = prompt.messages
mixtral_prompt =
trim_messages(messages).reduce(+"") do |memo, msg|
if msg[:type] == :tool_call
memo << "\n"
memo << tool_call_to_xml(msg)
elsif msg[:type] == :system
memo << (<<~TEXT).strip
<s> [INST]
#{msg[:content]}
#{build_tools_prompt}
[/INST] Ok </s>
TEXT
elsif msg[:type] == :model
memo << "\n#{msg[:content]}</s>"
elsif msg[:type] == :tool
memo << "\n"
memo << tool_result_to_xml(msg)
else
memo << "\n[INST]#{msg[:content]}[/INST]"
end
memo
end
mixtral_prompt << "\n" if mixtral_prompt.ends_with?("[/INST]")
mixtral_prompt
end
def max_prompt_tokens
32_000
end
end
end
end
end

View File

@ -0,0 +1,62 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class OpenAiTools
def initialize(tools)
@raw_tools = tools
end
def translated_tools
raw_tools.map do |t|
tool = t.dup
tool[:parameters] = t[:parameters]
.to_a
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
memo[:properties][name] = p.except(:name, :required, :item_type)
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
memo
end
{ type: "function", function: tool }
end
end
def instructions
"" # Noop. Tools are listed separate.
end
def from_raw_tool_call(raw_message)
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
call_details[:arguments] = call_details[:arguments].to_json
call_details[:name] = raw_message[:name]
{
role: "assistant",
content: nil,
tool_calls: [{ type: "function", function: call_details, id: raw_message[:id] }],
}
end
def from_raw_tool(raw_message)
{
role: "tool",
tool_call_id: raw_message[:id],
content: raw_message[:content],
name: raw_message[:name],
}
end
private
attr_reader :raw_tools
end
end
end
end

View File

@ -1,59 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class OrcaStyle < Dialect
class << self
def can_translate?(model_name)
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer
end
end
def translate
messages = prompt.messages
trimmed_messages = trim_messages(messages)
# Need to include this differently
last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
llama2_prompt =
trimmed_messages.reduce(+"") do |memo, msg|
if msg[:type] == :tool_call
memo << "\n### Assistant:\n"
memo << tool_call_to_xml(msg)
elsif msg[:type] == :system
memo << (<<~TEXT).strip
### System:
#{msg[:content]}
#{build_tools_prompt}
TEXT
elsif msg[:type] == :model
memo << "\n### Assistant:\n#{msg[:content]}"
elsif msg[:type] == :tool
memo << "\n### User:\n"
memo << tool_result_to_xml(msg)
else
memo << "\n### User:\n#{msg[:content]}"
end
memo
end
llama2_prompt << "\n### Assistant:\n"
llama2_prompt << "#{last_message[:content]}:" if last_message
llama2_prompt
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end
end
end

View File

@ -0,0 +1,125 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class XmlTools
def initialize(tools)
@raw_tools = tools
end
def translated_tools
raw_tools.reduce(+"") do |tools, function|
parameters = +""
if function[:parameters].present?
function[:parameters].each do |parameter|
parameters << <<~PARAMETER
<parameter>
<name>#{parameter[:name]}</name>
<type>#{parameter[:type]}</type>
<description>#{parameter[:description]}</description>
<required>#{parameter[:required]}</required>
PARAMETER
if parameter[:enum]
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
end
parameters << "</parameter>\n"
end
end
tools << <<~TOOLS
<tool_description>
<tool_name>#{function[:name]}</tool_name>
<description>#{function[:description]}</description>
<parameters>
#{parameters}</parameters>
</tool_description>
TOOLS
end
end
def instructions
return "" if raw_tools.blank?
has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
(<<~TEXT).strip
#{tool_preamble(include_array_tip: has_arrays)}
<tools>
#{translated_tools}</tools>
TEXT
end
def from_raw_tool(raw_message)
(<<~TEXT).strip
<function_results>
<result>
<tool_name>#{raw_message[:name] || raw_message[:id]}</tool_name>
<json>
#{raw_message[:content]}
</json>
</result>
</function_results>
TEXT
end
def from_raw_tool_call(raw_message)
parsed = JSON.parse(raw_message[:content], symbolize_names: true)
parameters = +""
if parsed[:arguments]
parameters << "<parameters>\n"
parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
parameters << "</parameters>\n"
end
(<<~TEXT).strip
<function_calls>
<invoke>
<tool_name>#{raw_message[:name] || parsed[:name]}</tool_name>
#{parameters}</invoke>
</function_calls>
TEXT
end
private
attr_reader :raw_tools
def tool_preamble(include_array_tip: true)
array_tip =
if include_array_tip
<<~TEXT
If a parameter type is an array, return an array of values. For example:
<$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
TEXT
else
""
end
<<~TEXT
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this.
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
#{array_tip}
If you wish to call multiple function in one reply, wrap multiple <invoke>
block in a single <function_calls> block.
Always prefer to lead with tool calls, if you need to execute any.
Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
Here are the complete list of tools available:
TEXT
end
end
end
end
end

View File

@ -62,7 +62,7 @@ module DiscourseAi
# this is an approximation, we will update it later if request goes through # this is an approximation, we will update it later if request goes through
def prompt_size(prompt) def prompt_size(prompt)
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s) tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end end
def model_uri def model_uri

View File

@ -51,7 +51,7 @@ module DiscourseAi
def prompt_size(prompt) def prompt_size(prompt)
# approximation # approximation
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s) tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end end
def model_uri def model_uri

View File

@ -19,6 +19,8 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::Cohere, DiscourseAi::Completions::Endpoints::Cohere,
] ]
endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
if Rails.env.test? || Rails.env.development? if Rails.env.test? || Rails.env.development?
endpoints << DiscourseAi::Completions::Endpoints::Fake endpoints << DiscourseAi::Completions::Endpoints::Fake
end end
@ -67,6 +69,10 @@ module DiscourseAi
false false
end end
def use_ssl?
true
end
def perform_completion!(dialect, user, model_params = {}, &blk) def perform_completion!(dialect, user, model_params = {}, &blk)
allow_tools = dialect.prompt.has_tools? allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params) model_params = normalize_model_params(model_params)
@ -78,7 +84,7 @@ module DiscourseAi
FinalDestination::HTTP.start( FinalDestination::HTTP.start(
model_uri.host, model_uri.host,
model_uri.port, model_uri.port,
use_ssl: true, use_ssl: use_ssl?,
read_timeout: TIMEOUT, read_timeout: TIMEOUT,
open_timeout: TIMEOUT, open_timeout: TIMEOUT,
write_timeout: TIMEOUT, write_timeout: TIMEOUT,
@ -315,7 +321,7 @@ module DiscourseAi
end end
def extract_prompt_for_tokenizer(prompt) def extract_prompt_for_tokenizer(prompt)
prompt prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end end
def build_buffer def build_buffer

View File

@ -8,14 +8,9 @@ module DiscourseAi
def can_contact?(endpoint_name, model_name) def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "hugging_face" return false unless endpoint_name == "hugging_face"
%w[ %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
StableBeluga2 model_name,
Upstage-Llama-2-*-instruct-v2 )
Llama2-*-chat-hf
Llama2-chat-hf
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
].include?(model_name)
end end
def dependant_setting_names def dependant_setting_names
@ -31,24 +26,21 @@ module DiscourseAi
end end
end end
def default_options
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end
def normalize_model_params(model_params) def normalize_model_params(model_params)
model_params = model_params.dup model_params = model_params.dup
# max_tokens, temperature are already supported
if model_params[:stop_sequences] if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences) model_params[:stop] = model_params.delete(:stop_sequences)
end end
if model_params[:max_tokens]
model_params[:max_new_tokens] = model_params.delete(:max_tokens)
end
model_params model_params
end end
def default_options
{ model: model, temperature: 0.7 }
end
def provider_id def provider_id
AiApiAuditLog::Provider::HuggingFaceTextGeneration AiApiAuditLog::Provider::HuggingFaceTextGeneration
end end
@ -61,13 +53,14 @@ module DiscourseAi
def prepare_payload(prompt, model_params, _dialect) def prepare_payload(prompt, model_params, _dialect)
default_options default_options
.merge(inputs: prompt) .merge(model_params)
.merge(messages: prompt)
.tap do |payload| .tap do |payload|
payload[:parameters].merge!(model_params) if !payload[:max_tokens]
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 payload[:max_tokens] = token_limit - prompt_size(prompt)
end
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
payload[:stream] = true if @streaming_mode payload[:stream] = true if @streaming_mode
end end
@ -85,16 +78,13 @@ module DiscourseAi
end end
def extract_completion_from(response_raw) def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true) parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
if @streaming_mode response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
# Last chunk contains full response, which we already yielded.
return if parsed.dig(:token, :special)
parsed.dig(:token, :text).to_s response_h.dig(:content)
else
parsed[0][:generated_text].to_s
end
end end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)

View File

@ -0,0 +1,89 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class Ollama < Base
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "ollama" && %w[mistral].include?(model_name)
end
def dependant_setting_names
%w[ai_ollama_endpoint]
end
def correctly_configured?(_model_name)
SiteSetting.ai_ollama_endpoint.present?
end
def endpoint_name(model_name)
"Ollama - #{model_name}"
end
end
def normalize_model_params(model_params)
model_params = model_params.dup
# max_tokens, temperature are already supported
if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences)
end
model_params
end
def default_options
{ max_tokens: 2000, model: model }
end
def provider_id
AiApiAuditLog::Provider::Ollama
end
def use_ssl?
false
end
private
def model_uri
URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
end
def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(model_params)
.merge(messages: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode }
end
def prepare_request(payload)
headers = { "Content-Type" => "application/json" }
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
.map do |line|
data = line.split("data: ", 2)[1]
data == "[DONE]" ? nil : data
end
.compact
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
response_h.dig(:content)
end
end
end
end
end

View File

@ -153,10 +153,6 @@ module DiscourseAi
.compact .compact
end end
def extract_prompt_for_tokenizer(prompt)
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
def has_tool?(_response_data) def has_tool?(_response_data)
@has_function_call @has_function_call
end end

View File

@ -7,14 +7,9 @@ module DiscourseAi
class << self class << self
def can_contact?(endpoint_name, model_name) def can_contact?(endpoint_name, model_name)
endpoint_name == "vllm" && endpoint_name == "vllm" &&
%w[ %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
mistralai/Mixtral-8x7B-Instruct-v0.1 model_name,
mistralai/Mistral-7B-Instruct-v0.2 )
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
].include?(model_name)
end end
def dependant_setting_names def dependant_setting_names
@ -54,9 +49,9 @@ module DiscourseAi
def model_uri def model_uri
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv) service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
if service.present? if service.present?
api_endpoint = "https://#{service.target}:#{service.port}/v1/completions" api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
else else
api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/completions" api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions"
end end
@uri ||= URI(api_endpoint) @uri ||= URI(api_endpoint)
end end
@ -64,7 +59,7 @@ module DiscourseAi
def prepare_payload(prompt, model_params, _dialect) def prepare_payload(prompt, model_params, _dialect)
default_options default_options
.merge(model_params) .merge(model_params)
.merge(prompt: prompt) .merge(messages: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode } .tap { |payload| payload[:stream] = true if @streaming_mode }
end end
@ -76,15 +71,6 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
parsed.dig(:text)
end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)
decoded_chunk decoded_chunk
.split("\n") .split("\n")
@ -94,6 +80,16 @@ module DiscourseAi
end end
.compact .compact
end end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
response_h.dig(:content)
end
end end
end end
end end

View File

@ -31,21 +31,10 @@ module DiscourseAi
claude-3-opus claude-3-opus
], ],
anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus], anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
vllm: %w[ vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2],
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
],
hugging_face: %w[ hugging_face: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2 mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
], ],
cohere: %w[command-light command command-r command-r-plus], cohere: %w[command-light command command-r command-r-plus],
open_ai: %w[ open_ai: %w[
@ -57,7 +46,10 @@ module DiscourseAi
gpt-4-vision-preview gpt-4-vision-preview
], ],
google: %w[gemini-pro gemini-1.5-pro], google: %w[gemini-pro gemini-1.5-pro],
}.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? } }.tap do |h|
h[:ollama] = ["mistral"] if Rails.env.development?
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
end
end end
def valid_provider_models def valid_provider_models
@ -120,8 +112,6 @@ module DiscourseAi
@gateway = gateway @gateway = gateway
end end
delegate :tokenizer, to: :dialect_klass
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
# @param user { User } - User requesting the summary. # @param user { User } - User requesting the summary.
# #
@ -184,6 +174,8 @@ module DiscourseAi
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
end end
delegate :tokenizer, to: :dialect_klass
attr_reader :model_name attr_reader :model_name
private private

View File

@ -10,14 +10,6 @@ module DiscourseAi
Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000), Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000),
Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096), Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384), Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
Models::Llama2.new(
"hugging_face:Llama2-chat-hf",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
Models::Llama2FineTunedOrcaStyle.new(
"hugging_face:StableBeluga2",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768), Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000), Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
] ]

View File

@ -17,47 +17,6 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
end end
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
describe "#build_tools_prompt" do
it "can exclude array instructions" do
prompt = DiscourseAi::Completions::Prompt.new("12345")
prompt.tools = [
{
name: "weather",
description: "lookup weather in a city",
parameters: [{ name: "city", type: "string", description: "city name", required: true }],
},
]
dialect = TestDialect.new(prompt, "test")
expect(dialect.build_tools_prompt).not_to include("array")
end
it "can include array instructions" do
prompt = DiscourseAi::Completions::Prompt.new("12345")
prompt.tools = [
{
name: "weather",
description: "lookup weather in a city",
parameters: [{ name: "city", type: "array", description: "city names", required: true }],
},
]
dialect = TestDialect.new(prompt, "test")
expect(dialect.build_tools_prompt).to include("array")
end
it "does not break if there are no params" do
prompt = DiscourseAi::Completions::Prompt.new("12345")
prompt.tools = [{ name: "categories", description: "lookup all categories" }]
dialect = TestDialect.new(prompt, "test")
expect(dialect.build_tools_prompt).not_to include("array")
end
end
describe "#trim_messages" do describe "#trim_messages" do
it "should trim tool messages if tool_calls are trimmed" do it "should trim tool messages if tool_calls are trimmed" do
prompt = DiscourseAi::Completions::Prompt.new("12345") prompt = DiscourseAi::Completions::Prompt.new("12345")

View File

@ -1,62 +0,0 @@
# frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
let(:model_name) { "Llama2-chat-hf" }
let(:context) { DialectContext.new(described_class, model_name) }
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
[INST]
<<SYS>>
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
<</SYS>>
[/INST]
[INST]#{context.simple_user_input}[/INST]
TEXT
translated = context.system_user_scenario
expect(translated).to eq(llama2_classic_version)
end
it "translates tool messages" do
expected = +(<<~TEXT)
[INST]
<<SYS>>
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
<</SYS>>
[/INST]
[INST]This is a message by a user[/INST]
I'm a previous bot reply, that's why there's no user
[INST]This is a new message by a user[/INST]
[INST]
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
"I'm a tool result"
</json>
</result>
</function_results>
[/INST]
TEXT
expect(context.multi_turn_scenario).to eq(expected)
end
it "trims content if it's getting too long" do
translated = context.long_user_input_scenario
expect(translated.length).to be < context.long_message_text.length
end
end
end

View File

@ -1,66 +0,0 @@
# frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
let(:context) { DialectContext.new(described_class, model_name) }
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
<s> [INST]
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
[/INST] Ok </s>
[INST]#{context.simple_user_input}[/INST]
TEXT
translated = context.system_user_scenario
expect(translated).to eq(llama2_classic_version)
end
it "translates tool messages" do
expected = +(<<~TEXT).strip
<s> [INST]
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
[/INST] Ok </s>
[INST]This is a message by a user[/INST]
I'm a previous bot reply, that's why there's no user</s>
[INST]This is a new message by a user[/INST]
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
<function_results>
<result>
<tool_name>get_weather</tool_name>
<json>
"I'm a tool result"
</json>
</result>
</function_results>
TEXT
expect(context.multi_turn_scenario).to eq(expected)
end
it "trims content if it's getting too long" do
length = 6_000
translated = context.long_user_input_scenario(length: length)
expect(translated.length).to be < context.long_message_text(length: length).length
end
end
end

View File

@ -1,71 +0,0 @@
# frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
let(:model_name) { "StableBeluga2" }
let(:context) { DialectContext.new(described_class, model_name) }
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
### System:
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
### User:
#{context.simple_user_input}
### Assistant:
TEXT
translated = context.system_user_scenario
expect(translated).to eq(llama2_classic_version)
end
it "translates tool messages" do
expected = +(<<~TEXT)
### System:
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
### User:
This is a message by a user
### Assistant:
I'm a previous bot reply, that's why there's no user
### User:
This is a new message by a user
### Assistant:
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
### User:
<function_results>
<result>
<tool_name>get_weather</tool_name>
<json>
"I'm a tool result"
</json>
</result>
</function_results>
### Assistant:
TEXT
expect(context.multi_turn_scenario).to eq(expected)
end
it "trims content if it's getting too long" do
translated = context.long_user_input_scenario
expect(translated.length).to be < context.long_message_text.length
end
end
end

View File

@ -4,7 +4,20 @@ require_relative "endpoint_compliance"
class HuggingFaceMock < EndpointMock class HuggingFaceMock < EndpointMock
def response(content) def response(content)
[{ generated_text: content }] {
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "chat.completion",
created: 1_678_464_820,
model: "Llama2-*-chat-hf",
usage: {
prompt_tokens: 337,
completion_tokens: 162,
total_tokens: 499,
},
choices: [
{ message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
],
}
end end
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
@ -14,26 +27,32 @@ class HuggingFaceMock < EndpointMock
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
def stream_line(delta, deltas, finish_reason: nil) def stream_line(delta, finish_reason: nil)
+"data: " << { +"data: " << {
token: { id: "chatcmpl-#{SecureRandom.hex}",
id: 29_889, object: "chat.completion.chunk",
text: delta, created: 1_681_283_881,
logprob: -0.08319092, model: "Llama2-*-chat-hf",
special: !!finish_reason, choices: [{ delta: { content: delta } }],
}, finish_reason: finish_reason,
generated_text: finish_reason ? deltas.join : nil, index: 0,
details: nil,
}.to_json }.to_json
end end
def stub_raw(chunks)
WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return(
status: 200,
body: chunks,
)
end
def stub_streamed_response(prompt, deltas, tool_call: false) def stub_streamed_response(prompt, deltas, tool_call: false)
chunks = chunks =
deltas.each_with_index.map do |_, index| deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1) if index == (deltas.length - 1)
stream_line(deltas[index], deltas, finish_reason: true) stream_line(deltas[index], finish_reason: "stop_sequence")
else else
stream_line(deltas[index], deltas) stream_line(deltas[index])
end end
end end
@ -43,16 +62,18 @@ class HuggingFaceMock < EndpointMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}") .stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body(prompt, stream: true)) .with(body: request_body(prompt, stream: true))
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
yield if block_given?
end end
def request_body(prompt, stream: false) def request_body(prompt, stream: false, tool_call: false)
model model
.default_options .default_options
.merge(inputs: prompt) .merge(messages: prompt)
.tap do |payload| .tap do |b|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) - b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt) model.prompt_size(prompt)
payload[:stream] = true if stream b[:stream] = true if stream
end end
.to_json .to_json
end end
@ -70,7 +91,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
let(:hf_mock) { HuggingFaceMock.new(endpoint) } let(:hf_mock) { HuggingFaceMock.new(endpoint) }
let(:compliance) do let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user) EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user)
end end
describe "#perform_completion!" do describe "#perform_completion!" do

View File

@ -6,7 +6,7 @@ class VllmMock < EndpointMock
def response(content) def response(content)
{ {
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S", id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "text_completion", object: "chat.completion",
created: 1_678_464_820, created: 1_678_464_820,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1", model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
usage: { usage: {
@ -14,14 +14,16 @@ class VllmMock < EndpointMock
completion_tokens: 162, completion_tokens: 162,
total_tokens: 499, total_tokens: 499,
}, },
choices: [{ text: content, finish_reason: "stop", index: 0 }], choices: [
{ message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
],
} }
end end
def stub_response(prompt, response_text, tool_call: false) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
.with(body: model.default_options.merge(prompt: prompt).to_json) .with(body: model.default_options.merge(messages: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
@ -30,7 +32,7 @@ class VllmMock < EndpointMock
id: "cmpl-#{SecureRandom.hex}", id: "cmpl-#{SecureRandom.hex}",
created: 1_681_283_881, created: 1_681_283_881,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1", model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }], choices: [{ delta: { content: delta } }],
index: 0, index: 0,
}.to_json }.to_json
end end
@ -48,8 +50,8 @@ class VllmMock < EndpointMock
chunks = (chunks.join("\n\n") << "data: [DONE]").split("") chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions") .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json) .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
end end
@ -67,14 +69,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
let(:anthropic_mock) { VllmMock.new(endpoint) } let(:anthropic_mock) { VllmMock.new(endpoint) }
let(:compliance) do let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user) EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user)
end end
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) } let(:dialect) { DiscourseAi::Completions::Dialects::Mistral.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate } let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json } let(:request_body) { model.default_options.merge(messages: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json } let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" } before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }

View File

@ -3,8 +3,8 @@
RSpec.describe DiscourseAi::Completions::Llm do RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do subject(:llm) do
described_class.new( described_class.new(
DiscourseAi::Completions::Dialects::OrcaStyle, DiscourseAi::Completions::Dialects::Mistral,
nil, canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2", "hugging_face:Upstage-Llama-2-*-instruct-v2",
gateway: canned_response, gateway: canned_response,
) )