diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 7001bc08..f426925a 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -11,6 +11,7 @@ class AiApiAuditLog < ActiveRecord::Base
Gemini = 4
Vllm = 5
Cohere = 6
+ Ollama = 7
end
end
diff --git a/config/settings.yml b/config/settings.yml
index 0dbcf624..cff41ce5 100644
--- a/config/settings.yml
+++ b/config/settings.yml
@@ -185,6 +185,9 @@ discourse_ai:
ai_strict_token_counting:
default: false
hidden: true
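+  # Endpoint of a locally running Ollama server, e.g. http://localhost:11434 (Ollama's default port).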
+  ai_ollama_endpoint:
+    default: ""
+    hidden: true
composer_ai_helper_enabled:
default: false
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index d935b336..3e3c932e 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -164,12 +164,15 @@ module DiscourseAi
when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
"open_ai:gpt-3.5-turbo-16k"
when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
- if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(
- "mistralai/Mixtral-8x7B-Instruct-v0.1",
- )
- "vllm:mistralai/Mixtral-8x7B-Instruct-v0.1"
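+          # Prefer vLLM, then Hugging Face; otherwise fall back to a local
+          # Ollama-served Mistral (the Ollama endpoint is only registered in development).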
+ mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+ if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
+ "vllm:#{mixtral_model}"
+ elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
+ mixtral_model,
+ )
+ "hugging_face:#{mixtral_model}"
else
- "hugging_face:mistralai/Mixtral-8x7B-Instruct-v0.1"
+ "ollama:mistral"
end
when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
"google:gemini-pro"
diff --git a/lib/automation.rb b/lib/automation.rb
index d1604fa6..b755f1db 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -40,8 +40,10 @@ module DiscourseAi
if model.start_with?("mistral")
if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(model)
return "vllm:#{model}"
+ elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(model)
+          return "hugging_face:#{model}"
else
- return "hugging_face:#{model}"
+          return "ollama:mistral"
end
end
diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 7368deff..9383019c 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -6,14 +6,7 @@ module DiscourseAi
class ChatGpt < Dialect
class << self
def can_translate?(model_name)
- %w[
- gpt-3.5-turbo
- gpt-4
- gpt-3.5-turbo-16k
- gpt-4-32k
- gpt-4-turbo
- gpt-4-vision-preview
- ].include?(model_name)
+ model_name.starts_with?("gpt-")
end
def tokenizer
@@ -23,72 +16,17 @@ module DiscourseAi
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
+ def native_tool_support?
+ true
+ end
+
def translate
- messages = prompt.messages
-
- # ChatGPT doesn't use an assistant msg to improve long-context responses.
- if messages.last[:type] == :model
- messages = messages.dup
- messages.pop
- end
-
- trimmed_messages = trim_messages(messages)
-
- embed_user_ids =
- trimmed_messages.any? do |m|
+ @embed_user_ids =
+ prompt.messages.any? do |m|
m[:id] && m[:type] == :user && !m[:id].to_s.match?(VALID_ID_REGEX)
end
- trimmed_messages.map do |msg|
- if msg[:type] == :system
- { role: "system", content: msg[:content] }
- elsif msg[:type] == :model
- { role: "assistant", content: msg[:content] }
- elsif msg[:type] == :tool_call
- call_details = JSON.parse(msg[:content], symbolize_names: true)
- call_details[:arguments] = call_details[:arguments].to_json
- call_details[:name] = msg[:name]
-
- {
- role: "assistant",
- content: nil,
- tool_calls: [{ type: "function", function: call_details, id: msg[:id] }],
- }
- elsif msg[:type] == :tool
- { role: "tool", tool_call_id: msg[:id], content: msg[:content], name: msg[:name] }
- else
- user_message = { role: "user", content: msg[:content] }
- if msg[:id]
- if embed_user_ids
- user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
- else
- user_message[:name] = msg[:id]
- end
- end
- user_message[:content] = inline_images(user_message[:content], msg)
- user_message
- end
- end
- end
-
- def tools
- prompt.tools.map do |t|
- tool = t.dup
-
- tool[:parameters] = t[:parameters]
- .to_a
- .reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
- name = p[:name]
- memo[:required] << name if p[:required]
-
- memo[:properties][name] = p.except(:name, :required, :item_type)
-
- memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
- memo
- end
-
- { type: "function", function: tool }
- end
+ super
end
def max_prompt_tokens
@@ -107,6 +45,41 @@ module DiscourseAi
private
+ def tools_dialect
+ @tools_dialect ||= DiscourseAi::Completions::Dialects::OpenAiTools.new(prompt.tools)
+ end
+
+ def system_msg(msg)
+ { role: "system", content: msg[:content] }
+ end
+
+ def model_msg(msg)
+ { role: "assistant", content: msg[:content] }
+ end
+
+ def tool_call_msg(msg)
+ tools_dialect.from_raw_tool_call(msg)
+ end
+
+ def tool_msg(msg)
+ tools_dialect.from_raw_tool(msg)
+ end
+
+ def user_msg(msg)
+ user_message = { role: "user", content: msg[:content] }
+
+ if msg[:id]
+ if @embed_user_ids
+ user_message[:content] = "#{msg[:id]}: #{msg[:content]}"
+ else
+ user_message[:name] = msg[:id]
+ end
+ end
+
+ user_message[:content] = inline_images(user_message[:content], msg)
+ user_message
+ end
+
def inline_images(content, message)
if model_name.include?("gpt-4-vision") || model_name == "gpt-4-turbo"
content = message[:content]
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index e3c93a59..9a15b293 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -27,41 +27,13 @@ module DiscourseAi
end
def translate
- messages = prompt.messages
- system_prompt = +""
+ messages = super
- messages =
- trim_messages(messages)
- .map do |msg|
- case msg[:type]
- when :system
- system_prompt << msg[:content]
- nil
- when :tool_call
- { role: "assistant", content: tool_call_to_xml(msg) }
- when :tool
- { role: "user", content: tool_result_to_xml(msg) }
- when :model
- { role: "assistant", content: msg[:content] }
- when :user
- content = +""
- content << "#{msg[:id]}: " if msg[:id]
- content << msg[:content]
- content = inline_images(content, msg)
-
- { role: "user", content: content }
- end
- end
- .compact
-
- if prompt.tools.present?
- system_prompt << "\n\n"
- system_prompt << build_tools_prompt
- end
+ system_prompt = messages.shift[:content] if messages.first[:role] == "system"
interleving_messages = []
-
previous_message = nil
+
messages.each do |message|
if previous_message
if previous_message[:role] == "user" && message[:role] == "user"
@@ -84,6 +56,29 @@ module DiscourseAi
private
+ def model_msg(msg)
+ { role: "assistant", content: msg[:content] }
+ end
+
+ def system_msg(msg)
+ msg = { role: "system", content: msg[:content] }
+
+ if tools_dialect.instructions.present?
+ msg[:content] = msg[:content].dup << "\n\n#{tools_dialect.instructions}"
+ end
+
+ msg
+ end
+
+ def user_msg(msg)
+ content = +""
+ content << "#{msg[:id]}: " if msg[:id]
+ content << msg[:content]
+ content = inline_images(content, msg)
+
+ { role: "user", content: content }
+ end
+
def inline_images(content, message)
if model_name.include?("claude-3")
encoded_uploads = prompt.encoded_uploads(message)
diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb
index 8b4bf67d..f119aba8 100644
--- a/lib/completions/dialects/command.rb
+++ b/lib/completions/dialects/command.rb
@@ -19,57 +19,17 @@ module DiscourseAi
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def translate
- messages = prompt.messages
+ messages = super
- # ChatGPT doesn't use an assistant msg to improve long-context responses.
- if messages.last[:type] == :model
- messages = messages.dup
- messages.pop
- end
+ system_message = messages.shift[:message] if messages.first[:role] == "SYSTEM"
- trimmed_messages = trim_messages(messages)
+ prompt = { preamble: +"#{system_message}" }
+ prompt[:chat_history] = messages if messages.present?
- chat_history = []
- system_message = nil
-
- prompt = {}
-
- trimmed_messages.each do |msg|
- case msg[:type]
- when :system
- if system_message
- chat_history << { role: "SYSTEM", message: msg[:content] }
- else
- system_message = msg[:content]
- end
- when :model
- chat_history << { role: "CHATBOT", message: msg[:content] }
- when :tool_call
- chat_history << { role: "CHATBOT", message: tool_call_to_xml(msg) }
- when :tool
- chat_history << { role: "USER", message: tool_result_to_xml(msg) }
- when :user
- user_message = { role: "USER", message: msg[:content] }
- user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
- chat_history << user_message
- end
- end
-
- tools_prompt = build_tools_prompt
- prompt[:preamble] = +"#{system_message}"
- if tools_prompt.present?
- prompt[:preamble] << "\n#{tools_prompt}"
- prompt[
- :preamble
- ] << "\nNEVER attempt to run tools using JSON, always use XML. Lives depend on it."
- end
-
- prompt[:chat_history] = chat_history if chat_history.present?
-
- chat_history.reverse_each do |msg|
+ messages.reverse_each do |msg|
if msg[:role] == "USER"
prompt[:message] = msg[:message]
- chat_history.delete(msg)
+ messages.delete(msg)
break
end
end
@@ -101,6 +61,43 @@ module DiscourseAi
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
+
+ def tools_dialect
+ @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
+ end
+
+ def system_msg(msg)
+ cmd_msg = { role: "SYSTEM", message: msg[:content] }
+
+ if tools_dialect.instructions.present?
+ cmd_msg[:message] = [
+ msg[:content],
+ tools_dialect.instructions,
+ "NEVER attempt to run tools using JSON, always use XML. Lives depend on it.",
+ ].join("\n")
+ end
+
+ cmd_msg
+ end
+
+ def model_msg(msg)
+ { role: "CHATBOT", message: msg[:content] }
+ end
+
+ def tool_call_msg(msg)
+ { role: "CHATBOT", message: tools_dialect.from_raw_tool_call(msg) }
+ end
+
+ def tool_msg(msg)
+ { role: "USER", message: tools_dialect.from_raw_tool(msg) }
+ end
+
+ def user_msg(msg)
+ user_message = { role: "USER", message: msg[:content] }
+ user_message[:message] = "#{msg[:id]}: #{msg[:content]}" if msg[:id]
+
+ user_message
+ end
end
end
end
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
index 465884e2..865b5509 100644
--- a/lib/completions/dialects/dialect.rb
+++ b/lib/completions/dialects/dialect.rb
@@ -11,11 +11,9 @@ module DiscourseAi
def dialect_for(model_name)
dialects = [
- DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
- DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
- DiscourseAi::Completions::Dialects::Mixtral,
+ DiscourseAi::Completions::Dialects::Mistral,
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Command,
]
@@ -32,40 +30,6 @@ module DiscourseAi
def tokenizer
raise NotImplemented
end
-
- def tool_preamble(include_array_tip: true)
- array_tip =
- if include_array_tip
- <<~TEXT
-              If a parameter type is an array, return an array of values. For example:
-              <$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
-            TEXT
-          else
-            ""
-          end
-
-        <<~TEXT
-          In this environment you have access to a set of tools you can use to answer the user's question.
-          You may call them like this.
-
-          <function_calls>
-          <invoke>
-          <tool_name>$TOOL_NAME</tool_name>
-          <parameters>
-          <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
-          ...
-          </parameters>
-          </invoke>
-          </function_calls>
-
-          #{array_tip}
-          If you wish to call multiple function in one reply, wrap multiple <invoke>
-          block in a single <function_calls> block.
-
- Always prefer to lead with tool calls, if you need to execute any.
- Avoid all niceties prior to tool calls, Eg: "Let me look this up for you.." etc.
- Here are the complete list of tools available:
- TEXT
- end
end
def initialize(generic_prompt, model_name, opts: {})
@@ -74,74 +38,30 @@ module DiscourseAi
@opts = opts
end
- def translate
- raise NotImplemented
+ VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
+
+ def can_end_with_assistant_msg?
+ false
end
- def tool_result_to_xml(message)
- (<<~TEXT).strip
-          <function_results>
-          <result>
-          <tool_name>#{message[:name] || message[:id]}</tool_name>
-          <json>
-          #{message[:content]}
-          </json>
-          </result>
-          </function_results>
- TEXT
- end
-
- def tool_call_to_xml(message)
- parsed = JSON.parse(message[:content], symbolize_names: true)
- parameters = +""
-
- if parsed[:arguments]
-          parameters << "<parameters>\n"
-          parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
-          parameters << "</parameters>\n"
- end
-
- (<<~TEXT).strip
-          <function_calls>
-          <invoke>
-          <tool_name>#{message[:name] || parsed[:name]}</tool_name>
-          #{parameters}</invoke>
-          </function_calls>
- TEXT
+ def native_tool_support?
+ false
end
def tools
- tools = +""
+ @tools ||= tools_dialect.translated_tools
+ end
- prompt.tools.each do |function|
- parameters = +""
- if function[:parameters].present?
- function[:parameters].each do |parameter|
- parameters << <<~PARAMETER
-                  <parameter>
-                  <name>#{parameter[:name]}</name>
-                  <type>#{parameter[:type]}</type>
-                  <description>#{parameter[:description]}</description>
-                  <required>#{parameter[:required]}</required>
-                PARAMETER
-                if parameter[:enum]
-                  parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
-                end
-                parameters << "</parameter>\n"
- end
- end
+ def translate
+ messages = prompt.messages
- tools << <<~TOOLS
-            <tool_description>
-            <tool_name>#{function[:name]}</tool_name>
-            <description>#{function[:description]}</description>
-            <parameters>
-            #{parameters}</parameters>
-            </tool_description>
- TOOLS
+ # Some models use an assistant msg to improve long-context responses.
+ if messages.last[:type] == :model && can_end_with_assistant_msg?
+ messages = messages.dup
+ messages.pop
end
- tools
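+        # Dispatch each message to the per-type hooks (system_msg, user_msg,
+        # model_msg, tool_msg, tool_call_msg) that each dialect implements.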
+ trim_messages(messages).map { |msg| send("#{msg[:type]}_msg", msg) }.compact
end
def conversation_context
@@ -154,19 +74,6 @@ module DiscourseAi
attr_reader :prompt
- def build_tools_prompt
- return "" if prompt.tools.blank?
-
- has_arrays =
- prompt.tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
-
- (<<~TEXT).strip
- #{self.class.tool_preamble(include_array_tip: has_arrays)}
-
- #{tools}
- TEXT
- end
-
private
attr_reader :model_name, :opts
@@ -230,6 +137,30 @@ module DiscourseAi
def calculate_message_token(msg)
self.class.tokenizer.size(msg[:content].to_s)
end
+
+ def tools_dialect
+ @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools)
+ end
+
+ def system_msg(msg)
+ raise NotImplemented
+ end
+
+      def model_msg(msg)
+ raise NotImplemented
+ end
+
+ def user_msg(msg)
+ raise NotImplemented
+ end
+
+ def tool_call_msg(msg)
+ { role: "assistant", content: tools_dialect.from_raw_tool_call(msg) }
+ end
+
+ def tool_msg(msg)
+ { role: "user", content: tools_dialect.from_raw_tool(msg) }
+ end
end
end
end
diff --git a/lib/completions/dialects/fake.rb b/lib/completions/dialects/fake.rb
index c569ee28..898f3364 100644
--- a/lib/completions/dialects/fake.rb
+++ b/lib/completions/dialects/fake.rb
@@ -9,14 +9,14 @@ module DiscourseAi
model_name == "fake"
end
- def translate
- ""
- end
-
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
+
+ def translate
+ ""
+ end
end
end
end
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index a425d4f5..678dc0cd 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -14,59 +14,30 @@ module DiscourseAi
end
end
+ def native_tool_support?
+ true
+ end
+
def translate
# Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } }
+ messages = super
- messages = prompt.messages
+      interleaving_messages = []
+ previous_message = nil
- # Gemini doesn't use an assistant msg to improve long-context responses.
- messages.pop if messages.last[:type] == :model
-
- memo = []
-
- trim_messages(messages).each do |msg|
- if msg[:type] == :system
- memo << { role: "user", parts: { text: msg[:content] } }
- memo << noop_model_response.dup
- elsif msg[:type] == :model
- memo << { role: "model", parts: { text: msg[:content] } }
- elsif msg[:type] == :tool_call
- call_details = JSON.parse(msg[:content], symbolize_names: true)
-
- memo << {
- role: "model",
- parts: {
- functionCall: {
- name: msg[:name] || call_details[:name],
- args: call_details[:arguments],
- },
- },
- }
- elsif msg[:type] == :tool
- memo << {
- role: "function",
- parts: {
- functionResponse: {
- name: msg[:name] || msg[:id],
- response: {
- content: msg[:content],
- },
- },
- },
- }
- else
- # Gemini quirk. Doesn't accept tool -> user or user -> user msgs.
- previous_msg_role = memo.last&.dig(:role)
- if previous_msg_role == "user" || previous_msg_role == "function"
- memo << noop_model_response.dup
+ messages.each do |message|
+ if previous_message
+ if (previous_message[:role] == "user" || previous_message[:role] == "function") &&
+ message[:role] == "user"
+            interleaving_messages << noop_model_response.dup
end
-
- memo << { role: "user", parts: { text: msg[:content] } }
end
+        interleaving_messages << message
+ previous_message = message
end
- memo
+      interleaving_messages
end
def tools
@@ -110,6 +81,46 @@ module DiscourseAi
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
+
+ def system_msg(msg)
+ { role: "user", parts: { text: msg[:content] } }
+ end
+
+ def model_msg(msg)
+ { role: "model", parts: { text: msg[:content] } }
+ end
+
+ def user_msg(msg)
+ { role: "user", parts: { text: msg[:content] } }
+ end
+
+ def tool_call_msg(msg)
+ call_details = JSON.parse(msg[:content], symbolize_names: true)
+
+ {
+ role: "model",
+ parts: {
+ functionCall: {
+ name: msg[:name] || call_details[:name],
+ args: call_details[:arguments],
+ },
+ },
+ }
+ end
+
+ def tool_msg(msg)
+ {
+ role: "function",
+ parts: {
+ functionResponse: {
+ name: msg[:name] || msg[:id],
+ response: {
+ content: msg[:content],
+ },
+ },
+ },
+ }
+ end
end
end
end
diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb
deleted file mode 100644
index 3b5675f1..00000000
--- a/lib/completions/dialects/llama2_classic.rb
+++ /dev/null
@@ -1,68 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Completions
- module Dialects
- class Llama2Classic < Dialect
- class << self
- def can_translate?(model_name)
- %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
- end
-
- def tokenizer
- DiscourseAi::Tokenizer::Llama2Tokenizer
- end
- end
-
- def translate
- messages = prompt.messages
-
- llama2_prompt =
- trim_messages(messages).reduce(+"") do |memo, msg|
- next(memo) if msg[:type] == :tool_call
-
- if msg[:type] == :system
- memo << (<<~TEXT).strip
-              [INST]
-              <<SYS>>
-              #{msg[:content]}
-              #{build_tools_prompt}
-              <</SYS>>
-              [/INST]
- TEXT
- elsif msg[:type] == :model
- memo << "\n#{msg[:content]}"
- elsif msg[:type] == :tool
- JSON.parse(msg[:content], symbolize_names: true)
- memo << "\n[INST]\n"
-
- memo << (<<~TEXT).strip
-              <function_results>
-              <result>
-              <tool_name>#{msg[:id]}</tool_name>
-              <json>
-              #{msg[:content]}
-              </json>
-              </result>
-              </function_results>
- [/INST]
- TEXT
- else
- memo << "\n[INST]#{msg[:content]}[/INST]"
- end
-
- memo
- end
-
- llama2_prompt << "\n" if llama2_prompt.ends_with?("[/INST]")
-
- llama2_prompt
- end
-
- def max_prompt_tokens
- SiteSetting.ai_hugging_face_token_limit
- end
- end
- end
- end
-end
diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb
new file mode 100644
index 00000000..7752a876
--- /dev/null
+++ b/lib/completions/dialects/mistral.rb
@@ -0,0 +1,57 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Dialects
+ class Mistral < Dialect
+ class << self
+ def can_translate?(model_name)
+ %w[
+ mistralai/Mixtral-8x7B-Instruct-v0.1
+ mistralai/Mistral-7B-Instruct-v0.2
+ mistral
+ ].include?(model_name)
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::MixtralTokenizer
+ end
+ end
+
+ def tools
+ @tools ||= tools_dialect.translated_tools
+ end
+
+ def max_prompt_tokens
+ 32_000
+ end
+
+ private
+
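+      # Mistral's instruct template has no dedicated system role, so system
+      # instructions are surfaced through an assistant turn instead.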
+ def system_msg(msg)
+ { role: "assistant", content: "#{msg[:content]}" }
+ end
+
+ def model_msg(msg)
+ { role: "assistant", content: msg[:content] }
+ end
+
+ def tool_call_msg(msg)
+ tools_dialect.from_raw_tool_call(msg)
+ end
+
+ def tool_msg(msg)
+ tools_dialect.from_raw_tool(msg)
+ end
+
+ def user_msg(msg)
+ content = +""
+ content << "#{msg[:id]}: " if msg[:id]
+ content << msg[:content]
+
+ { role: "user", content: content }
+ end
+ end
+ end
+ end
+end
diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb
deleted file mode 100644
index 425d741e..00000000
--- a/lib/completions/dialects/mixtral.rb
+++ /dev/null
@@ -1,57 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Completions
- module Dialects
- class Mixtral < Dialect
- class << self
- def can_translate?(model_name)
- %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
- model_name,
- )
- end
-
- def tokenizer
- DiscourseAi::Tokenizer::MixtralTokenizer
- end
- end
-
- def translate
- messages = prompt.messages
-
- mixtral_prompt =
- trim_messages(messages).reduce(+"") do |memo, msg|
- if msg[:type] == :tool_call
- memo << "\n"
- memo << tool_call_to_xml(msg)
- elsif msg[:type] == :system
- memo << (<<~TEXT).strip
- [INST]
- #{msg[:content]}
- #{build_tools_prompt}
- [/INST] Ok
- TEXT
- elsif msg[:type] == :model
- memo << "\n#{msg[:content]}"
- elsif msg[:type] == :tool
- memo << "\n"
- memo << tool_result_to_xml(msg)
- else
- memo << "\n[INST]#{msg[:content]}[/INST]"
- end
-
- memo
- end
-
- mixtral_prompt << "\n" if mixtral_prompt.ends_with?("[/INST]")
-
- mixtral_prompt
- end
-
- def max_prompt_tokens
- 32_000
- end
- end
- end
- end
-end
diff --git a/lib/completions/dialects/open_ai_tools.rb b/lib/completions/dialects/open_ai_tools.rb
new file mode 100644
index 00000000..b990e379
--- /dev/null
+++ b/lib/completions/dialects/open_ai_tools.rb
@@ -0,0 +1,62 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Dialects
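+      # Maps our generic tool schema, tool calls, and tool results onto
+      # OpenAI's native function-calling format.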
+ class OpenAiTools
+ def initialize(tools)
+ @raw_tools = tools
+ end
+
+ def translated_tools
+ raw_tools.map do |t|
+ tool = t.dup
+
+ tool[:parameters] = t[:parameters]
+ .to_a
+ .reduce({ type: "object", properties: {}, required: [] }) do |memo, p|
+ name = p[:name]
+ memo[:required] << name if p[:required]
+
+ memo[:properties][name] = p.except(:name, :required, :item_type)
+
+ memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
+ memo
+ end
+
+ { type: "function", function: tool }
+ end
+ end
+
+ def instructions
+        "" # Noop. Tools are listed separately.
+ end
+
+ def from_raw_tool_call(raw_message)
+ call_details = JSON.parse(raw_message[:content], symbolize_names: true)
+ call_details[:arguments] = call_details[:arguments].to_json
+ call_details[:name] = raw_message[:name]
+
+ {
+ role: "assistant",
+ content: nil,
+ tool_calls: [{ type: "function", function: call_details, id: raw_message[:id] }],
+ }
+ end
+
+ def from_raw_tool(raw_message)
+ {
+ role: "tool",
+ tool_call_id: raw_message[:id],
+ content: raw_message[:content],
+ name: raw_message[:name],
+ }
+ end
+
+ private
+
+ attr_reader :raw_tools
+ end
+ end
+ end
+end
diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb
deleted file mode 100644
index 4ec42a36..00000000
--- a/lib/completions/dialects/orca_style.rb
+++ /dev/null
@@ -1,59 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
- module Completions
- module Dialects
- class OrcaStyle < Dialect
- class << self
- def can_translate?(model_name)
- %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
- end
-
- def tokenizer
- DiscourseAi::Tokenizer::Llama2Tokenizer
- end
- end
-
- def translate
- messages = prompt.messages
- trimmed_messages = trim_messages(messages)
-
- # Need to include this differently
- last_message = trimmed_messages.last[:type] == :assistant ? trimmed_messages.pop : nil
-
- llama2_prompt =
- trimmed_messages.reduce(+"") do |memo, msg|
- if msg[:type] == :tool_call
- memo << "\n### Assistant:\n"
- memo << tool_call_to_xml(msg)
- elsif msg[:type] == :system
- memo << (<<~TEXT).strip
- ### System:
- #{msg[:content]}
- #{build_tools_prompt}
- TEXT
- elsif msg[:type] == :model
- memo << "\n### Assistant:\n#{msg[:content]}"
- elsif msg[:type] == :tool
- memo << "\n### User:\n"
- memo << tool_result_to_xml(msg)
- else
- memo << "\n### User:\n#{msg[:content]}"
- end
-
- memo
- end
-
- llama2_prompt << "\n### Assistant:\n"
- llama2_prompt << "#{last_message[:content]}:" if last_message
-
- llama2_prompt
- end
-
- def max_prompt_tokens
- SiteSetting.ai_hugging_face_token_limit
- end
- end
- end
- end
-end
diff --git a/lib/completions/dialects/xml_tools.rb b/lib/completions/dialects/xml_tools.rb
new file mode 100644
index 00000000..47988a71
--- /dev/null
+++ b/lib/completions/dialects/xml_tools.rb
@@ -0,0 +1,125 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Dialects
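+      # Renders tool definitions, calls, and results as XML markup for models
+      # without native function-calling support.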
+ class XmlTools
+ def initialize(tools)
+ @raw_tools = tools
+ end
+
+ def translated_tools
+ raw_tools.reduce(+"") do |tools, function|
+ parameters = +""
+ if function[:parameters].present?
+ function[:parameters].each do |parameter|
+ parameters << <<~PARAMETER
+                <parameter>
+                <name>#{parameter[:name]}</name>
+                <type>#{parameter[:type]}</type>
+                <description>#{parameter[:description]}</description>
+                <required>#{parameter[:required]}</required>
+              PARAMETER
+              if parameter[:enum]
+                parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
+              end
+              parameters << "</parameter>\n"
+ end
+ end
+
+ tools << <<~TOOLS
+            <tool_description>
+            <tool_name>#{function[:name]}</tool_name>
+            <description>#{function[:description]}</description>
+            <parameters>
+            #{parameters}</parameters>
+            </tool_description>
+ TOOLS
+ end
+ end
+
+ def instructions
+ return "" if raw_tools.blank?
+
+ has_arrays = raw_tools.any? { |tool| tool[:parameters]&.any? { |p| p[:type] == "array" } }
+
+ (<<~TEXT).strip
+ #{tool_preamble(include_array_tip: has_arrays)}
+
+ #{translated_tools}
+ TEXT
+ end
+
+ def from_raw_tool(raw_message)
+ (<<~TEXT).strip
+          <function_results>
+          <result>
+          <tool_name>#{raw_message[:name] || raw_message[:id]}</tool_name>
+          <json>
+          #{raw_message[:content]}
+          </json>
+          </result>
+          </function_results>
+ TEXT
+ end
+
+ def from_raw_tool_call(raw_message)
+ parsed = JSON.parse(raw_message[:content], symbolize_names: true)
+ parameters = +""
+
+ if parsed[:arguments]
+          parameters << "<parameters>\n"
+          parsed[:arguments].each { |k, v| parameters << "<#{k}>#{v}</#{k}>\n" }
+          parameters << "</parameters>\n"
+ end
+
+ (<<~TEXT).strip
+          <function_calls>
+          <invoke>
+          <tool_name>#{raw_message[:name] || parsed[:name]}</tool_name>
+          #{parameters}</invoke>
+          </function_calls>
+ TEXT
+ end
+
+ private
+
+ attr_reader :raw_tools
+
+ def tool_preamble(include_array_tip: true)
+ array_tip =
+ if include_array_tip
+ <<~TEXT
+ If a parameter type is an array, return an array of values. For example:
+              <$PARAMETER_NAME>["one","two","three"]</$PARAMETER_NAME>
+ TEXT
+ else
+ ""
+ end
+
+ <<~TEXT
+ In this environment you have access to a set of tools you can use to answer the user's question.
+ You may call them like this.
+
+            <function_calls>
+            <invoke>
+            <tool_name>$TOOL_NAME</tool_name>
+            <parameters>
+            <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
+            ...
+            </parameters>
+            </invoke>
+            </function_calls>
+
+            #{array_tip}
+            If you wish to call multiple functions in one reply, wrap multiple <invoke>
+            blocks in a single <function_calls> block.
+
+            Always prefer to lead with tool calls if you need to execute any.
+            Avoid all niceties prior to tool calls, e.g. "Let me look this up for you...".
+            Here is the complete list of tools available:
+ TEXT
+ end
+ end
+ end
+ end
+end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 8c27a269..ee9d4f17 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -62,7 +62,7 @@ module DiscourseAi
# this is an approximation, we will update it later if request goes through
def prompt_size(prompt)
- super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
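+        # The base implementation now expects an array of chat messages, but
+        # this prompt is a plain string, so size it with the tokenizer directly.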
+ tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
def model_uri
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 5f17cad1..b7f3e8bf 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -51,7 +51,7 @@ module DiscourseAi
def prompt_size(prompt)
# approximation
- super(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
+ tokenizer.size(prompt.system_prompt.to_s + " " + prompt.messages.to_s)
end
def model_uri
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 0766a9a4..7a914a1e 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -19,6 +19,8 @@ module DiscourseAi
DiscourseAi::Completions::Endpoints::Cohere,
]
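+
+      # Ollama talks to a locally running server, so only register it in development.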
+ endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
+
if Rails.env.test? || Rails.env.development?
endpoints << DiscourseAi::Completions::Endpoints::Fake
end
@@ -67,6 +69,10 @@ module DiscourseAi
false
end
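+
+      # Endpoints for local servers (e.g. Ollama) override this to use plain HTTP.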
+ def use_ssl?
+ true
+ end
+
def perform_completion!(dialect, user, model_params = {}, &blk)
allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
@@ -78,7 +84,7 @@ module DiscourseAi
FinalDestination::HTTP.start(
model_uri.host,
model_uri.port,
- use_ssl: true,
+ use_ssl: use_ssl?,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
@@ -315,7 +321,7 @@ module DiscourseAi
end
def extract_prompt_for_tokenizer(prompt)
- prompt
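+        # All remaining endpoints speak in arrays of chat messages, so join
+        # their contents when estimating the prompt's token count.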
+ prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
def build_buffer
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index 5542c73f..d6237c05 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -8,14 +8,9 @@ module DiscourseAi
def can_contact?(endpoint_name, model_name)
return false unless endpoint_name == "hugging_face"
- %w[
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
- Llama2-*-chat-hf
- Llama2-chat-hf
- mistralai/Mixtral-8x7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
- ].include?(model_name)
+ %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
+ model_name,
+ )
end
def dependant_setting_names
@@ -31,24 +26,21 @@ module DiscourseAi
end
end
- def default_options
- { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
- end
-
def normalize_model_params(model_params)
model_params = model_params.dup
+      # max_tokens and temperature are already supported
if model_params[:stop_sequences]
model_params[:stop] = model_params.delete(:stop_sequences)
end
- if model_params[:max_tokens]
- model_params[:max_new_tokens] = model_params.delete(:max_tokens)
- end
-
model_params
end
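+
+      # TGI exposes an OpenAI-compatible chat completions API; we target that
+      # instead of the legacy generate-style inputs/parameters payload.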
+ def default_options
+ { model: model, temperature: 0.7 }
+ end
+
def provider_id
AiApiAuditLog::Provider::HuggingFaceTextGeneration
end
@@ -61,13 +53,14 @@ module DiscourseAi
def prepare_payload(prompt, model_params, _dialect)
default_options
- .merge(inputs: prompt)
+ .merge(model_params)
+ .merge(messages: prompt)
.tap do |payload|
- payload[:parameters].merge!(model_params)
+ if !payload[:max_tokens]
+ token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
- token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
-
- payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
+ payload[:max_tokens] = token_limit - prompt_size(prompt)
+ end
payload[:stream] = true if @streaming_mode
end
@@ -85,16 +78,13 @@ module DiscourseAi
end
def extract_completion_from(response_raw)
- parsed = JSON.parse(response_raw, symbolize_names: true)
+ parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
+ # half a line sent here
+ return if !parsed
- if @streaming_mode
- # Last chunk contains full response, which we already yielded.
- return if parsed.dig(:token, :special)
+ response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
- parsed.dig(:token, :text).to_s
- else
- parsed[0][:generated_text].to_s
- end
+ response_h.dig(:content)
end
def partials_from(decoded_chunk)
diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb
new file mode 100644
index 00000000..0fd748d4
--- /dev/null
+++ b/lib/completions/endpoints/ollama.rb
@@ -0,0 +1,89 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Endpoints
+ class Ollama < Base
+ class << self
+ def can_contact?(endpoint_name, model_name)
+ endpoint_name == "ollama" && %w[mistral].include?(model_name)
+ end
+
+ def dependant_setting_names
+ %w[ai_ollama_endpoint]
+ end
+
+ def correctly_configured?(_model_name)
+ SiteSetting.ai_ollama_endpoint.present?
+ end
+
+ def endpoint_name(model_name)
+ "Ollama - #{model_name}"
+ end
+ end
+
+ def normalize_model_params(model_params)
+ model_params = model_params.dup
+
+        # max_tokens and temperature are already supported
+ if model_params[:stop_sequences]
+ model_params[:stop] = model_params.delete(:stop_sequences)
+ end
+
+ model_params
+ end
+
+ def default_options
+ { max_tokens: 2000, model: model }
+ end
+
+ def provider_id
+ AiApiAuditLog::Provider::Ollama
+ end
+
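+      # Ollama serves plain HTTP locally, so skip TLS when contacting it.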
+ def use_ssl?
+ false
+ end
+
+ private
+
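+      # Ollama ships an OpenAI-compatible API, so we reuse the
+      # /v1/chat/completions request and response shapes.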
+ def model_uri
+ URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
+ end
+
+ def prepare_payload(prompt, model_params, _dialect)
+ default_options
+ .merge(model_params)
+ .merge(messages: prompt)
+ .tap { |payload| payload[:stream] = true if @streaming_mode }
+ end
+
+ def prepare_request(payload)
+ headers = { "Content-Type" => "application/json" }
+
+ Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
+ end
+
+ def partials_from(decoded_chunk)
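+        # Streamed responses arrive as SSE lines ("data: {json}"), with a
+        # final "data: [DONE]" sentinel we drop below.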
+ decoded_chunk
+ .split("\n")
+ .map do |line|
+ data = line.split("data: ", 2)[1]
+ data == "[DONE]" ? nil : data
+ end
+ .compact
+ end
+
+ def extract_completion_from(response_raw)
+ parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
+ # half a line sent here
+ return if !parsed
+
+ response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
+
+ response_h.dig(:content)
+ end
+ end
+ end
+ end
+end
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 16e0b886..2ccd817e 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -153,10 +153,6 @@ module DiscourseAi
.compact
end
- def extract_prompt_for_tokenizer(prompt)
- prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
- end
-
def has_tool?(_response_data)
@has_function_call
end
diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb
index 310bcdc9..7db1452d 100644
--- a/lib/completions/endpoints/vllm.rb
+++ b/lib/completions/endpoints/vllm.rb
@@ -7,14 +7,9 @@ module DiscourseAi
class << self
def can_contact?(endpoint_name, model_name)
endpoint_name == "vllm" &&
- %w[
- mistralai/Mixtral-8x7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
- Llama2-*-chat-hf
- Llama2-chat-hf
- ].include?(model_name)
+ %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2].include?(
+ model_name,
+ )
end
def dependant_setting_names
@@ -54,9 +49,9 @@ module DiscourseAi
def model_uri
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
if service.present?
- api_endpoint = "https://#{service.target}:#{service.port}/v1/completions"
+ api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
else
- api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/completions"
+ api_endpoint = "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions"
end
@uri ||= URI(api_endpoint)
end
@@ -64,7 +59,7 @@ module DiscourseAi
def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(model_params)
- .merge(prompt: prompt)
+ .merge(messages: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode }
end
@@ -76,15 +71,6 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
- def extract_completion_from(response_raw)
- parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-
- # half a line sent here
- return if !parsed
-
- parsed.dig(:text)
- end
-
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
@@ -94,6 +80,16 @@ module DiscourseAi
end
.compact
end
+
+ def extract_completion_from(response_raw)
+ parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
+ # half a line sent here
+ return if !parsed
+
+ response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
+
+ response_h.dig(:content)
+ end
end
end
end
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index c54c0dd4..47190644 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -31,21 +31,10 @@ module DiscourseAi
claude-3-opus
],
anthropic: %w[claude-instant-1 claude-2 claude-3-haiku claude-3-sonnet claude-3-opus],
- vllm: %w[
- mistralai/Mixtral-8x7B-Instruct-v0.1
- mistralai/Mistral-7B-Instruct-v0.2
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
- Llama2-*-chat-hf
- Llama2-chat-hf
- ],
+ vllm: %w[mistralai/Mixtral-8x7B-Instruct-v0.1 mistralai/Mistral-7B-Instruct-v0.2],
hugging_face: %w[
mistralai/Mixtral-8x7B-Instruct-v0.1
mistralai/Mistral-7B-Instruct-v0.2
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
- Llama2-*-chat-hf
- Llama2-chat-hf
],
cohere: %w[command-light command command-r command-r-plus],
open_ai: %w[
@@ -57,7 +46,10 @@ module DiscourseAi
gpt-4-vision-preview
],
google: %w[gemini-pro gemini-1.5-pro],
- }.tap { |h| h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development? }
+ }.tap do |h|
+ h[:ollama] = ["mistral"] if Rails.env.development?
+ h[:fake] = ["fake"] if Rails.env.test? || Rails.env.development?
+ end
end
def valid_provider_models
@@ -120,8 +112,6 @@ module DiscourseAi
@gateway = gateway
end
- delegate :tokenizer, to: :dialect_klass
-
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
# @param user { User } - User requesting the summary.
#
@@ -184,6 +174,8 @@ module DiscourseAi
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
end
+ delegate :tokenizer, to: :dialect_klass
+
attr_reader :model_name
private
diff --git a/lib/summarization/entry_point.rb b/lib/summarization/entry_point.rb
index 9e60eb67..37c72725 100644
--- a/lib/summarization/entry_point.rb
+++ b/lib/summarization/entry_point.rb
@@ -10,14 +10,6 @@ module DiscourseAi
Models::OpenAi.new("open_ai:gpt-4-turbo", max_tokens: 100_000),
Models::OpenAi.new("open_ai:gpt-3.5-turbo", max_tokens: 4096),
Models::OpenAi.new("open_ai:gpt-3.5-turbo-16k", max_tokens: 16_384),
- Models::Llama2.new(
- "hugging_face:Llama2-chat-hf",
- max_tokens: SiteSetting.ai_hugging_face_token_limit,
- ),
- Models::Llama2FineTunedOrcaStyle.new(
- "hugging_face:StableBeluga2",
- max_tokens: SiteSetting.ai_hugging_face_token_limit,
- ),
Models::Gemini.new("google:gemini-pro", max_tokens: 32_768),
Models::Gemini.new("google:gemini-1.5-pro", max_tokens: 800_000),
]
diff --git a/spec/lib/completions/dialects/dialect_spec.rb b/spec/lib/completions/dialects/dialect_spec.rb
index 2410f415..c54d1838 100644
--- a/spec/lib/completions/dialects/dialect_spec.rb
+++ b/spec/lib/completions/dialects/dialect_spec.rb
@@ -17,47 +17,6 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
end
RSpec.describe DiscourseAi::Completions::Dialects::Dialect do
- describe "#build_tools_prompt" do
- it "can exclude array instructions" do
- prompt = DiscourseAi::Completions::Prompt.new("12345")
- prompt.tools = [
- {
- name: "weather",
- description: "lookup weather in a city",
- parameters: [{ name: "city", type: "string", description: "city name", required: true }],
- },
- ]
-
- dialect = TestDialect.new(prompt, "test")
-
- expect(dialect.build_tools_prompt).not_to include("array")
- end
-
- it "can include array instructions" do
- prompt = DiscourseAi::Completions::Prompt.new("12345")
- prompt.tools = [
- {
- name: "weather",
- description: "lookup weather in a city",
- parameters: [{ name: "city", type: "array", description: "city names", required: true }],
- },
- ]
-
- dialect = TestDialect.new(prompt, "test")
-
- expect(dialect.build_tools_prompt).to include("array")
- end
-
- it "does not break if there are no params" do
- prompt = DiscourseAi::Completions::Prompt.new("12345")
- prompt.tools = [{ name: "categories", description: "lookup all categories" }]
-
- dialect = TestDialect.new(prompt, "test")
-
- expect(dialect.build_tools_prompt).not_to include("array")
- end
- end
-
describe "#trim_messages" do
it "should trim tool messages if tool_calls are trimmed" do
prompt = DiscourseAi::Completions::Prompt.new("12345")
diff --git a/spec/lib/completions/dialects/llama2_classic_spec.rb b/spec/lib/completions/dialects/llama2_classic_spec.rb
deleted file mode 100644
index 0242ebf6..00000000
--- a/spec/lib/completions/dialects/llama2_classic_spec.rb
+++ /dev/null
@@ -1,62 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "dialect_context"
-
-RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
- let(:model_name) { "Llama2-chat-hf" }
- let(:context) { DialectContext.new(described_class, model_name) }
-
- describe "#translate" do
- it "translates a prompt written in our generic format to the Llama2 format" do
- llama2_classic_version = <<~TEXT
- [INST]
-        <<SYS>>
- #{context.system_insts}
- #{described_class.tool_preamble(include_array_tip: false)}
-
- #{context.dialect_tools}
-        <</SYS>>
- [/INST]
- [INST]#{context.simple_user_input}[/INST]
- TEXT
-
- translated = context.system_user_scenario
-
- expect(translated).to eq(llama2_classic_version)
- end
-
- it "translates tool messages" do
- expected = +(<<~TEXT)
- [INST]
-        <<SYS>>
- #{context.system_insts}
- #{described_class.tool_preamble(include_array_tip: false)}
-
- #{context.dialect_tools}
-        <</SYS>>
- [/INST]
- [INST]This is a message by a user[/INST]
- I'm a previous bot reply, that's why there's no user
- [INST]This is a new message by a user[/INST]
- [INST]
-        <function_results>
-        <result>
-        <tool_name>tool_id</tool_name>
-        <json>
-        "I'm a tool result"
-        </json>
-        </result>
-        </function_results>
- [/INST]
- TEXT
-
- expect(context.multi_turn_scenario).to eq(expected)
- end
-
- it "trims content if it's getting too long" do
- translated = context.long_user_input_scenario
-
- expect(translated.length).to be < context.long_message_text.length
- end
- end
-end
diff --git a/spec/lib/completions/dialects/mixtral_spec.rb b/spec/lib/completions/dialects/mixtral_spec.rb
deleted file mode 100644
index 499dad73..00000000
--- a/spec/lib/completions/dialects/mixtral_spec.rb
+++ /dev/null
@@ -1,66 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "dialect_context"
-
-RSpec.describe DiscourseAi::Completions::Dialects::Mixtral do
- let(:model_name) { "mistralai/Mixtral-8x7B-Instruct-v0.1" }
- let(:context) { DialectContext.new(described_class, model_name) }
-
- describe "#translate" do
- it "translates a prompt written in our generic format to the Llama2 format" do
- llama2_classic_version = <<~TEXT
- [INST]
- #{context.system_insts}
- #{described_class.tool_preamble(include_array_tip: false)}
-
- #{context.dialect_tools}
- [/INST] Ok
- [INST]#{context.simple_user_input}[/INST]
- TEXT
-
- translated = context.system_user_scenario
-
- expect(translated).to eq(llama2_classic_version)
- end
-
- it "translates tool messages" do
- expected = +(<<~TEXT).strip
- [INST]
- #{context.system_insts}
- #{described_class.tool_preamble(include_array_tip: false)}
-
- #{context.dialect_tools}
- [/INST] Ok
- [INST]This is a message by a user[/INST]
- I'm a previous bot reply, that's why there's no user
- [INST]This is a new message by a user[/INST]
-        <function_calls>
-        <invoke>
-        <tool_name>get_weather</tool_name>
-        <parameters>
-        <location>Sydney</location>
-        <unit>c</unit>
-        </parameters>
-        </invoke>
-        </function_calls>
-        <function_results>
-        <result>
-        <tool_name>get_weather</tool_name>
-        <json>
-        "I'm a tool result"
-        </json>
-        </result>
-        </function_results>
- TEXT
-
- expect(context.multi_turn_scenario).to eq(expected)
- end
-
- it "trims content if it's getting too long" do
- length = 6_000
- translated = context.long_user_input_scenario(length: length)
-
- expect(translated.length).to be < context.long_message_text(length: length).length
- end
- end
-end
diff --git a/spec/lib/completions/dialects/orca_style_spec.rb b/spec/lib/completions/dialects/orca_style_spec.rb
deleted file mode 100644
index 63b414b8..00000000
--- a/spec/lib/completions/dialects/orca_style_spec.rb
+++ /dev/null
@@ -1,71 +0,0 @@
-# frozen_string_literal: true
-
-require_relative "dialect_context"
-
-RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
- let(:model_name) { "StableBeluga2" }
- let(:context) { DialectContext.new(described_class, model_name) }
-
- describe "#translate" do
- it "translates a prompt written in our generic format to the Llama2 format" do
- llama2_classic_version = <<~TEXT
- ### System:
- #{context.system_insts}
- #{described_class.tool_preamble(include_array_tip: false)}
-
- #{context.dialect_tools}
- ### User:
- #{context.simple_user_input}
- ### Assistant:
- TEXT
-
- translated = context.system_user_scenario
-
- expect(translated).to eq(llama2_classic_version)
- end
-
- it "translates tool messages" do
- expected = +(<<~TEXT)
- ### System:
- #{context.system_insts}
- #{described_class.tool_preamble(include_array_tip: false)}
-
- #{context.dialect_tools}
- ### User:
- This is a message by a user
- ### Assistant:
- I'm a previous bot reply, that's why there's no user
- ### User:
- This is a new message by a user
- ### Assistant:
-        <function_calls>
-        <invoke>
-        <tool_name>get_weather</tool_name>
-        <parameters>
-        <location>Sydney</location>
-        <unit>c</unit>
-        </parameters>
-        </invoke>
-        </function_calls>
- ### User:
-        <function_results>
-        <result>
-        <tool_name>get_weather</tool_name>
-        <json>
-        "I'm a tool result"
-        </json>
-        </result>
-        </function_results>
- ### Assistant:
- TEXT
-
- expect(context.multi_turn_scenario).to eq(expected)
- end
-
- it "trims content if it's getting too long" do
- translated = context.long_user_input_scenario
-
- expect(translated.length).to be < context.long_message_text.length
- end
- end
-end
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index bfdfa74e..5b9bd9f5 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -4,7 +4,20 @@ require_relative "endpoint_compliance"
class HuggingFaceMock < EndpointMock
def response(content)
- [{ generated_text: content }]
+ {
+ id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
+ object: "chat.completion",
+ created: 1_678_464_820,
+ model: "Llama2-*-chat-hf",
+ usage: {
+ prompt_tokens: 337,
+ completion_tokens: 162,
+ total_tokens: 499,
+ },
+ choices: [
+ { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
+ ],
+ }
end
def stub_response(prompt, response_text, tool_call: false)
@@ -14,26 +27,32 @@ class HuggingFaceMock < EndpointMock
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
- def stream_line(delta, deltas, finish_reason: nil)
+ def stream_line(delta, finish_reason: nil)
+"data: " << {
- token: {
- id: 29_889,
- text: delta,
- logprob: -0.08319092,
- special: !!finish_reason,
- },
- generated_text: finish_reason ? deltas.join : nil,
- details: nil,
+ id: "chatcmpl-#{SecureRandom.hex}",
+ object: "chat.completion.chunk",
+ created: 1_681_283_881,
+ model: "Llama2-*-chat-hf",
+ choices: [{ delta: { content: delta } }],
+ finish_reason: finish_reason,
+ index: 0,
}.to_json
end
+ def stub_raw(chunks)
+ WebMock.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}").to_return(
+ status: 200,
+ body: chunks,
+ )
+ end
+
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
- stream_line(deltas[index], deltas, finish_reason: true)
+ stream_line(deltas[index], finish_reason: "stop_sequence")
else
- stream_line(deltas[index], deltas)
+ stream_line(deltas[index])
end
end
@@ -43,16 +62,18 @@ class HuggingFaceMock < EndpointMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body(prompt, stream: true))
.to_return(status: 200, body: chunks)
+
+ yield if block_given?
end
- def request_body(prompt, stream: false)
+ def request_body(prompt, stream: false, tool_call: false)
model
.default_options
- .merge(inputs: prompt)
- .tap do |payload|
- payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
+ .merge(messages: prompt)
+ .tap do |b|
+ b[:max_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
- payload[:stream] = true if stream
+ b[:stream] = true if stream
end
.to_json
end
@@ -70,7 +91,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
let(:hf_mock) { HuggingFaceMock.new(endpoint) }
let(:compliance) do
- EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Llama2Classic, user)
+ EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user)
end
describe "#perform_completion!" do
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 52d87007..d879ad09 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -6,7 +6,7 @@ class VllmMock < EndpointMock
def response(content)
{
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
- object: "text_completion",
+ object: "chat.completion",
created: 1_678_464_820,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
usage: {
@@ -14,14 +14,16 @@ class VllmMock < EndpointMock
completion_tokens: 162,
total_tokens: 499,
},
- choices: [{ text: content, finish_reason: "stop", index: 0 }],
+ choices: [
+ { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
+ ],
}
end
def stub_response(prompt, response_text, tool_call: false)
WebMock
- .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
- .with(body: model.default_options.merge(prompt: prompt).to_json)
+ .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
+ .with(body: model.default_options.merge(messages: prompt).to_json)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@@ -30,7 +32,7 @@ class VllmMock < EndpointMock
id: "cmpl-#{SecureRandom.hex}",
created: 1_681_283_881,
model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
- choices: [{ text: delta, finish_reason: finish_reason, index: 0 }],
+ choices: [{ delta: { content: delta } }],
index: 0,
}.to_json
end
@@ -48,8 +50,8 @@ class VllmMock < EndpointMock
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
- .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/completions")
- .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
+ .stub_request(:post, "#{SiteSetting.ai_vllm_endpoint}/v1/chat/completions")
+ .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.to_return(status: 200, body: chunks)
end
end
@@ -67,14 +69,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
let(:anthropic_mock) { VllmMock.new(endpoint) }
let(:compliance) do
- EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mixtral, user)
+ EndpointsCompliance.new(self, endpoint, DiscourseAi::Completions::Dialects::Mistral, user)
end
- let(:dialect) { DiscourseAi::Completions::Dialects::Mixtral.new(generic_prompt, model_name) }
+ let(:dialect) { DiscourseAi::Completions::Dialects::Mistral.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
- let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
- let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
+ let(:request_body) { model.default_options.merge(messages: prompt).to_json }
+ let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index 3aeeb04a..e91a4a60 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -3,8 +3,8 @@
RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do
described_class.new(
- DiscourseAi::Completions::Dialects::OrcaStyle,
- nil,
+ DiscourseAi::Completions::Dialects::Mistral,
+ canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2",
gateway: canned_response,
)