REFACTOR: Migrate Vllm/TGI-served models to the OpenAI format. (#588)

Both endpoints provide OpenAI-compatible servers. The only difference is that Vllm doesn't support passing tools as a separate parameter. Even if the tool param is supported, it ultimately relies on the model's ability to handle native functions, which is not the case with the models we have today.

As a part of this change, we are dropping support for StableBeluga/Llama2 models. They don't have a chat_template, meaning the new API can translate them.

These changes let us remove some of our existing dialects and are a first step in our plan to support any LLM by defining them as data-driven concepts.

 I rewrote the "translate" method to use a template method and extracted the tool support strategies into its classes to simplify the code.

Finally, these changes bring support for Ollama when running in dev mode. It only works with Mistral for now, but it will change soon..
This commit is contained in:
Roman Rizzi 2024-05-07 10:02:16 -03:00 committed by GitHub
parent dacc1b9f28
commit 4f1a3effe0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 665 additions and 845 deletions

View File

@ -11,6 +11,7 @@ class AiApiAuditLog < ActiveRecord::Base
Gemini = 4
Vllm = 5
Cohere = 6
Ollama = 7
end
end

View File

@ -185,6 +185,9 @@ discourse_ai:
ai_strict_token_counting:
default: false
hidden: true
ai_ollama_endpoint:
hidden: true
default: ""
composer_ai_helper_enabled:
default: false

View File

@ -164,12 +164,15 @@ module DiscourseAi
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
"mistralai/Mixtral-8x7B-Instruct-v0.1",
)
"vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
"vllm:#{mixtral_model}"
elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
mixtral_model,
)
"hugging_face:#{mixtral_model}"
else
"hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
"ollama:mistral"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro"

View File

@ -40,8 +40,10 @@ module DiscourseAi
if model.start_with?("mistral")
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model)
return "vllm:#{model}"
elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(model)
"hugging_face:#{model}"
else
return "hugging_face:#{model}"
"ollama:mistral"
end
end

View File

@ -6,14 +6,7 @@ module DiscourseAi
class ChatGpt < Dialect
class << self
def can_translate?(model_name)
%w[
gpt-3.5-turbo
gpt-4
gpt-3.5-turbo-16k
gpt-4-32k
gpt-4-turbo
gpt-4-vision-preview
].include?(model_name)
model_name.starts_with?("gpt-")
end
def tokenizer
@ -23,72 +16,17 @@ module DiscourseAi
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def native_tool_support?
true
end
def translate
messages = prompt.messages
# ChatGPT doesn't use an assistant msg to improve long-context responses.
if messages.last[:type] == :model
messages = messages.dup
messages.pop
end
trimmed_messages = trim_messages(messages)
embed_user_ids =
trimmed_messages.any? do |m|
@embed_user_ids =
prompt.messages.any? do |m|
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
end
trimmed_messages.map do |msg|
if msg[:type] == :system
{ role: "system", content: msg[:content] }
elsif msg[:type] == :model
{ role: "assistant", content: msg[:content] }
elsif msg[:type] == :tool_call
call_details = JSON.parse(msg[:content], symbolize_names: true)
call_details[:arguments] = call_details[:arguments].to_json
call_details[:name] = msg[:name]
{
role: "assistant",
content: nil,
tool_calls: [{ type: "function", function: call_details, id: msg[:id] }],
}
elsif msg[:type] == :tool
{ role: "tool", tool_call_id: msg[:id], content: msg[:content], name: msg[:name] }
else
user_message = { role: "user", content: msg[:content] }
if msg[:id]
if embed_user_ids
user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
else
user_message[:name] = msg[:id]
end
end
user_message[:content] = inline_images(user_message[:content], msg)
user_message
end
end
end
def tools
prompt.tools.map do |t|
tool = t.dup
tool[:parameters] = t[:parameters]
.to_a
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
memo[:properties][name] = p.except(:name, :required, :item_type)
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
memo
end
{ type: "function", function: tool }
end
super
end
def max_prompt_tokens
@ -107,6 +45,41 @@ module DiscourseAi
private
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::OpenAiTools.new(prompt.tools)
end
def system_msg(msg)
{ role: "system", content: msg[:content] }
end
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
end
def user_msg(msg)
user_message = { role: "user", content: msg[:content] }
if msg[:id]
if @embed_user_ids
user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
else
user_message[:name] = msg[:id]
end
end
user_message[:content] = inline_images(user_message[:content], msg)
user_message
end
def inline_images(content, message)
if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo"
content = message[:content]

