DEV: Tool support for the LLM service. (#366)

This PR adds tool support to the available LLMs. We buffer tool invocations and return them whole instead of making users of this service parse the raw response.

It also adds support for conversation context in the generic prompt. The context can include bot messages, user messages, and tool invocations, which we trim to keep the prompt within its token limit and then translate into the correct dialect.

Finally, it adds buffering when reading chunks to handle cases where streaming is extremely slow.
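
For reference, a caller might exercise the new behavior roughly like this (a sketch based on the API in this PR; `handle_invocation` is a hypothetical helper):

```ruby
llm = DiscourseAi::Completions::Llm.proxy("gpt-4")

generic_prompt = {
  insts: "You can tell me the weather",
  input: "Return the weather in Sydney",
  tools: [
    {
      name: "get_weather",
      description: "Get the weather in a city",
      parameters: [
        { name: "location", type: "string", description: "the city name", required: true },
      ],
    },
  ],
}

llm.completion!(generic_prompt, Discourse.system_user) do |partial, cancel|
  if partial.include?("<function_calls>")
    # Tool invocations arrive as one fully-buffered string, never char by char.
    handle_invocation(partial) # hypothetical helper supplied by the caller
  else
    print partial
  end
end
```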
Roman Rizzi 2023-12-18 18:06:01 -03:00 committed by GitHub
parent 203906be65
commit e0bf6adb5b
27 changed files with 1625 additions and 284 deletions

View File

@ -3,33 +3,101 @@
module DiscourseAi
module Completions
module Dialects
class ChatGpt
def self.can_translate?(model_name)
class ChatGpt < Dialect
class << self
def can_translate?(model_name)
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
end
def translate(generic_prompt)
open_ai_prompt = [
{
role: "system",
content: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
},
]
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
open_ai_prompt << { role: "user", content: example_pair.first }
open_ai_prompt << { role: "assistant", content: example_pair.second }
end
end
open_ai_prompt << { role: "user", content: generic_prompt[:input] }
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
def translate
open_ai_prompt = [
{ role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") },
]
if prompt[:examples]
prompt[:examples].each do |example_pair|
open_ai_prompt << { role: "user", content: example_pair.first }
open_ai_prompt << { role: "assistant", content: example_pair.second }
end
end
open_ai_prompt.concat(conversation_context) if prompt[:conversation_context]
open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
open_ai_prompt
end
def tools
return if prompt[:tools].blank?
prompt[:tools].map { |t| { type: "function", tool: t } }
end
def conversation_context
return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context.reverse.map do |context|
translated = context.slice(:content)
translated[:role] = context[:type]
if context[:name]
if translated[:role] == "tool"
translated[:tool_call_id] = context[:name]
else
translated[:name] = context[:name]
end
end
translated
end
end
def max_prompt_tokens
# provide a buffer of extra tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
buffer = (opts[:max_tokens_to_sample] || 2500) + 50
if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation
@function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
buffer += @function_size
end
model_max_tokens - buffer
end
private
def per_message_overhead
# open ai defines about 4 tokens per message of overhead
4
end
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def model_max_tokens
case model_name
when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
16_384
when "gpt-4"
8192
when "gpt-4-32k"
32_768
else
8192
end
end
end
end
end
end
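
As a rough worked example of the budget computed above (assuming gpt-4, the default 2,500 max_tokens_to_sample, and no tools declared):

```ruby
model_max_tokens = 8192                       # gpt-4
buffer = 2500 + 50                            # reply allowance plus safety margin
max_prompt_tokens = model_max_tokens - buffer # => 5642 tokens left for the prompt
```

Declaring tools shrinks the budget further by roughly the token size of their JSON representation.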

View File

@ -3,26 +3,66 @@
module DiscourseAi
module Completions
module Dialects
class Claude
def self.can_translate?(model_name)
class Claude < Dialect
class << self
def can_translate?(model_name)
%w[claude-instant-1 claude-2].include?(model_name)
end
def translate(generic_prompt)
claude_prompt = +"Human: #{generic_prompt[:insts]}\n"
claude_prompt << build_examples(generic_prompt[:examples]) if generic_prompt[:examples]
claude_prompt << "#{generic_prompt[:input]}\n"
claude_prompt << "#{generic_prompt[:post_insts]}\n" if generic_prompt[:post_insts]
claude_prompt << "Assistant:\n"
end
def tokenizer
DiscourseAi::Tokenizer::AnthropicTokenizer
end
end
def translate
claude_prompt = +"Human: #{prompt[:insts]}\n"
claude_prompt << build_tools_prompt if prompt[:tools]
claude_prompt << build_examples(prompt[:examples]) if prompt[:examples]
claude_prompt << conversation_context if prompt[:conversation_context]
claude_prompt << "#{prompt[:input]}\n"
claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts]
claude_prompt << "Assistant:\n"
end
def max_prompt_tokens
50_000
end
def conversation_context
return "" if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context
.reverse
.reduce(+"") do |memo, context|
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
if context[:type] == "tool"
memo << <<~TEXT
<function_results>
<result>
<tool_name>#{context[:name]}</tool_name>
<json>
#{context[:content]}
</json>
</result>
</function_results>
TEXT
else
memo << " " << context[:content] << "\n"
end
memo
end
end
private

View File

@ -0,0 +1,160 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
module Dialects
class Dialect
class << self
def can_translate?(_model_name)
raise NotImplementedError
end
def dialect_for(model_name)
dialects = [
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
]
dialects.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |d|
d.can_translate?(model_name)
end
end
def tokenizer
raise NotImplementedError
end
end
def initialize(generic_prompt, model_name, opts: {})
@prompt = generic_prompt
@model_name = model_name
@opts = opts
end
def translate
raise NotImplementedError
end
def tools
tools = +""
prompt[:tools].each do |function|
parameters = +""
if function[:parameters].present?
function[:parameters].each do |parameter|
parameters << <<~PARAMETER
<parameter>
<name>#{parameter[:name]}</name>
<type>#{parameter[:type]}</type>
<description>#{parameter[:description]}</description>
<required>#{parameter[:required]}</required>
PARAMETER
if parameter[:enum]
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
end
parameters << "</parameter>\n"
end
end
tools << <<~TOOLS
<tool_description>
<tool_name>#{function[:name]}</tool_name>
<description>#{function[:description]}</description>
<parameters>
#{parameters}</parameters>
</tool_description>
TOOLS
end
tools
end
def conversation_context
raise NotImplementedError
end
def max_prompt_tokens
raise NotImplementedError
end
private
attr_reader :prompt, :model_name, :opts
def trim_context(conversation_context)
prompt_limit = max_prompt_tokens
current_token_count = calculate_token_count_without_context
conversation_context.reduce([]) do |memo, context|
break(memo) if current_token_count >= prompt_limit
dupped_context = context.dup
message_tokens = calculate_message_token(dupped_context)
# Trimming content to make sure we respect token limit.
while dupped_context[:content].present? &&
message_tokens + current_token_count + per_message_overhead > prompt_limit
dupped_context[:content] = dupped_context[:content][0..-100] || ""
message_tokens = calculate_message_token(dupped_context)
end
next(memo) if dupped_context[:content].blank?
current_token_count += message_tokens + per_message_overhead
memo << dupped_context
end
end
def calculate_token_count_without_context
tokenizer = self.class.tokenizer
examples_count =
prompt[:examples].to_a.sum do |pair|
tokenizer.size(pair.join) + (per_message_overhead * 2)
end
input_count = tokenizer.size(prompt[:input].to_s) + per_message_overhead
examples_count + input_count +
prompt
.except(:conversation_context, :tools, :examples, :input)
.sum { |_, v| tokenizer.size(v) + per_message_overhead }
end
def per_message_overhead
0
end
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s)
end
def build_tools_prompt
return "" if prompt[:tools].blank?
<<~TEXT
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{tools}</tools>
TEXT
end
end
end
end
end
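
The trimming strategy above walks the context newest-first and stops adding messages once the budget is spent (the real code also shortens an over-long message before dropping it). A rough standalone illustration, with word counts standing in for tokenizer sizes:

```ruby
def trim(messages, limit)
  used = 0
  messages.each_with_object([]) do |msg, kept|
    cost = msg[:content].split.size
    break kept if used + cost > limit # budget exhausted, drop the rest
    used += cost
    kept << msg
  end
end

context = [
  { type: "user", content: "newest message from the user" },
  { type: "assistant", content: "an older bot reply" },
  { type: "user", content: "a very old message that may get dropped" },
]

trim(context, 9) # => keeps the two newest entries, drops the oldest
```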

View File

@ -3,36 +3,93 @@
module DiscourseAi
module Completions
module Dialects
class Gemini
def self.can_translate?(model_name)
class Gemini < Dialect
class << self
def can_translate?(model_name)
%w[gemini-pro].include?(model_name)
end
def translate(generic_prompt)
gemini_prompt = [
{
role: "user",
parts: {
text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
},
},
{ role: "model", parts: { text: "Ok." } },
]
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
end
end
gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
end
end
def translate
gemini_prompt = [
{
role: "user",
parts: {
text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
},
},
{ role: "model", parts: { text: "Ok." } },
]
if prompt[:examples]
prompt[:examples].each do |example_pair|
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
end
end
gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
end
def tools
return if prompt[:tools].blank?
translated_tools =
prompt[:tools].map do |t|
required_fields = []
tool = t.dup
tool[:parameters] = t[:parameters].map do |p|
required_fields << p[:name] if p[:required]
p.except(:required)
end
tool.merge(required: required_fields)
end
[{ function_declarations: translated_tools }]
end
def conversation_context
return [] if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context.reverse.map do |context|
translated = {}
translated[:role] = (context[:type] == "user" ? "user" : "model")
part = {}
if context[:type] == "tool"
part["functionResponse"] = { name: context[:name], content: context[:content] }
else
part[:text] = context[:content]
end
translated[:parts] = [part]
translated
end
end
def max_prompt_tokens
16_384 # 50% of model tokens
end
protected
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
end
end
end
end

View File

@ -3,29 +3,74 @@
module DiscourseAi
module Completions
module Dialects
class Llama2Classic
def self.can_translate?(model_name)
class Llama2Classic < Dialect
class << self
def can_translate?(model_name)
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
end
def translate(generic_prompt)
llama2_prompt =
+"[INST]<<SYS>>#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}<</SYS>>[/INST]\n"
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
llama2_prompt << "#{example_pair.second}\n"
end
end
llama2_prompt << "[INST]#{generic_prompt[:input]}[/INST]\n"
end
def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer
end
end
def translate
llama2_prompt = +<<~TEXT
[INST]
<<SYS>>
#{prompt[:insts]}
#{build_tools_prompt}#{prompt[:post_insts]}
<</SYS>>
[/INST]
TEXT
if prompt[:examples]
prompt[:examples].each do |example_pair|
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
llama2_prompt << "#{example_pair.second}\n"
end
end
llama2_prompt << conversation_context if prompt[:conversation_context].present?
llama2_prompt << "[INST]#{prompt[:input]}[/INST]\n"
end
def conversation_context
return "" if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context
.reverse
.reduce(+"") do |memo, context|
if context[:type] == "tool"
memo << <<~TEXT
[INST]
<function_results>
<result>
<tool_name>#{context[:name]}</tool_name>
<json>
#{context[:content]}
</json>
</result>
</function_results>
[/INST]
TEXT
elsif context[:type] == "assistant"
memo << "[INST]" << context[:content] << "[/INST]\n"
else
memo << context[:content] << "\n"
end
memo
end
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end
end
end

View File

@ -3,31 +3,70 @@
module DiscourseAi
module Completions
module Dialects
class OrcaStyle
def self.can_translate?(model_name)
class OrcaStyle < Dialect
class << self
def can_translate?(model_name)
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
end
def translate(generic_prompt)
orca_style_prompt =
+"### System:\n#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}\n"
if generic_prompt[:examples]
generic_prompt[:examples].each do |example_pair|
orca_style_prompt << "### User:\n#{example_pair.first}\n"
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
end
end
orca_style_prompt << "### User:\n#{generic_prompt[:input]}\n"
orca_style_prompt << "### Assistant:\n"
end
def tokenizer
DiscourseAi::Tokenizer::Llama2Tokenizer
end
end
def translate
orca_style_prompt = +<<~TEXT
### System:
#{prompt[:insts]}
#{build_tools_prompt}#{prompt[:post_insts]}
TEXT
if prompt[:examples]
prompt[:examples].each do |example_pair|
orca_style_prompt << "### User:\n#{example_pair.first}\n"
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
end
end
orca_style_prompt << "### User:\n#{prompt[:input]}\n"
orca_style_prompt << "### Assistant:\n"
end
def conversation_context
return "" if prompt[:conversation_context].blank?
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context
.reverse
.reduce(+"") do |memo, context|
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
if context[:type] == "tool"
memo << <<~TEXT
<function_results>
<result>
<tool_name>#{context[:name]}</tool_name>
<json>
#{context[:content]}
</json>
</result>
</function_results>
TEXT
else
memo << " " << context[:content] << "\n"
end
memo
end
end
def max_prompt_tokens
SiteSetting.ai_hugging_face_token_limit
end
end
end
end
end

View File

@ -22,7 +22,7 @@ module DiscourseAi
@uri ||= URI("https://api.anthropic.com/v1/complete")
end
def prepare_payload(prompt, model_params)
def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(model_params)
.merge(prompt: prompt)

View File

@ -37,7 +37,7 @@ module DiscourseAi
URI(api_url)
end
def prepare_payload(prompt, model_params)
def prepare_payload(prompt, model_params, _dialect)
default_options.merge(prompt: prompt).merge(model_params)
end

View File

@ -30,9 +30,11 @@ module DiscourseAi
@tokenizer = tokenizer
end
def perform_completion!(prompt, user, model_params = {})
def perform_completion!(dialect, user, model_params = {})
@streaming_mode = block_given?
prompt = dialect.translate
Net::HTTP.start(
model_uri.host,
model_uri.port,
@ -43,7 +45,10 @@ module DiscourseAi
) do |http|
response_data = +""
response_raw = +""
request_body = prepare_payload(prompt, model_params).to_json
# Needed for response token calculations. We cannot rely on response_data due to function buffering.
partials_raw = +""
request_body = prepare_payload(prompt, model_params, dialect).to_json
request = prepare_request(request_body)
@ -66,6 +71,15 @@ module DiscourseAi
if !@streaming_mode
response_raw = response.read_body
response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
if has_tool?("", response_data)
function_buffer = build_buffer # Nokogiri document
function_buffer = add_to_buffer(function_buffer, "", response_data)
response_data = +function_buffer.at("function_calls").to_s
response_data << "\n"
end
return response_data
end
@ -75,6 +89,7 @@ module DiscourseAi
cancel = lambda { cancelled = true }
leftover = ""
function_buffer = build_buffer # Nokogiri document
response.read_body do |chunk|
if cancelled
@ -85,6 +100,12 @@ module DiscourseAi
decoded_chunk = decode(chunk)
response_raw << decoded_chunk
# Buffering for extremely slow streaming.
if (leftover + decoded_chunk).length < "data: [DONE]".length
leftover += decoded_chunk
next
end
partials_from(leftover + decoded_chunk).each do |raw_partial|
next if cancelled
next if raw_partial.blank?
@ -93,11 +114,27 @@ module DiscourseAi
partial = extract_completion_from(raw_partial)
next if partial.nil?
leftover = ""
if has_tool?(response_data, partial)
function_buffer = add_to_buffer(function_buffer, response_data, partial)
if buffering_finished?(dialect.tools, function_buffer)
invocation = +function_buffer.at("function_calls").to_s
invocation << "\n"
partials_raw << partial.to_s
response_data << invocation
yield invocation, cancel
end
else
partials_raw << partial
response_data << partial
yield partial, cancel if partial
end
rescue JSON::ParserError
leftover = raw_partial
leftover += decoded_chunk
end
end
end
@ -109,7 +146,7 @@ module DiscourseAi
ensure
if log
log.raw_response_payload = response_raw
log.response_tokens = tokenizer.size(response_data)
log.response_tokens = tokenizer.size(partials_raw)
log.save!
if Rails.env.development?
@ -165,6 +202,40 @@ module DiscourseAi
def extract_prompt_for_tokenizer(prompt)
prompt
end
def build_buffer
Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
<invoke>
<tool_name></tool_name>
<tool_id></tool_id>
<parameters></parameters>
</invoke>
</function_calls>
TEXT
end
def has_tool?(response, partial)
(response + partial).include?("<function_calls>")
end
def add_to_buffer(function_buffer, response_data, partial)
new_buffer = Nokogiri::HTML5.fragment(response_data + partial)
if tool_name = new_buffer.at("tool_name")&.text
if new_buffer.at("tool_id").nil?
tool_id_node =
Nokogiri::HTML5::DocumentFragment.parse("\n<tool_id>#{tool_name}</tool_id>")
new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node)
end
end
new_buffer
end
def buffering_finished?(_available_functions, buffer)
buffer.to_s.include?("</function_calls>")
end
end
end
end
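
The slow-stream guard added above simply holds on to bytes until there are at least enough to contain the shortest meaningful SSE line. A self-contained sketch of the same idea (the chunk values are made up):

```ruby
MIN_CHUNK = "data: [DONE]".length

chunks = ["da", "ta: {\"completion\": \"Hi\"}\n\n", "dat", "a: [DONE]"] # simulated slow stream
leftover = +""

chunks.each do |chunk|
  data = leftover + chunk
  if data.length < MIN_CHUNK
    leftover = data # too short to hold a complete "data: ..." line yet
    next
  end
  leftover = ""
  puts "processing: #{data.inspect}" # the real code splits out the "data: {...}" partials here
end
```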

View File

@ -31,10 +31,15 @@ module DiscourseAi
cancelled = false
cancel_fn = lambda { cancelled = true }
# We buffer and return tool invocations in one go.
if is_tool?(response)
yield(response, cancel_fn)
else
response.each_char do |char|
break if cancelled
yield(char, cancel_fn)
end
end
else
response
end
@ -43,6 +48,12 @@ module DiscourseAi
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
private
def is_tool?(response)
Nokogiri::HTML5.fragment(response).at("function_calls").present?
end
end
end
end

View File

@ -25,8 +25,11 @@ module DiscourseAi
URI(url)
end
def prepare_payload(prompt, model_params)
default_options.merge(model_params).merge(contents: prompt)
def prepare_payload(prompt, model_params, dialect)
default_options
.merge(model_params)
.merge(contents: prompt)
.tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
end
def prepare_request(payload)
@ -36,25 +39,72 @@ module DiscourseAi
end
def extract_completion_from(response_raw)
if @streaming_mode
parsed = response_raw
else
parsed = JSON.parse(response_raw, symbolize_names: true)
end
completion = dig_text(parsed).to_s
response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
has_function_call = response_h.dig(:functionCall).present?
has_function_call ? response_h[:functionCall] : response_h.dig(:text)
end
def partials_from(decoded_chunk)
JSON.parse(decoded_chunk, symbolize_names: true)
decoded_chunk
.split("\n")
.map do |line|
if line == ","
nil
elsif line.starts_with?("[")
line[1..-1]
elsif line.ends_with?("]")
line[0..-2]
else
line
end
end
.compact_blank
end
def extract_prompt_for_tokenizer(prompt)
prompt.to_s
end
def dig_text(response)
response.dig(:candidates, 0, :content, :parts, 0, :text)
def has_tool?(_response_data, partial)
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
end
def add_to_buffer(function_buffer, _response_data, partial)
if partial[:name].present?
function_buffer.at("tool_name").content = partial[:name]
function_buffer.at("tool_id").content = partial[:name]
end
if partial[:args]
argument_fragments =
partial[:args].reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
end
argument_fragments << "\n"
function_buffer.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
end
function_buffer
end
def buffering_finished?(available_functions, buffer)
tool_name = buffer.at("tool_name")&.text
return false if tool_name.blank?
signature =
available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
signature[:parameters].reduce(true) do |memo, param|
param_present = buffer.at(param[:name]).present?
next(memo) if param_present || !signature[:required].include?(param[:name])
memo && param_present
end
end
end
end
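
Gemini streams its reply as a single JSON array whose elements arrive one per line, separated by bare `,` lines and wrapped in `[` / `]` (the spec's stubbed stream below mimics this shape), so partials_from strips the array punctuation and yields one JSON document per line. An illustrative chunk:

```ruby
decoded_chunk = <<~CHUNK
  [{"candidates": [{"content": {"parts": [{"text": "Mount"}]}}]}
  ,
  {"candidates": [{"content": {"parts": [{"text": "ain"}]}}]}
  ]
CHUNK

# After partials_from, the surviving lines are individual JSON documents:
#   {"candidates": [{"content": {"parts": [{"text": "Mount"}]}}]}
#   {"candidates": [{"content": {"parts": [{"text": "ain"}]}}]}
```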

View File

@ -11,7 +11,7 @@ module DiscourseAi
end
def default_options
{ parameters: { repetition_penalty: 1.1, temperature: 0.7 } }
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
end
def provider_id
@ -24,7 +24,7 @@ module DiscourseAi
URI(SiteSetting.ai_hugging_face_api_url)
end
def prepare_payload(prompt, model_params)
def prepare_payload(prompt, model_params, _dialect)
default_options
.merge(inputs: prompt)
.tap do |payload|
@ -33,7 +33,6 @@ module DiscourseAi
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
payload[:parameters][:return_full_text] = false
payload[:stream] = true if @streaming_mode
end

View File

@ -37,11 +37,14 @@ module DiscourseAi
URI(url)
end
def prepare_payload(prompt, model_params)
def prepare_payload(prompt, model_params, dialect)
default_options
.merge(model_params)
.merge(messages: prompt)
.tap { |payload| payload[:stream] = true if @streaming_mode }
.tap do |payload|
payload[:stream] = true if @streaming_mode
payload[:tools] = dialect.tools if dialect.tools.present?
end
end
def prepare_request(payload)
@ -62,15 +65,12 @@ module DiscourseAi
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
(
if @streaming_mode
parsed.dig(:choices, 0, :delta, :content)
else
parsed.dig(:choices, 0, :message, :content)
end
).to_s
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
has_function_call = response_h.dig(:tool_calls).present?
has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
end
def partials_from(decoded_chunk)
@ -86,6 +86,42 @@ module DiscourseAi
def extract_prompt_for_tokenizer(prompt)
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
def has_tool?(_response_data, partial)
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
end
def add_to_buffer(function_buffer, _response_data, partial)
function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
if partial[:arguments]
argument_fragments =
partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
end
argument_fragments << "\n"
function_buffer.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
end
function_buffer
end
def buffering_finished?(available_functions, buffer)
tool_name = buffer.at("tool_name")&.text
return false if tool_name.blank?
signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
signature[:parameters].reduce(true) do |memo, param|
param_present = buffer.at(param[:name]).present?
next(memo) if param_present || !param[:required]
memo && param_present
end
end
end
end
end

View File

@ -24,35 +24,26 @@ module DiscourseAi
end
def self.proxy(model_name)
dialects = [
DiscourseAi::Completions::Dialects::Claude,
DiscourseAi::Completions::Dialects::Llama2Classic,
DiscourseAi::Completions::Dialects::ChatGpt,
DiscourseAi::Completions::Dialects::OrcaStyle,
DiscourseAi::Completions::Dialects::Gemini,
]
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
dialect =
dialects.detect(-> { raise UNKNOWN_MODEL }) { |d| d.can_translate?(model_name) }.new
return new(dialect, @canned_response, model_name) if @canned_response
return new(dialect_klass, @canned_response, model_name) if @canned_response
gateway =
DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
model_name,
dialect.tokenizer,
dialect_klass.tokenizer,
)
new(dialect, gateway, model_name)
new(dialect_klass, gateway, model_name)
end
def initialize(dialect, gateway, model_name)
@dialect = dialect
def initialize(dialect_klass, gateway, model_name)
@dialect_klass = dialect_klass
@gateway = gateway
@model_name = model_name
end
delegate :tokenizer, to: :dialect
delegate :tokenizer, to: :dialect_klass
# @param generic_prompt { Hash } - Prompt using our generic format.
# We use the following keys from the hash:
@ -60,23 +51,64 @@ module DiscourseAi
# - input: String containing user input
# - examples (optional): Array of arrays with examples of input and responses. Each array is an input/response pair like [[example1, response1], [example2, response2]].
# - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt.
# - conversation_context (optional): Array of hashes to provide context about an ongoing conversation with the model.
# We translate the array in reverse order, meaning the first element is the most recent message in the conversation.
# Example:
#
# [
# { type: "user", name: "user1", content: "This is a new message by a user" },
# { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
# { type: "tool", name: "tool_id", content: "I'm a tool result" },
# ]
#
# - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:
#
# {
# name: "get_weather",
# description: "Get the weather in a city",
# parameters: [
# { name: "location", type: "string", description: "the city name", required: true },
# {
# name: "unit",
# type: "string",
# description: "the unit of measurement celcius c or fahrenheit f",
# enum: %w[c f],
# required: true,
# },
# ],
# }
#
# @param user { User } - User requesting the summary.
#
# @param &partial_read_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
#
# @returns { String } - Completion result.
#
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool invocation,
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
#
# <function_calls>
# <invoke>
# <tool_name>get_weather</tool_name>
# <tool_id>get_weather</tool_id>
# <parameters>
# <location>Sydney</location>
# <unit>c</unit>
# </parameters>
# </invoke>
# </function_calls>
#
def completion!(generic_prompt, user, &partial_read_blk)
prompt = dialect.translate(generic_prompt)
model_params = generic_prompt.dig(:params, model_name) || {}
gateway.perform_completion!(prompt, user, model_params, &partial_read_blk)
dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
end
private
attr_reader :dialect, :gateway, :model_name
attr_reader :dialect_klass, :gateway, :model_name
end
end
end
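
Once the caller has run the requested tool, the result can be fed back through conversation_context on a follow-up call, mirroring the `type: "tool"` entry documented above (a sketch; the weather JSON is made up):

```ruby
llm = DiscourseAi::Completions::Llm.proxy("gpt-4")

followup = {
  insts: "You can tell me the weather",
  input: "Anything else I should know?",
  conversation_context: [
    # Newest first; the tool entry's name carries the tool_id from the invocation.
    { type: "tool", name: "get_weather", content: { temperature: 23, unit: "c" }.to_json },
    { type: "user", name: "user1", content: "Return the weather in Sydney" },
  ],
}

llm.completion!(followup, Discourse.system_user) { |partial, _cancel| print partial }
```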

View File

@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
subject(:dialect) { described_class.new }
subject(:dialect) { described_class.new(prompt, "gpt-4") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do
{
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
TEXT
post_insts:
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
tools: [tool],
}
end
@ -35,7 +53,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
{ role: "user", content: prompt[:input] },
]
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to contain_exactly(*open_ai_version)
end
@ -55,9 +73,51 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
{ role: "user", content: prompt[:input] },
]
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to contain_exactly(*open_ai_version)
end
end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context).to eq(
[
{ role: "tool", content: context.last[:content], tool_call_id: context.last[:name] },
{ role: "assistant", content: context.second[:content] },
{ role: "user", content: context.first[:content], name: context.first[:name] },
],
)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.last[:content].length).to be < context.last[:content].length
end
end
describe "#tools" do
it "returns a list of available tools" do
open_ai_tool_f = { type: "function", tool: tool }
expect(subject.tools).to contain_exactly(open_ai_tool_f)
end
end
end

View File

@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
subject(:dialect) { described_class.new }
subject(:dialect) { described_class.new(prompt, "claude-2") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do
{
@ -37,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
Assistant:
TEXT
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to eq(anthropic_version)
end
@ -60,9 +77,111 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
Assistant:
TEXT
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to eq(anthropic_version)
end
it "include tools inside the prompt" do
prompt[:tools] = [tool]
anthropic_version = <<~TEXT
Human: #{prompt[:insts]}
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{dialect.tools}</tools>
#{prompt[:input]}
#{prompt[:post_insts]}
Assistant:
TEXT
translated = dialect.translate
expect(translated).to eq(anthropic_version)
end
end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
expected = <<~TEXT
Assistant:
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
#{context.last[:content]}
</json>
</result>
</function_results>
Assistant: #{context.second[:content]}
Human: #{context.first[:content]}
TEXT
translated_context = dialect.conversation_context
expect(translated_context).to eq(expected)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 10_000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.length).to be < context.last[:content].length
end
end
describe "#tools" do
it "translates tools to the tool syntax" do
prompt[:tools] = [tool]
translated_tool = <<~TEXT
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>true</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
TEXT
expect(dialect.tools).to eq(translated_tool)
end
end
end

View File

@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
subject(:dialect) { described_class.new }
subject(:dialect) { described_class.new(prompt, "gemini-pro") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do
{
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
TEXT
post_insts:
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
tools: [tool],
}
end
@ -36,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{ role: "user", parts: { text: prompt[:input] } },
]
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to eq(gemini_version)
end
@ -57,9 +75,79 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
{ role: "user", parts: { text: prompt[:input] } },
]
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to contain_exactly(*gemini_version)
end
end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context).to eq(
[
{
role: "model",
parts: [
{
"functionResponse" => {
name: context.last[:name],
content: context.last[:content],
},
},
],
},
{ role: "model", parts: [{ text: context.second[:content] }] },
{ role: "user", parts: [{ text: context.first[:content] }] },
],
)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.last.dig(:parts, 0, :text).length).to be <
context.last[:content].length
end
end
describe "#tools" do
it "returns a list of available tools" do
gemini_tools = {
function_declarations: [
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name" },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
],
required: %w[location unit],
},
],
}
expect(subject.tools).to contain_exactly(gemini_tools)
end
end
end

View File

@ -1,7 +1,24 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
subject(:dialect) { described_class.new }
subject(:dialect) { described_class.new(prompt, "Llama2-chat-hf") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:prompt) do
{
@ -31,11 +48,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
describe "#translate" do
it "translates a prompt written in our generic format to the Llama2 format" do
llama2_classic_version = <<~TEXT
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
[INST]
<<SYS>>
#{prompt[:insts]}
#{prompt[:post_insts]}
<</SYS>>
[/INST]
[INST]#{prompt[:input]}[/INST]
TEXT
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to eq(llama2_classic_version)
end
@ -49,15 +71,126 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
]
llama2_classic_version = <<~TEXT
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
[INST]
<<SYS>>
#{prompt[:insts]}
#{prompt[:post_insts]}
<</SYS>>
[/INST]
[INST]#{prompt[:examples][0][0]}[/INST]
#{prompt[:examples][0][1]}
[INST]#{prompt[:input]}[/INST]
TEXT
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to eq(llama2_classic_version)
end
it "include tools inside the prompt" do
prompt[:tools] = [tool]
llama2_classic_version = <<~TEXT
[INST]
<<SYS>>
#{prompt[:insts]}
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{dialect.tools}</tools>
#{prompt[:post_insts]}
<</SYS>>
[/INST]
[INST]#{prompt[:input]}[/INST]
TEXT
translated = dialect.translate
expect(translated).to eq(llama2_classic_version)
end
end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
expected = <<~TEXT
[INST]
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
#{context.last[:content]}
</json>
</result>
</function_results>
[/INST]
[INST]#{context.second[:content]}[/INST]
#{context.first[:content]}
TEXT
translated_context = dialect.conversation_context
expect(translated_context).to eq(expected)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1_000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.length).to be < context.last[:content].length
end
end
describe "#tools" do
it "translates functions to the tool syntax" do
prompt[:tools] = [tool]
translated_tool = <<~TEXT
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>true</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
TEXT
expect(dialect.tools).to eq(translated_tool)
end
end
end

View File

@ -1,9 +1,25 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
subject(:dialect) { described_class.new }
subject(:dialect) { described_class.new(prompt, "StableBeluga2") }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
describe "#translate" do
let(:prompt) do
{
insts: <<~TEXT,
@ -29,16 +45,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
}
end
describe "#translate" do
it "translates a prompt written in our generic format to the Open AI format" do
orca_style_version = <<~TEXT
### System:
#{[prompt[:insts], prompt[:post_insts]].join("\n")}
#{prompt[:insts]}
#{prompt[:post_insts]}
### User:
#{prompt[:input]}
### Assistant:
TEXT
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to eq(orca_style_version)
end
@ -53,7 +71,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
orca_style_version = <<~TEXT
### System:
#{[prompt[:insts], prompt[:post_insts]].join("\n")}
#{prompt[:insts]}
#{prompt[:post_insts]}
### User:
#{prompt[:examples][0][0]}
### Assistant:
@ -63,9 +82,113 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
### Assistant:
TEXT
translated = dialect.translate(prompt)
translated = dialect.translate
expect(translated).to eq(orca_style_version)
end
it "include tools inside the prompt" do
prompt[:tools] = [tool]
orca_style_version = <<~TEXT
### System:
#{prompt[:insts]}
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>
Here are the tools available:
<tools>
#{dialect.tools}</tools>
#{prompt[:post_insts]}
### User:
#{prompt[:input]}
### Assistant:
TEXT
translated = dialect.translate
expect(translated).to eq(orca_style_version)
end
end
describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end
it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context
expected = <<~TEXT
### Assistant:
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
#{context.last[:content]}
</json>
</result>
</function_results>
### Assistant: #{context.second[:content]}
### User: #{context.first[:content]}
TEXT
translated_context = dialect.conversation_context
expect(translated_context).to eq(expected)
end
it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1_000
prompt[:conversation_context] = context
translated_context = dialect.conversation_context
expect(translated_context.length).to be < context.last[:content].length
end
end
describe "#tools" do
it "translates tools to the tool syntax" do
prompt[:tools] = [tool]
translated_tool = <<~TEXT
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>true</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
TEXT
expect(dialect.tools).to eq(translated_tool)
end
end
end

View File

@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
let(:model_name) { "claude-2" }
let(:prompt) { "Human: write 3 words\n\n" }
let(:generic_prompt) { { insts: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
@ -23,10 +25,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}
end
def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: model.default_options.merge(prompt: prompt).to_json)
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
@ -42,7 +44,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}.to_json
end
def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
@ -52,13 +54,27 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
end
chunks = chunks.join("\n\n")
chunks = chunks.join("\n\n").split("")
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
end

View File

@ -9,10 +9,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
let(:model_name) { "claude-2" }
let(:bedrock_name) { "claude-v2" }
let(:prompt) { "Human: write 3 words\n\n" }
let(:generic_prompt) { { insts: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { request_body }
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
@ -20,39 +22,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1"
end
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end
after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end
def response(content)
{
completion: content,
@ -65,7 +34,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
end
def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(
:post,
@ -102,7 +71,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
encoder.encode(message)
end
def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
@ -121,5 +90,19 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
end

View File

@ -1,22 +1,86 @@
# frozen_string_literal: true
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end
yield response if block_given?
response
end
end
end
@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end
after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end
describe "#perform_completion!" do
fab!(:user) { Fabricate(:user) }
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end
let(:invocation) { <<~TEXT }
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<tool_id>get_weather</tool_id>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
TEXT
context "when using regular mode" do
context "with simple prompts" do
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
before { stub_response(prompt, response_text) }
it "can complete a trivial prompt" do
completion_response = model.perform_completion!(prompt, user)
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(response_text)
end
it "creates an audit log for the request" do
model.perform_completion!(prompt, user)
model.perform_completion!(dialect, user)
expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
@ -32,7 +96,27 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
end
end
context "with functions" do
let(:generic_prompt) do
{
insts: "You can tell me the weather",
input: "Return the weather in Sydney",
tools: [tool],
}
end
before { stub_response(prompt, tool_call, tool_call: true) }
it "returns a function invocation" do
completion_response = model.perform_completion!(dialect, user)
expect(completion_response).to eq(invocation)
end
end
end
context "when using stream mode" do
context "with simple prompts" do
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
before { stub_streamed_response(prompt, deltas) }
@ -40,7 +124,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
it "can complete a trivial prompt" do
completion_response = +""
model.perform_completion!(prompt, user) do |partial, cancel|
model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
@ -51,7 +135,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
it "creates an audit log and updates is on each read." do
completion_response = +""
model.perform_completion!(prompt, user) do |partial, cancel|
model.perform_completion!(dialect, user) do |partial, cancel|
completion_response << partial
cancel.call if completion_response.split(" ").length == 2
end
@ -67,5 +151,28 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
end
end
context "with functions" do
let(:generic_prompt) do
{
insts: "You can tell me the weather",
input: "Return the weather in Sydney",
tools: [tool],
}
end
before { stub_streamed_response(prompt, tool_deltas, tool_call: true) }
it "waits for the invocation to finish before calling the partial" do
buffered_partial = ""
model.perform_completion!(dialect, user) do |partial, cancel|
buffered_partial = partial if partial.include?("<function_calls>")
end
expect(buffered_partial).to eq(invocation)
end
end
end
end
end

View File

@ -6,22 +6,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gemini-pro" }
let(:prompt) do
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_payload) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name" },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
],
required: %w[location unit],
}
end
let(:request_body) do
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
.to_json
end
let(:tool_deltas) do
[
{ role: "system", content: "You are a helpful bot." },
{ role: "user", content: "Write 3 words" },
{ "functionCall" => { name: "get_weather", args: {} } },
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end
let(:request_body) { model.default_options.merge(contents: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
let(:tool_call) do
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
end
def response(content)
def response(content, tool_call: false)
{
candidates: [
{
content: {
parts: [{ text: content }],
parts: [(tool_call ? content : { text: content })],
role: "model",
},
finishReason: "STOP",
@ -45,22 +83,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}
end
def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: { contents: prompt })
.to_return(status: 200, body: JSON.dump(response(response_text)))
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
def stream_line(delta, finish_reason: nil)
def stream_line(delta, finish_reason: nil, tool_call: false)
{
candidates: [
{
content: {
parts: [{ text: delta }],
parts: [(tool_call ? delta : { text: delta })],
role: "model",
},
finishReason: finish_reason,
@ -76,24 +114,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}.to_json
end
def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: "STOP")
stream_line(deltas[index], finish_reason: "STOP", tool_call: tool_call)
else
stream_line(deltas[index])
stream_line(deltas[index], tool_call: tool_call)
end
end
chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: model.default_options.merge(contents: prompt).to_json)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end

View File

@ -6,10 +6,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
let(:model_name) { "Llama2-*-chat-hf" }
let(:prompt) { <<~TEXT }
[INST]<<SYS>>You are a helpful bot.<</SYS>>[/INST]
[INST]Write 3 words[/INST]
TEXT
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) do
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
end
let(:prompt) { dialect.translate }
let(:request_body) do
model
@ -18,7 +19,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:parameters][:return_full_text] = false
end
.to_json
end
@ -30,7 +30,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true
payload[:parameters][:return_full_text] = false
end
.to_json
end
@ -41,14 +40,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
[{ generated_text: content }]
end
def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end
def stream_line(delta, finish_reason: nil)
def stream_line(delta, deltas, finish_reason: nil)
+"data: " << {
token: {
id: 29_889,
@ -56,22 +55,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
logprob: -0.08319092,
special: !!finish_reason,
},
generated_text: finish_reason ? response_text : nil,
generated_text: finish_reason ? deltas.join : nil,
details: nil,
}.to_json
end
def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: true)
stream_line(deltas[index], deltas, finish_reason: true)
else
stream_line(deltas[index])
stream_line(deltas[index], deltas)
end
end
chunks = chunks.join("\n\n")
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
@ -79,5 +78,29 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.to_return(status: 200, body: chunks)
end
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
let(:tool_call) { invocation }
it_behaves_like "an endpoint that can communicate with a completion service"
end

View File

@ -6,17 +6,53 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
let(:model_name) { "gpt-3.5-turbo" }
let(:prompt) do
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
let(:tool_deltas) do
[
{ role: "system", content: "You are a helpful bot." },
{ role: "user", content: "Write 3 words" },
{ id: "get_weather", name: "get_weather", arguments: {} },
{ id: "get_weather", name: "get_weather", arguments: { location: "" } },
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
]
end
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
let(:tool_call) do
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
end
let(:request_body) do
model
.default_options
.merge(messages: prompt)
.tap do |b|
b[:tools] = generic_prompt[:tools].map do |t|
{ type: "function", tool: t }
end if generic_prompt[:tools]
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(messages: prompt, stream: true)
.tap do |b|
b[:tools] = generic_prompt[:tools].map do |t|
{ type: "function", tool: t }
end if generic_prompt[:tools]
end
.to_json
end
def response(content, tool_call: false)
message_content =
if tool_call
{ tool_calls: [{ function: content }] }
else
{ content: content }
end
def response(content)
{
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "chat.completion",
@ -28,45 +64,52 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
total_tokens: 499,
},
choices: [
{ message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
{ message: { role: "assistant" }.merge(message_content), finish_reason: "stop", index: 0 },
],
}
end
def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: { model: model_name, messages: prompt })
.to_return(status: 200, body: JSON.dump(response(response_text)))
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end
def stream_line(delta, finish_reason: nil, tool_call: false)
message_content =
if tool_call
{ tool_calls: [{ function: delta }] }
else
{ content: delta }
end
def stream_line(delta, finish_reason: nil)
+"data: " << {
id: "chatcmpl-#{SecureRandom.hex}",
object: "chat.completion.chunk",
created: 1_681_283_881,
model: "gpt-3.5-turbo-0301",
choices: [{ delta: { content: delta } }],
choices: [{ delta: message_content }],
finish_reason: finish_reason,
index: 0,
}.to_json
end
def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: "stop_sequence")
stream_line(deltas[index], finish_reason: "stop_sequence", tool_call: tool_call)
else
stream_line(deltas[index])
stream_line(deltas[index], tool_call: tool_call)
end
end
chunks = chunks.join("\n\n")
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end

View File

@ -3,7 +3,7 @@
RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do
described_class.new(
DiscourseAi::Completions::Dialects::OrcaStyle.new,
DiscourseAi::Completions::Dialects::OrcaStyle,
canned_response,
"Upstage-Llama-2-*-instruct-v2",
)

View File

@ -103,7 +103,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
expect(response.status).to eq(200)
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
expect(response.parsed_body["diff"]).to eq(expected_diff)
expect(spy.prompt.last[:content]).to eq(expected_input)
expect(spy.prompt.translate.last[:content]).to eq(expected_input)
end
end
end