DEV: Tool support for the LLM service. (#366)
This PR adds tool support to the available LLMs. We buffer tool invocations and return them whole instead of making users of this service parse the response. It also adds support for conversation context in the generic prompt: the context can include bot messages, user messages, and tool invocations, which we trim so the prompt doesn't exceed the token limit and then translate to the correct dialect. Finally, it adds some buffering when reading chunks to handle cases where streaming is extremely slow.
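For context, here is a minimal usage sketch of the new interface. The instructions, input, and get_weather tool below are made up for illustration; the prompt keys, the reverse-ordered conversation_context, and the Llm.proxy / completion! calls follow the generic format documented in the DiscourseAi::Completions::Llm class in this diff.

# Hypothetical caller-side example (not part of this commit).
generic_prompt = {
  insts: "You are a helpful bot.",
  input: "What's the weather in Sydney?",
  conversation_context: [
    # Reverse chronological order: the first element is the most recent message.
    { type: "user", name: "user1", content: "Is it raining today?" },
    { type: "assistant", content: "Let me check the forecast for you." },
    { type: "tool", name: "tool_id", content: "raining, 18c" },
  ],
  tools: [
    {
      name: "get_weather",
      description: "Get the weather in a city",
      parameters: [
        { name: "location", type: "string", description: "the city name", required: true },
      ],
    },
  ],
}

llm = DiscourseAi::Completions::Llm.proxy("gpt-4")

# When the model calls a tool, the block receives the whole buffered
# <function_calls> invocation in one go instead of a token-by-token stream.
llm.completion!(generic_prompt, user) do |partial, _cancel| # user is any User record
  print partial
end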
parent 203906be65
commit e0bf6adb5b
|
@ -3,31 +3,99 @@
|
|||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class ChatGpt
|
||||
def self.can_translate?(model_name)
|
||||
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
||||
class ChatGpt < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
end
|
||||
|
||||
def translate(generic_prompt)
|
||||
def translate
|
||||
open_ai_prompt = [
|
||||
{
|
||||
role: "system",
|
||||
content: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
|
||||
},
|
||||
{ role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") },
|
||||
]
|
||||
|
||||
if generic_prompt[:examples]
|
||||
generic_prompt[:examples].each do |example_pair|
|
||||
if prompt[:examples]
|
||||
prompt[:examples].each do |example_pair|
|
||||
open_ai_prompt << { role: "user", content: example_pair.first }
|
||||
open_ai_prompt << { role: "assistant", content: example_pair.second }
|
||||
end
|
||||
end
|
||||
|
||||
open_ai_prompt << { role: "user", content: generic_prompt[:input] }
|
||||
open_ai_prompt.concat(conversation_context) if prompt[:conversation_context]
|
||||
|
||||
open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
|
||||
|
||||
open_ai_prompt
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
def tools
|
||||
return if prompt[:tools].blank?
|
||||
|
||||
prompt[:tools].map { |t| { type: "function", tool: t } }
|
||||
end
|
||||
|
||||
def conversation_context
|
||||
return [] if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
|
||||
trimmed_context.reverse.map do |context|
|
||||
translated = context.slice(:content)
|
||||
translated[:role] = context[:type]
|
||||
|
||||
if context[:name]
|
||||
if translated[:role] == "tool"
|
||||
translated[:tool_call_id] = context[:name]
|
||||
else
|
||||
translated[:name] = context[:name]
|
||||
end
|
||||
end
|
||||
|
||||
translated
|
||||
end
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
# provide a buffer of 120 tokens - our function counting is not
|
||||
# 100% accurate and getting numbers to align exactly is very hard
|
||||
buffer = (opts[:max_tokens_to_sample] || 2500) + 50
|
||||
|
||||
if tools.present?
|
||||
# note this is about 100 tokens over, OpenAI have a more optimal representation
|
||||
@function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
|
||||
buffer += @function_size
|
||||
end
|
||||
|
||||
model_max_tokens - buffer
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def per_message_overhead
|
||||
# open ai defines about 4 tokens per message of overhead
|
||||
4
|
||||
end
|
||||
|
||||
def calculate_message_token(context)
|
||||
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
|
||||
def model_max_tokens
|
||||
case model_name
|
||||
when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
|
||||
16_384
|
||||
when "gpt-4"
|
||||
8192
|
||||
when "gpt-4-32k"
|
||||
32_768
|
||||
else
|
||||
8192
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,25 +3,65 @@
|
|||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class Claude
|
||||
def self.can_translate?(model_name)
|
||||
%w[claude-instant-1 claude-2].include?(model_name)
|
||||
class Claude < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[claude-instant-1 claude-2].include?(model_name)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AnthropicTokenizer
|
||||
end
|
||||
end
|
||||
|
||||
def translate(generic_prompt)
|
||||
claude_prompt = +"Human: #{generic_prompt[:insts]}\n"
|
||||
def translate
|
||||
claude_prompt = +"Human: #{prompt[:insts]}\n"
|
||||
|
||||
claude_prompt << build_examples(generic_prompt[:examples]) if generic_prompt[:examples]
|
||||
claude_prompt << build_tools_prompt if prompt[:tools]
|
||||
|
||||
claude_prompt << "#{generic_prompt[:input]}\n"
|
||||
claude_prompt << build_examples(prompt[:examples]) if prompt[:examples]
|
||||
|
||||
claude_prompt << "#{generic_prompt[:post_insts]}\n" if generic_prompt[:post_insts]
|
||||
claude_prompt << conversation_context if prompt[:conversation_context]
|
||||
|
||||
claude_prompt << "#{prompt[:input]}\n"
|
||||
|
||||
claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts]
|
||||
|
||||
claude_prompt << "Assistant:\n"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::AnthropicTokenizer
|
||||
def max_prompt_tokens
|
||||
50_000
|
||||
end
|
||||
|
||||
def conversation_context
|
||||
return "" if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
|
||||
trimmed_context
|
||||
.reverse
|
||||
.reduce(+"") do |memo, context|
|
||||
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
|
||||
|
||||
if context[:type] == "tool"
|
||||
memo << <<~TEXT
|
||||
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>#{context[:name]}</tool_name>
|
||||
<json>
|
||||
#{context[:content]}
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
TEXT
|
||||
else
|
||||
memo << " " << context[:content] << "\n"
|
||||
end
|
||||
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
||||
private
|
||||
|
|
|
@ -0,0 +1,160 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class Dialect
|
||||
class << self
|
||||
def can_translate?(_model_name)
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def dialect_for(model_name)
|
||||
dialects = [
|
||||
DiscourseAi::Completions::Dialects::Claude,
|
||||
DiscourseAi::Completions::Dialects::Llama2Classic,
|
||||
DiscourseAi::Completions::Dialects::ChatGpt,
|
||||
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||
DiscourseAi::Completions::Dialects::Gemini,
|
||||
]
|
||||
dialects.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |d|
|
||||
d.can_translate?(model_name)
|
||||
end
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
raise NotImplemented
|
||||
end
|
||||
end
|
||||
|
||||
def initialize(generic_prompt, model_name, opts: {})
|
||||
@prompt = generic_prompt
|
||||
@model_name = model_name
|
||||
@opts = opts
|
||||
end
|
||||
|
||||
def translate
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def tools
|
||||
tools = +""
|
||||
|
||||
prompt[:tools].each do |function|
|
||||
parameters = +""
|
||||
if function[:parameters].present?
|
||||
function[:parameters].each do |parameter|
|
||||
parameters << <<~PARAMETER
|
||||
<parameter>
|
||||
<name>#{parameter[:name]}</name>
|
||||
<type>#{parameter[:type]}</type>
|
||||
<description>#{parameter[:description]}</description>
|
||||
<required>#{parameter[:required]}</required>
|
||||
PARAMETER
|
||||
if parameter[:enum]
|
||||
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
|
||||
end
|
||||
parameters << "</parameter>\n"
|
||||
end
|
||||
end
|
||||
|
||||
tools << <<~TOOLS
|
||||
<tool_description>
|
||||
<tool_name>#{function[:name]}</tool_name>
|
||||
<description>#{function[:description]}</description>
|
||||
<parameters>
|
||||
#{parameters}</parameters>
|
||||
</tool_description>
|
||||
TOOLS
|
||||
end
|
||||
|
||||
tools
|
||||
end
|
||||
|
||||
def conversation_context
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
raise NotImplemented
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
attr_reader :prompt, :model_name, :opts
|
||||
|
||||
def trim_context(conversation_context)
|
||||
prompt_limit = max_prompt_tokens
|
||||
current_token_count = calculate_token_count_without_context
|
||||
|
||||
conversation_context.reduce([]) do |memo, context|
|
||||
break(memo) if current_token_count >= prompt_limit
|
||||
|
||||
dupped_context = context.dup
|
||||
|
||||
message_tokens = calculate_message_token(dupped_context)
|
||||
|
||||
# Trimming content to make sure we respect token limit.
|
||||
while dupped_context[:content].present? &&
|
||||
message_tokens + current_token_count + per_message_overhead > prompt_limit
|
||||
dupped_context[:content] = dupped_context[:content][0..-100] || ""
|
||||
message_tokens = calculate_message_token(dupped_context)
|
||||
end
|
||||
|
||||
next(memo) if dupped_context[:content].blank?
|
||||
|
||||
current_token_count += message_tokens + per_message_overhead
|
||||
|
||||
memo << dupped_context
|
||||
end
|
||||
end
|
||||
|
||||
def calculate_token_count_without_context
|
||||
tokenizer = self.class.tokenizer
|
||||
|
||||
examples_count =
|
||||
prompt[:examples].to_a.sum do |pair|
|
||||
tokenizer.size(pair.join) + (per_message_overhead * 2)
|
||||
end
|
||||
input_count = tokenizer.size(prompt[:input].to_s) + per_message_overhead
|
||||
|
||||
examples_count + input_count +
|
||||
prompt
|
||||
.except(:conversation_context, :tools, :examples, :input)
|
||||
.sum { |_, v| tokenizer.size(v) + per_message_overhead }
|
||||
end
|
||||
|
||||
def per_message_overhead
|
||||
0
|
||||
end
|
||||
|
||||
def calculate_message_token(context)
|
||||
self.class.tokenizer.size(context[:content].to_s)
|
||||
end
|
||||
|
||||
def build_tools_prompt
|
||||
return "" if prompt[:tools].blank?
|
||||
|
||||
<<~TEXT
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
<tools>
|
||||
#{tools}</tools>
|
||||
TEXT
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
|
@ -3,34 +3,91 @@
|
|||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class Gemini
|
||||
def self.can_translate?(model_name)
|
||||
%w[gemini-pro].include?(model_name)
|
||||
class Gemini < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[gemini-pro].include?(model_name)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
|
||||
end
|
||||
end
|
||||
|
||||
def translate(generic_prompt)
|
||||
def translate
|
||||
gemini_prompt = [
|
||||
{
|
||||
role: "user",
|
||||
parts: {
|
||||
text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
|
||||
text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
|
||||
},
|
||||
},
|
||||
{ role: "model", parts: { text: "Ok." } },
|
||||
]
|
||||
|
||||
if generic_prompt[:examples]
|
||||
generic_prompt[:examples].each do |example_pair|
|
||||
if prompt[:examples]
|
||||
prompt[:examples].each do |example_pair|
|
||||
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
|
||||
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
|
||||
end
|
||||
end
|
||||
|
||||
gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
|
||||
gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
|
||||
|
||||
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
|
||||
def tools
|
||||
return if prompt[:tools].blank?
|
||||
|
||||
translated_tools =
|
||||
prompt[:tools].map do |t|
|
||||
required_fields = []
|
||||
tool = t.dup
|
||||
|
||||
tool[:parameters] = t[:parameters].map do |p|
|
||||
required_fields << p[:name] if p[:required]
|
||||
|
||||
p.except(:required)
|
||||
end
|
||||
|
||||
tool.merge(required: required_fields)
|
||||
end
|
||||
|
||||
[{ function_declarations: translated_tools }]
|
||||
end
|
||||
|
||||
def conversation_context
|
||||
return [] if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
|
||||
trimmed_context.reverse.map do |context|
|
||||
translated = {}
|
||||
translated[:role] = (context[:type] == "user" ? "user" : "model")
|
||||
|
||||
part = {}
|
||||
|
||||
if context[:type] == "tool"
|
||||
part["functionResponse"] = { name: context[:name], content: context[:content] }
|
||||
else
|
||||
part[:text] = context[:content]
|
||||
end
|
||||
|
||||
translated[:parts] = [part]
|
||||
|
||||
translated
|
||||
end
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
16_384 # 50% of model tokens
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def calculate_message_token(context)
|
||||
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,27 +3,72 @@
|
|||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class Llama2Classic
|
||||
def self.can_translate?(model_name)
|
||||
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
|
||||
class Llama2Classic < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::Llama2Tokenizer
|
||||
end
|
||||
end
|
||||
|
||||
def translate(generic_prompt)
|
||||
llama2_prompt =
|
||||
+"[INST]<<SYS>>#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}<</SYS>>[/INST]\n"
|
||||
def translate
|
||||
llama2_prompt = +<<~TEXT
|
||||
[INST]
|
||||
<<SYS>>
|
||||
#{prompt[:insts]}
|
||||
#{build_tools_prompt}#{prompt[:post_insts]}
|
||||
<</SYS>>
|
||||
[/INST]
|
||||
TEXT
|
||||
|
||||
if generic_prompt[:examples]
|
||||
generic_prompt[:examples].each do |example_pair|
|
||||
if prompt[:examples]
|
||||
prompt[:examples].each do |example_pair|
|
||||
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
|
||||
llama2_prompt << "#{example_pair.second}\n"
|
||||
end
|
||||
end
|
||||
|
||||
llama2_prompt << "[INST]#{generic_prompt[:input]}[/INST]\n"
|
||||
llama2_prompt << conversation_context if prompt[:conversation_context].present?
|
||||
|
||||
llama2_prompt << "[INST]#{prompt[:input]}[/INST]\n"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::Llama2Tokenizer
|
||||
def conversation_context
|
||||
return "" if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
|
||||
trimmed_context
|
||||
.reverse
|
||||
.reduce(+"") do |memo, context|
|
||||
if context[:type] == "tool"
|
||||
memo << <<~TEXT
|
||||
[INST]
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>#{context[:name]}</tool_name>
|
||||
<json>
|
||||
#{context[:content]}
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
[/INST]
|
||||
TEXT
|
||||
elsif context[:type] == "assistant"
|
||||
memo << "[INST]" << context[:content] << "[/INST]\n"
|
||||
else
|
||||
memo << context[:content] << "\n"
|
||||
end
|
||||
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
SiteSetting.ai_hugging_face_token_limit
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,29 +3,68 @@
|
|||
module DiscourseAi
|
||||
module Completions
|
||||
module Dialects
|
||||
class OrcaStyle
|
||||
def self.can_translate?(model_name)
|
||||
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
|
||||
class OrcaStyle < Dialect
|
||||
class << self
|
||||
def can_translate?(model_name)
|
||||
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::Llama2Tokenizer
|
||||
end
|
||||
end
|
||||
|
||||
def translate(generic_prompt)
|
||||
orca_style_prompt =
|
||||
+"### System:\n#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}\n"
|
||||
def translate
|
||||
orca_style_prompt = +<<~TEXT
|
||||
### System:
|
||||
#{prompt[:insts]}
|
||||
#{build_tools_prompt}#{prompt[:post_insts]}
|
||||
TEXT
|
||||
|
||||
if generic_prompt[:examples]
|
||||
generic_prompt[:examples].each do |example_pair|
|
||||
if prompt[:examples]
|
||||
prompt[:examples].each do |example_pair|
|
||||
orca_style_prompt << "### User:\n#{example_pair.first}\n"
|
||||
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
|
||||
end
|
||||
end
|
||||
|
||||
orca_style_prompt << "### User:\n#{generic_prompt[:input]}\n"
|
||||
orca_style_prompt << "### User:\n#{prompt[:input]}\n"
|
||||
|
||||
orca_style_prompt << "### Assistant:\n"
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
DiscourseAi::Tokenizer::Llama2Tokenizer
|
||||
def conversation_context
|
||||
return "" if prompt[:conversation_context].blank?
|
||||
|
||||
trimmed_context = trim_context(prompt[:conversation_context])
|
||||
|
||||
trimmed_context
|
||||
.reverse
|
||||
.reduce(+"") do |memo, context|
|
||||
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
|
||||
|
||||
if context[:type] == "tool"
|
||||
memo << <<~TEXT
|
||||
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>#{context[:name]}</tool_name>
|
||||
<json>
|
||||
#{context[:content]}
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
TEXT
|
||||
else
|
||||
memo << " " << context[:content] << "\n"
|
||||
end
|
||||
|
||||
memo
|
||||
end
|
||||
end
|
||||
|
||||
def max_prompt_tokens
|
||||
SiteSetting.ai_hugging_face_token_limit
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -22,7 +22,7 @@ module DiscourseAi
|
|||
@uri ||= URI("https://api.anthropic.com/v1/complete")
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params)
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
default_options
|
||||
.merge(model_params)
|
||||
.merge(prompt: prompt)
|
||||
|
|
|
@ -37,7 +37,7 @@ module DiscourseAi
|
|||
URI(api_url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params)
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
default_options.merge(prompt: prompt).merge(model_params)
|
||||
end
|
||||
|
||||
|
|
|
@ -30,9 +30,11 @@ module DiscourseAi
|
|||
@tokenizer = tokenizer
|
||||
end
|
||||
|
||||
def perform_completion!(prompt, user, model_params = {})
|
||||
def perform_completion!(dialect, user, model_params = {})
|
||||
@streaming_mode = block_given?
|
||||
|
||||
prompt = dialect.translate
|
||||
|
||||
Net::HTTP.start(
|
||||
model_uri.host,
|
||||
model_uri.port,
|
||||
|
@ -43,7 +45,10 @@ module DiscourseAi
|
|||
) do |http|
|
||||
response_data = +""
|
||||
response_raw = +""
|
||||
request_body = prepare_payload(prompt, model_params).to_json
|
||||
|
||||
# Needed for response token calculations. Cannot rely on response_data due to function buffering.
|
||||
partials_raw = +""
|
||||
request_body = prepare_payload(prompt, model_params, dialect).to_json
|
||||
|
||||
request = prepare_request(request_body)
|
||||
|
||||
|
@ -66,6 +71,15 @@ module DiscourseAi
|
|||
if !@streaming_mode
|
||||
response_raw = response.read_body
|
||||
response_data = extract_completion_from(response_raw)
|
||||
partials_raw = response_data.to_s
|
||||
|
||||
if has_tool?("", response_data)
|
||||
function_buffer = build_buffer # Nokogiri document
|
||||
function_buffer = add_to_buffer(function_buffer, "", response_data)
|
||||
|
||||
response_data = +function_buffer.at("function_calls").to_s
|
||||
response_data << "\n"
|
||||
end
|
||||
|
||||
return response_data
|
||||
end
|
||||
|
@ -75,6 +89,7 @@ module DiscourseAi
|
|||
cancel = lambda { cancelled = true }
|
||||
|
||||
leftover = ""
|
||||
function_buffer = build_buffer # Nokogiri document
|
||||
|
||||
response.read_body do |chunk|
|
||||
if cancelled
|
||||
|
@ -85,6 +100,12 @@ module DiscourseAi
|
|||
decoded_chunk = decode(chunk)
|
||||
response_raw << decoded_chunk
|
||||
|
||||
# Buffering for extremely slow streaming.
|
||||
if (leftover + decoded_chunk).length < "data: [DONE]".length
|
||||
leftover += decoded_chunk
|
||||
next
|
||||
end
|
||||
|
||||
partials_from(leftover + decoded_chunk).each do |raw_partial|
|
||||
next if cancelled
|
||||
next if raw_partial.blank?
|
||||
|
@ -93,11 +114,27 @@ module DiscourseAi
|
|||
partial = extract_completion_from(raw_partial)
|
||||
next if partial.nil?
|
||||
leftover = ""
|
||||
response_data << partial
|
||||
|
||||
yield partial, cancel if partial
|
||||
if has_tool?(response_data, partial)
|
||||
function_buffer = add_to_buffer(function_buffer, response_data, partial)
|
||||
|
||||
if buffering_finished?(dialect.tools, function_buffer)
|
||||
invocation = +function_buffer.at("function_calls").to_s
|
||||
invocation << "\n"
|
||||
|
||||
partials_raw << partial.to_s
|
||||
response_data << invocation
|
||||
|
||||
yield invocation, cancel
|
||||
end
|
||||
else
|
||||
partials_raw << partial
|
||||
response_data << partial
|
||||
|
||||
yield partial, cancel if partial
|
||||
end
|
||||
rescue JSON::ParserError
|
||||
leftover = raw_partial
|
||||
leftover += decoded_chunk
|
||||
end
|
||||
end
|
||||
end
|
||||
|
@ -109,7 +146,7 @@ module DiscourseAi
|
|||
ensure
|
||||
if log
|
||||
log.raw_response_payload = response_raw
|
||||
log.response_tokens = tokenizer.size(response_data)
|
||||
log.response_tokens = tokenizer.size(partials_raw)
|
||||
log.save!
|
||||
|
||||
if Rails.env.development?
|
||||
|
@ -165,6 +202,40 @@ module DiscourseAi
|
|||
def extract_prompt_for_tokenizer(prompt)
|
||||
prompt
|
||||
end
|
||||
|
||||
def build_buffer
|
||||
Nokogiri::HTML5.fragment(<<~TEXT)
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name></tool_name>
|
||||
<tool_id></tool_id>
|
||||
<parameters></parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
end
|
||||
|
||||
def has_tool?(response, partial)
|
||||
(response + partial).include?("<function_calls>")
|
||||
end
|
||||
|
||||
def add_to_buffer(function_buffer, response_data, partial)
|
||||
new_buffer = Nokogiri::HTML5.fragment(response_data + partial)
|
||||
if tool_name = new_buffer.at("tool_name").text
|
||||
if new_buffer.at("tool_id").nil?
|
||||
tool_id_node =
|
||||
Nokogiri::HTML5::DocumentFragment.parse("\n<tool_id>#{tool_name}</tool_id>")
|
||||
|
||||
new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node)
|
||||
end
|
||||
end
|
||||
|
||||
new_buffer
|
||||
end
|
||||
|
||||
def buffering_finished?(_available_functions, buffer)
|
||||
buffer.to_s.include?("</function_calls>")
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -31,9 +31,14 @@ module DiscourseAi
|
|||
cancelled = false
|
||||
cancel_fn = lambda { cancelled = true }
|
||||
|
||||
response.each_char do |char|
|
||||
break if cancelled
|
||||
yield(char, cancel_fn)
|
||||
# We buffer and return tool invocations in one go.
|
||||
if is_tool?(response)
|
||||
yield(response, cancel_fn)
|
||||
else
|
||||
response.each_char do |char|
|
||||
break if cancelled
|
||||
yield(char, cancel_fn)
|
||||
end
|
||||
end
|
||||
else
|
||||
response
|
||||
|
@ -43,6 +48,12 @@ module DiscourseAi
|
|||
def tokenizer
|
||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
def is_tool?(response)
|
||||
Nokogiri::HTML5.fragment(response).at("function_calls").present?
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -25,8 +25,11 @@ module DiscourseAi
|
|||
URI(url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params)
|
||||
default_options.merge(model_params).merge(contents: prompt)
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
default_options
|
||||
.merge(model_params)
|
||||
.merge(contents: prompt)
|
||||
.tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
|
@ -36,25 +39,72 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
if @streaming_mode
|
||||
parsed = response_raw
|
||||
else
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
end
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
|
||||
completion = dig_text(parsed).to_s
|
||||
response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
|
||||
|
||||
has_function_call = response_h.dig(:functionCall).present?
|
||||
has_function_call ? response_h[:functionCall] : response_h.dig(:text)
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
JSON.parse(decoded_chunk, symbolize_names: true)
|
||||
decoded_chunk
|
||||
.split("\n")
|
||||
.map do |line|
|
||||
if line == ","
|
||||
nil
|
||||
elsif line.starts_with?("[")
|
||||
line[1..-1]
|
||||
elsif line.ends_with?("]")
|
||||
line[0..-1]
|
||||
else
|
||||
line
|
||||
end
|
||||
end
|
||||
.compact_blank
|
||||
end
|
||||
|
||||
def extract_prompt_for_tokenizer(prompt)
|
||||
prompt.to_s
|
||||
end
|
||||
|
||||
def dig_text(response)
|
||||
response.dig(:candidates, 0, :content, :parts, 0, :text)
|
||||
def has_tool?(_response_data, partial)
|
||||
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
|
||||
end
|
||||
|
||||
def add_to_buffer(function_buffer, _response_data, partial)
|
||||
if partial[:name].present?
|
||||
function_buffer.at("tool_name").content = partial[:name]
|
||||
function_buffer.at("tool_id").content = partial[:name]
|
||||
end
|
||||
|
||||
if partial[:args]
|
||||
argument_fragments =
|
||||
partial[:args].reduce(+"") do |memo, (arg_name, value)|
|
||||
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
|
||||
end
|
||||
argument_fragments << "\n"
|
||||
|
||||
function_buffer.at("parameters").children =
|
||||
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||
end
|
||||
|
||||
function_buffer
|
||||
end
|
||||
|
||||
def buffering_finished?(available_functions, buffer)
|
||||
tool_name = buffer.at("tool_name")&.text
|
||||
return false if tool_name.blank?
|
||||
|
||||
signature =
|
||||
available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
|
||||
|
||||
signature[:parameters].reduce(true) do |memo, param|
|
||||
param_present = buffer.at(param[:name]).present?
|
||||
next(memo) if param_present || !signature[:required].include?(param[:name])
|
||||
|
||||
memo && param_present
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -11,7 +11,7 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def default_options
|
||||
{ parameters: { repetition_penalty: 1.1, temperature: 0.7 } }
|
||||
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
|
||||
end
|
||||
|
||||
def provider_id
|
||||
|
@ -24,7 +24,7 @@ module DiscourseAi
|
|||
URI(SiteSetting.ai_hugging_face_api_url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params)
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
default_options
|
||||
.merge(inputs: prompt)
|
||||
.tap do |payload|
|
||||
|
@ -33,7 +33,6 @@ module DiscourseAi
|
|||
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
|
||||
|
||||
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
|
||||
payload[:parameters][:return_full_text] = false
|
||||
|
||||
payload[:stream] = true if @streaming_mode
|
||||
end
|
||||
|
|
|
@ -37,11 +37,14 @@ module DiscourseAi
|
|||
URI(url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params)
|
||||
def prepare_payload(prompt, model_params, dialect)
|
||||
default_options
|
||||
.merge(model_params)
|
||||
.merge(messages: prompt)
|
||||
.tap { |payload| payload[:stream] = true if @streaming_mode }
|
||||
.tap do |payload|
|
||||
payload[:stream] = true if @streaming_mode
|
||||
payload[:tools] = dialect.tools if dialect.tools.present?
|
||||
end
|
||||
end
|
||||
|
||||
def prepare_request(payload)
|
||||
|
@ -62,15 +65,12 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def extract_completion_from(response_raw)
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
|
||||
|
||||
(
|
||||
if @streaming_mode
|
||||
parsed.dig(:choices, 0, :delta, :content)
|
||||
else
|
||||
parsed.dig(:choices, 0, :message, :content)
|
||||
end
|
||||
).to_s
|
||||
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
|
||||
|
||||
has_function_call = response_h.dig(:tool_calls).present?
|
||||
has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
|
||||
end
|
||||
|
||||
def partials_from(decoded_chunk)
|
||||
|
@ -86,6 +86,42 @@ module DiscourseAi
|
|||
def extract_prompt_for_tokenizer(prompt)
|
||||
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
|
||||
end
|
||||
|
||||
def has_tool?(_response_data, partial)
|
||||
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
|
||||
end
|
||||
|
||||
def add_to_buffer(function_buffer, _response_data, partial)
|
||||
function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
|
||||
function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
|
||||
|
||||
if partial[:arguments]
|
||||
argument_fragments =
|
||||
partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
|
||||
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
|
||||
end
|
||||
argument_fragments << "\n"
|
||||
|
||||
function_buffer.at("parameters").children =
|
||||
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||
end
|
||||
|
||||
function_buffer
|
||||
end
|
||||
|
||||
def buffering_finished?(available_functions, buffer)
|
||||
tool_name = buffer.at("tool_name")&.text
|
||||
return false if tool_name.blank?
|
||||
|
||||
signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
|
||||
|
||||
signature[:parameters].reduce(true) do |memo, param|
|
||||
param_present = buffer.at(param[:name]).present?
|
||||
next(memo) if param_present && !param[:required]
|
||||
|
||||
memo && param_present
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -24,35 +24,26 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def self.proxy(model_name)
|
||||
dialects = [
|
||||
DiscourseAi::Completions::Dialects::Claude,
|
||||
DiscourseAi::Completions::Dialects::Llama2Classic,
|
||||
DiscourseAi::Completions::Dialects::ChatGpt,
|
||||
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||
DiscourseAi::Completions::Dialects::Gemini,
|
||||
]
|
||||
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
|
||||
|
||||
dialect =
|
||||
dialects.detect(-> { raise UNKNOWN_MODEL }) { |d| d.can_translate?(model_name) }.new
|
||||
|
||||
return new(dialect, @canned_response, model_name) if @canned_response
|
||||
return new(dialect_klass, @canned_response, model_name) if @canned_response
|
||||
|
||||
gateway =
|
||||
DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
|
||||
model_name,
|
||||
dialect.tokenizer,
|
||||
dialect_klass.tokenizer,
|
||||
)
|
||||
|
||||
new(dialect, gateway, model_name)
|
||||
new(dialect_klass, gateway, model_name)
|
||||
end
|
||||
|
||||
def initialize(dialect, gateway, model_name)
|
||||
@dialect = dialect
|
||||
def initialize(dialect_klass, gateway, model_name)
|
||||
@dialect_klass = dialect_klass
|
||||
@gateway = gateway
|
||||
@model_name = model_name
|
||||
end
|
||||
|
||||
delegate :tokenizer, to: :dialect
|
||||
delegate :tokenizer, to: :dialect_klass
|
||||
|
||||
# @param generic_prompt { Hash } - Prompt using our generic format.
|
||||
# We use the following keys from the hash:
|
||||
|
@ -60,23 +51,64 @@ module DiscourseAi
|
|||
# - input: String containing user input
|
||||
# - examples (optional): Array of arrays with examples of input and responses. Each array is a input/response pair like [[example1, response1], [example2, response2]].
|
||||
# - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt.
|
||||
# - conversation_context (optional): Array of hashes to provide context about an ongoing conversation with the model.
|
||||
# We translate the array in reverse order, meaning the first element would be the most recent message in the conversation.
|
||||
# Example:
|
||||
#
|
||||
# [
|
||||
# { type: "user", name: "user1", content: "This is a new message by a user" },
|
||||
# { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||
# { type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||
# ]
|
||||
#
|
||||
# - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:
|
||||
#
|
||||
# {
|
||||
# name: "get_weather",
|
||||
# description: "Get the weather in a city",
|
||||
# parameters: [
|
||||
# { name: "location", type: "string", description: "the city name", required: true },
|
||||
# {
|
||||
# name: "unit",
|
||||
# type: "string",
|
||||
# description: "the unit of measurement celcius c or fahrenheit f",
|
||||
# enum: %w[c f],
|
||||
# required: true,
|
||||
# },
|
||||
# ],
|
||||
# }
|
||||
#
|
||||
# @param user { User } - User requesting the summary.
|
||||
#
|
||||
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
|
||||
#
|
||||
# @returns { String } - Completion result.
|
||||
#
|
||||
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool,
|
||||
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
|
||||
#
|
||||
# <function_calls>
|
||||
# <invoke>
|
||||
# <tool_name>get_weather</tool_name>
|
||||
# <tool_id>get_weather</tool_id>
|
||||
# <parameters>
|
||||
# <location>Sydney</location>
|
||||
# <unit>c</unit>
|
||||
# </parameters>
|
||||
# </invoke>
|
||||
# </function_calls>
|
||||
#
|
||||
def completion!(generic_prompt, user, &partial_read_blk)
|
||||
prompt = dialect.translate(generic_prompt)
|
||||
|
||||
model_params = generic_prompt.dig(:params, model_name) || {}
|
||||
|
||||
gateway.perform_completion!(prompt, user, model_params, &partial_read_blk)
|
||||
dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
|
||||
|
||||
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
|
||||
end
|
||||
|
||||
private
|
||||
|
||||
attr_reader :dialect, :gateway, :model_name
|
||||
attr_reader :dialect_klass, :gateway, :model_name
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,7 +1,24 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||
subject(:dialect) { described_class.new }
|
||||
subject(:dialect) { described_class.new(prompt, "gpt-4") }
|
||||
|
||||
let(:tool) do
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name", required: true },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
let(:prompt) do
|
||||
{
|
||||
|
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
|||
TEXT
|
||||
post_insts:
|
||||
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
|
||||
tools: [tool],
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -35,7 +53,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
|||
{ role: "user", content: prompt[:input] },
|
||||
]
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to contain_exactly(*open_ai_version)
|
||||
end
|
||||
|
@ -55,9 +73,51 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
|||
{ role: "user", content: prompt[:input] },
|
||||
]
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to contain_exactly(*open_ai_version)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#conversation_context" do
|
||||
let(:context) do
|
||||
[
|
||||
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||
]
|
||||
end
|
||||
|
||||
it "adds conversation in reverse order (first == newer)" do
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context).to eq(
|
||||
[
|
||||
{ role: "tool", content: context.last[:content], tool_call_id: context.last[:name] },
|
||||
{ role: "assistant", content: context.second[:content] },
|
||||
{ role: "user", content: context.first[:content], name: context.first[:name] },
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
it "trims content if it's getting too long" do
|
||||
context.last[:content] = context.last[:content] * 1000
|
||||
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context.last[:content].length).to be < context.last[:content].length
|
||||
end
|
||||
end
|
||||
|
||||
describe "#tools" do
|
||||
it "returns a list of available tools" do
|
||||
open_ai_tool_f = { type: "function", tool: tool }
|
||||
|
||||
expect(subject.tools).to contain_exactly(open_ai_tool_f)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,7 +1,24 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||
subject(:dialect) { described_class.new }
|
||||
subject(:dialect) { described_class.new(prompt, "claude-2") }
|
||||
|
||||
let(:tool) do
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name", required: true },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
let(:prompt) do
|
||||
{
|
||||
|
@ -37,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
Assistant:
|
||||
TEXT
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(anthropic_version)
|
||||
end
|
||||
|
@ -60,9 +77,111 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
|||
Assistant:
|
||||
TEXT
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(anthropic_version)
|
||||
end
|
||||
|
||||
it "include tools inside the prompt" do
|
||||
prompt[:tools] = [tool]
|
||||
|
||||
anthropic_version = <<~TEXT
|
||||
Human: #{prompt[:insts]}
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
<tools>
|
||||
#{dialect.tools}</tools>
|
||||
#{prompt[:input]}
|
||||
#{prompt[:post_insts]}
|
||||
Assistant:
|
||||
TEXT
|
||||
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(anthropic_version)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#conversation_context" do
|
||||
let(:context) do
|
||||
[
|
||||
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||
]
|
||||
end
|
||||
|
||||
it "adds conversation in reverse order (first == newer)" do
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
expected = <<~TEXT
|
||||
Assistant:
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>tool_id</tool_name>
|
||||
<json>
|
||||
#{context.last[:content]}
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
Assistant: #{context.second[:content]}
|
||||
Human: #{context.first[:content]}
|
||||
TEXT
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context).to eq(expected)
|
||||
end
|
||||
|
||||
it "trims content if it's getting too long" do
|
||||
context.last[:content] = context.last[:content] * 10_000
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context.length).to be < context.last[:content].length
|
||||
end
|
||||
end
|
||||
|
||||
describe "#tools" do
|
||||
it "translates tools to the tool syntax" do
|
||||
prompt[:tools] = [tool]
|
||||
|
||||
translated_tool = <<~TEXT
|
||||
<tool_description>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<description>Get the weather in a city</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>location</name>
|
||||
<type>string</type>
|
||||
<description>the city name</description>
|
||||
<required>true</required>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>unit</name>
|
||||
<type>string</type>
|
||||
<description>the unit of measurement celcius c or fahrenheit f</description>
|
||||
<required>true</required>
|
||||
<options>c,f</options>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</tool_description>
|
||||
TEXT
|
||||
|
||||
expect(dialect.tools).to eq(translated_tool)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,7 +1,24 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||
subject(:dialect) { described_class.new }
|
||||
subject(:dialect) { described_class.new(prompt, "gemini-pro") }
|
||||
|
||||
let(:tool) do
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name", required: true },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
let(:prompt) do
|
||||
{
|
||||
|
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
|||
TEXT
|
||||
post_insts:
|
||||
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
|
||||
tools: [tool],
|
||||
}
|
||||
end
|
||||
|
||||
|
@ -36,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
|||
{ role: "user", parts: { text: prompt[:input] } },
|
||||
]
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(gemini_version)
|
||||
end
|
||||
|
@ -57,9 +75,79 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
|||
{ role: "user", parts: { text: prompt[:input] } },
|
||||
]
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to contain_exactly(*gemini_version)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#conversation_context" do
|
||||
let(:context) do
|
||||
[
|
||||
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||
]
|
||||
end
|
||||
|
||||
it "adds conversation in reverse order (first == newer)" do
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context).to eq(
|
||||
[
|
||||
{
|
||||
role: "model",
|
||||
parts: [
|
||||
{
|
||||
"functionResponse" => {
|
||||
name: context.last[:name],
|
||||
content: context.last[:content],
|
||||
},
|
||||
},
|
||||
],
|
||||
},
|
||||
{ role: "model", parts: [{ text: context.second[:content] }] },
|
||||
{ role: "user", parts: [{ text: context.first[:content] }] },
|
||||
],
|
||||
)
|
||||
end
|
||||
|
||||
it "trims content if it's getting too long" do
|
||||
context.last[:content] = context.last[:content] * 1000
|
||||
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context.last.dig(:parts, 0, :text).length).to be <
|
||||
context.last[:content].length
|
||||
end
|
||||
end
|
||||
|
||||
describe "#tools" do
|
||||
it "returns a list of available tools" do
|
||||
gemini_tools = {
|
||||
function_declarations: [
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name" },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
},
|
||||
],
|
||||
required: %w[location unit],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
||||
expect(subject.tools).to contain_exactly(gemini_tools)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,7 +1,24 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
|
||||
subject(:dialect) { described_class.new }
|
||||
subject(:dialect) { described_class.new(prompt, "Llama2-chat-hf") }
|
||||
|
||||
let(:tool) do
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name", required: true },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
let(:prompt) do
|
||||
{
|
||||
|
@ -31,11 +48,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
|
|||
describe "#translate" do
|
||||
it "translates a prompt written in our generic format to the Llama2 format" do
|
||||
llama2_classic_version = <<~TEXT
|
||||
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
|
||||
[INST]
|
||||
<<SYS>>
|
||||
#{prompt[:insts]}
|
||||
#{prompt[:post_insts]}
|
||||
<</SYS>>
|
||||
[/INST]
|
||||
[INST]#{prompt[:input]}[/INST]
|
||||
TEXT
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(llama2_classic_version)
|
||||
end
|
||||
|
@ -49,15 +71,126 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
|
|||
]
|
||||
|
||||
llama2_classic_version = <<~TEXT
|
||||
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
|
||||
[INST]
|
||||
<<SYS>>
|
||||
#{prompt[:insts]}
|
||||
#{prompt[:post_insts]}
|
||||
<</SYS>>
|
||||
[/INST]
|
||||
[INST]#{prompt[:examples][0][0]}[/INST]
|
||||
#{prompt[:examples][0][1]}
|
||||
[INST]#{prompt[:input]}[/INST]
|
||||
TEXT
|
||||
|
||||
translated = dialect.translate(prompt)
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(llama2_classic_version)
|
||||
end
|
||||
|
||||
it "include tools inside the prompt" do
|
||||
prompt[:tools] = [tool]
|
||||
|
||||
llama2_classic_version = <<~TEXT
|
||||
[INST]
|
||||
<<SYS>>
|
||||
#{prompt[:insts]}
|
||||
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>$TOOL_NAME</tool_name>
|
||||
<parameters>
|
||||
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||
...
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
|
||||
Here are the tools available:
|
||||
|
||||
<tools>
|
||||
#{dialect.tools}</tools>
|
||||
#{prompt[:post_insts]}
|
||||
<</SYS>>
|
||||
[/INST]
|
||||
[INST]#{prompt[:input]}[/INST]
|
||||
TEXT
|
||||
|
||||
translated = dialect.translate
|
||||
|
||||
expect(translated).to eq(llama2_classic_version)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#conversation_context" do
|
||||
let(:context) do
|
||||
[
|
||||
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||
]
|
||||
end
|
||||
|
||||
it "adds conversation in reverse order (first == newer)" do
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
expected = <<~TEXT
|
||||
[INST]
|
||||
<function_results>
|
||||
<result>
|
||||
<tool_name>tool_id</tool_name>
|
||||
<json>
|
||||
#{context.last[:content]}
|
||||
</json>
|
||||
</result>
|
||||
</function_results>
|
||||
[/INST]
|
||||
[INST]#{context.second[:content]}[/INST]
|
||||
#{context.first[:content]}
|
||||
TEXT
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context).to eq(expected)
|
||||
end
|
||||
|
||||
it "trims content if it's getting too long" do
|
||||
context.last[:content] = context.last[:content] * 1_000
|
||||
prompt[:conversation_context] = context
|
||||
|
||||
translated_context = dialect.conversation_context
|
||||
|
||||
expect(translated_context.length).to be < context.last[:content].length
|
||||
end
|
||||
end
|
||||
|
||||
describe "#tools" do
|
||||
it "translates functions to the tool syntax" do
|
||||
prompt[:tools] = [tool]
|
||||
|
||||
translated_tool = <<~TEXT
|
||||
<tool_description>
|
||||
<tool_name>get_weather</tool_name>
|
||||
<description>Get the weather in a city</description>
|
||||
<parameters>
|
||||
<parameter>
|
||||
<name>location</name>
|
||||
<type>string</type>
|
||||
<description>the city name</description>
|
||||
<required>true</required>
|
||||
</parameter>
|
||||
<parameter>
|
||||
<name>unit</name>
|
||||
<type>string</type>
|
||||
<description>the unit of measurement celcius c or fahrenheit f</description>
|
||||
<required>true</required>
|
||||
<options>c,f</options>
|
||||
</parameter>
|
||||
</parameters>
|
||||
</tool_description>
|
||||
TEXT
|
||||
|
||||
expect(dialect.tools).to eq(translated_tool)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -1,44 +1,62 @@
|
|||
# frozen_string_literal: true
|
||||
|
||||
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
|
||||
subject(:dialect) { described_class.new }
|
||||
subject(:dialect) { described_class.new(prompt, "StableBeluga2") }
|
||||
|
||||
let(:tool) do
|
||||
{
|
||||
name: "get_weather",
|
||||
description: "Get the weather in a city",
|
||||
parameters: [
|
||||
{ name: "location", type: "string", description: "the city name", required: true },
|
||||
{
|
||||
name: "unit",
|
||||
type: "string",
|
||||
description: "the unit of measurement celcius c or fahrenheit f",
|
||||
enum: %w[c f],
|
||||
required: true,
|
||||
},
|
||||
],
|
||||
}
|
||||
end
|
||||
|
||||
let(:prompt) do
|
||||
{
|
||||
insts: <<~TEXT,
|
||||
I want you to act as a title generator for written pieces. I will provide you with a text,
|
||||
and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
|
||||
and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
|
||||
TEXT
|
||||
input: <<~TEXT,
|
||||
Here is the text, inside <input></input> XML tags:
|
||||
<input>
|
||||
To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
|
||||
discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
|
||||
defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.
|
||||
|
||||
Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
dies so that a scene may be repeated.
</input>
TEXT
post_insts:
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
}
end

describe "#translate" do
let(:prompt) do
{
insts: <<~TEXT,
I want you to act as a title generator for written pieces. I will provide you with a text,
and you will generate five attention-grabbing titles. Please keep the title concise and under 20 words,
and ensure that the meaning is maintained. Replies will utilize the language type of the topic.
TEXT
input: <<~TEXT,
Here is the text, inside <input></input> XML tags:
<input>
To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
discovers among the faces and blades that of Marcus Brutus, his protege, perhaps his son, and he no longer
defends himself, but instead exclaims: 'You too, my son!' Shakespeare and Quevedo capture the pathetic cry.

Destiny favors repetitions, variants, symmetries; nineteen centuries later, in the southern province of Buenos Aires,
a gaucho is attacked by other gauchos and, as he falls, recognizes a godson of his and says with gentle rebuke and
slow surprise (these words must be heard, not read): 'But, my friend!' He is killed and does not know that he
dies so that a scene may be repeated.
</input>
TEXT
post_insts:
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
}
end

it "translates a prompt written in our generic format to the Open AI format" do
orca_style_version = <<~TEXT
### System:
#{[prompt[:insts], prompt[:post_insts]].join("\n")}
#{prompt[:insts]}
#{prompt[:post_insts]}
### User:
#{prompt[:input]}
### Assistant:
TEXT

translated = dialect.translate(prompt)
translated = dialect.translate

expect(translated).to eq(orca_style_version)
end

@ -53,7 +71,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do

orca_style_version = <<~TEXT
### System:
#{[prompt[:insts], prompt[:post_insts]].join("\n")}
#{prompt[:insts]}
#{prompt[:post_insts]}
### User:
#{prompt[:examples][0][0]}
### Assistant:

@ -63,9 +82,113 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
### Assistant:
TEXT

translated = dialect.translate(prompt)
translated = dialect.translate

expect(translated).to eq(orca_style_version)
end

it "include tools inside the prompt" do
prompt[:tools] = [tool]

orca_style_version = <<~TEXT
### System:
#{prompt[:insts]}
In this environment you have access to a set of tools you can use to answer the user's question.
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
<function_calls>
<invoke>
<tool_name>$TOOL_NAME</tool_name>
<parameters>
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
...
</parameters>
</invoke>
</function_calls>

Here are the tools available:

<tools>
#{dialect.tools}</tools>
#{prompt[:post_insts]}
### User:
#{prompt[:input]}
### Assistant:
TEXT

translated = dialect.translate

expect(translated).to eq(orca_style_version)
end
end

describe "#conversation_context" do
let(:context) do
[
{ type: "user", name: "user1", content: "This is a new message by a user" },
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
]
end

it "adds conversation in reverse order (first == newer)" do
prompt[:conversation_context] = context

expected = <<~TEXT
### Assistant:
<function_results>
<result>
<tool_name>tool_id</tool_name>
<json>
#{context.last[:content]}
</json>
</result>
</function_results>
### Assistant: #{context.second[:content]}
### User: #{context.first[:content]}
TEXT

translated_context = dialect.conversation_context

expect(translated_context).to eq(expected)
end

it "trims content if it's getting too long" do
context.last[:content] = context.last[:content] * 1_000
prompt[:conversation_context] = context

translated_context = dialect.conversation_context

expect(translated_context.length).to be < context.last[:content].length
end
end

describe "#tools" do
it "translates tools to the tool syntax" do
prompt[:tools] = [tool]

translated_tool = <<~TEXT
<tool_description>
<tool_name>get_weather</tool_name>
<description>Get the weather in a city</description>
<parameters>
<parameter>
<name>location</name>
<type>string</type>
<description>the city name</description>
<required>true</required>
</parameter>
<parameter>
<name>unit</name>
<type>string</type>
<description>the unit of measurement celcius c or fahrenheit f</description>
<required>true</required>
<options>c,f</options>
</parameter>
</parameters>
</tool_description>
TEXT

expect(dialect.tools).to eq(translated_tool)
end
end
end

@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }

let(:model_name) { "claude-2" }
let(:prompt) { "Human: write 3 words\n\n" }
let(:generic_prompt) { { insts: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }

let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }

@ -23,10 +25,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}
end

def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: model.default_options.merge(prompt: prompt).to_json)
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end

@ -42,7 +44,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
}.to_json
end

def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)

@ -52,13 +54,27 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
end

chunks = chunks.join("\n\n")
chunks = chunks.join("\n\n").split("")

WebMock
.stub_request(:post, "https://api.anthropic.com/v1/complete")
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end

let(:tool_deltas) { ["<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY

let(:tool_call) { invocation }

it_behaves_like "an endpoint that can communicate with a completion service"
end

@ -9,10 +9,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do

let(:model_name) { "claude-2" }
let(:bedrock_name) { "claude-v2" }
let(:prompt) { "Human: write 3 words\n\n" }
let(:generic_prompt) { { insts: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }

let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { request_body }

before do
SiteSetting.ai_bedrock_access_key_id = "123456"

@ -20,39 +22,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
SiteSetting.ai_bedrock_region = "us-east-1"
end

# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end

yield response if block_given?

response
end
end
end

@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end

after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end

def response(content)
{
completion: content,

@ -65,7 +34,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
end

def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(
:post,

@ -102,7 +71,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
encoder.encode(message)
end

def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)

@ -121,5 +90,19 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
.to_return(status: 200, body: chunks)
end

let(:tool_deltas) { ["<function", <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY

let(:tool_call) { invocation }

it_behaves_like "an endpoint that can communicate with a completion service"
end

@ -1,70 +1,177 @@
# frozen_string_literal: true

RSpec.shared_examples "an endpoint that can communicate with a completion service" do
# Copied from https://github.com/bblimke/webmock/issues/629
# Workaround for stubbing a streamed response
before do
mocked_http =
Class.new(Net::HTTP) do
def request(*)
super do |response|
response.instance_eval do
def read_body(*, &block)
if block_given?
@body.each(&block)
else
super
end
end
end

yield response if block_given?

response
end
end
end

@original_net_http = Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, mocked_http)
end

after do
Net.send(:remove_const, :HTTP)
Net.send(:const_set, :HTTP, @original_net_http)
end

describe "#perform_completion!" do
fab!(:user) { Fabricate(:user) }

let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
let(:tool) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name", required: true },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
required: true,
},
],
}
end

let(:invocation) { <<~TEXT }
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<tool_id>get_weather</tool_id>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
TEXT

context "when using regular mode" do
before { stub_response(prompt, response_text) }
context "with simple prompts" do
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }

it "can complete a trivial prompt" do
completion_response = model.perform_completion!(prompt, user)
before { stub_response(prompt, response_text) }

expect(completion_response).to eq(response_text)
it "can complete a trivial prompt" do
completion_response = model.perform_completion!(dialect, user)

expect(completion_response).to eq(response_text)
end

it "creates an audit log for the request" do
model.perform_completion!(dialect, user)

expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first

response_body = response(response_text).to_json

expect(log.provider_id).to eq(model.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to eq(request_body)
expect(log.raw_response_payload).to eq(response_body)
expect(log.request_tokens).to eq(model.prompt_size(prompt))
expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
end
end

it "creates an audit log for the request" do
model.perform_completion!(prompt, user)
context "with functions" do
let(:generic_prompt) do
{
insts: "You can tell me the weather",
input: "Return the weather in Sydney",
tools: [tool],
}
end

expect(AiApiAuditLog.count).to eq(1)
log = AiApiAuditLog.first
before { stub_response(prompt, tool_call, tool_call: true) }

response_body = response(response_text).to_json
it "returns a function invocation" do
completion_response = model.perform_completion!(dialect, user)

expect(log.provider_id).to eq(model.provider_id)
expect(log.user_id).to eq(user.id)
expect(log.raw_request_payload).to eq(request_body)
expect(log.raw_response_payload).to eq(response_body)
expect(log.request_tokens).to eq(model.prompt_size(prompt))
expect(log.response_tokens).to eq(model.tokenizer.size(response_text))
expect(completion_response).to eq(invocation)
end
end
end

context "when using stream mode" do
|
||||
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
|
||||
context "with simple prompts" do
|
||||
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
|
||||
|
||||
before { stub_streamed_response(prompt, deltas) }
|
||||
before { stub_streamed_response(prompt, deltas) }
|
||||
|
||||
it "can complete a trivial prompt" do
|
||||
completion_response = +""
|
||||
it "can complete a trivial prompt" do
|
||||
completion_response = +""
|
||||
|
||||
model.perform_completion!(prompt, user) do |partial, cancel|
|
||||
completion_response << partial
|
||||
cancel.call if completion_response.split(" ").length == 2
|
||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||
completion_response << partial
|
||||
cancel.call if completion_response.split(" ").length == 2
|
||||
end
|
||||
|
||||
expect(completion_response).to eq(deltas[0...-1].join)
|
||||
end
|
||||
|
||||
expect(completion_response).to eq(deltas[0...-1].join)
|
||||
it "creates an audit log and updates is on each read." do
|
||||
completion_response = +""
|
||||
|
||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||
completion_response << partial
|
||||
cancel.call if completion_response.split(" ").length == 2
|
||||
end
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
|
||||
expect(log.provider_id).to eq(model.provider_id)
|
||||
expect(log.user_id).to eq(user.id)
|
||||
expect(log.raw_request_payload).to eq(stream_request_body)
|
||||
expect(log.raw_response_payload).to be_present
|
||||
expect(log.request_tokens).to eq(model.prompt_size(prompt))
|
||||
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
|
||||
end
|
||||
end
|
||||
|
||||
it "creates an audit log and updates is on each read." do
|
||||
completion_response = +""
|
||||
|
||||
model.perform_completion!(prompt, user) do |partial, cancel|
|
||||
completion_response << partial
|
||||
cancel.call if completion_response.split(" ").length == 2
|
||||
context "with functions" do
|
||||
let(:generic_prompt) do
|
||||
{
|
||||
insts: "You can tell me the weather",
|
||||
input: "Return the weather in Sydney",
|
||||
tools: [tool],
|
||||
}
|
||||
end
|
||||
|
||||
expect(AiApiAuditLog.count).to eq(1)
|
||||
log = AiApiAuditLog.first
|
||||
before { stub_streamed_response(prompt, tool_deltas, tool_call: true) }
|
||||
|
||||
expect(log.provider_id).to eq(model.provider_id)
|
||||
expect(log.user_id).to eq(user.id)
|
||||
expect(log.raw_request_payload).to eq(stream_request_body)
|
||||
expect(log.raw_response_payload).to be_present
|
||||
expect(log.request_tokens).to eq(model.prompt_size(prompt))
|
||||
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
|
||||
it "waits for the invocation to finish before calling the partial" do
|
||||
buffered_partial = ""
|
||||
|
||||
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||
buffered_partial = partial if partial.include?("<function_calls>")
|
||||
end
|
||||
|
||||
expect(buffered_partial).to eq(invocation)
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -6,22 +6,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }

let(:model_name) { "gemini-pro" }
let(:prompt) do
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }

let(:tool_payload) do
{
name: "get_weather",
description: "Get the weather in a city",
parameters: [
{ name: "location", type: "string", description: "the city name" },
{
name: "unit",
type: "string",
description: "the unit of measurement celcius c or fahrenheit f",
enum: %w[c f],
},
],
required: %w[location unit],
}
end

let(:request_body) do
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(contents: prompt)
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
.to_json
end

let(:tool_deltas) do
[
{ role: "system", content: "You are a helpful bot." },
{ role: "user", content: "Write 3 words" },
{ "functionCall" => { name: "get_weather", args: {} } },
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
]
end

let(:request_body) { model.default_options.merge(contents: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
let(:tool_call) do
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
end

def response(content)
def response(content, tool_call: false)
{
candidates: [
{
content: {
parts: [{ text: content }],
parts: [(tool_call ? content : { text: content })],
role: "model",
},
finishReason: "STOP",

@ -45,22 +83,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}
end

def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: { contents: prompt })
.to_return(status: 200, body: JSON.dump(response(response_text)))
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end

def stream_line(delta, finish_reason: nil)
def stream_line(delta, finish_reason: nil, tool_call: false)
{
candidates: [
{
content: {
parts: [{ text: delta }],
parts: [(tool_call ? delta : { text: delta })],
role: "model",
},
finishReason: finish_reason,

@ -76,24 +114,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
}.to_json
end

def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: "STOP")
stream_line(deltas[index], finish_reason: "STOP", tool_call: tool_call)
else
stream_line(deltas[index])
stream_line(deltas[index], tool_call: tool_call)
end
end

chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")

WebMock
.stub_request(
:post,
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
)
.with(body: model.default_options.merge(contents: prompt).to_json)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end

@ -6,10 +6,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }

let(:model_name) { "Llama2-*-chat-hf" }
let(:prompt) { <<~TEXT }
[INST]<<SYS>>You are a helpful bot.<</SYS>>[/INST]
[INST]Write 3 words[/INST]
TEXT
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) do
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
end
let(:prompt) { dialect.translate }

let(:request_body) do
model

@ -18,7 +19,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.tap do |payload|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:parameters][:return_full_text] = false
end
.to_json
end

@ -30,7 +30,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
model.prompt_size(prompt)
payload[:stream] = true
payload[:parameters][:return_full_text] = false
end
.to_json
end

@ -41,14 +40,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
[{ generated_text: content }]
end

def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text)))
end

def stream_line(delta, finish_reason: nil)
def stream_line(delta, deltas, finish_reason: nil)
+"data: " << {
token: {
id: 29_889,

@ -56,22 +55,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
logprob: -0.08319092,
special: !!finish_reason,
},
generated_text: finish_reason ? response_text : nil,
generated_text: finish_reason ? deltas.join : nil,
details: nil,
}.to_json
end

def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: true)
stream_line(deltas[index], deltas, finish_reason: true)
else
stream_line(deltas[index])
stream_line(deltas[index], deltas)
end
end

chunks = chunks.join("\n\n")
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")

WebMock
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")

@ -79,5 +78,29 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
.to_return(status: 200, body: chunks)
end

let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
</invoke>
</function_calls>
REPLY

let(:tool_call) { invocation }

it_behaves_like "an endpoint that can communicate with a completion service"
end

@ -6,17 +6,53 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }

let(:model_name) { "gpt-3.5-turbo" }
let(:prompt) do
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }

let(:tool_deltas) do
[
{ role: "system", content: "You are a helpful bot." },
{ role: "user", content: "Write 3 words" },
{ id: "get_weather", name: "get_weather", arguments: {} },
{ id: "get_weather", name: "get_weather", arguments: { location: "" } },
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
]
end

let(:request_body) { model.default_options.merge(messages: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
let(:tool_call) do
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
end

let(:request_body) do
model
.default_options
.merge(messages: prompt)
.tap do |b|
b[:tools] = generic_prompt[:tools].map do |t|
{ type: "function", tool: t }
end if generic_prompt[:tools]
end
.to_json
end
let(:stream_request_body) do
model
.default_options
.merge(messages: prompt, stream: true)
.tap do |b|
b[:tools] = generic_prompt[:tools].map do |t|
{ type: "function", tool: t }
end if generic_prompt[:tools]
end
.to_json
end

def response(content, tool_call: false)
message_content =
if tool_call
{ tool_calls: [{ function: content }] }
else
{ content: content }
end

def response(content)
{
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
object: "chat.completion",

@ -28,45 +64,52 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
total_tokens: 499,
},
choices: [
{ message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
{ message: { role: "assistant" }.merge(message_content), finish_reason: "stop", index: 0 },
],
}
end

def stub_response(prompt, response_text)
def stub_response(prompt, response_text, tool_call: false)
WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: { model: model_name, messages: prompt })
.to_return(status: 200, body: JSON.dump(response(response_text)))
.with(body: request_body)
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
end

def stream_line(delta, finish_reason: nil)
def stream_line(delta, finish_reason: nil, tool_call: false)
message_content =
if tool_call
{ tool_calls: [{ function: delta }] }
else
{ content: delta }
end

+"data: " << {
id: "chatcmpl-#{SecureRandom.hex}",
object: "chat.completion.chunk",
created: 1_681_283_881,
model: "gpt-3.5-turbo-0301",
choices: [{ delta: { content: delta } }],
choices: [{ delta: message_content }],
finish_reason: finish_reason,
index: 0,
}.to_json
end

def stub_streamed_response(prompt, deltas)
def stub_streamed_response(prompt, deltas, tool_call: false)
chunks =
deltas.each_with_index.map do |_, index|
if index == (deltas.length - 1)
stream_line(deltas[index], finish_reason: "stop_sequence")
stream_line(deltas[index], finish_reason: "stop_sequence", tool_call: tool_call)
else
stream_line(deltas[index])
stream_line(deltas[index], tool_call: tool_call)
end
end

chunks = chunks.join("\n\n")
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")

WebMock
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.with(body: stream_request_body)
.to_return(status: 200, body: chunks)
end

@ -3,7 +3,7 @@
RSpec.describe DiscourseAi::Completions::Llm do
subject(:llm) do
described_class.new(
DiscourseAi::Completions::Dialects::OrcaStyle.new,
DiscourseAi::Completions::Dialects::OrcaStyle,
canned_response,
"Upstage-Llama-2-*-instruct-v2",
)

@ -103,7 +103,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
expect(response.status).to eq(200)
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
expect(response.parsed_body["diff"]).to eq(expected_diff)
expect(spy.prompt.last[:content]).to eq(expected_input)
expect(spy.prompt.translate.last[:content]).to eq(expected_input)
end
end
end