View File

@ -27,41 +27,13 @@ module DiscourseAi
end
def translate
messages = prompt.messages
system_prompt = +""
messages = super
messages =
trim_messages(messages)
.map do |msg|
case msg[:type]
when :system
system_prompt << msg[:content]
nil
when :tool_call
{ role: "assistant", content: tool_call_to_xml(msg) }
when :tool
{ role: "user", content: tool_result_to_xml(msg) }
when :model
{ role: "assistant", content: msg[:content] }
when :user
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
content = inline_images(content, msg)
{ role: "user", content: content }
end
end
.compact
if prompt.tools.present?
system_prompt << "\n\n"
system_prompt << build_tools_prompt
end
system_prompt = messages.shift[:content] if messages.first[:role] == "system"
interleving_messages = []
previous_message = nil
messages.each do |message|
if previous_message
if previous_message[:role] == "user" && message[:role] == "user"
@ -84,6 +56,29 @@ module DiscourseAi
private
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def system_msg(msg)
msg = { role: "system", content: msg[:content] }
if tools_dialect.instructions.present?
msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
end
msg
end
def user_msg(msg)
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
content = inline_images(content, msg)
{ role: "user", content: content }
end
def inline_images(content, message)
if model_name.include?("claude-3")
encoded_uploads = prompt.encoded_uploads(message)

View File

