DEV: Tool support for the LLM service. (#366)

This PR adds tool support to the available LLMs. We buffer tool invocations and return them whole instead of making consumers of this service parse the response themselves.
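
A rough usage sketch (the model name and user are placeholders; the prompt keys follow the generic format documented further down in this PR):

```ruby
# Not part of this diff - a sketch of how a consumer would use the new tool support.
llm = DiscourseAi::Completions::Llm.proxy("claude-2")

prompt = {
  insts: "You can tell me the weather",
  input: "Return the weather in Sydney",
  tools: [
    {
      name: "get_weather",
      description: "Get the weather in a city",
      parameters: [
        { name: "location", type: "string", description: "the city name", required: true },
      ],
    },
  ],
}

# When the model invokes a tool, the block receives the fully-formed
# <function_calls>...</function_calls> string in a single yield instead of
# character-by-character partials.
llm.completion!(prompt, Discourse.system_user) do |partial, cancel|
  print partial
end
```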

It also adds support for conversation context in the generic prompt. The context can include bot messages, user messages, and tool invocations; we trim it so the prompt stays within the token limit, then translate it to the correct dialect.
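
The context is an array of hashes, newest first, matching the format documented in the Llm class below. Each dialect trims it against its `max_prompt_tokens` via `trim_context` before translating:

```ruby
# Example conversation context (first element == most recent message):
prompt[:conversation_context] = [
  { type: "user", name: "user1", content: "This is a new message by a user" },
  { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
  { type: "tool", name: "tool_id", content: "I'm a tool result" },
]
```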

Finally, it adds buffering when reading chunks to handle cases where streaming is extremely slow.
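
The guard in the streaming read loop looks roughly like this (simplified from the endpoint base class in this diff):

```ruby
# Accumulate tiny chunks until we have at least enough bytes to hold
# "data: [DONE]"; otherwise extremely slow streams can split a partial
# mid-marker and break parsing.
if (leftover + decoded_chunk).length < "data: [DONE]".length
  leftover += decoded_chunk
  next
end
```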
Roman Rizzi 2023-12-18 18:06:01 -03:00 committed by GitHub
parent 203906be65
commit e0bf6adb5b
27 changed files with 1625 additions and 284 deletions


@ -3,33 +3,101 @@
module DiscourseAi module DiscourseAi
module Completions module Completions
module Dialects module Dialects
class ChatGpt class ChatGpt < Dialect
def self.can_translate?(model_name) class << self
def can_translate?(model_name)
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name) %w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
end end
def translate(generic_prompt)
open_ai_prompt = [
{
role: "system",
content: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
},
]
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
open_ai_prompt << { role: "user", content: example_pair.first }
open_ai_prompt << { role: "assistant", content: example_pair.second }
end
end
open_ai_prompt << { role: "user", content: generic_prompt[:input] }
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer DiscourseAi::Tokenizer::OpenAiTokenizer
end end
end end
def translate
open_ai_prompt = [
{ role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") },
]
if prompt[:examples]
prompt[:examples].each do |example_pair|
open_ai_prompt << { role: "user", content: example_pair.first }
open_ai_prompt << { role: "assistant", content: example_pair.second }
end
end
open_ai_prompt.concat!(conversation_context) if prompt[:conversation_context]
open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
open_ai_prompt
end
def tools
return if prompt[:tools].blank?
prompt[:tools].map { |t| { type: "function", tool: t } }
end
def conversation_context
return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context.reverse.map do |context|
translated = context.slice(:content)
translated[:role] = context[:type]
if context[:name]
if translated[:role] == "tool"
translated[:tool_call_id] = context[:name]
else
translated[:name] = context[:name]
end
end
translated
end
end
def max_prompt_tokens
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
buffer = (opts[:max_tokens_to_sample] || 2500) + 50
if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation
@function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
buffer += @function_size
end
model_max_tokens - buffer
end
private
def per_message_overhead
# open ai defines about 4 tokens per message of overhead
4
end
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def model_max_tokens
case model_name
when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
16_384
when "gpt-4"
8192
when "gpt-4-32k"
32_768
else
8192
end
end
end
end end
end end
end end


@ -3,26 +3,66 @@
module DiscourseAi module DiscourseAi
module Completions module Completions
module Dialects module Dialects
class Claude class Claude < Dialect
def self.can_translate?(model_name) class << self
def can_translate?(model_name)
%w[claude-instant-1 claude-2].include?(model_name) %w[claude-instant-1 claude-2].include?(model_name)
end end
def translate(generic_prompt)
claude_prompt = +"Human: #{generic_prompt[:insts]}\n"
claude_prompt << build_examples(generic_prompt[:examples]) if generic_prompt[:examples]
claude_prompt << "#{generic_prompt[:input]}\n"
claude_prompt << "#{generic_prompt[:post_insts]}\n" if generic_prompt[:post_insts]
claude_prompt << "Assistant:\n"
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::AnthropicTokenizer DiscourseAi::Tokenizer::AnthropicTokenizer
end end
end
def translate
claude_prompt = +"Human: #{prompt[:insts]}\n"
claude_prompt << build_tools_prompt if prompt[:tools]
claude_prompt << build_examples(prompt[:examples]) if prompt[:examples]
claude_prompt << conversation_context if prompt[:conversation_context]
claude_prompt << "#{prompt[:input]}\n"
claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts]
claude_prompt << "Assistant:\n"
end
def max_prompt_tokens
50_000
end
def conversation_context
return "" if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context
.reverse
.reduce(+"") do |memo, context|
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
if context[:type] == "tool"
memo << <<~TEXT
<function_results>
<result>
<tool_name>#{context[:name]}</tool_name>
<json>
#{context[:content]}
</json>
</result>
</function_results>
TEXT
else
memo << " " << context[:content] << "\n"
end
memo
end
end
private private


@ -0,0 +1,160 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Dialect
class << self
def can_translate?(_model_name)
raise NotImplemented
end
def dialect_for(model_name)
dialects = [
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
]
dialects.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |d|
d.can_translate?(model_name)
end
end
def tokenizer
raise NotImplemented
end
end
def initialize(generic_prompt, model_name, opts: {})
@prompt = generic_prompt
@model_name = model_name
@opts = opts
end
def translate
raise NotImplemented
end
def tools
tools = +""
prompt[:tools].each do |function|
parameters = +""
if function[:parameters].present?
function[:parameters].each do |parameter|
parameters << <<~PARAMETER
<parameter>
<name>#{parameter[:name]}</name>
<type>#{parameter[:type]}</type>
<description>#{parameter[:description]}</description>
<required>#{parameter[:required]}</required>
PARAMETER
if parameter[:enum]
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
end
parameters << "</parameter>\n"
end
end
tools << <<~TOOLS
<tool_description>
<tool_name>#{function[:name]}</tool_name>
<description>#{function[:description]}</description>
<parameters>
#{parameters}</parameters>
</tool_description>
TOOLS
end
tools
end
def conversation_context
raise NotImplemented
end
def max_prompt_tokens
raise NotImplemented
end
private
attr_reader :prompt, :model_name, :opts
def trim_context(conversation_context)
prompt_limit = max_prompt_tokens
current_token_count = calculate_token_count_without_context
conversation_context.reduce([]) do |memo, context|
break(memo) if current_token_count >= prompt_limit
dupped_context = context.dup
message_tokens = calculate_message_token(dupped_context)
# Trimming content to make sure we respect token limit.
while dupped_context[:content].present? &&
message_tokens + current_token_count + per_message_overhead > prompt_limit
dupped_context[:content] = dupped_context[:content][0..-100] || ""
message_tokens = calculate_message_token(dupped_context)
end
next(memo) if dupped_context[:content].blank?
current_token_count += message_tokens + per_message_overhead
memo << dupped_context
end
end
def calculate_token_count_without_context
tokenizer = self.class.tokenizer
examples_count =
prompt[:examples].to_a.sum do |pair|
tokenizer.size(pair.join) + (per_message_overhead * 2)
end
input_count = tokenizer.size(prompt[:input].to_s) + per_message_overhead
examples_count + input_count +
prompt
.except(:conversation_context, :tools, :examples, :input)
.sum { |_, v| tokenizer.size(v) + per_message_overhead }
end
def per_message_overhead
0
end
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s)
end
def build_tools_prompt
return "" if prompt[:tools].blank?
<<~TEXT
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{tools}</tools>
TEXT
end
end
end
end
end


@ -3,36 +3,93 @@
module DiscourseAi module DiscourseAi
module Completions module Completions
module Dialects module Dialects
class Gemini class Gemini < Dialect
def self.can_translate?(model_name) class << self
def can_translate?(model_name)
%w[gemini-pro].include?(model_name) %w[gemini-pro].include?(model_name)
end end
def translate(generic_prompt)
gemini_prompt = [
{
role: "user",
parts: {
text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
},
},
{ role: "model", parts: { text: "Ok." } },
]
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
end
end
gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
end end
end end
def translate
gemini_prompt = [
{
role: "user",
parts: {
text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
},
},
{ role: "model", parts: { text: "Ok." } },
]
if prompt[:examples]
prompt[:examples].each do |example_pair|
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
end
end
gemini_prompt.concat!(conversation_context) if prompt[:conversation_context]
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
end
def tools
return if prompt[:tools].blank?
translated_tools =
prompt[:tools].map do |t|
required_fields = []
tool = t.dup
tool[:parameters] = t[:parameters].map do |p|
required_fields << p[:name] if p[:required]
p.except(:required)
end
tool.merge(required: required_fields)
end
[{ function_declarations: translated_tools }]
end
def conversation_context
return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context.reverse.map do |context|
translated = {}
translated[:role] = (context[:type] == "user" ? "user" : "model")
part = {}
if context[:type] == "tool"
part["functionResponse"] = { name: context[:name], content: context[:content] }
else
part[:text] = context[:content]
end
translated[:parts] = [part]
translated
end
end
def max_prompt_tokens
16_384 # 50% of model tokens
end
protected
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
end
end end
end end
end end


@ -3,29 +3,74 @@
module DiscourseAi module DiscourseAi
module Completions module Completions
module Dialects module Dialects
class Llama2Classic class Llama2Classic < Dialect
def self.can_translate?(model_name) class << self
def can_translate?(model_name)
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name) %w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
end end
def translate(generic_prompt)
llama2_prompt =
+"[INST]<<SYS>>#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}<</SYS>>[/INST]\n"
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
llama2_prompt << "#{example_pair.second}\n"
end
end
llama2_prompt << "[INST]#{generic_prompt[:input]}[/INST]\n"
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer DiscourseAi::Tokenizer::Llama2Tokenizer
end end
end end
def translate
llama2_prompt = +<<~TEXT
[INST]
<<SYS>>
#{prompt[:insts]}
#{build_tools_prompt}#{prompt[:post_insts]}
<</SYS>>
[/INST]
TEXT
if prompt[:examples]
prompt[:examples].each do |example_pair|
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
llama2_prompt << "#{example_pair.second}\n"
end
end
llama2_prompt << conversation_context if prompt[:conversation_context].present?
llama2_prompt << "[INST]#{prompt[:input]}[/INST]\n"
end
def conversation_context
return "" if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context
.reverse
.reduce(+"") do |memo, context|
if context[:type] == "tool"
memo << <<~TEXT
[INST]
<function_results>
<result>
<tool_name>#{context[:name]}</tool_name>
<json>
#{context[:content]}
</json>
</result>
</function_results>
[/INST]
TEXT
elsif context[:type] == "assistant"
memo << "[INST]" << context[:content] << "[/INST]\n"
else
memo << context[:content] << "\n"
end
memo
end
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end end
end end
end end


@ -3,31 +3,70 @@
module DiscourseAi module DiscourseAi
module Completions module Completions
module Dialects module Dialects
class OrcaStyle class OrcaStyle < Dialect
def self.can_translate?(model_name) class << self
def can_translate?(model_name)
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name) %w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
end end
def translate(generic_prompt)
orca_style_prompt =
+"### System:\n#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}\n"
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
orca_style_prompt << "### User:\n#{example_pair.first}\n"
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
end
end
orca_style_prompt << "### User:\n#{generic_prompt[:input]}\n"
orca_style_prompt << "### Assistant:\n"
end
def tokenizer def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer DiscourseAi::Tokenizer::Llama2Tokenizer
end end
end end
def translate
orca_style_prompt = +<<~TEXT
### System:
#{prompt[:insts]}
#{build_tools_prompt}#{prompt[:post_insts]}
TEXT
if prompt[:examples]
prompt[:examples].each do |example_pair|
orca_style_prompt << "### User:\n#{example_pair.first}\n"
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
end
end
orca_style_prompt << "### User:\n#{prompt[:input]}\n"
orca_style_prompt << "### Assistant:\n"
end
def conversation_context
return "" if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context
.reverse
.reduce(+"") do |memo, context|
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
if context[:type] == "tool"
memo << <<~TEXT
<function_results>
<result>
<tool_name>#{context[:name]}</tool_name>
<json>
#{context[:content]}
</json>
</result>
</function_results>
TEXT
else
memo << " " << context[:content] << "\n"
end
memo
end
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end end
end end
end end


@ -22,7 +22,7 @@ module DiscourseAi
@uri ||= URI("https://api.anthropic.com/v1/complete") @uri ||= URI("https://api.anthropic.com/v1/complete")
end end
def prepare_payload(prompt, model_params) def prepare_payload(prompt, model_params, _dialect)
default_options default_options
.merge(model_params) .merge(model_params)
.merge(prompt: prompt) .merge(prompt: prompt)


@ -37,7 +37,7 @@ module DiscourseAi
URI(api_url) URI(api_url)
end end
def prepare_payload(prompt, model_params) def prepare_payload(prompt, model_params, _dialect)
default_options.merge(prompt: prompt).merge(model_params) default_options.merge(prompt: prompt).merge(model_params)
end end


@ -30,9 +30,11 @@ module DiscourseAi
@tokenizer = tokenizer @tokenizer = tokenizer
end end
def perform_completion!(prompt, user, model_params = {}) def perform_completion!(dialect, user, model_params = {})
@streaming_mode = block_given? @streaming_mode = block_given?
prompt = dialect.translate
Net::HTTP.start( Net::HTTP.start(
model_uri.host, model_uri.host,
model_uri.port, model_uri.port,
@ -43,7 +45,10 @@ module DiscourseAi
) do |http| ) do |http|
response_data = +"" response_data = +""
response_raw = +"" response_raw = +""
request_body = prepare_payload(prompt, model_params).to_json
# Needed to response token calculations. Cannot rely on response_data due to function buffering.
partials_raw = +""
request_body = prepare_payload(prompt, model_params, dialect).to_json
request = prepare_request(request_body) request = prepare_request(request_body)
@ -66,6 +71,15 @@ module DiscourseAi
if !@streaming_mode if !@streaming_mode
response_raw = response.read_body response_raw = response.read_body
response_data = extract_completion_from(response_raw) response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
if has_tool?("", response_data)
function_buffer = build_buffer # Nokogiri document
function_buffer = add_to_buffer(function_buffer, "", response_data)
response_data = +function_buffer.at("function_calls").to_s
response_data << "\n"
end
return response_data return response_data
end end
@ -75,6 +89,7 @@ module DiscourseAi
cancel = lambda { cancelled = true } cancel = lambda { cancelled = true }
leftover = "" leftover = ""
function_buffer = build_buffer # Nokogiri document
response.read_body do |chunk| response.read_body do |chunk|
if cancelled if cancelled
@ -85,6 +100,12 @@ module DiscourseAi
decoded_chunk = decode(chunk) decoded_chunk = decode(chunk)
response_raw << decoded_chunk response_raw << decoded_chunk
# Buffering for extremely slow streaming.
if (leftover + decoded_chunk).length < "data: [DONE]".length
leftover += decoded_chunk
next
end
partials_from(leftover + decoded_chunk).each do |raw_partial| partials_from(leftover + decoded_chunk).each do |raw_partial|
next if cancelled next if cancelled
next if raw_partial.blank? next if raw_partial.blank?
@ -93,11 +114,27 @@ module DiscourseAi
partial = extract_completion_from(raw_partial) partial = extract_completion_from(raw_partial)
next if partial.nil? next if partial.nil?
leftover = "" leftover = ""
if has_tool?(response_data, partial)
function_buffer = add_to_buffer(function_buffer, response_data, partial)
if buffering_finished?(dialect.tools, function_buffer)
invocation = +function_buffer.at("function_calls").to_s
invocation << "\n"
partials_raw << partial.to_s
response_data << invocation
yield invocation, cancel
end
else
partials_raw << partial
response_data << partial response_data << partial
yield partial, cancel if partial yield partial, cancel if partial
end
rescue JSON::ParserError rescue JSON::ParserError
leftover = raw_partial leftover += decoded_chunk
end end
end end
end end
@ -109,7 +146,7 @@ module DiscourseAi
ensure ensure
if log if log
log.raw_response_payload = response_raw log.raw_response_payload = response_raw
log.response_tokens = tokenizer.size(response_data) log.response_tokens = tokenizer.size(partials_raw)
log.save! log.save!
if Rails.env.development? if Rails.env.development?
@ -165,6 +202,40 @@ module DiscourseAi
def extract_prompt_for_tokenizer(prompt) def extract_prompt_for_tokenizer(prompt)
prompt prompt
end end
def build_buffer
Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
<invoke>
<tool_name></tool_name>
<tool_id></tool_id>
<parameters></parameters>
</invoke>
</function_calls>
TEXT
end
def has_tool?(response, partial)
(response + partial).include?("<function_calls>")
end
def add_to_buffer(function_buffer, response_data, partial)
new_buffer = Nokogiri::HTML5.fragment(response_data + partial)
if tool_name = new_buffer.at("tool_name").text
if new_buffer.at("tool_id").nil?
tool_id_node =
Nokogiri::HTML5::DocumentFragment.parse("\n<tool_id>#{tool_name}</tool_id>")
new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node)
end
end
new_buffer
end
def buffering_finished?(_available_functions, buffer)
buffer.to_s.include?("</function_calls>")
end
end end
end end
end end


@ -31,10 +31,15 @@ module DiscourseAi
cancelled = false cancelled = false
cancel_fn = lambda { cancelled = true } cancel_fn = lambda { cancelled = true }
# We buffer and return tool invocations in one go.
if is_tool?(response)
yield(response, cancel_fn)
else
response.each_char do |char| response.each_char do |char|
break if cancelled break if cancelled
yield(char, cancel_fn) yield(char, cancel_fn)
end end
end
else else
response response
end end
@ -43,6 +48,12 @@ module DiscourseAi
def tokenizer def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer DiscourseAi::Tokenizer::OpenAiTokenizer
end end
private
def is_tool?(response)
Nokogiri::HTML5.fragment(response).at("function_calls").present?
end
end end
end end
end end


@ -25,8 +25,11 @@ module DiscourseAi
URI(url) URI(url)
end end
def prepare_payload(prompt, model_params) def prepare_payload(prompt, model_params, dialect)
default_options.merge(model_params).merge(contents: prompt) default_options
.merge(model_params)
.merge(contents: prompt)
.tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
end end
def prepare_request(payload) def prepare_request(payload)
@ -36,25 +39,72 @@ module DiscourseAi
end end
def extract_completion_from(response_raw) def extract_completion_from(response_raw)
if @streaming_mode
parsed = response_raw
else
parsed = JSON.parse(response_raw, symbolize_names: true) parsed = JSON.parse(response_raw, symbolize_names: true)
end
completion = dig_text(parsed).to_s response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
has_function_call = response_h.dig(:functionCall).present?
has_function_call ? response_h[:functionCall] : response_h.dig(:text)
end end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)
JSON.parse(decoded_chunk, symbolize_names: true) decoded_chunk
.split("\n")
.map do |line|
if line == ","
nil
elsif line.starts_with?("[")
line[1..-1]
elsif line.ends_with?("]")
line[0..-1]
else
line
end
end
.compact_blank
end end
def extract_prompt_for_tokenizer(prompt) def extract_prompt_for_tokenizer(prompt)
prompt.to_s prompt.to_s
end end
def dig_text(response) def has_tool?(_response_data, partial)
response.dig(:candidates, 0, :content, :parts, 0, :text) partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
end
def add_to_buffer(function_buffer, _response_data, partial)
if partial[:name].present?
function_buffer.at("tool_name").content = partial[:name]
function_buffer.at("tool_id").content = partial[:name]
end
if partial[:args]
argument_fragments =
partial[:args].reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
end
argument_fragments << "\n"
function_buffer.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
end
function_buffer
end
def buffering_finished?(available_functions, buffer)
tool_name = buffer.at("tool_name")&.text
return false if tool_name.blank?
signature =
available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
signature[:parameters].reduce(true) do |memo, param|
param_present = buffer.at(param[:name]).present?
next(memo) if param_present || !signature[:required].include?(param[:name])
memo && param_present
end
end end
end end
end end


@ -11,7 +11,7 @@ module DiscourseAi
end end
def default_options def default_options
{ parameters: { repetition_penalty: 1.1, temperature: 0.7 } } { parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end end
def provider_id def provider_id
@ -24,7 +24,7 @@ module DiscourseAi
URI(SiteSetting.ai_hugging_face_api_url) URI(SiteSetting.ai_hugging_face_api_url)
end end
def prepare_payload(prompt, model_params) def prepare_payload(prompt, model_params, _dialect)
default_options default_options
.merge(inputs: prompt) .merge(inputs: prompt)
.tap do |payload| .tap do |payload|
@ -33,7 +33,6 @@ module DiscourseAi
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000 token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt) payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
payload[:parameters][:return_full_text] = false
payload[:stream] = true if @streaming_mode payload[:stream] = true if @streaming_mode
end end


@ -37,11 +37,14 @@ module DiscourseAi
URI(url) URI(url)
end end
def prepare_payload(prompt, model_params) def prepare_payload(prompt, model_params, dialect)
default_options default_options
.merge(model_params) .merge(model_params)
.merge(messages: prompt) .merge(messages: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode } .tap do |payload|
payload[:stream] = true if @streaming_mode
payload[:tools] = dialect.tools if dialect.tools.present?
end
end end
def prepare_request(payload) def prepare_request(payload)
@ -62,15 +65,12 @@ module DiscourseAi
end end
def extract_completion_from(response_raw) def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true) parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
( response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
if @streaming_mode
parsed.dig(:choices, 0, :delta, :content) has_function_call = response_h.dig(:tool_calls).present?
else has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
parsed.dig(:choices, 0, :message, :content)
end
).to_s
end end
def partials_from(decoded_chunk) def partials_from(decoded_chunk)
@ -86,6 +86,42 @@ module DiscourseAi
def extract_prompt_for_tokenizer(prompt) def extract_prompt_for_tokenizer(prompt)
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end end
def has_tool?(_response_data, partial)
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
end
def add_to_buffer(function_buffer, _response_data, partial)
function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
if partial[:arguments]
argument_fragments =
partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
end
argument_fragments << "\n"
function_buffer.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
end
function_buffer
end
def buffering_finished?(available_functions, buffer)
tool_name = buffer.at("tool_name")&.text
return false if tool_name.blank?
signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
signature[:parameters].reduce(true) do |memo, param|
param_present = buffer.at(param[:name]).present?
next(memo) if param_present && !param[:required]
memo && param_present
end
end
end end
end end
end end


@ -24,35 +24,26 @@ module DiscourseAi
end end
def self.proxy(model_name) def self.proxy(model_name)
dialects = [ dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
]
dialect = return new(dialect_klass, @canned_response, model_name) if @canned_response
dialects.detect(-> { raise UNKNOWN_MODEL }) { |d| d.can_translate?(model_name) }.new
return new(dialect, @canned_response, model_name) if @canned_response
gateway = gateway =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new( DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
model_name, model_name,
dialect.tokenizer, dialect_klass.tokenizer,
) )
new(dialect, gateway, model_name) new(dialect_klass, gateway, model_name)
end end
def initialize(dialect, gateway, model_name) def initialize(dialect_klass, gateway, model_name)
@dialect = dialect @dialect_klass = dialect_klass
@gateway = gateway @gateway = gateway
@model_name = model_name @model_name = model_name
end end
delegate :tokenizer, to: :dialect delegate :tokenizer, to: :dialect_klass
# @param generic_prompt { Hash } - Prompt using our generic format. # @param generic_prompt { Hash } - Prompt using our generic format.
# We use the following keys from the hash: # We use the following keys from the hash:
@ -60,23 +51,64 @@ module DiscourseAi
# - input: String containing user input # - input: String containing user input
# - examples (optional): Array of arrays with examples of input and responses. Each array is a input/response pair like [[example1, response1], [example2, response2]]. # - examples (optional): Array of arrays with examples of input and responses. Each array is a input/response pair like [[example1, response1], [example2, response2]].
# - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt. # - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt.
# - conversation_context (optional): Array of hashes to provide context about an ongoing conversation with the model.
# We translate the array in reverse order, meaning the first element would be the most recent message in the conversation.
# Example:
#
# [
# { type: "user", name: "user1", content: "This is a new message by a user" },
# { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
# { type: "tool", name: "tool_id", content: "I'm a tool result" },
# ]
#
# - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:
#
# {
# name: "get_weather",
# description: "Get the weather in a city",
# parameters: [
# { name: "location", type: "string", description: "the city name", required: true },
# {
# name: "unit",
# type: "string",
# description: "the unit of measurement celcius c or fahrenheit f",
# enum: %w[c f],
# required: true,
# },
# ],
# }
# #
# @param user { User } - User requesting the summary. # @param user { User } - User requesting the summary.
# #
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function. # @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
# #
# @returns { String } - Completion result. # @returns { String } - Completion result.
#
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool,
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
#
# <function_calls>
# <invoke>
# <tool_name>get_weather</tool_name>
# <tool_id>get_weather</tool_id>
# <parameters>
# <location>Sydney</location>
# <unit>c</unit>
# </parameters>
# </invoke>
# </function_calls>
#
def completion!(generic_prompt, user, &partial_read_blk) def completion!(generic_prompt, user, &partial_read_blk)
prompt = dialect.translate(generic_prompt)
model_params = generic_prompt.dig(:params, model_name) || {} model_params = generic_prompt.dig(:params, model_name) || {}
gateway.perform_completion!(prompt, user, model_params, &partial_read_blk) dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end end
private private
attr_reader :dialect, :gateway, :model_name attr_reader :dialect_klass, :gateway, :model_name
end end
end end
end end


@ -1,7 +1,24 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
subject(:dialect) { described_class.new } subject(:dialect) { described_class.new(prompt, "gpt-4") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do let(:prompt) do
{ {
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
TEXT TEXT
post_insts: post_insts:
"Please put the translation between <ai></ai> tags and separate each title with a comma.", "Please put the translation between <ai></ai> tags and separate each title with a comma.",
tools: [tool],
} }
end end
@ -35,7 +53,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
{ role: "user", content: prompt[:input] }, { role: "user", content: prompt[:input] },
] ]
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to contain_exactly(*open_ai_version) expect(translated).to contain_exactly(*open_ai_version)
end end
@ -55,9 +73,51 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
{ role: "user", content: prompt[:input] }, { role: "user", content: prompt[:input] },
] ]
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to contain_exactly(*open_ai_version) expect(translated).to contain_exactly(*open_ai_version)
end end
end end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context).to eq(
[
{ role: "tool", content: context.last[:content], tool_call_id: context.last[:name] },
{ role: "assistant", content: context.second[:content] },
{ role: "user", content: context.first[:content], name: context.first[:name] },
],
)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.last[:content].length).to be < context.last[:content].length
end
end
describe "#tools" do
it "returns a list of available tools" do
open_ai_tool_f = { type: "function", tool: tool }
expect(subject.tools).to contain_exactly(open_ai_tool_f)
end
end
end end


@ -1,7 +1,24 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Claude do RSpec.describe DiscourseAi::Completions::Dialects::Claude do
subject(:dialect) { described_class.new } subject(:dialect) { described_class.new(prompt, "claude-2") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do let(:prompt) do
{ {
@ -37,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
Assistant: Assistant:
TEXT TEXT
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to eq(anthropic_version) expect(translated).to eq(anthropic_version)
end end
@ -60,9 +77,111 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
Assistant: Assistant:
TEXT TEXT
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to eq(anthropic_version)
end
it "include tools inside the prompt" do
prompt[:tools] = [tool]
anthropic_version = <<~TEXT
Human: #{prompt[:insts]}
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{dialect.tools}</tools>
#{prompt[:input]}
#{prompt[:post_insts]}
Assistant:
TEXT
translated = dialect.translate
expect(translated).to eq(anthropic_version) expect(translated).to eq(anthropic_version)
end end
end end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
expected = <<~TEXT
Assistant:
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
#{context.last[:content]}
</json>
</result>
</function_results>
Assistant: #{context.second[:content]}
Human: #{context.first[:content]}
TEXT
translated_context = dialect.conversation_context
expect(translated_context).to eq(expected)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 10_000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.length).to be < context.last[:content].length
end
end
describe "#tools" do
it "translates tools to the tool syntax" do
prompt[:tools] = [tool]
translated_tool = <<~TEXT
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>true</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
TEXT
expect(dialect.tools).to eq(translated_tool)
end
end
end end


@ -1,7 +1,24 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
subject(:dialect) { described_class.new } subject(:dialect) { described_class.new(prompt, "gemini-pro") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do let(:prompt) do
{ {
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
TEXT TEXT
post_insts: post_insts:
"Please put the translation between <ai></ai> tags and separate each title with a comma.", "Please put the translation between <ai></ai> tags and separate each title with a comma.",
tools: [tool],
} }
end end
@ -36,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{ role: "user", parts: { text: prompt[:input] } }, { role: "user", parts: { text: prompt[:input] } },
] ]
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to eq(gemini_version) expect(translated).to eq(gemini_version)
end end
@ -57,9 +75,79 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{ role: "user", parts: { text: prompt[:input] } }, { role: "user", parts: { text: prompt[:input] } },
] ]
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to contain_exactly(*gemini_version) expect(translated).to contain_exactly(*gemini_version)
end end
end end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context).to eq(
[
{
role: "model",
parts: [
{
"functionResponse" => {
name: context.last[:name],
content: context.last[:content],
},
},
],
},
{ role: "model", parts: [{ text: context.second[:content] }] },
{ role: "user", parts: [{ text: context.first[:content] }] },
],
)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.last.dig(:parts, 0, :text).length).to be <
context.last[:content].length
end
end
describe "#tools" do
it "returns a list of available tools" do
gemini_tools = {
function_declarations: [
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name" },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
],
required: %w[location unit],
},
],
}
expect(subject.tools).to contain_exactly(gemini_tools)
end
end
end end


@ -1,7 +1,24 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
subject(:dialect) { described_class.new } subject(:dialect) { described_class.new(prompt, "Llama2-chat-hf") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do let(:prompt) do
{ {
@ -31,11 +48,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
describe "#translate" do describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT llama2_classic_version = <<~TEXT
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST] [INST]
<<SYS>>
#{prompt[:insts]}
#{prompt[:post_insts]}
<</SYS>>
[/INST]
[INST]#{prompt[:input]}[/INST] [INST]#{prompt[:input]}[/INST]
TEXT TEXT
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to eq(llama2_classic_version) expect(translated).to eq(llama2_classic_version)
end end
@ -49,15 +71,126 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
] ]
llama2_classic_version = <<~TEXT llama2_classic_version = <<~TEXT
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST] [INST]
<<SYS>>
#{prompt[:insts]}
#{prompt[:post_insts]}
<</SYS>>
[/INST]
[INST]#{prompt[:examples][0][0]}[/INST] [INST]#{prompt[:examples][0][0]}[/INST]
#{prompt[:examples][0][1]} #{prompt[:examples][0][1]}
[INST]#{prompt[:input]}[/INST] [INST]#{prompt[:input]}[/INST]
TEXT TEXT
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to eq(llama2_classic_version)
end
it "include tools inside the prompt" do
prompt[:tools] = [tool]
llama2_classic_version = <<~TEXT
[INST]
<<SYS>>
#{prompt[:insts]}
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{dialect.tools}</tools>
#{prompt[:post_insts]}
<</SYS>>
[/INST]
[INST]#{prompt[:input]}[/INST]
TEXT
translated = dialect.translate
expect(translated).to eq(llama2_classic_version) expect(translated).to eq(llama2_classic_version)
end end
end end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
expected = <<~TEXT
[INST]
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
#{context.last[:content]}
</json>
</result>
</function_results>
[/INST]
[INST]#{context.second[:content]}[/INST]
#{context.first[:content]}
TEXT
translated_context = dialect.conversation_context
expect(translated_context).to eq(expected)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1_000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.length).to be < context.last[:content].length
end
end
describe "#tools" do
it "translates functions to the tool syntax" do
prompt[:tools] = [tool]
translated_tool = <<~TEXT
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>true</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
TEXT
expect(dialect.tools).to eq(translated_tool)
end
end
end end


@ -1,9 +1,25 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
subject(:dialect) { described_class.new } subject(:dialect) { described_class.new(prompt, "StableBeluga2") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
describe "#translate" do
let(:prompt) do let(:prompt) do
{ {
insts: <<~TEXT, insts: <<~TEXT,
@ -29,16 +45,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
} }
end end
describe "#translate" do
it "translates a prompt written in our generic format to the Open AI format" do it "translates a prompt written in our generic format to the Open AI format" do
orca_style_version = <<~TEXT orca_style_version = <<~TEXT
### System: ### System:
#{[prompt[:insts], prompt[:post_insts]].join("\n")} #{prompt[:insts]}
#{prompt[:post_insts]}
### User: ### User:
#{prompt[:input]} #{prompt[:input]}
### Assistant: ### Assistant:
TEXT TEXT
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to eq(orca_style_version) expect(translated).to eq(orca_style_version)
end end
@ -53,7 +71,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
orca_style_version = <<~TEXT orca_style_version = <<~TEXT
### System: ### System:
#{[prompt[:insts], prompt[:post_insts]].join("\n")} #{prompt[:insts]}
#{prompt[:post_insts]}
### User: ### User:
#{prompt[:examples][0][0]} #{prompt[:examples][0][0]}
### Assistant: ### Assistant:
@ -63,9 +82,113 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
### Assistant: ### Assistant:
TEXT TEXT
translated = dialect.translate(prompt) translated = dialect.translate
expect(translated).to eq(orca_style_version)
end
it "include tools inside the prompt" do
prompt[:tools] = [tool]
orca_style_version = <<~TEXT
### System:
#{prompt[:insts]}
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{dialect.tools}</tools>
#{prompt[:post_insts]}
### User:
#{prompt[:input]}
### Assistant:
TEXT
translated = dialect.translate
expect(translated).to eq(orca_style_version) expect(translated).to eq(orca_style_version)
end end
end end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
expected = <<~TEXT
### Assistant:
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
#{context.last[:content]}
</json>
</result>
</function_results>
### Assistant: #{context.second[:content]}
### User: #{context.first[:content]}
TEXT
translated_context = dialect.conversation_context
expect(translated_context).to eq(expected)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1_000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.length).to be < context.last[:content].length
end
end
describe "#tools" do
it "translates tools to the tool syntax" do
prompt[:tools] = [tool]
translated_tool = <<~TEXT
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>true</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
TEXT
expect(dialect.tools).to eq(translated_tool)
end
end
end end


@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) } subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
let(:model_name) { "claude-2" } let(:model_name) { "claude-2" }
let(:prompt) { "Human: write 3 words\n\n" } let(:generic_prompt) { { insts: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json } let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json } let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
@ -23,10 +25,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
} }
end end
def stub_response(prompt, response_text) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete") .stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: model.default_options.merge(prompt: prompt).to_json) .with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text))) .to_return(status: 200, body: JSON.dump(response(response_text)))
end end
@ -42,7 +44,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}.to_json }.to_json
end end
def stub_streamed_response(prompt, deltas) def stub_streamed_response(prompt, deltas, tool_call: false)
chunks = chunks =
deltas.each_with_index.map do |_, index| deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1) if index == (deltas.length - 1)
@ -52,13 +54,27 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end end
end end
chunks = chunks.join("\n\n") chunks = chunks.join("\n\n").split("")
WebMock WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete") .stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json) .with(body: stream_request_body)
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
let(:tool_deltas) { ["<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service" it_behaves_like "an endpoint that can communicate with a completion service"
end end


@ -9,10 +9,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
let(:model_name) { "claude-2" } let(:model_name) { "claude-2" }
let(:bedrock_name) { "claude-v2" } let(:bedrock_name) { "claude-v2" }
let(:prompt) { "Human: write 3 words\n\n" } let(:generic_prompt) { { insts: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json } let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt).to_json } let(:stream_request_body) { request_body }
before do before do
SiteSetting.ai_bedrock_access_key_id = "123456" SiteSetting.ai_bedrock_access_key_id = "123456"
@ -20,39 +22,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1" SiteSetting.ai_bedrock_region = "us-east-1"
end end
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end
after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end
def response(content) def response(content)
{ {
completion: content, completion: content,
@ -65,7 +34,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
} }
end end
def stub_response(prompt, response_text) def stub_response(prompt, response_text, tool_call: false)
WebMock WebMock
.stub_request( .stub_request(
:post, :post,
@ -102,7 +71,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
encoder.encode(message) encoder.encode(message)
end end
def stub_streamed_response(prompt, deltas) def stub_streamed_response(prompt, deltas, tool_call: false)
chunks = chunks =
deltas.each_with_index.map do |_, index| deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1) if index == (deltas.length - 1)
@ -121,5 +90,19 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
let(:tool_deltas) { ["<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service" it_behaves_like "an endpoint that can communicate with a completion service"
end end


@ -1,22 +1,86 @@
# frozen_string_literal: true # frozen_string_literal: true
RSpec.shared_examples "an endpoint that can communicate with a completion service" do RSpec.shared_examples "an endpoint that can communicate with a completion service" do
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end
after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end
describe "#perform_completion!" do describe "#perform_completion!" do
fab!(:user) { Fabricate(:user) } fab!(:user) { Fabricate(:user) }
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" } let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:invocation) { <<~TEXT }
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<tool_id>get_weather</tool_id>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
TEXT
context "when using regular mode" do context "when using regular mode" do
context "with simple prompts" do
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
before { stub_response(prompt, response_text) }
it "can complete a trivial prompt" do
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(response_text)
end
it "creates an audit log for the request" do
model.perform_completion!(dialect, user)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
@ -32,7 +96,27 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
end
end
context "with functions" do
let(:generic_prompt) do
{
insts: "You can tell me the weather",
input: "Return the weather in Sydney",
tools: [tool],
}
end
before { stub_response(prompt, tool_call, tool_call: true) }
it "returns a function invocation" do
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(invocation)
end
end
end
context "when using stream mode" do context "when using stream mode" do
context "with simple prompts" do
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
before { stub_streamed_response(prompt, deltas) }
@ -40,7 +124,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
it "can complete a trivial prompt" do it "can complete a trivial prompt" do
completion_response = +"" completion_response = +""
model.perform_completion!(prompt, user) do |partial, cancel| model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial completion_response << partial
cancel.call if completion_response.split(" ").length == 2 cancel.call if completion_response.split(" ").length == 2
end end
@ -51,7 +135,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
it "creates an audit log and updates is on each read." do it "creates an audit log and updates is on each read." do
completion_response = +"" completion_response = +""
model.perform_completion!(prompt, user) do |partial, cancel| model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial completion_response << partial
cancel.call if completion_response.split(" ").length == 2 cancel.call if completion_response.split(" ").length == 2
end end
@ -67,5 +151,28 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
end
end
context "with functions" do
let(:generic_prompt) do
{
insts: "You can tell me the weather",
input: "Return the weather in Sydney",
tools: [tool],
}
end
before { stub_streamed_response(prompt, tool_deltas, tool_call: true) }
it "waits for the invocation to finish before calling the partial" do
buffered_partial = ""
model.perform_completion!(dialect, user) do |partial, cancel|
buffered_partial = partial if partial.include?("<function_calls>")
end
expect(buffered_partial).to eq(invocation)
end
end
end
end
end
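For orientation: a minimal sketch of what a concrete endpoint spec needs to supply before it can include this shared example. The endpoint class, model name, and stub bodies below are placeholders rather than code from this commit; `tool` and `invocation` are defined inside the shared example itself.

# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::Endpoints::SomeProvider do # hypothetical endpoint class
  subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }

  let(:model_name) { "some-model" } # placeholder
  let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
  let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
  let(:prompt) { dialect.translate }

  let(:tool_deltas) { ["<function", "..."] } # streamed pieces of a tool invocation
  let(:tool_call) { invocation } # `invocation` comes from the shared example

  def stub_response(prompt, response_text, tool_call: false)
    # stub the provider's blocking completion endpoint with WebMock here
  end

  def stub_streamed_response(prompt, deltas, tool_call: false)
    # stub the provider's streaming endpoint with WebMock here
  end

  it_behaves_like "an endpoint that can communicate with a completion service"
end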
@ -6,22 +6,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gemini-pro" }
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_payload) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name" },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
],
required: %w[location unit],
}
end
let(:request_body) do
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
.to_json
end
let(:tool_deltas) do
[
{ "functionCall" => { name: "get_weather", args: {} } },
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end
let(:tool_call) do
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
end
def response(content, tool_call: false)
{
candidates: [
{
content: {
parts: [(tool_call ? content : { text: content })],
role: "model",
},
finishReason: "STOP",
@ -45,22 +83,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}
end
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
def stream_line(delta, finish_reason: nil, tool_call: false)
{
candidates: [
{
content: {
parts: [(tool_call ? delta : { text: delta })],
role: "model",
},
finishReason: finish_reason,
@ -76,24 +114,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}.to_json
end
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: "STOP", tool_call: tool_call)
else
stream_line(deltas[index], tool_call: tool_call)
end
end
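# Splitting the payload into single characters makes the stub emit one byte per chunk,
# exercising how the endpoint reassembles partial streamed payloads.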
chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end
@ -6,10 +6,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
let(:model_name) { "Llama2-*-chat-hf" }
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) do
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
end
let(:prompt) { dialect.translate }
let(:request_body) do
model
@ -18,7 +19,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:parameters][:return_full_text] = false
end
.to_json
end
@ -30,7 +30,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true
payload[:parameters][:return_full_text] = false
end
.to_json
end
@ -41,14 +40,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
[{ generated_text: content }]
end
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
def stream_line(delta, deltas, finish_reason: nil)
+"data: " << {
token: {
id: 29_889,
@ -56,22 +55,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
logprob: -0.08319092,
special: !!finish_reason,
},
generated_text: finish_reason ? deltas.join : nil,
details: nil,
}.to_json
end
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], deltas, finish_reason: true)
else
stream_line(deltas[index], deltas)
end
end
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
@ -79,5 +78,29 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
end
@ -6,17 +6,53 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gpt-3.5-turbo" }
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_deltas) do
[
{ id: "get_weather", name: "get_weather", arguments: {} },
{ id: "get_weather", name: "get_weather", arguments: { location: "" } },
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
]
end
let(:tool_call) do
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
end
let(:request_body) do
model
.default_options
.merge(messages: prompt)
.tap do |b|
b[:tools] = generic_prompt[:tools].map do |t|
{ type: "function", tool: t }
end if generic_prompt[:tools]
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(messages: prompt, stream: true)
.tap do |b|
b[:tools] = generic_prompt[:tools].map do |t|
{ type: "function", tool: t }
end if generic_prompt[:tools]
end
.to_json
end
def response(content, tool_call: false)
message_content =
if tool_call
{ tool_calls: [{ function: content }] }
else
{ content: content }
end
def response(content)
{
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "chat.completion",
@ -28,45 +64,52 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
total_tokens: 499,
},
choices: [
{ message: { role: "assistant" }.merge(message_content), finish_reason: "stop", index: 0 },
],
}
end
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
def stream_line(delta, finish_reason: nil, tool_call: false)
message_content =
if tool_call
{ tool_calls: [{ function: delta }] }
else
{ content: delta }
end
def stream_line(delta, finish_reason: nil)
+"data: " << { +"data: " << {
id: "chatcmpl-#{SecureRandom.hex}", id: "chatcmpl-#{SecureRandom.hex}",
object: "chat.completion.chunk", object: "chat.completion.chunk",
created: 1_681_283_881, created: 1_681_283_881,
model: "gpt-3.5-turbo-0301", model: "gpt-3.5-turbo-0301",
choices: [{ delta: { content: delta } }], choices: [{ delta: message_content }],
finish_reason: finish_reason, finish_reason: finish_reason,
index: 0, index: 0,
}.to_json }.to_json
end end
def stub_streamed_response(prompt, deltas) def stub_streamed_response(prompt, deltas, tool_call: false)
chunks = chunks =
deltas.each_with_index.map do |_, index| deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1) if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: "stop_sequence") stream_line(deltas[index], finish_reason: "stop_sequence", tool_call: tool_call)
else else
stream_line(deltas[index]) stream_line(deltas[index], tool_call: tool_call)
end end
end end
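# "data: [DONE]" is OpenAI's end-of-stream sentinel; splitting the result into single
# characters simulates a very slow, byte-at-a-time stream for the endpoint to buffer.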
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end
@ -3,7 +3,7 @@
RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do
described_class.new(
DiscourseAi::Completions::Dialects::OrcaStyle,
canned_response,
"Upstage-Llama-2-*-instruct-v2",
)
@ -103,7 +103,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
expect(response.status).to eq(200)
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
expect(response.parsed_body["diff"]).to eq(expected_diff)
expect(spy.prompt.translate.last[:content]).to eq(expected_input)
end
end
end