diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 82595225..dda3e8a2 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -3,31 +3,99 @@
module DiscourseAi
module Completions
module Dialects
- class ChatGpt
- def self.can_translate?(model_name)
- %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
+ class ChatGpt < Dialect
+ class << self
+ def can_translate?(model_name)
+ %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::OpenAiTokenizer
+ end
end
- def translate(generic_prompt)
+ def translate
open_ai_prompt = [
- {
- role: "system",
- content: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
- },
+ { role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") },
]
- if generic_prompt[:examples]
- generic_prompt[:examples].each do |example_pair|
+ if prompt[:examples]
+ prompt[:examples].each do |example_pair|
open_ai_prompt << { role: "user", content: example_pair.first }
open_ai_prompt << { role: "assistant", content: example_pair.second }
end
end
- open_ai_prompt << { role: "user", content: generic_prompt[:input] }
+          open_ai_prompt.concat(conversation_context) if prompt[:conversation_context]
+
+ open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
+
+ open_ai_prompt
end
- def tokenizer
- DiscourseAi::Tokenizer::OpenAiTokenizer
+ def tools
+ return if prompt[:tools].blank?
+
+ prompt[:tools].map { |t| { type: "function", tool: t } }
+ end
+
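+          # Converts the generic conversation_context (newest message first) into OpenAI chat
+          # messages; tool results become "tool" messages carrying a tool_call_id, and older
+          # entries are trimmed via trim_context so the request fits max_prompt_tokens.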
+ def conversation_context
+ return [] if prompt[:conversation_context].blank?
+
+ trimmed_context = trim_context(prompt[:conversation_context])
+
+ trimmed_context.reverse.map do |context|
+ translated = context.slice(:content)
+ translated[:role] = context[:type]
+
+ if context[:name]
+ if translated[:role] == "tool"
+ translated[:tool_call_id] = context[:name]
+ else
+ translated[:name] = context[:name]
+ end
+ end
+
+ translated
+ end
+ end
+
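+          # Token budget available for the prompt: the model's context window minus a reserve
+          # for the response (opts[:max_tokens_to_sample]) and, when tools are present, their
+          # serialized JSON size.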
+ def max_prompt_tokens
+ # provide a buffer of 120 tokens - our function counting is not
+ # 100% accurate and getting numbers to align exactly is very hard
+ buffer = (opts[:max_tokens_to_sample] || 2500) + 50
+
+ if tools.present?
+ # note this is about 100 tokens over, OpenAI have a more optimal representation
+ @function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
+ buffer += @function_size
+ end
+
+ model_max_tokens - buffer
+ end
+
+ private
+
+ def per_message_overhead
+ # open ai defines about 4 tokens per message of overhead
+ 4
+ end
+
+ def calculate_message_token(context)
+ self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+ end
+
+ def model_max_tokens
+ case model_name
+ when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
+ 16_384
+ when "gpt-4"
+ 8192
+ when "gpt-4-32k"
+ 32_768
+ else
+ 8192
+ end
end
end
end
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 07438985..f10c4c49 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -3,25 +3,65 @@
module DiscourseAi
module Completions
module Dialects
- class Claude
- def self.can_translate?(model_name)
- %w[claude-instant-1 claude-2].include?(model_name)
+ class Claude < Dialect
+ class << self
+ def can_translate?(model_name)
+ %w[claude-instant-1 claude-2].include?(model_name)
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::AnthropicTokenizer
+ end
end
- def translate(generic_prompt)
- claude_prompt = +"Human: #{generic_prompt[:insts]}\n"
+ def translate
+ claude_prompt = +"Human: #{prompt[:insts]}\n"
- claude_prompt << build_examples(generic_prompt[:examples]) if generic_prompt[:examples]
+ claude_prompt << build_tools_prompt if prompt[:tools]
- claude_prompt << "#{generic_prompt[:input]}\n"
+ claude_prompt << build_examples(prompt[:examples]) if prompt[:examples]
- claude_prompt << "#{generic_prompt[:post_insts]}\n" if generic_prompt[:post_insts]
+ claude_prompt << conversation_context if prompt[:conversation_context]
+
+ claude_prompt << "#{prompt[:input]}\n"
+
+ claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts]
claude_prompt << "Assistant:\n"
end
- def tokenizer
- DiscourseAi::Tokenizer::AnthropicTokenizer
+ def max_prompt_tokens
+ 50_000
+ end
+
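+        # Serializes the conversation_context as Human:/Assistant: lines, wrapping tool results
+        # in the <function_results> block Claude expects; the oldest messages are dropped first
+        # when the prompt would exceed max_prompt_tokens.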
+ def conversation_context
+ return "" if prompt[:conversation_context].blank?
+
+ trimmed_context = trim_context(prompt[:conversation_context])
+
+ trimmed_context
+ .reverse
+ .reduce(+"") do |memo, context|
+ memo << (context[:type] == "user" ? "Human:" : "Assistant:")
+
+ if context[:type] == "tool"
+                memo << <<~TEXT
+
+                  <function_results>
+                  <result>
+                  <tool_name>#{context[:name]}</tool_name>
+                  <json>
+                  #{context[:content]}
+                  </json>
+                  </result>
+                  </function_results>
+                TEXT
+ else
+ memo << " " << context[:content] << "\n"
+ end
+
+ memo
+ end
end
private
diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb
new file mode 100644
index 00000000..5e6d6b97
--- /dev/null
+++ b/lib/completions/dialects/dialect.rb
@@ -0,0 +1,160 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ module Dialects
+ class Dialect
+ class << self
+ def can_translate?(_model_name)
+ raise NotImplemented
+ end
+
+ def dialect_for(model_name)
+ dialects = [
+ DiscourseAi::Completions::Dialects::Claude,
+ DiscourseAi::Completions::Dialects::Llama2Classic,
+ DiscourseAi::Completions::Dialects::ChatGpt,
+ DiscourseAi::Completions::Dialects::OrcaStyle,
+ DiscourseAi::Completions::Dialects::Gemini,
+ ]
+ dialects.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |d|
+ d.can_translate?(model_name)
+ end
+ end
+
+ def tokenizer
+ raise NotImplemented
+ end
+ end
+
+ def initialize(generic_prompt, model_name, opts: {})
+ @prompt = generic_prompt
+ @model_name = model_name
+ @opts = opts
+ end
+
+ def translate
+ raise NotImplemented
+ end
+
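+        # Default tool serialization: each tool is rendered as a <tool_description> XML block,
+        # which build_tools_prompt embeds in the system section. API-native dialects
+        # (ChatGpt, Gemini) override this with their providers' structured formats.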
+ def tools
+ tools = +""
+
+ prompt[:tools].each do |function|
+ parameters = +""
+ if function[:parameters].present?
+ function[:parameters].each do |parameter|
+                parameters << <<~PARAMETER
+                  <parameter>
+                  <name>#{parameter[:name]}</name>
+                  <type>#{parameter[:type]}</type>
+                  <description>#{parameter[:description]}</description>
+                  <required>#{parameter[:required]}</required>
+                PARAMETER
+                if parameter[:enum]
+                  parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
+                end
+                parameters << "</parameter>\n"
+ end
+ end
+
+              tools << <<~TOOLS
+                <tool_description>
+                <tool_name>#{function[:name]}</tool_name>
+                <description>#{function[:description]}</description>
+                <parameters>
+                #{parameters}</parameters>
+                </tool_description>
+              TOOLS
+ end
+
+ tools
+ end
+
+ def conversation_context
+ raise NotImplemented
+ end
+
+ def max_prompt_tokens
+ raise NotImplemented
+ end
+
+ private
+
+ attr_reader :prompt, :model_name, :opts
+
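+        # Walks the context newest-to-oldest, repeatedly truncating an over-long message's
+        # content and stopping entirely once the token budget (max_prompt_tokens minus the
+        # non-context prompt and per-message overhead) is exhausted.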
+ def trim_context(conversation_context)
+ prompt_limit = max_prompt_tokens
+ current_token_count = calculate_token_count_without_context
+
+ conversation_context.reduce([]) do |memo, context|
+ break(memo) if current_token_count >= prompt_limit
+
+ dupped_context = context.dup
+
+ message_tokens = calculate_message_token(dupped_context)
+
+ # Trimming content to make sure we respect token limit.
+ while dupped_context[:content].present? &&
+ message_tokens + current_token_count + per_message_overhead > prompt_limit
+ dupped_context[:content] = dupped_context[:content][0..-100] || ""
+ message_tokens = calculate_message_token(dupped_context)
+ end
+
+ next(memo) if dupped_context[:content].blank?
+
+ current_token_count += message_tokens + per_message_overhead
+
+ memo << dupped_context
+ end
+ end
+
+ def calculate_token_count_without_context
+ tokenizer = self.class.tokenizer
+
+ examples_count =
+ prompt[:examples].to_a.sum do |pair|
+ tokenizer.size(pair.join) + (per_message_overhead * 2)
+ end
+ input_count = tokenizer.size(prompt[:input].to_s) + per_message_overhead
+
+ examples_count + input_count +
+ prompt
+ .except(:conversation_context, :tools, :examples, :input)
+ .sum { |_, v| tokenizer.size(v) + per_message_overhead }
+ end
+
+ def per_message_overhead
+ 0
+ end
+
+ def calculate_message_token(context)
+ self.class.tokenizer.size(context[:content].to_s)
+ end
+
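+        # Shared instruction block describing the <function_calls> invocation syntax;
+        # prompt-based dialects (Claude, Llama2Classic, OrcaStyle) prepend it to their
+        # system section when tools are present.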
+ def build_tools_prompt
+ return "" if prompt[:tools].blank?
+
+ <<~TEXT
+ In this environment you have access to a set of tools you can use to answer the user's question.
+ You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
+            <function_calls>
+            <invoke>
+            <tool_name>$TOOL_NAME</tool_name>
+            <parameters>
+            <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
+            ...
+            </parameters>
+            </invoke>
+            </function_calls>
+
+            Here are the tools available:
+
+            <tools>
+            #{tools}</tools>
+ TEXT
+ end
+ end
+ end
+ end
+end
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index 49fa716e..b6b938bf 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -3,34 +3,91 @@
module DiscourseAi
module Completions
module Dialects
- class Gemini
- def self.can_translate?(model_name)
- %w[gemini-pro].include?(model_name)
+ class Gemini < Dialect
+ class << self
+ def can_translate?(model_name)
+ %w[gemini-pro].include?(model_name)
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
+ end
end
- def translate(generic_prompt)
+ def translate
gemini_prompt = [
{
role: "user",
parts: {
- text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
+ text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
},
},
{ role: "model", parts: { text: "Ok." } },
]
- if generic_prompt[:examples]
- generic_prompt[:examples].each do |example_pair|
+ if prompt[:examples]
+ prompt[:examples].each do |example_pair|
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
end
end
- gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
+          gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
+
+ gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
end
- def tokenizer
- DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
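+        # Gemini expects a single function_declarations group; per-parameter `required`
+        # flags are lifted into a `required` array on each declaration.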
+ def tools
+ return if prompt[:tools].blank?
+
+ translated_tools =
+ prompt[:tools].map do |t|
+ required_fields = []
+ tool = t.dup
+
+ tool[:parameters] = t[:parameters].map do |p|
+ required_fields << p[:name] if p[:required]
+
+ p.except(:required)
+ end
+
+ tool.merge(required: required_fields)
+ end
+
+ [{ function_declarations: translated_tools }]
+ end
+
+ def conversation_context
+ return [] if prompt[:conversation_context].blank?
+
+ trimmed_context = trim_context(prompt[:conversation_context])
+
+ trimmed_context.reverse.map do |context|
+ translated = {}
+ translated[:role] = (context[:type] == "user" ? "user" : "model")
+
+ part = {}
+
+ if context[:type] == "tool"
+ part["functionResponse"] = { name: context[:name], content: context[:content] }
+ else
+ part[:text] = context[:content]
+ end
+
+ translated[:parts] = [part]
+
+ translated
+ end
+ end
+
+ def max_prompt_tokens
+ 16_384 # 50% of model tokens
+ end
+
+ protected
+
+ def calculate_message_token(context)
+ self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
end
end
diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb
index 542e5f57..0470a61c 100644
--- a/lib/completions/dialects/llama2_classic.rb
+++ b/lib/completions/dialects/llama2_classic.rb
@@ -3,27 +3,72 @@
module DiscourseAi
module Completions
module Dialects
- class Llama2Classic
- def self.can_translate?(model_name)
- %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
+ class Llama2Classic < Dialect
+ class << self
+ def can_translate?(model_name)
+ %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::Llama2Tokenizer
+ end
end
- def translate(generic_prompt)
- llama2_prompt =
-            +"[INST]<<SYS>>#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}<</SYS>>[/INST]\n"
+ def translate
+ llama2_prompt = +<<~TEXT
+ [INST]
+            <<SYS>>
+ #{prompt[:insts]}
+ #{build_tools_prompt}#{prompt[:post_insts]}
+            <</SYS>>
+ [/INST]
+ TEXT
- if generic_prompt[:examples]
- generic_prompt[:examples].each do |example_pair|
+ if prompt[:examples]
+ prompt[:examples].each do |example_pair|
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
llama2_prompt << "#{example_pair.second}\n"
end
end
- llama2_prompt << "[INST]#{generic_prompt[:input]}[/INST]\n"
+ llama2_prompt << conversation_context if prompt[:conversation_context].present?
+
+ llama2_prompt << "[INST]#{prompt[:input]}[/INST]\n"
end
- def tokenizer
- DiscourseAi::Tokenizer::Llama2Tokenizer
+ def conversation_context
+ return "" if prompt[:conversation_context].blank?
+
+ trimmed_context = trim_context(prompt[:conversation_context])
+
+ trimmed_context
+ .reverse
+ .reduce(+"") do |memo, context|
+ if context[:type] == "tool"
+                memo << <<~TEXT
+                  [INST]
+                  <function_results>
+                  <result>
+                  <tool_name>#{context[:name]}</tool_name>
+                  <json>
+                  #{context[:content]}
+                  </json>
+                  </result>
+                  </function_results>
+                  [/INST]
+                TEXT
+ elsif context[:type] == "assistant"
+ memo << "[INST]" << context[:content] << "[/INST]\n"
+ else
+ memo << context[:content] << "\n"
+ end
+
+ memo
+ end
+ end
+
+ def max_prompt_tokens
+ SiteSetting.ai_hugging_face_token_limit
end
end
end
diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb
index 3aa11609..fd76f3b5 100644
--- a/lib/completions/dialects/orca_style.rb
+++ b/lib/completions/dialects/orca_style.rb
@@ -3,29 +3,68 @@
module DiscourseAi
module Completions
module Dialects
- class OrcaStyle
- def self.can_translate?(model_name)
- %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
+ class OrcaStyle < Dialect
+ class << self
+ def can_translate?(model_name)
+ %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
+ end
+
+ def tokenizer
+ DiscourseAi::Tokenizer::Llama2Tokenizer
+ end
end
- def translate(generic_prompt)
- orca_style_prompt =
- +"### System:\n#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}\n"
+ def translate
+ orca_style_prompt = +<<~TEXT
+ ### System:
+ #{prompt[:insts]}
+ #{build_tools_prompt}#{prompt[:post_insts]}
+ TEXT
- if generic_prompt[:examples]
- generic_prompt[:examples].each do |example_pair|
+ if prompt[:examples]
+ prompt[:examples].each do |example_pair|
orca_style_prompt << "### User:\n#{example_pair.first}\n"
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
end
end
- orca_style_prompt << "### User:\n#{generic_prompt[:input]}\n"
+ orca_style_prompt << "### User:\n#{prompt[:input]}\n"
orca_style_prompt << "### Assistant:\n"
end
- def tokenizer
- DiscourseAi::Tokenizer::Llama2Tokenizer
+ def conversation_context
+ return "" if prompt[:conversation_context].blank?
+
+ trimmed_context = trim_context(prompt[:conversation_context])
+
+ trimmed_context
+ .reverse
+ .reduce(+"") do |memo, context|
+ memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
+
+ if context[:type] == "tool"
+                memo << <<~TEXT
+
+                  <function_results>
+                  <result>
+                  <tool_name>#{context[:name]}</tool_name>
+                  <json>
+                  #{context[:content]}
+                  </json>
+                  </result>
+                  </function_results>
+                TEXT
+ else
+ memo << " " << context[:content] << "\n"
+ end
+
+ memo
+ end
+ end
+
+ def max_prompt_tokens
+ SiteSetting.ai_hugging_face_token_limit
end
end
end
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 5216d4e7..3846d7e4 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -22,7 +22,7 @@ module DiscourseAi
@uri ||= URI("https://api.anthropic.com/v1/complete")
end
- def prepare_payload(prompt, model_params)
+ def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(model_params)
.merge(prompt: prompt)
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 0e8cf68f..902a375c 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -37,7 +37,7 @@ module DiscourseAi
URI(api_url)
end
- def prepare_payload(prompt, model_params)
+ def prepare_payload(prompt, model_params, _dialect)
default_options.merge(prompt: prompt).merge(model_params)
end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 61846e23..ef92a162 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -30,9 +30,11 @@ module DiscourseAi
@tokenizer = tokenizer
end
- def perform_completion!(prompt, user, model_params = {})
+ def perform_completion!(dialect, user, model_params = {})
@streaming_mode = block_given?
+ prompt = dialect.translate
+
Net::HTTP.start(
model_uri.host,
model_uri.port,
@@ -43,7 +45,10 @@ module DiscourseAi
) do |http|
response_data = +""
response_raw = +""
- request_body = prepare_payload(prompt, model_params).to_json
+
+          # Needed for response token calculations. Cannot rely on response_data due to function buffering.
+ partials_raw = +""
+ request_body = prepare_payload(prompt, model_params, dialect).to_json
request = prepare_request(request_body)
@@ -66,6 +71,15 @@ module DiscourseAi
if !@streaming_mode
response_raw = response.read_body
response_data = extract_completion_from(response_raw)
+ partials_raw = response_data.to_s
+
+ if has_tool?("", response_data)
+ function_buffer = build_buffer # Nokogiri document
+ function_buffer = add_to_buffer(function_buffer, "", response_data)
+
+ response_data = +function_buffer.at("function_calls").to_s
+ response_data << "\n"
+ end
return response_data
end
@@ -75,6 +89,7 @@ module DiscourseAi
cancel = lambda { cancelled = true }
leftover = ""
+ function_buffer = build_buffer # Nokogiri document
response.read_body do |chunk|
if cancelled
@@ -85,6 +100,12 @@ module DiscourseAi
decoded_chunk = decode(chunk)
response_raw << decoded_chunk
+ # Buffering for extremely slow streaming.
+ if (leftover + decoded_chunk).length < "data: [DONE]".length
+ leftover += decoded_chunk
+ next
+ end
+
partials_from(leftover + decoded_chunk).each do |raw_partial|
next if cancelled
next if raw_partial.blank?
@@ -93,11 +114,27 @@ module DiscourseAi
partial = extract_completion_from(raw_partial)
next if partial.nil?
leftover = ""
- response_data << partial
- yield partial, cancel if partial
+ if has_tool?(response_data, partial)
+ function_buffer = add_to_buffer(function_buffer, response_data, partial)
+
+ if buffering_finished?(dialect.tools, function_buffer)
+ invocation = +function_buffer.at("function_calls").to_s
+ invocation << "\n"
+
+ partials_raw << partial.to_s
+ response_data << invocation
+
+ yield invocation, cancel
+ end
+ else
+ partials_raw << partial
+ response_data << partial
+
+ yield partial, cancel if partial
+ end
rescue JSON::ParserError
- leftover = raw_partial
+ leftover += decoded_chunk
end
end
end
@@ -109,7 +146,7 @@ module DiscourseAi
ensure
if log
log.raw_response_payload = response_raw
- log.response_tokens = tokenizer.size(response_data)
+ log.response_tokens = tokenizer.size(partials_raw)
log.save!
if Rails.env.development?
@@ -165,6 +202,40 @@ module DiscourseAi
def extract_prompt_for_tokenizer(prompt)
prompt
end
+
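+          # Empty invocation skeleton; streamed tool-call fragments are merged into it before
+          # the complete <function_calls> block is yielded to the caller.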
+ def build_buffer
+ Nokogiri::HTML5.fragment(<<~TEXT)
+              <function_calls>
+              <invoke>
+              <tool_name></tool_name>
+              <tool_id></tool_id>
+              <parameters>
+              </parameters>
+              </invoke>
+              </function_calls>
+ TEXT
+ end
+
+ def has_tool?(response, partial)
+            (response + partial).include?("<function_calls>")
+ end
+
+ def add_to_buffer(function_buffer, response_data, partial)
+ new_buffer = Nokogiri::HTML5.fragment(response_data + partial)
+ if tool_name = new_buffer.at("tool_name").text
+ if new_buffer.at("tool_id").nil?
+ tool_id_node =
+                Nokogiri::HTML5::DocumentFragment.parse("\n<tool_id>#{tool_name}</tool_id>")
+
+ new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node)
+ end
+ end
+
+ new_buffer
+ end
+
+ def buffering_finished?(_available_functions, buffer)
+            buffer.to_s.include?("</function_calls>")
+ end
end
end
end
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index 83300fcc..56d2b913 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -31,9 +31,14 @@ module DiscourseAi
cancelled = false
cancel_fn = lambda { cancelled = true }
- response.each_char do |char|
- break if cancelled
- yield(char, cancel_fn)
+ # We buffer and return tool invocations in one go.
+ if is_tool?(response)
+ yield(response, cancel_fn)
+ else
+ response.each_char do |char|
+ break if cancelled
+ yield(char, cancel_fn)
+ end
end
else
response
@@ -43,6 +48,12 @@ module DiscourseAi
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
+
+ private
+
+ def is_tool?(response)
+ Nokogiri::HTML5.fragment(response).at("function_calls").present?
+ end
end
end
end
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index 7738085c..8fa8a2d3 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -25,8 +25,11 @@ module DiscourseAi
URI(url)
end
- def prepare_payload(prompt, model_params)
- default_options.merge(model_params).merge(contents: prompt)
+ def prepare_payload(prompt, model_params, dialect)
+ default_options
+ .merge(model_params)
+ .merge(contents: prompt)
+ .tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
end
def prepare_request(payload)
@@ -36,25 +39,72 @@ module DiscourseAi
end
def extract_completion_from(response_raw)
- if @streaming_mode
- parsed = response_raw
- else
- parsed = JSON.parse(response_raw, symbolize_names: true)
- end
+ parsed = JSON.parse(response_raw, symbolize_names: true)
- completion = dig_text(parsed).to_s
+ response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
+
+ has_function_call = response_h.dig(:functionCall).present?
+ has_function_call ? response_h[:functionCall] : response_h.dig(:text)
end
def partials_from(decoded_chunk)
- JSON.parse(decoded_chunk, symbolize_names: true)
+ decoded_chunk
+ .split("\n")
+ .map do |line|
+ if line == ","
+ nil
+ elsif line.starts_with?("[")
+ line[1..-1]
+ elsif line.ends_with?("]")
+                line[0..-2]
+ else
+ line
+ end
+ end
+ .compact_blank
end
def extract_prompt_for_tokenizer(prompt)
prompt.to_s
end
- def dig_text(response)
- response.dig(:candidates, 0, :content, :parts, 0, :text)
+ def has_tool?(_response_data, partial)
+ partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
+ end
+
+ def add_to_buffer(function_buffer, _response_data, partial)
+ if partial[:name].present?
+ function_buffer.at("tool_name").content = partial[:name]
+ function_buffer.at("tool_id").content = partial[:name]
+ end
+
+ if partial[:args]
+ argument_fragments =
+ partial[:args].reduce(+"") do |memo, (arg_name, value)|
+                memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
+ end
+ argument_fragments << "\n"
+
+ function_buffer.at("parameters").children =
+ Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
+ end
+
+ function_buffer
+ end
+
+ def buffering_finished?(available_functions, buffer)
+ tool_name = buffer.at("tool_name")&.text
+ return false if tool_name.blank?
+
+ signature =
+ available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
+
+ signature[:parameters].reduce(true) do |memo, param|
+ param_present = buffer.at(param[:name]).present?
+ next(memo) if param_present || !signature[:required].include?(param[:name])
+
+ memo && param_present
+ end
end
end
end
diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb
index d5d10cea..819cec47 100644
--- a/lib/completions/endpoints/hugging_face.rb
+++ b/lib/completions/endpoints/hugging_face.rb
@@ -11,7 +11,7 @@ module DiscourseAi
end
def default_options
- { parameters: { repetition_penalty: 1.1, temperature: 0.7 } }
+ { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end
def provider_id
@@ -24,7 +24,7 @@ module DiscourseAi
URI(SiteSetting.ai_hugging_face_api_url)
end
- def prepare_payload(prompt, model_params)
+ def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(inputs: prompt)
.tap do |payload|
@@ -33,7 +33,6 @@ module DiscourseAi
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
- payload[:parameters][:return_full_text] = false
payload[:stream] = true if @streaming_mode
end
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index 65b01314..b0664cd0 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -37,11 +37,14 @@ module DiscourseAi
URI(url)
end
- def prepare_payload(prompt, model_params)
+ def prepare_payload(prompt, model_params, dialect)
default_options
.merge(model_params)
.merge(messages: prompt)
- .tap { |payload| payload[:stream] = true if @streaming_mode }
+ .tap do |payload|
+ payload[:stream] = true if @streaming_mode
+ payload[:tools] = dialect.tools if dialect.tools.present?
+ end
end
def prepare_request(payload)
@@ -62,15 +65,12 @@ module DiscourseAi
end
def extract_completion_from(response_raw)
- parsed = JSON.parse(response_raw, symbolize_names: true)
+ parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
- (
- if @streaming_mode
- parsed.dig(:choices, 0, :delta, :content)
- else
- parsed.dig(:choices, 0, :message, :content)
- end
- ).to_s
+ response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
+
+ has_function_call = response_h.dig(:tool_calls).present?
+ has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
end
def partials_from(decoded_chunk)
@@ -86,6 +86,42 @@ module DiscourseAi
def extract_prompt_for_tokenizer(prompt)
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
+
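+        # Tool-call support: OpenAI reports tool calls as a function hash with :name/:arguments;
+        # these hooks rebuild them into the shared <function_calls> XML invocation format.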
+ def has_tool?(_response_data, partial)
+ partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
+ end
+
+ def add_to_buffer(function_buffer, _response_data, partial)
+ function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
+ function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
+
+ if partial[:arguments]
+ argument_fragments =
+ partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
+                memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
+ end
+ argument_fragments << "\n"
+
+ function_buffer.at("parameters").children =
+ Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
+ end
+
+ function_buffer
+ end
+
+ def buffering_finished?(available_functions, buffer)
+ tool_name = buffer.at("tool_name")&.text
+ return false if tool_name.blank?
+
+ signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
+
+ signature[:parameters].reduce(true) do |memo, param|
+ param_present = buffer.at(param[:name]).present?
+              next(memo) if param_present || !param[:required]
+
+ memo && param_present
+ end
+ end
end
end
end
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index f2cbc0e2..16e1c96b 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -24,35 +24,26 @@ module DiscourseAi
end
def self.proxy(model_name)
- dialects = [
- DiscourseAi::Completions::Dialects::Claude,
- DiscourseAi::Completions::Dialects::Llama2Classic,
- DiscourseAi::Completions::Dialects::ChatGpt,
- DiscourseAi::Completions::Dialects::OrcaStyle,
- DiscourseAi::Completions::Dialects::Gemini,
- ]
+ dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
- dialect =
- dialects.detect(-> { raise UNKNOWN_MODEL }) { |d| d.can_translate?(model_name) }.new
-
- return new(dialect, @canned_response, model_name) if @canned_response
+ return new(dialect_klass, @canned_response, model_name) if @canned_response
gateway =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
model_name,
- dialect.tokenizer,
+ dialect_klass.tokenizer,
)
- new(dialect, gateway, model_name)
+ new(dialect_klass, gateway, model_name)
end
- def initialize(dialect, gateway, model_name)
- @dialect = dialect
+ def initialize(dialect_klass, gateway, model_name)
+ @dialect_klass = dialect_klass
@gateway = gateway
@model_name = model_name
end
- delegate :tokenizer, to: :dialect
+ delegate :tokenizer, to: :dialect_klass
# @param generic_prompt { Hash } - Prompt using our generic format.
# We use the following keys from the hash:
@@ -60,23 +51,64 @@ module DiscourseAi
# - input: String containing user input
# - examples (optional): Array of arrays with examples of input and responses. Each array is a input/response pair like [[example1, response1], [example2, response2]].
# - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt.
+ # - conversation_context (optional): Array of hashes to provide context about an ongoing conversation with the model.
+ # We translate the array in reverse order, meaning the first element would be the most recent message in the conversation.
+ # Example:
+ #
+ # [
+ # { type: "user", name: "user1", content: "This is a new message by a user" },
+ # { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+ # { type: "tool", name: "tool_id", content: "I'm a tool result" },
+ # ]
+ #
+ # - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:
+ #
+ # {
+ # name: "get_weather",
+ # description: "Get the weather in a city",
+ # parameters: [
+ # { name: "location", type: "string", description: "the city name", required: true },
+ # {
+ # name: "unit",
+ # type: "string",
+ # description: "the unit of measurement celcius c or fahrenheit f",
+ # enum: %w[c f],
+ # required: true,
+ # },
+ # ],
+ # }
#
# @param user { User } - User requesting the summary.
#
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
#
# @returns { String } - Completion result.
+ #
+  # When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool invocation,
+ # even if you passed a partial_read_blk block. Invocations are strings that look like this:
+ #
+  #   <function_calls>
+  #   <invoke>
+  #   <tool_name>get_weather</tool_name>
+  #   <tool_id>get_weather</tool_id>
+  #   <parameters>
+  #   <location>Sydney</location>
+  #   <unit>c</unit>
+  #   </parameters>
+  #   </invoke>
+  #   </function_calls>
+ #
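+  # A minimal usage sketch (hypothetical prompt values):
+  #
+  #   llm = DiscourseAi::Completions::Llm.proxy("gpt-4")
+  #   llm.completion!({ insts: "You are a helpful bot", input: "Hello!" }, Discourse.system_user)
+  #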
def completion!(generic_prompt, user, &partial_read_blk)
- prompt = dialect.translate(generic_prompt)
-
model_params = generic_prompt.dig(:params, model_name) || {}
- gateway.perform_completion!(prompt, user, model_params, &partial_read_blk)
+ dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
+
+ gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end
private
- attr_reader :dialect, :gateway, :model_name
+ attr_reader :dialect_klass, :gateway, :model_name
end
end
end
diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb
index 27b56b49..2792dbf6 100644
--- a/spec/lib/completions/dialects/chat_gpt_spec.rb
+++ b/spec/lib/completions/dialects/chat_gpt_spec.rb
@@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
- subject(:dialect) { described_class.new }
+ subject(:dialect) { described_class.new(prompt, "gpt-4") }
+
+ let(:tool) do
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name", required: true },
+ {
+ name: "unit",
+ type: "string",
+ description: "the unit of measurement celcius c or fahrenheit f",
+ enum: %w[c f],
+ required: true,
+ },
+ ],
+ }
+ end
let(:prompt) do
{
@@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
TEXT
post_insts:
"Please put the translation between tags and separate each title with a comma.",
+ tools: [tool],
}
end
@@ -35,7 +53,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
{ role: "user", content: prompt[:input] },
]
- translated = dialect.translate(prompt)
+ translated = dialect.translate
expect(translated).to contain_exactly(*open_ai_version)
end
@@ -55,9 +73,51 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
{ role: "user", content: prompt[:input] },
]
- translated = dialect.translate(prompt)
+ translated = dialect.translate
expect(translated).to contain_exactly(*open_ai_version)
end
end
+
+ describe "#conversation_context" do
+ let(:context) do
+ [
+ { type: "user", name: "user1", content: "This is a new message by a user" },
+ { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+ { type: "tool", name: "tool_id", content: "I'm a tool result" },
+ ]
+ end
+
+ it "adds conversation in reverse order (first == newer)" do
+ prompt[:conversation_context] = context
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context).to eq(
+ [
+ { role: "tool", content: context.last[:content], tool_call_id: context.last[:name] },
+ { role: "assistant", content: context.second[:content] },
+ { role: "user", content: context.first[:content], name: context.first[:name] },
+ ],
+ )
+ end
+
+ it "trims content if it's getting too long" do
+ context.last[:content] = context.last[:content] * 1000
+
+ prompt[:conversation_context] = context
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context.last[:content].length).to be < context.last[:content].length
+ end
+ end
+
+ describe "#tools" do
+ it "returns a list of available tools" do
+ open_ai_tool_f = { type: "function", tool: tool }
+
+ expect(subject.tools).to contain_exactly(open_ai_tool_f)
+ end
+ end
end
diff --git a/spec/lib/completions/dialects/claude_spec.rb b/spec/lib/completions/dialects/claude_spec.rb
index d26dd570..1107814b 100644
--- a/spec/lib/completions/dialects/claude_spec.rb
+++ b/spec/lib/completions/dialects/claude_spec.rb
@@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
- subject(:dialect) { described_class.new }
+ subject(:dialect) { described_class.new(prompt, "claude-2") }
+
+ let(:tool) do
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name", required: true },
+ {
+ name: "unit",
+ type: "string",
+ description: "the unit of measurement celcius c or fahrenheit f",
+ enum: %w[c f],
+ required: true,
+ },
+ ],
+ }
+ end
let(:prompt) do
{
@@ -37,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
Assistant:
TEXT
- translated = dialect.translate(prompt)
+ translated = dialect.translate
expect(translated).to eq(anthropic_version)
end
@@ -60,9 +77,111 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
Assistant:
TEXT
- translated = dialect.translate(prompt)
+ translated = dialect.translate
+
+ expect(translated).to eq(anthropic_version)
+ end
+
+ it "include tools inside the prompt" do
+ prompt[:tools] = [tool]
+
+ anthropic_version = <<~TEXT
+ Human: #{prompt[:insts]}
+ In this environment you have access to a set of tools you can use to answer the user's question.
+ You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
+        <function_calls>
+        <invoke>
+        <tool_name>$TOOL_NAME</tool_name>
+        <parameters>
+        <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
+        ...
+        </parameters>
+        </invoke>
+        </function_calls>
+
+        Here are the tools available:
+
+        <tools>
+        #{dialect.tools}</tools>
+ #{prompt[:input]}
+ #{prompt[:post_insts]}
+ Assistant:
+ TEXT
+
+ translated = dialect.translate
expect(translated).to eq(anthropic_version)
end
end
+
+ describe "#conversation_context" do
+ let(:context) do
+ [
+ { type: "user", name: "user1", content: "This is a new message by a user" },
+ { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+ { type: "tool", name: "tool_id", content: "I'm a tool result" },
+ ]
+ end
+
+ it "adds conversation in reverse order (first == newer)" do
+ prompt[:conversation_context] = context
+
+ expected = <<~TEXT
+ Assistant:
+        <function_results>
+        <result>
+        <tool_name>tool_id</tool_name>
+        <json>
+        #{context.last[:content]}
+        </json>
+        </result>
+        </function_results>
+ Assistant: #{context.second[:content]}
+ Human: #{context.first[:content]}
+ TEXT
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context).to eq(expected)
+ end
+
+ it "trims content if it's getting too long" do
+ context.last[:content] = context.last[:content] * 10_000
+ prompt[:conversation_context] = context
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context.length).to be < context.last[:content].length
+ end
+ end
+
+ describe "#tools" do
+ it "translates tools to the tool syntax" do
+ prompt[:tools] = [tool]
+
+ translated_tool = <<~TEXT
+            <tool_description>
+            <tool_name>get_weather</tool_name>
+            <description>Get the weather in a city</description>
+            <parameters>
+            <parameter>
+            <name>location</name>
+            <type>string</type>
+            <description>the city name</description>
+            <required>true</required>
+            </parameter>
+            <parameter>
+            <name>unit</name>
+            <type>string</type>
+            <description>the unit of measurement celcius c or fahrenheit f</description>
+            <required>true</required>
+            <options>c,f</options>
+            </parameter>
+            </parameters>
+            </tool_description>
+ TEXT
+
+ expect(dialect.tools).to eq(translated_tool)
+ end
+ end
end
diff --git a/spec/lib/completions/dialects/gemini_spec.rb b/spec/lib/completions/dialects/gemini_spec.rb
index 84aec2d3..0a5266b7 100644
--- a/spec/lib/completions/dialects/gemini_spec.rb
+++ b/spec/lib/completions/dialects/gemini_spec.rb
@@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
- subject(:dialect) { described_class.new }
+ subject(:dialect) { described_class.new(prompt, "gemini-pro") }
+
+ let(:tool) do
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name", required: true },
+ {
+ name: "unit",
+ type: "string",
+ description: "the unit of measurement celcius c or fahrenheit f",
+ enum: %w[c f],
+ required: true,
+ },
+ ],
+ }
+ end
let(:prompt) do
{
@@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
TEXT
post_insts:
"Please put the translation between tags and separate each title with a comma.",
+ tools: [tool],
}
end
@@ -36,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{ role: "user", parts: { text: prompt[:input] } },
]
- translated = dialect.translate(prompt)
+ translated = dialect.translate
expect(translated).to eq(gemini_version)
end
@@ -57,9 +75,79 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{ role: "user", parts: { text: prompt[:input] } },
]
- translated = dialect.translate(prompt)
+ translated = dialect.translate
expect(translated).to contain_exactly(*gemini_version)
end
end
+
+ describe "#conversation_context" do
+ let(:context) do
+ [
+ { type: "user", name: "user1", content: "This is a new message by a user" },
+ { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+ { type: "tool", name: "tool_id", content: "I'm a tool result" },
+ ]
+ end
+
+ it "adds conversation in reverse order (first == newer)" do
+ prompt[:conversation_context] = context
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context).to eq(
+ [
+ {
+ role: "model",
+ parts: [
+ {
+ "functionResponse" => {
+ name: context.last[:name],
+ content: context.last[:content],
+ },
+ },
+ ],
+ },
+ { role: "model", parts: [{ text: context.second[:content] }] },
+ { role: "user", parts: [{ text: context.first[:content] }] },
+ ],
+ )
+ end
+
+ it "trims content if it's getting too long" do
+ context.last[:content] = context.last[:content] * 1000
+
+ prompt[:conversation_context] = context
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context.last.dig(:parts, 0, :text).length).to be <
+ context.last[:content].length
+ end
+ end
+
+ describe "#tools" do
+ it "returns a list of available tools" do
+ gemini_tools = {
+ function_declarations: [
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name" },
+ {
+ name: "unit",
+ type: "string",
+ description: "the unit of measurement celcius c or fahrenheit f",
+ enum: %w[c f],
+ },
+ ],
+ required: %w[location unit],
+ },
+ ],
+ }
+
+ expect(subject.tools).to contain_exactly(gemini_tools)
+ end
+ end
end
diff --git a/spec/lib/completions/dialects/llama2_classic_spec.rb b/spec/lib/completions/dialects/llama2_classic_spec.rb
index 2b1d93a2..81d088e6 100644
--- a/spec/lib/completions/dialects/llama2_classic_spec.rb
+++ b/spec/lib/completions/dialects/llama2_classic_spec.rb
@@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
- subject(:dialect) { described_class.new }
+ subject(:dialect) { described_class.new(prompt, "Llama2-chat-hf") }
+
+ let(:tool) do
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name", required: true },
+ {
+ name: "unit",
+ type: "string",
+ description: "the unit of measurement celcius c or fahrenheit f",
+ enum: %w[c f],
+ required: true,
+ },
+ ],
+ }
+ end
let(:prompt) do
{
@@ -31,11 +48,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
-          [INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
+ [INST]
+        <<SYS>>
+ #{prompt[:insts]}
+ #{prompt[:post_insts]}
+        <</SYS>>
+ [/INST]
[INST]#{prompt[:input]}[/INST]
TEXT
- translated = dialect.translate(prompt)
+ translated = dialect.translate
expect(translated).to eq(llama2_classic_version)
end
@@ -49,15 +71,126 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
]
llama2_classic_version = <<~TEXT
-          [INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
+ [INST]
+        <<SYS>>
+ #{prompt[:insts]}
+ #{prompt[:post_insts]}
+        <</SYS>>
+ [/INST]
[INST]#{prompt[:examples][0][0]}[/INST]
#{prompt[:examples][0][1]}
[INST]#{prompt[:input]}[/INST]
TEXT
- translated = dialect.translate(prompt)
+ translated = dialect.translate
+
+ expect(translated).to eq(llama2_classic_version)
+ end
+
+ it "include tools inside the prompt" do
+ prompt[:tools] = [tool]
+
+ llama2_classic_version = <<~TEXT
+ [INST]
+        <<SYS>>
+ #{prompt[:insts]}
+ In this environment you have access to a set of tools you can use to answer the user's question.
+ You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
+        <function_calls>
+        <invoke>
+        <tool_name>$TOOL_NAME</tool_name>
+        <parameters>
+        <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
+        ...
+        </parameters>
+        </invoke>
+        </function_calls>
+
+        Here are the tools available:
+
+        <tools>
+        #{dialect.tools}</tools>
+ #{prompt[:post_insts]}
+        <</SYS>>
+ [/INST]
+ [INST]#{prompt[:input]}[/INST]
+ TEXT
+
+ translated = dialect.translate
expect(translated).to eq(llama2_classic_version)
end
end
+
+ describe "#conversation_context" do
+ let(:context) do
+ [
+ { type: "user", name: "user1", content: "This is a new message by a user" },
+ { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+ { type: "tool", name: "tool_id", content: "I'm a tool result" },
+ ]
+ end
+
+ it "adds conversation in reverse order (first == newer)" do
+ prompt[:conversation_context] = context
+
+ expected = <<~TEXT
+ [INST]
+          <function_results>
+          <result>
+          <tool_name>tool_id</tool_name>
+          <json>
+          #{context.last[:content]}
+          </json>
+          </result>
+          </function_results>
+ [/INST]
+ [INST]#{context.second[:content]}[/INST]
+ #{context.first[:content]}
+ TEXT
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context).to eq(expected)
+ end
+
+ it "trims content if it's getting too long" do
+ context.last[:content] = context.last[:content] * 1_000
+ prompt[:conversation_context] = context
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context.length).to be < context.last[:content].length
+ end
+ end
+
+ describe "#tools" do
+ it "translates functions to the tool syntax" do
+ prompt[:tools] = [tool]
+
+ translated_tool = <<~TEXT
+        <tool_description>
+        <tool_name>get_weather</tool_name>
+        <description>Get the weather in a city</description>
+        <parameters>
+        <parameter>
+        <name>location</name>
+        <type>string</type>
+        <description>the city name</description>
+        <required>true</required>
+        </parameter>
+        <parameter>
+        <name>unit</name>
+        <type>string</type>
+        <description>the unit of measurement celcius c or fahrenheit f</description>
+        <required>true</required>
+        <options>c,f</options>
+        </parameter>
+        </parameters>
+        </tool_description>
+ TEXT
+
+ expect(dialect.tools).to eq(translated_tool)
+ end
+ end
end
diff --git a/spec/lib/completions/dialects/orca_style_spec.rb b/spec/lib/completions/dialects/orca_style_spec.rb
index 411a84a8..d27dc9d3 100644
--- a/spec/lib/completions/dialects/orca_style_spec.rb
+++ b/spec/lib/completions/dialects/orca_style_spec.rb
@@ -1,44 +1,62 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
- subject(:dialect) { described_class.new }
+ subject(:dialect) { described_class.new(prompt, "StableBeluga2") }
+
+ let(:tool) do
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name", required: true },
+ {
+ name: "unit",
+ type: "string",
+ description: "the unit of measurement celcius c or fahrenheit f",
+ enum: %w[c f],
+ required: true,
+ },
+ ],
+ }
+ end
+
+ let(:prompt) do
+ {
+ insts: <<~TEXT,
+ I want you to act as a title generator for written pieces. I will provide you with a text,
+ and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
+ and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
+ TEXT
+ input: <<~TEXT,
+          Here is the text, inside <input></input> XML tags:
+          <input>
+ To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
+ discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
+ defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.
+
+ Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
+ a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
+ slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
+ dies so that a scene may be repeated.
+          </input>
+ TEXT
+ post_insts:
+ "Please put the translation between tags and separate each title with a comma.",
+ }
+ end
describe "#translate" do
- let(:prompt) do
- {
- insts: <<~TEXT,
- I want you to act as a title generator for written pieces. I will provide you with a text,
- and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
- and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
- TEXT
- input: <<~TEXT,
- Here is the text, inside XML tags:
-
- To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
- discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
- defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.
-
- Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
- a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
- slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
- dies so that a scene may be repeated.
-
- TEXT
- post_insts:
- "Please put the translation between tags and separate each title with a comma.",
- }
- end
-
it "translates a prompt written in our generic format to the Open AI format" do
orca_style_version = <<~TEXT
### System:
- #{[prompt[:insts], prompt[:post_insts]].join("\n")}
+ #{prompt[:insts]}
+ #{prompt[:post_insts]}
### User:
#{prompt[:input]}
### Assistant:
TEXT
- translated = dialect.translate(prompt)
+ translated = dialect.translate
expect(translated).to eq(orca_style_version)
end
@@ -53,7 +71,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
orca_style_version = <<~TEXT
### System:
- #{[prompt[:insts], prompt[:post_insts]].join("\n")}
+ #{prompt[:insts]}
+ #{prompt[:post_insts]}
### User:
#{prompt[:examples][0][0]}
### Assistant:
@@ -63,9 +82,113 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
### Assistant:
TEXT
- translated = dialect.translate(prompt)
+ translated = dialect.translate
+
+ expect(translated).to eq(orca_style_version)
+ end
+
+ it "include tools inside the prompt" do
+ prompt[:tools] = [tool]
+
+ orca_style_version = <<~TEXT
+ ### System:
+ #{prompt[:insts]}
+ In this environment you have access to a set of tools you can use to answer the user's question.
+ You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
+        <function_calls>
+        <invoke>
+        <tool_name>$TOOL_NAME</tool_name>
+        <parameters>
+        <$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
+        ...
+        </parameters>
+        </invoke>
+        </function_calls>
+
+        Here are the tools available:
+
+        <tools>
+        #{dialect.tools}</tools>
+ #{prompt[:post_insts]}
+ ### User:
+ #{prompt[:input]}
+ ### Assistant:
+ TEXT
+
+ translated = dialect.translate
expect(translated).to eq(orca_style_version)
end
end
+
+ describe "#conversation_context" do
+ let(:context) do
+ [
+ { type: "user", name: "user1", content: "This is a new message by a user" },
+ { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
+ { type: "tool", name: "tool_id", content: "I'm a tool result" },
+ ]
+ end
+
+ it "adds conversation in reverse order (first == newer)" do
+ prompt[:conversation_context] = context
+
+ expected = <<~TEXT
+ ### Assistant:
+        <function_results>
+        <result>
+        <tool_name>tool_id</tool_name>
+        <json>
+        #{context.last[:content]}
+        </json>
+        </result>
+        </function_results>
+ ### Assistant: #{context.second[:content]}
+ ### User: #{context.first[:content]}
+ TEXT
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context).to eq(expected)
+ end
+
+ it "trims content if it's getting too long" do
+ context.last[:content] = context.last[:content] * 1_000
+ prompt[:conversation_context] = context
+
+ translated_context = dialect.conversation_context
+
+ expect(translated_context.length).to be < context.last[:content].length
+ end
+ end
+
+ describe "#tools" do
+ it "translates tools to the tool syntax" do
+ prompt[:tools] = [tool]
+
+ translated_tool = <<~TEXT
+        <tool_description>
+        <tool_name>get_weather</tool_name>
+        <description>Get the weather in a city</description>
+        <parameters>
+        <parameter>
+        <name>location</name>
+        <type>string</type>
+        <description>the city name</description>
+        <required>true</required>
+        </parameter>
+        <parameter>
+        <name>unit</name>
+        <type>string</type>
+        <description>the unit of measurement celcius c or fahrenheit f</description>
+        <required>true</required>
+        <options>c,f</options>
+        </parameter>
+        </parameters>
+        </tool_description>
+ TEXT
+
+ expect(dialect.tools).to eq(translated_tool)
+ end
+ end
end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index d0309e2f..0a57ad29 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
let(:model_name) { "claude-2" }
- let(:prompt) { "Human: write 3 words\n\n" }
+ let(:generic_prompt) { { insts: "write 3 words" } }
+ let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
+ let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
@@ -23,10 +25,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}
end
- def stub_response(prompt, response_text)
+ def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
- .with(body: model.default_options.merge(prompt: prompt).to_json)
+ .with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@@ -42,7 +44,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}.to_json
end
- def stub_streamed_response(prompt, deltas)
+ def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
@@ -52,13 +54,27 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
end
- chunks = chunks.join("\n\n")
+ chunks = chunks.join("\n\n").split("")
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
- .with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
+ .with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end
+  let(:tool_deltas) { ["<function_calls>", <<~REPLY] }
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+
+ let(:tool_call) { invocation }
+
it_behaves_like "an endpoint that can communicate with a completion service"
end
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 5c0cb8cc..2c866898 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -9,10 +9,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
let(:model_name) { "claude-2" }
let(:bedrock_name) { "claude-v2" }
- let(:prompt) { "Human: write 3 words\n\n" }
+ let(:generic_prompt) { { insts: "write 3 words" } }
+ let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
+ let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
- let(:stream_request_body) { model.default_options.merge(prompt: prompt).to_json }
+ let(:stream_request_body) { request_body }
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
@@ -20,39 +22,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1"
end
- # Copied from https://github.com/bblimke/webmock/issues/629
- # Workaround for stubbing a streamed response
- before do
- mocked_http =
- Class.new(Net::HTTP) do
- def request(*)
- super do |response|
- response.instance_eval do
- def read_body(*, &block)
- if block_given?
- @body.each(&block)
- else
- super
- end
- end
- end
-
- yield response if block_given?
-
- response
- end
- end
- end
-
- @original_net_http = Net.send(:remove_const, :HTTP)
- Net.send(:const_set, :HTTP, mocked_http)
- end
-
- after do
- Net.send(:remove_const, :HTTP)
- Net.send(:const_set, :HTTP, @original_net_http)
- end
-
def response(content)
{
completion: content,
@@ -65,7 +34,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
end
- def stub_response(prompt, response_text)
+ def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(
:post,
@@ -102,7 +71,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
encoder.encode(message)
end
- def stub_streamed_response(prompt, deltas)
+ def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
@@ -121,5 +90,19 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
.to_return(status: 200, body: chunks)
end
+  let(:tool_deltas) { ["<function_calls>", <<~REPLY] }
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+
+ let(:tool_call) { invocation }
+
it_behaves_like "an endpoint that can communicate with a completion service"
end
diff --git a/spec/lib/completions/endpoints/endpoint_examples.rb b/spec/lib/completions/endpoints/endpoint_examples.rb
index 6ca86070..0de7ef73 100644
--- a/spec/lib/completions/endpoints/endpoint_examples.rb
+++ b/spec/lib/completions/endpoints/endpoint_examples.rb
@@ -1,70 +1,177 @@
# frozen_string_literal: true
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
+ # Copied from https://github.com/bblimke/webmock/issues/629
+ # Workaround for stubbing a streamed response
+ before do
+ mocked_http =
+ Class.new(Net::HTTP) do
+ def request(*)
+ super do |response|
+ response.instance_eval do
+ def read_body(*, &block)
+ if block_given?
+ @body.each(&block)
+ else
+ super
+ end
+ end
+ end
+
+ yield response if block_given?
+
+ response
+ end
+ end
+ end
+
+ @original_net_http = Net.send(:remove_const, :HTTP)
+ Net.send(:const_set, :HTTP, mocked_http)
+ end
+
+ after do
+ Net.send(:remove_const, :HTTP)
+ Net.send(:const_set, :HTTP, @original_net_http)
+ end
+
describe "#perform_completion!" do
fab!(:user) { Fabricate(:user) }
- let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
+ let(:tool) do
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name", required: true },
+ {
+ name: "unit",
+ type: "string",
+ description: "the unit of measurement celcius c or fahrenheit f",
+ enum: %w[c f],
+ required: true,
+ },
+ ],
+ }
+ end
+
+ let(:invocation) { <<~TEXT }
+      <function_calls>
+      <invoke>
+      <tool_name>get_weather</tool_name>
+      <tool_id>get_weather</tool_id>
+      <parameters>
+      <location>Sydney</location>
+      <unit>c</unit>
+      </parameters>
+      </invoke>
+      </function_calls>
+ TEXT
context "when using regular mode" do
- before { stub_response(prompt, response_text) }
+ context "with simple prompts" do
+ let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
- it "can complete a trivial prompt" do
- completion_response = model.perform_completion!(prompt, user)
+ before { stub_response(prompt, response_text) }
- expect(completion_response).to eq(response_text)
+ it "can complete a trivial prompt" do
+ completion_response = model.perform_completion!(dialect, user)
+
+ expect(completion_response).to eq(response_text)
+ end
+
+ it "creates an audit log for the request" do
+ model.perform_completion!(dialect, user)
+
+ expect(AiApiAuditLog.count).to eq(1)
+ log = AiApiAuditLog.first
+
+ response_body = response(response_text).to_json
+
+ expect(log.provider_id).to eq(model.provider_id)
+ expect(log.user_id).to eq(user.id)
+ expect(log.raw_request_payload).to eq(request_body)
+ expect(log.raw_response_payload).to eq(response_body)
+ expect(log.request_tokens).to eq(model.prompt_size(prompt))
+ expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
+ end
end
- it "creates an audit log for the request" do
- model.perform_completion!(prompt, user)
+ context "with functions" do
+ let(:generic_prompt) do
+ {
+ insts: "You can tell me the weather",
+ input: "Return the weather in Sydney",
+ tools: [tool],
+ }
+ end
- expect(AiApiAuditLog.count).to eq(1)
- log = AiApiAuditLog.first
+ before { stub_response(prompt, tool_call, tool_call: true) }
- response_body = response(response_text).to_json
+ it "returns a function invocation" do
+ completion_response = model.perform_completion!(dialect, user)
- expect(log.provider_id).to eq(model.provider_id)
- expect(log.user_id).to eq(user.id)
- expect(log.raw_request_payload).to eq(request_body)
- expect(log.raw_response_payload).to eq(response_body)
- expect(log.request_tokens).to eq(model.prompt_size(prompt))
- expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
+ expect(completion_response).to eq(invocation)
+ end
end
end
context "when using stream mode" do
- let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
+ context "with simple prompts" do
+ let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
- before { stub_streamed_response(prompt, deltas) }
+ before { stub_streamed_response(prompt, deltas) }
- it "can complete a trivial prompt" do
- completion_response = +""
+ it "can complete a trivial prompt" do
+ completion_response = +""
- model.perform_completion!(prompt, user) do |partial, cancel|
- completion_response << partial
- cancel.call if completion_response.split(" ").length == 2
+ model.perform_completion!(dialect, user) do |partial, cancel|
+ completion_response << partial
+ cancel.call if completion_response.split(" ").length == 2
+ end
+
+ expect(completion_response).to eq(deltas[0...-1].join)
end
- expect(completion_response).to eq(deltas[0...-1].join)
+        it "creates an audit log and updates it on each read." do
+ completion_response = +""
+
+ model.perform_completion!(dialect, user) do |partial, cancel|
+ completion_response << partial
+ cancel.call if completion_response.split(" ").length == 2
+ end
+
+ expect(AiApiAuditLog.count).to eq(1)
+ log = AiApiAuditLog.first
+
+ expect(log.provider_id).to eq(model.provider_id)
+ expect(log.user_id).to eq(user.id)
+ expect(log.raw_request_payload).to eq(stream_request_body)
+ expect(log.raw_response_payload).to be_present
+ expect(log.request_tokens).to eq(model.prompt_size(prompt))
+ expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
+ end
end
- it "creates an audit log and updates is on each read." do
- completion_response = +""
-
- model.perform_completion!(prompt, user) do |partial, cancel|
- completion_response << partial
- cancel.call if completion_response.split(" ").length == 2
+ context "with functions" do
+ let(:generic_prompt) do
+ {
+ insts: "You can tell me the weather",
+ input: "Return the weather in Sydney",
+ tools: [tool],
+ }
end
- expect(AiApiAuditLog.count).to eq(1)
- log = AiApiAuditLog.first
+ before { stub_streamed_response(prompt, tool_deltas, tool_call: true) }
- expect(log.provider_id).to eq(model.provider_id)
- expect(log.user_id).to eq(user.id)
- expect(log.raw_request_payload).to eq(stream_request_body)
- expect(log.raw_response_payload).to be_present
- expect(log.request_tokens).to eq(model.prompt_size(prompt))
- expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
+ it "waits for the invocation to finish before calling the partial" do
+ buffered_partial = ""
+
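+          # endpoints buffer tool-call deltas and only yield once the invocation is complete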
+ model.perform_completion!(dialect, user) do |partial, cancel|
+            buffered_partial = partial if partial.include?("<function_calls>")
+ end
+
+ expect(buffered_partial).to eq(invocation)
+ end
end
end
end
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index a9431dde..df855e73 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -6,22 +6,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gemini-pro" }
- let(:prompt) do
+ let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+ let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
+ let(:prompt) { dialect.translate }
+
+ let(:tool_payload) do
+ {
+ name: "get_weather",
+ description: "Get the weather in a city",
+ parameters: [
+ { name: "location", type: "string", description: "the city name" },
+ {
+ name: "unit",
+ type: "string",
+          description: "the unit of measurement celsius c or fahrenheit f",
+ enum: %w[c f],
+ },
+ ],
+ required: %w[location unit],
+ }
+ end
+
+ let(:request_body) do
+ model
+ .default_options
+ .merge(contents: prompt)
+ .tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
+ .to_json
+ end
+ let(:stream_request_body) do
+ model
+ .default_options
+ .merge(contents: prompt)
+ .tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
+ .to_json
+ end
+
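+  # Gemini streams tool calls as functionCall parts whose args hash grows with each chunk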
+ let(:tool_deltas) do
[
- { role: "system", content: "You are a helpful bot." },
- { role: "user", content: "Write 3 words" },
+ { "functionCall" => { name: "get_weather", args: {} } },
+ { "functionCall" => { name: "get_weather", args: { location: "" } } },
+ { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end
- let(:request_body) { model.default_options.merge(contents: prompt).to_json }
- let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
+ let(:tool_call) do
+ { "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
+ end
- def response(content)
+ def response(content, tool_call: false)
{
candidates: [
{
content: {
- parts: [{ text: content }],
+ parts: [(tool_call ? content : { text: content })],
role: "model",
},
finishReason: "STOP",
@@ -45,22 +83,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}
end
- def stub_response(prompt, response_text)
+ def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
- .with(body: { contents: prompt })
- .to_return(status: 200, body: JSON.dump(response(response_text)))
+ .with(body: request_body)
+ .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
- def stream_line(delta, finish_reason: nil)
+ def stream_line(delta, finish_reason: nil, tool_call: false)
{
candidates: [
{
content: {
- parts: [{ text: delta }],
+ parts: [(tool_call ? delta : { text: delta })],
role: "model",
},
finishReason: finish_reason,
@@ -76,24 +114,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}.to_json
end
- def stub_streamed_response(prompt, deltas)
+ def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
- stream_line(deltas[index], finish_reason: "STOP")
+ stream_line(deltas[index], finish_reason: "STOP", tool_call: tool_call)
else
- stream_line(deltas[index])
+ stream_line(deltas[index], tool_call: tool_call)
end
end
- chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
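+    # with the Net::HTTP patch from the shared examples, each array element is yielded as
+    # its own chunk, so single characters simulate a response arriving in tiny pieces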
+ chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
- .with(body: model.default_options.merge(contents: prompt).to_json)
+ .with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index ba29893a..087ca1fc 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -6,10 +6,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
let(:model_name) { "Llama2-*-chat-hf" }
- let(:prompt) { <<~TEXT }
-    [INST]<<SYS>>You are a helpful bot.<</SYS>>[/INST]
- [INST]Write 3 words[/INST]
- TEXT
+ let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+ let(:dialect) do
+ DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
+ end
+ let(:prompt) { dialect.translate }
let(:request_body) do
model
@@ -18,7 +19,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
- payload[:parameters][:return_full_text] = false
end
.to_json
end
@@ -30,7 +30,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true
- payload[:parameters][:return_full_text] = false
end
.to_json
end
@@ -41,14 +40,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
[{ generated_text: content }]
end
- def stub_response(prompt, response_text)
+ def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
- def stream_line(delta, finish_reason: nil)
+ def stream_line(delta, deltas, finish_reason: nil)
+"data: " << {
token: {
id: 29_889,
@@ -56,22 +55,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
logprob: -0.08319092,
special: !!finish_reason,
},
- generated_text: finish_reason ? response_text : nil,
+ generated_text: finish_reason ? deltas.join : nil,
details: nil,
}.to_json
end
- def stub_streamed_response(prompt, deltas)
+ def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
- stream_line(deltas[index], finish_reason: true)
+ stream_line(deltas[index], deltas, finish_reason: true)
else
- stream_line(deltas[index])
+ stream_line(deltas[index], deltas)
end
end
- chunks = chunks.join("\n\n")
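+    # append a trailing "data: [DONE]" line and split into single characters so the
+    # mocked Net::HTTP streams the body in tiny chunks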
+ chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
@@ -79,5 +78,29 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.to_return(status: 200, body: chunks)
end
+  let(:tool_deltas) { ["<function", "_calls>", <<~REPLY, <<~REPLY] }
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+    <function_calls>
+    <invoke>
+    <tool_name>get_weather</tool_name>
+    <parameters>
+    <location>Sydney</location>
+    <unit>c</unit>
+    </parameters>
+    </invoke>
+    </function_calls>
+  REPLY
+
+  let(:tool_call) { invocation }
+
it_behaves_like "an endpoint that can communicate with a completion service"
end
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index fa22d461..929c087e 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -6,17 +6,53 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gpt-3.5-turbo" }
- let(:prompt) do
+ let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
+ let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
+ let(:prompt) { dialect.translate }
+
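+  # mimics OpenAI's streamed tool_calls, where the arguments hash is filled in incrementally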
+ let(:tool_deltas) do
[
- { role: "system", content: "You are a helpful bot." },
- { role: "user", content: "Write 3 words" },
+ { id: "get_weather", name: "get_weather", arguments: {} },
+ { id: "get_weather", name: "get_weather", arguments: { location: "" } },
+ { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
]
end
- let(:request_body) { model.default_options.merge(messages: prompt).to_json }
- let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
+ let(:tool_call) do
+ { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
+ end
+
+ let(:request_body) do
+ model
+ .default_options
+ .merge(messages: prompt)
+ .tap do |b|
+ b[:tools] = generic_prompt[:tools].map do |t|
+ { type: "function", tool: t }
+ end if generic_prompt[:tools]
+ end
+ .to_json
+ end
+ let(:stream_request_body) do
+ model
+ .default_options
+ .merge(messages: prompt, stream: true)
+ .tap do |b|
+ b[:tools] = generic_prompt[:tools].map do |t|
+ { type: "function", tool: t }
+ end if generic_prompt[:tools]
+ end
+ .to_json
+ end
+
+ def response(content, tool_call: false)
+ message_content =
+ if tool_call
+ { tool_calls: [{ function: content }] }
+ else
+ { content: content }
+ end
- def response(content)
{
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "chat.completion",
@@ -28,45 +64,52 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
total_tokens: 499,
},
choices: [
- { message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
+ { message: { role: "assistant" }.merge(message_content), finish_reason: "stop", index: 0 },
],
}
end
- def stub_response(prompt, response_text)
+ def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
- .with(body: { model: model_name, messages: prompt })
- .to_return(status: 200, body: JSON.dump(response(response_text)))
+ .with(body: request_body)
+ .to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
- def stream_line(delta, finish_reason: nil)
+ def stream_line(delta, finish_reason: nil, tool_call: false)
+ message_content =
+ if tool_call
+ { tool_calls: [{ function: delta }] }
+ else
+ { content: delta }
+ end
+
+"data: " << {
id: "chatcmpl-#{SecureRandom.hex}",
object: "chat.completion.chunk",
created: 1_681_283_881,
model: "gpt-3.5-turbo-0301",
- choices: [{ delta: { content: delta } }],
+ choices: [{ delta: message_content }],
finish_reason: finish_reason,
index: 0,
}.to_json
end
- def stub_streamed_response(prompt, deltas)
+ def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
- stream_line(deltas[index], finish_reason: "stop_sequence")
+ stream_line(deltas[index], finish_reason: "stop_sequence", tool_call: tool_call)
else
- stream_line(deltas[index])
+ stream_line(deltas[index], tool_call: tool_call)
end
end
- chunks = chunks.join("\n\n")
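+    # OpenAI terminates its SSE stream with "data: [DONE]"; character-sized chunks force the
+    # endpoint to reassemble partial SSE messages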
+ chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
- .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
+ .with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index 66f53060..3a67ff22 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -3,7 +3,7 @@
RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do
described_class.new(
- DiscourseAi::Completions::Dialects::OrcaStyle.new,
+ DiscourseAi::Completions::Dialects::OrcaStyle,
canned_response,
"Upstage-Llama-2-*-instruct-v2",
)
diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb
index 78939334..8b115a27 100644
--- a/spec/requests/ai_helper/assistant_controller_spec.rb
+++ b/spec/requests/ai_helper/assistant_controller_spec.rb
@@ -103,7 +103,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
expect(response.status).to eq(200)
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
expect(response.parsed_body["diff"]).to eq(expected_diff)
- expect(spy.prompt.last[:content]).to eq(expected_input)
+ expect(spy.prompt.translate.last[:content]).to eq(expected_input)
end
end
end