@ -19,57 +19,17 @@ module DiscourseAi
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def translate
messages = prompt.messages
messages = super
# ChatGPT doesn't use an assistant msg to improve long-context responses.
if messages.last[:type] == :model
messages = messages.dup
messages.pop
end
system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
trimmed_messages = trim_messages(messages)
prompt = { preamble: +"#{system_message}" }
prompt[:chat_history] = messages if messages.present?
chat_history = []
system_message = nil
prompt = {}
trimmed_messages.each do |msg|
case msg[:type]
when :system
if system_message
chat_history << { role: "SYSTEM", message: msg[:content] }
else
system_message = msg[:content]
end
when :model
chat_history << { role: "CHATBOT", message: msg[:content] }
when :tool_call
chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
when :tool
chat_history << { role: "USER", message: tool_result_to_xml(msg) }
when :user
user_message = { role: "USER", message: msg[:content] }
user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
chat_history << user_message
end
end
tools_prompt = build_tools_prompt
prompt[:preamble] = +"#{system_message}"
if tools_prompt.present?
prompt[:preamble] << "\n#{tools_prompt}"
prompt[
:preamble
] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
end
prompt[:chat_history] = chat_history if chat_history.present?
chat_history.reverse_each do |msg|
messages.reverse_each do |msg|
if msg[:role] == "USER"
prompt[:message] = msg[:message]
chat_history.delete(msg)
messages.delete(msg)
break
end
end
@ -101,6 +61,43 @@ module DiscourseAi
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
end
def system_msg(msg)
cmd_msg = { role: "SYSTEM", message: msg[:content] }
if tools_dialect.instructions.present?
cmd_msg[:message] = [
msg[:content],
tools_dialect.instructions,
"NEVER attempt to run tools using JSON, always use XML. Lives depend on it.",
].join("\n")
end
cmd_msg
end
def model_msg(msg)
{ role: "CHATBOT", message: msg[:content] }
end
def tool_call_msg(msg)
{ role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
end
def tool_msg(msg)
{ role: "USER", message: tools_dialect.from_raw_tool(msg) }
end
def user_msg(msg)
user_message = { role: "USER", message: msg[:content] }
user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
user_message
end
end
end
end

View File

@ -11,11 +11,9 @@ module DiscourseAi
def dialect_for(model_name)
dialects = [
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
DiscourseAi::Completions::Dialects::Mixtral,
DiscourseAi::Completions::Dialects::Mistral,
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command,
]
@ -32,40 +30,6 @@ module DiscourseAi
def tokenizer
raise NotImplemented
end
def tool_preamble(include_array_tip: true)
array_tip =
if include_array_tip
<<~TEXT
If a parameter type is an array, return an array of values. For example:
<$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
TEXT
else
""
end
<<~TEXT
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this.
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
#{array_tip}
If you wish to call multiple function in one reply, wrap multiple <invoke>
block in a single <function_calls> block.
Always prefer to lead with tool calls, if you need to execute any.
Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
Here are the complete list of tools available:
TEXT
end
end
def initialize(generic_prompt, model_name, opts: {})
@ -74,74 +38,30 @@ module DiscourseAi
@opts = opts
end
def translate
raise NotImplemented
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def can_end_with_assistant_msg?
false
end
def tool_result_to_xml(message)
(<<~TEXT).strip
<function_results>
<result>
<tool_name>#{message[:name] || message[:id]}</tool_name>
<json>
#{message[:content]}
</json>
</result>
</function_results>
TEXT
end
def tool_call_to_xml(message)
parsed = JSON.parse(message[:content], symbolize_names: true)
parameters = +""
if parsed[:arguments]
parameters << "<parameters>\n"
parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
parameters << "</parameters>\n"
end
(<<~TEXT).strip
<function_calls>
<invoke>
<tool_name>#{message[:name] || parsed[:name]}</tool_name>
#{parameters}</invoke>
</function_calls>
TEXT
def native_tool_support?
false
end
def tools
tools = +""
@tools ||= tools_dialect.translated_tools
end
prompt.tools.each do |function|
parameters = +""
if function[:parameters].present?
function[:parameters].each do |parameter|
parameters << <<~PARAMETER
<parameter>
<name>#{parameter[:name]}</name>
<type>#{parameter[:type]}</type>
<description>#{parameter[:description]}</description>
<required>#{parameter[:required]}</required>
PARAMETER
if parameter[:enum]
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
end
parameters << "</parameter>\n"
end
end
def translate
messages = prompt.messages
tools << <<~TOOLS
<tool_description>
<tool_name>#{function[:name]}</tool_name>
<description>#{function[:description]}</description>
<parameters>
#{parameters}</parameters>
</tool_description>
TOOLS
# Some models use an assistant msg to improve long-context responses.
if messages.last[:type] == :model && can_end_with_assistant_msg?
messages = messages.dup
messages.pop
end
tools
trim_messages(messages).map { |msg| send("#{msg[:type]}_msg", msg) }.compact
end
def conversation_context
@ -154,19 +74,6 @@ module DiscourseAi
attr_reader :prompt
def build_tools_prompt
return "" if prompt.tools.blank?
has_arrays =
prompt.tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
(<<~TEXT).strip
#{self.class.tool_preamble(include_array_tip: has_arrays)}
<tools>
#{tools}</tools>
TEXT
end
private
attr_reader :model_name, :opts
@ -230,6 +137,30 @@ module DiscourseAi
def calculate_message_token(msg)
self.class.tokenizer.size(msg[:content].to_s)
end
def tools_dialect
@tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
end
def system_msg(msg)
raise NotImplemented
end
def assistant_msg(msg)
raise NotImplemented
end
def user_msg(msg)
raise NotImplemented
end
def tool_call_msg(msg)
{ role: "assistant", content: tools_dialect.from_raw_tool_call(msg) }
end
def tool_msg(msg)
{ role: "user", content: tools_dialect.from_raw_tool(msg) }
end
end
end
end

View File

@ -9,14 +9,14 @@ module DiscourseAi
model_name == "fake"
end
def translate
""
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
def translate
""
end
end
end
end

View File

@ -14,59 +14,30 @@ module DiscourseAi
end
end
def native_tool_support?
true
end
def translate
# Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } }
messages = super
messages = prompt.messages
interleving_messages = []
previous_message = nil
# Gemini doesn't use an assistant msg to improve long-context responses.
messages.pop if messages.last[:type] == :model
memo = []
trim_messages(messages).each do |msg|
if msg[:type] == :system
memo << { role: "user", parts: { text: msg[:content] } }
memo << noop_model_response.dup
elsif msg[:type] == :model
memo << { role: "model", parts: { text: msg[:content] } }
elsif msg[:type] == :tool_call
call_details = JSON.parse(msg[:content], symbolize_names: true)
memo << {
role: "model",
parts: {
functionCall: {
name: msg[:name] || call_details[:name],
args: call_details[:arguments],
},
},
}
elsif msg[:type] == :tool
memo << {
role: "function",
parts: {
functionResponse: {
name: msg[:name] || msg[:id],
response: {
content: msg[:content],
},
},
},
}
else
# Gemini quirk. Doesn't accept tool -> user or user -> user msgs.
previous_msg_role = memo.last&.dig(:role)
if previous_msg_role == "user" || previous_msg_role == "function"
memo << noop_model_response.dup
messages.each do |message|
if previous_message
if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
message[:role] == "user"
interleving_messages << noop_model_response.dup
end
memo << { role: "user", parts: { text: msg[:content] } }
end
interleving_messages << message
previous_message = message
end
memo
interleving_messages
end
def tools
@ -110,6 +81,46 @@ module DiscourseAi
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def system_msg(msg)
{ role: "user", parts: { text: msg[:content] } }
end
def model_msg(msg)
{ role: "model", parts: { text: msg[:content] } }
end
def user_msg(msg)
{ role: "user", parts: { text: msg[:content] } }
end
def tool_call_msg(msg)
call_details = JSON.parse(msg[:content], symbolize_names: true)
{
role: "model",
parts: {
functionCall: {
name: msg[:name] || call_details[:name],
args: call_details[:arguments],
},
},
}
end
def tool_msg(msg)
{
role: "function",
parts: {
functionResponse: {
name: msg[:name] || msg[:id],
response: {
content: msg[:content],
},
},
},
}
end
end
end
end

View File

@ -1,68 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Llama2Classic < Dialect
class << self
def can_translate?(model_name)
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer
end
end
def translate
messages = prompt.messages
llama2_prompt =
trim_messages(messages).reduce(+"") do |memo, msg|
next(memo) if msg[:type] == :tool_call
if msg[:type] == :system
memo << (<<~TEXT).strip
[INST]
<<SYS>>
#{msg[:content]}
#{build_tools_prompt}
<</SYS>>
[/INST]
TEXT
elsif msg[:type] == :model
memo << "\n#{msg[:content]}"
elsif msg[:type] == :tool
JSON.parse(msg[:content], symbolize_names: true)
memo << "\n[INST]\n"
memo << (<<~TEXT).strip
<function_results>
<result>
<tool_name>#{msg[:id]}</tool_name>
<json>
#{msg[:content]}
</json>
</result>
</function_results>
[/INST]
TEXT
else
memo << "\n[INST]#{msg[:content]}[/INST]"
end
memo
end
llama2_prompt << "\n" if llama2_prompt.ends_with?("[/INST]")
llama2_prompt
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end
end
end

View File

@ -0,0 +1,57 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Mistral < Dialect
class << self
def can_translate?(model_name)
%w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
mistral
].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::MixtralTokenizer
end
end
def tools
@tools ||= tools_dialect.translated_tools
end
def max_prompt_tokens
32_000
end
private
def system_msg(msg)
{ role: "assistant", content: "<s>#{msg[:content]}</s>" }
end
def model_msg(msg)
{ role: "assistant", content: msg[:content] }
end
def tool_call_msg(msg)
tools_dialect.from_raw_tool_call(msg)
end
def tool_msg(msg)
tools_dialect.from_raw_tool(msg)
end
def user_msg(msg)
content = +""
content << "#{msg[:id]}: " if msg[:id]
content << msg[:content]
{ role: "user", content: content }
end
end
end
end
end

View File

@ -1,57 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Mixtral < Dialect
class << self
def can_translate?(model_name)
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
model_name,
)
end
def tokenizer
DiscourseAi::Tokenizer::MixtralTokenizer
end
end
def translate
messages = prompt.messages
mixtral_prompt =
trim_messages(messages).reduce(+"") do |memo, msg|
if msg[:type] == :tool_call
memo << "\n"
memo << tool_call_to_xml(msg)
elsif msg[:type] == :system
memo << (<<~TEXT).strip
<s> [INST]
#{msg[:content]}
#{build_tools_prompt}
[/INST] Ok </s>
TEXT
elsif msg[:type] == :model
memo << "\n#{msg[:content]}</s>"
elsif msg[:type] == :tool
memo << "\n"
memo << tool_result_to_xml(msg)
else
memo << "\n[INST]#{msg[:content]}[/INST]"
end
memo
end
mixtral_prompt << "\n" if mixtral_prompt.ends_with?("[/INST]")
mixtral_prompt
end
def max_prompt_tokens
32_000
end
end
end
end
end

View File

@ -0,0 +1,62 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class OpenAiTools
def initialize(tools)
@raw_tools = tools
end
def translated_tools
raw_tools.map do |t|
tool = t.dup
tool[:parameters] = t[:parameters]
.to_a
.reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
name = p[:name]
memo[:required] << name if p[:required]
memo[:properties][name] = p.except(:name, :required, :item_type)
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
memo
end
{ type: "function", function: tool }
end
end
def instructions
"" # Noop. Tools are listed separate.
end
def from_raw_tool_call(raw_message)
call_details = JSON.parse(raw_message[:content], symbolize_names: true)
call_details[:arguments] = call_details[:arguments].to_json
call_details[:name] = raw_message[:name]
{
role: "assistant",
content: nil,
tool_calls: [{ type: "function", function: call_details, id: raw_message[:id] }],
}
end
def from_raw_tool(raw_message)
{
role: "tool",
tool_call_id: raw_message[:id],
content: raw_message[:content],
name: raw_message[:name],
}
end
private
attr_reader :raw_tools
end
end
end
end

View File

@ -1,59 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class OrcaStyle < Dialect
class << self
def can_translate?(model_name)
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer
end
end
def translate
messages = prompt.messages
trimmed_messages = trim_messages(messages)
# Need to include this differently
last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
llama2_prompt =
trimmed_messages.reduce(+"") do |memo, msg|
if msg[:type] == :tool_call
memo << "\n### Assistant:\n"
memo << tool_call_to_xml(msg)
elsif msg[:type] == :system
memo << (<<~TEXT).strip
### System:
#{msg[:content]}
#{build_tools_prompt}
TEXT
elsif msg[:type] == :model
memo << "\n### Assistant:\n#{msg[:content]}"
elsif msg[:type] == :tool
memo << "\n### User:\n"
memo << tool_result_to_xml(msg)
else
memo << "\n### User:\n#{msg[:content]}"
end
memo
end
llama2_prompt << "\n### Assistant:\n"
llama2_prompt << "#{last_message[:content]}:" if last_message
llama2_prompt
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end
end
end

View File

@ -0,0 +1,125 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class XmlTools
def initialize(tools)
@raw_tools = tools
end
def translated_tools
raw_tools.reduce(+"") do |tools, function|
parameters = +""
if function[:parameters].present?
function[:parameters].each do |parameter|
parameters << <<~PARAMETER
<parameter>
<name>#{parameter[:name]}</name>
<type>#{parameter[:type]}</type>
<description>#{parameter[:description]}</description>
<required>#{parameter[:required]}</required>
PARAMETER
if parameter[:enum]
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
end
parameters << "</parameter>\n"
end
end
tools << <<~TOOLS
<tool_description>
<tool_name>#{function[:name]}</tool_name>
<description>#{function[:description]}</description>
<parameters>
#{parameters}</parameters>
</tool_description>
TOOLS
end
end
def instructions
return "" if raw_tools.blank?
has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
(<<~TEXT).strip
#{tool_preamble(include_array_tip: has_arrays)}
<tools>
#{translated_tools}</tools>
TEXT
end
def from_raw_tool(raw_message)
(<<~TEXT).strip
<function_results>
<result>
<tool_name>#{raw_message[:name] || raw_message[:id]}</tool_name>
<json>
#{raw_message[:content]}
</json>
</result>
</function_results>
TEXT
end
def from_raw_tool_call(raw_message)
parsed = JSON.parse(raw_message[:content], symbolize_names: true)
parameters = +""
if parsed[:arguments]
parameters << "<parameters>\n"
parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
parameters << "</parameters>\n"
end
(<<~TEXT).strip
<function_calls>
<invoke>
<tool_name>#{raw_message[:name] || parsed[:name]}</tool_name>
#{parameters}</invoke>
</function_calls>
TEXT
end
private
attr_reader :raw_tools
def tool_preamble(include_array_tip: true)
array_tip =
if include_array_tip
<<~TEXT
If a parameter type is an array, return an array of values. For example:
<$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
TEXT
else
""
end
<<~TEXT
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this.
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
#{array_tip}
If you wish to call multiple function in one reply, wrap multiple <invoke>
block in a single <function_calls> block.
Always prefer to lead with tool calls, if you need to execute any.
Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
Here are the complete list of tools available:
TEXT
end
end
end
end
end

View File

@ -62,7 +62,7 @@ module DiscourseAi
# this is an approximation, we will update it later if request goes through
def prompt_size(prompt)
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
def model_uri

View File

@ -51,7 +51,7 @@ module DiscourseAi
def prompt_size(prompt)
# approximation
super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
def model_uri

View File

@ -19,6 +19,8 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::Cohere,
]
endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
if Rails.env.test? || Rails.env.development?
endpoints << DiscourseAi::Completions::Endpoints::Fake
end
@ -67,6 +69,10 @@ module DiscourseAi
false
end
def use_ssl?
true
end
def perform_completion!(dialect, user, model_params = {}, &blk)
allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
@ -78,7 +84,7 @@ module DiscourseAi
FinalDestination::HTTP.start(
model_uri.host,
model_uri.port,
use_ssl: true,
use_ssl: use_ssl?,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
@ -315,7 +321,7 @@ module DiscourseAi
end
def extract_prompt_for_tokenizer(prompt)
prompt
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
def build_buffer

View File

@ -8,14 +8,9 @@ module DiscourseAi
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "hugging_face"
%w[
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
].include?(model_name)
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
model_name,
)
end
def dependant_setting_names
@ -31,24 +26,21 @@ module DiscourseAi
end
end
def default_options
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end
def normalize_model_params(model_params)
model_params = model_params.dup
# max_tokens, temperature are already supported
if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences)
end
if model_params[:max_tokens]
model_params[:max_new_tokens] = model_params.delete(:max_tokens)
end
model_params
end
def default_options
{ model: model, temperature: 0.7 }
end
def provider_id
AiApiAuditLog::Provider::HuggingFaceTextGeneration
end
@ -61,13 +53,14 @@ module DiscourseAi
def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(inputs: prompt)
.merge(model_params)
.merge(messages: prompt)
.tap do |payload|
payload[:parameters].merge!(model_params)
if !payload[:max_tokens]
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
payload[:max_tokens] = token_limit - prompt_size(prompt)
end
payload[:stream] = true if @streaming_mode
end
@ -85,16 +78,13 @@ module DiscourseAi
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
if @streaming_mode
# Last chunk contains full response, which we already yielded.
return if parsed.dig(:token, :special)
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
parsed.dig(:token, :text).to_s
else
parsed[0][:generated_text].to_s
end
response_h.dig(:content)
end
def partials_from(decoded_chunk)

View File

@ -0,0 +1,89 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class Ollama < Base
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "ollama" && %w[mistral].include?(model_name)
end
def dependant_setting_names
%w[ai_ollama_endpoint]
end
def correctly_configured?(_model_name)
SiteSetting.ai_ollama_endpoint.present?
end
def endpoint_name(model_name)
"Ollama - #{model_name}"
end
end
def normalize_model_params(model_params)
model_params = model_params.dup
# max_tokens, temperature are already supported
if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences)
end
model_params
end
def default_options
{ max_tokens: 2000, model: model }
end
def provider_id
AiApiAuditLog::Provider::Ollama
end
def use_ssl?
false
end
private
def model_uri
URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
end
def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(model_params)
.merge(messages: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode }
end
def prepare_request(payload)
headers = { "Content-Type" => "application/json" }
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
.map do |line|
data = line.split("data: ", 2)[1]
data == "[DONE]" ? nil : data
end
.compact
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
response_h.dig(:content)
end
end
end
end
end

View File

@ -153,10 +153,6 @@ module DiscourseAi
.compact
end
def extract_prompt_for_tokenizer(prompt)
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
def has_tool?(_response_data)
@has_function_call
end

View File

@ -7,14 +7,9 @@ module DiscourseAi
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "vllm" &&
%w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
].include?(model_name)
%w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
model_name,
)
end
def dependant_setting_names
@ -54,9 +49,9 @@ module DiscourseAi
def model_uri
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
if service.present?
api_endpoint = "https://#{service.target}:#{service.port}/v1/completions"
api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
else
api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/completions"
api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions"
end
@uri ||= URI(api_endpoint)
end
@ -64,7 +59,7 @@ module DiscourseAi
def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(model_params)
.merge(prompt: prompt)
.merge(messages: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode }
end
@ -76,15 +71,6 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
parsed.dig(:text)
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
@ -94,6 +80,16 @@ module DiscourseAi
end
.compact
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
response_h.dig(:content)
end
end
end
end

View File

@ -31,21 +31,10 @@ module DiscourseAi
claude-3-opus
],
anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
vllm: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
],
vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2],
hugging_face: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
StableBeluga2
Upstage-Llama-2-*-instruct-v2
Llama2-*-chat-hf
Llama2-chat-hf
],
cohere: %w[command-light command command-r command-r-plus],
open_ai: %w[
@ -57,7 +46,10 @@ module DiscourseAi
gpt-4-vision-preview
],
google: %w[gemini-pro gemini-1.5-pro],
}.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
}.tap do |h|
h[:ollama] = ["mistral"] if Rails.env.development?
h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
end
end
def valid_provider_models
@ -120,8 +112,6 @@ module DiscourseAi
@gateway = gateway
end
delegate :tokenizer, to: :dialect_klass
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
# @param user { User } - User requesting the summary.
#
@ -184,6 +174,8 @@ module DiscourseAi
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
end
delegate :tokenizer, to: :dialect_klass
attr_reader :model_name
private

View File

@ -10,14 +10,6 @@ module DiscourseAi
Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000),
Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
Models::Llama2.new(
"hugging_face:Llama2-chat-hf",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
Models::Llama2FineTunedOrcaStyle.new(
"hugging_face:StableBeluga2",
max_tokens: SiteSetting.ai_hugging_face_token_limit,
),
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
]

View File

@ -17,47 +17,6 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
end
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
describe "#build_tools_prompt" do
it "can exclude array instructions" do
prompt = DiscourseAi::Completions::Prompt.new("12345")
prompt.tools = [
{
name: "weather",
description: "lookup weather in a city",
parameters: [{ name: "city", type: "string", description: "city name", required: true }],
},
]
dialect = TestDialect.new(prompt, "test")
expect(dialect.build_tools_prompt).not_to include("array")
end
it "can include array instructions" do
prompt = DiscourseAi::Completions::Prompt.new("12345")
prompt.tools = [
{
name: "weather",
description: "lookup weather in a city",
parameters: [{ name: "city", type: "array", description: "city names", required: true }],
},
]
dialect = TestDialect.new(prompt, "test")
expect(dialect.build_tools_prompt).to include("array")
end
it "does not break if there are no params" do
prompt = DiscourseAi::Completions::Prompt.new("12345")
prompt.tools = [{ name: "categories", description: "lookup all categories" }]
dialect = TestDialect.new(prompt, "test")
expect(dialect.build_tools_prompt).not_to include("array")
end
end
describe "#trim_messages" do
it "should trim tool messages if tool_calls are trimmed" do
prompt = DiscourseAi::Completions::Prompt.new("12345")

View File

@ -1,62 +0,0 @@
# frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
let(:model_name) { "Llama2-chat-hf" }
let(:context) { DialectContext.new(described_class, model_name) }
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
[INST]
<<SYS>>
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
<</SYS>>
[/INST]
[INST]#{context.simple_user_input}[/INST]
TEXT
translated = context.system_user_scenario
expect(translated).to eq(llama2_classic_version)
end
it "translates tool messages" do
expected = +(<<~TEXT)
[INST]
<<SYS>>
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
<</SYS>>
[/INST]
[INST]This is a message by a user[/INST]
I'm a previous bot reply, that's why there's no user
[INST]This is a new message by a user[/INST]
[INST]
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
"I'm a tool result"
</json>
</result>
</function_results>
[/INST]
TEXT
expect(context.multi_turn_scenario).to eq(expected)
end
it "trims content if it's getting too long" do
translated = context.long_user_input_scenario
expect(translated.length).to be < context.long_message_text.length
end
end
end

View File

@ -1,66 +0,0 @@
# frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
let(:context) { DialectContext.new(described_class, model_name) }
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
<s> [INST]
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
[/INST] Ok </s>
[INST]#{context.simple_user_input}[/INST]
TEXT
translated = context.system_user_scenario
expect(translated).to eq(llama2_classic_version)
end
it "translates tool messages" do
expected = +(<<~TEXT).strip
<s> [INST]
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
[/INST] Ok </s>
[INST]This is a message by a user[/INST]
I'm a previous bot reply, that's why there's no user</s>
[INST]This is a new message by a user[/INST]
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
<function_results>
<result>
<tool_name>get_weather</tool_name>
<json>
"I'm a tool result"
</json>
</result>
</function_results>
TEXT
expect(context.multi_turn_scenario).to eq(expected)
end
it "trims content if it's getting too long" do
length = 6_000
translated = context.long_user_input_scenario(length: length)
expect(translated.length).to be < context.long_message_text(length: length).length
end
end
end

View File

@ -1,71 +0,0 @@
# frozen_string_literal: true
require_relative "dialect_context"
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
let(:model_name) { "StableBeluga2" }
let(:context) { DialectContext.new(described_class, model_name) }
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
### System:
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
### User:
#{context.simple_user_input}
### Assistant:
TEXT
translated = context.system_user_scenario
expect(translated).to eq(llama2_classic_version)
end
it "translates tool messages" do
expected = +(<<~TEXT)
### System:
#{context.system_insts}
#{described_class.tool_preamble(include_array_tip: false)}
<tools>
#{context.dialect_tools}</tools>
### User:
This is a message by a user
### Assistant:
I'm a previous bot reply, that's why there's no user
### User:
This is a new message by a user
### Assistant:
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
### User:
<function_results>
<result>
<tool_name>get_weather</tool_name>
<json>
"I'm a tool result"
</json>
</result>
</function_results>
### Assistant:
TEXT
expect(context.multi_turn_scenario).to eq(expected)
end
it "trims content if it's getting too long" do
translated = context.long_user_input_scenario
expect(translated.length).to be < context.long_message_text.length
end
end
end

View File

@ -4,7 +4,20 @@ require_relative "endpoint_compliance"
class HuggingFaceMock < EndpointMock
def response(content)
[{ generated_text: content }]
{
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "chat.completion",
created: 1_678_464_820,
model: "Llama2-*-chat-hf",
usage: {
prompt_tokens: 337,
completion_tokens: 162,
total_tokens: 499,
},
choices: [
{ message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
],
}
end
def stub_response(prompt, response_text, tool_call: false)
@ -14,26 +27,32 @@ class HuggingFaceMock < EndpointMock
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
def stream_line(delta, deltas, finish_reason: nil)
def stream_line(delta, finish_reason: nil)
+"data: " << {
token: {
id: 29_889,
text: delta,
logprob: -0.08319092,
special: !!finish_reason,
},
generated_text: finish_reason ? deltas.join : nil,
details: nil,
id: "chatcmpl-#{SecureRandom.hex}",
object: "chat.completion.chunk",
created: 1_681_283_881,
model: "Llama2-*-chat-hf",
choices: [{ delta: { content: delta } }],
finish_reason: finish_reason,
index: 0,
}.to_json
end
def stub_raw(chunks)
WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return(
status: 200,
body: chunks,
)
end
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], deltas, finish_reason: true)
stream_line(deltas[index], finish_reason: "stop_sequence")
else
stream_line(deltas[index], deltas)
stream_line(deltas[index])
end
end
@ -43,16 +62,18 @@ class HuggingFaceMock < EndpointMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body(prompt, stream: true))
.to_return(status: 200, body: chunks)
yield if block_given?
end
def request_body(prompt, stream: false)
def request_body(prompt, stream: false, tool_call: false)
model
.default_options
.merge(inputs: prompt)
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
.merge(messages: prompt)
.tap do |b|
b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true if stream
b[:stream] = true if stream
end
.to_json
end
@ -70,7 +91,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user)
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user)
end
describe "#perform_completion!" do

View File

@ -6,7 +6,7 @@ class VllmMock < EndpointMock
def response(content)
{
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "text_completion",
object: "chat.completion",
created: 1_678_464_820,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
usage: {
@ -14,14 +14,16 @@ class VllmMock < EndpointMock
completion_tokens: 162,
total_tokens: 499,
},
choices: [{ text: content, finish_reason: "stop", index: 0 }],
choices: [
{ message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
],
}
end
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
.with(body: model.default_options.merge(prompt: prompt).to_json)
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@ -30,7 +32,7 @@ class VllmMock < EndpointMock
id: "cmpl-#{SecureRandom.hex}",
created: 1_681_283_881,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
choices: [{ delta: { content: delta } }],
index: 0,
}.to_json
end
@ -48,8 +50,8 @@ class VllmMock < EndpointMock
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks)
end
end
@ -67,14 +69,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
let(:anthropic_mock) { VllmMock.new(endpoint) }
let(:compliance) do
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user)
EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user)
end
let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
let(:dialect) { DiscourseAi::Completions::Dialects::Mistral.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }

View File

@ -3,8 +3,8 @@
RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do
described_class.new(
DiscourseAi::Completions::Dialects::OrcaStyle,
nil,
DiscourseAi::Completions::Dialects::Mistral,
canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2",
gateway: canned_response,
)