DEV: Tool support for the LLM service. (#366)
This PR adds tool support to available LLMs. We'll buffer tool invocations and return them instead of making users of this service parse the response. It also adds support for conversation context in the generic prompt. It includes bot messages, user messages, and tool invocations, which we'll trim to make sure it doesn't exceed the prompt limit, then translate them to the correct dialect. Finally, It adds some buffering when reading chunks to handle cases when streaming is extremely slow.:M
This commit is contained in:
parent
203906be65
commit
e0bf6adb5b
|
@ -3,33 +3,101 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Dialects
|
module Dialects
|
||||||
class ChatGpt
|
class ChatGpt < Dialect
|
||||||
def self.can_translate?(model_name)
|
class << self
|
||||||
|
def can_translate?(model_name)
|
||||||
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
%w[gpt-3.5-turbo gpt-4 gpt-3.5-turbo-16k gpt-4-32k].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
def translate(generic_prompt)
|
|
||||||
open_ai_prompt = [
|
|
||||||
{
|
|
||||||
role: "system",
|
|
||||||
content: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
|
|
||||||
},
|
|
||||||
]
|
|
||||||
|
|
||||||
if generic_prompt[:examples]
|
|
||||||
generic_prompt[:examples].each do |example_pair|
|
|
||||||
open_ai_prompt << { role: "user", content: example_pair.first }
|
|
||||||
open_ai_prompt << { role: "assistant", content: example_pair.second }
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
open_ai_prompt << { role: "user", content: generic_prompt[:input] }
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def translate
|
||||||
|
open_ai_prompt = [
|
||||||
|
{ role: "system", content: [prompt[:insts], prompt[:post_insts].to_s].join("\n") },
|
||||||
|
]
|
||||||
|
|
||||||
|
if prompt[:examples]
|
||||||
|
prompt[:examples].each do |example_pair|
|
||||||
|
open_ai_prompt << { role: "user", content: example_pair.first }
|
||||||
|
open_ai_prompt << { role: "assistant", content: example_pair.second }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
open_ai_prompt.concat!(conversation_context) if prompt[:conversation_context]
|
||||||
|
|
||||||
|
open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
|
||||||
|
|
||||||
|
open_ai_prompt
|
||||||
|
end
|
||||||
|
|
||||||
|
def tools
|
||||||
|
return if prompt[:tools].blank?
|
||||||
|
|
||||||
|
prompt[:tools].map { |t| { type: "function", tool: t } }
|
||||||
|
end
|
||||||
|
|
||||||
|
def conversation_context
|
||||||
|
return [] if prompt[:conversation_context].blank?
|
||||||
|
|
||||||
|
trimmed_context = trim_context(prompt[:conversation_context])
|
||||||
|
|
||||||
|
trimmed_context.reverse.map do |context|
|
||||||
|
translated = context.slice(:content)
|
||||||
|
translated[:role] = context[:type]
|
||||||
|
|
||||||
|
if context[:name]
|
||||||
|
if translated[:role] == "tool"
|
||||||
|
translated[:tool_call_id] = context[:name]
|
||||||
|
else
|
||||||
|
translated[:name] = context[:name]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
translated
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_prompt_tokens
|
||||||
|
# provide a buffer of 120 tokens - our function counting is not
|
||||||
|
# 100% accurate and getting numbers to align exactly is very hard
|
||||||
|
buffer = (opts[:max_tokens_to_sample] || 2500) + 50
|
||||||
|
|
||||||
|
if tools.present?
|
||||||
|
# note this is about 100 tokens over, OpenAI have a more optimal representation
|
||||||
|
@function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
|
||||||
|
buffer += @function_size
|
||||||
|
end
|
||||||
|
|
||||||
|
model_max_tokens - buffer
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def per_message_overhead
|
||||||
|
# open ai defines about 4 tokens per message of overhead
|
||||||
|
4
|
||||||
|
end
|
||||||
|
|
||||||
|
def calculate_message_token(context)
|
||||||
|
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||||
|
end
|
||||||
|
|
||||||
|
def model_max_tokens
|
||||||
|
case model_name
|
||||||
|
when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
|
||||||
|
16_384
|
||||||
|
when "gpt-4"
|
||||||
|
8192
|
||||||
|
when "gpt-4-32k"
|
||||||
|
32_768
|
||||||
|
else
|
||||||
|
8192
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,26 +3,66 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Dialects
|
module Dialects
|
||||||
class Claude
|
class Claude < Dialect
|
||||||
def self.can_translate?(model_name)
|
class << self
|
||||||
|
def can_translate?(model_name)
|
||||||
%w[claude-instant-1 claude-2].include?(model_name)
|
%w[claude-instant-1 claude-2].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
def translate(generic_prompt)
|
|
||||||
claude_prompt = +"Human: #{generic_prompt[:insts]}\n"
|
|
||||||
|
|
||||||
claude_prompt << build_examples(generic_prompt[:examples]) if generic_prompt[:examples]
|
|
||||||
|
|
||||||
claude_prompt << "#{generic_prompt[:input]}\n"
|
|
||||||
|
|
||||||
claude_prompt << "#{generic_prompt[:post_insts]}\n" if generic_prompt[:post_insts]
|
|
||||||
|
|
||||||
claude_prompt << "Assistant:\n"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::AnthropicTokenizer
|
DiscourseAi::Tokenizer::AnthropicTokenizer
|
||||||
end
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def translate
|
||||||
|
claude_prompt = +"Human: #{prompt[:insts]}\n"
|
||||||
|
|
||||||
|
claude_prompt << build_tools_prompt if prompt[:tools]
|
||||||
|
|
||||||
|
claude_prompt << build_examples(prompt[:examples]) if prompt[:examples]
|
||||||
|
|
||||||
|
claude_prompt << conversation_context if prompt[:conversation_context]
|
||||||
|
|
||||||
|
claude_prompt << "#{prompt[:input]}\n"
|
||||||
|
|
||||||
|
claude_prompt << "#{prompt[:post_insts]}\n" if prompt[:post_insts]
|
||||||
|
|
||||||
|
claude_prompt << "Assistant:\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_prompt_tokens
|
||||||
|
50_000
|
||||||
|
end
|
||||||
|
|
||||||
|
def conversation_context
|
||||||
|
return "" if prompt[:conversation_context].blank?
|
||||||
|
|
||||||
|
trimmed_context = trim_context(prompt[:conversation_context])
|
||||||
|
|
||||||
|
trimmed_context
|
||||||
|
.reverse
|
||||||
|
.reduce(+"") do |memo, context|
|
||||||
|
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
|
||||||
|
|
||||||
|
if context[:type] == "tool"
|
||||||
|
memo << <<~TEXT
|
||||||
|
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>#{context[:name]}</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
TEXT
|
||||||
|
else
|
||||||
|
memo << " " << context[:content] << "\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,160 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
|
||||||
|
module DiscourseAi
|
||||||
|
module Completions
|
||||||
|
module Dialects
|
||||||
|
class Dialect
|
||||||
|
class << self
|
||||||
|
def can_translate?(_model_name)
|
||||||
|
raise NotImplemented
|
||||||
|
end
|
||||||
|
|
||||||
|
def dialect_for(model_name)
|
||||||
|
dialects = [
|
||||||
|
DiscourseAi::Completions::Dialects::Claude,
|
||||||
|
DiscourseAi::Completions::Dialects::Llama2Classic,
|
||||||
|
DiscourseAi::Completions::Dialects::ChatGpt,
|
||||||
|
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||||
|
DiscourseAi::Completions::Dialects::Gemini,
|
||||||
|
]
|
||||||
|
dialects.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |d|
|
||||||
|
d.can_translate?(model_name)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def tokenizer
|
||||||
|
raise NotImplemented
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def initialize(generic_prompt, model_name, opts: {})
|
||||||
|
@prompt = generic_prompt
|
||||||
|
@model_name = model_name
|
||||||
|
@opts = opts
|
||||||
|
end
|
||||||
|
|
||||||
|
def translate
|
||||||
|
raise NotImplemented
|
||||||
|
end
|
||||||
|
|
||||||
|
def tools
|
||||||
|
tools = +""
|
||||||
|
|
||||||
|
prompt[:tools].each do |function|
|
||||||
|
parameters = +""
|
||||||
|
if function[:parameters].present?
|
||||||
|
function[:parameters].each do |parameter|
|
||||||
|
parameters << <<~PARAMETER
|
||||||
|
<parameter>
|
||||||
|
<name>#{parameter[:name]}</name>
|
||||||
|
<type>#{parameter[:type]}</type>
|
||||||
|
<description>#{parameter[:description]}</description>
|
||||||
|
<required>#{parameter[:required]}</required>
|
||||||
|
PARAMETER
|
||||||
|
if parameter[:enum]
|
||||||
|
parameters << "<options>#{parameter[:enum].join(",")}</options>\n"
|
||||||
|
end
|
||||||
|
parameters << "</parameter>\n"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
tools << <<~TOOLS
|
||||||
|
<tool_description>
|
||||||
|
<tool_name>#{function[:name]}</tool_name>
|
||||||
|
<description>#{function[:description]}</description>
|
||||||
|
<parameters>
|
||||||
|
#{parameters}</parameters>
|
||||||
|
</tool_description>
|
||||||
|
TOOLS
|
||||||
|
end
|
||||||
|
|
||||||
|
tools
|
||||||
|
end
|
||||||
|
|
||||||
|
def conversation_context
|
||||||
|
raise NotImplemented
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_prompt_tokens
|
||||||
|
raise NotImplemented
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
attr_reader :prompt, :model_name, :opts
|
||||||
|
|
||||||
|
def trim_context(conversation_context)
|
||||||
|
prompt_limit = max_prompt_tokens
|
||||||
|
current_token_count = calculate_token_count_without_context
|
||||||
|
|
||||||
|
conversation_context.reduce([]) do |memo, context|
|
||||||
|
break(memo) if current_token_count >= prompt_limit
|
||||||
|
|
||||||
|
dupped_context = context.dup
|
||||||
|
|
||||||
|
message_tokens = calculate_message_token(dupped_context)
|
||||||
|
|
||||||
|
# Trimming content to make sure we respect token limit.
|
||||||
|
while dupped_context[:content].present? &&
|
||||||
|
message_tokens + current_token_count + per_message_overhead > prompt_limit
|
||||||
|
dupped_context[:content] = dupped_context[:content][0..-100] || ""
|
||||||
|
message_tokens = calculate_message_token(dupped_context)
|
||||||
|
end
|
||||||
|
|
||||||
|
next(memo) if dupped_context[:content].blank?
|
||||||
|
|
||||||
|
current_token_count += message_tokens + per_message_overhead
|
||||||
|
|
||||||
|
memo << dupped_context
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def calculate_token_count_without_context
|
||||||
|
tokenizer = self.class.tokenizer
|
||||||
|
|
||||||
|
examples_count =
|
||||||
|
prompt[:examples].to_a.sum do |pair|
|
||||||
|
tokenizer.size(pair.join) + (per_message_overhead * 2)
|
||||||
|
end
|
||||||
|
input_count = tokenizer.size(prompt[:input].to_s) + per_message_overhead
|
||||||
|
|
||||||
|
examples_count + input_count +
|
||||||
|
prompt
|
||||||
|
.except(:conversation_context, :tools, :examples, :input)
|
||||||
|
.sum { |_, v| tokenizer.size(v) + per_message_overhead }
|
||||||
|
end
|
||||||
|
|
||||||
|
def per_message_overhead
|
||||||
|
0
|
||||||
|
end
|
||||||
|
|
||||||
|
def calculate_message_token(context)
|
||||||
|
self.class.tokenizer.size(context[:content].to_s)
|
||||||
|
end
|
||||||
|
|
||||||
|
def build_tools_prompt
|
||||||
|
return "" if prompt[:tools].blank?
|
||||||
|
|
||||||
|
<<~TEXT
|
||||||
|
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||||
|
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>$TOOL_NAME</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||||
|
...
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
|
||||||
|
Here are the tools available:
|
||||||
|
|
||||||
|
<tools>
|
||||||
|
#{tools}</tools>
|
||||||
|
TEXT
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
|
@ -3,36 +3,93 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Dialects
|
module Dialects
|
||||||
class Gemini
|
class Gemini < Dialect
|
||||||
def self.can_translate?(model_name)
|
class << self
|
||||||
|
def can_translate?(model_name)
|
||||||
%w[gemini-pro].include?(model_name)
|
%w[gemini-pro].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
def translate(generic_prompt)
|
|
||||||
gemini_prompt = [
|
|
||||||
{
|
|
||||||
role: "user",
|
|
||||||
parts: {
|
|
||||||
text: [generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
{ role: "model", parts: { text: "Ok." } },
|
|
||||||
]
|
|
||||||
|
|
||||||
if generic_prompt[:examples]
|
|
||||||
generic_prompt[:examples].each do |example_pair|
|
|
||||||
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
|
|
||||||
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
gemini_prompt << { role: "user", parts: { text: generic_prompt[:input] } }
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
|
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def translate
|
||||||
|
gemini_prompt = [
|
||||||
|
{
|
||||||
|
role: "user",
|
||||||
|
parts: {
|
||||||
|
text: [prompt[:insts], prompt[:post_insts].to_s].join("\n"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{ role: "model", parts: { text: "Ok." } },
|
||||||
|
]
|
||||||
|
|
||||||
|
if prompt[:examples]
|
||||||
|
prompt[:examples].each do |example_pair|
|
||||||
|
gemini_prompt << { role: "user", parts: { text: example_pair.first } }
|
||||||
|
gemini_prompt << { role: "model", parts: { text: example_pair.second } }
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
gemini_prompt.concat!(conversation_context) if prompt[:conversation_context]
|
||||||
|
|
||||||
|
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
|
||||||
|
end
|
||||||
|
|
||||||
|
def tools
|
||||||
|
return if prompt[:tools].blank?
|
||||||
|
|
||||||
|
translated_tools =
|
||||||
|
prompt[:tools].map do |t|
|
||||||
|
required_fields = []
|
||||||
|
tool = t.dup
|
||||||
|
|
||||||
|
tool[:parameters] = t[:parameters].map do |p|
|
||||||
|
required_fields << p[:name] if p[:required]
|
||||||
|
|
||||||
|
p.except(:required)
|
||||||
|
end
|
||||||
|
|
||||||
|
tool.merge(required: required_fields)
|
||||||
|
end
|
||||||
|
|
||||||
|
[{ function_declarations: translated_tools }]
|
||||||
|
end
|
||||||
|
|
||||||
|
def conversation_context
|
||||||
|
return [] if prompt[:conversation_context].blank?
|
||||||
|
|
||||||
|
trimmed_context = trim_context(prompt[:conversation_context])
|
||||||
|
|
||||||
|
trimmed_context.reverse.map do |context|
|
||||||
|
translated = {}
|
||||||
|
translated[:role] = (context[:type] == "user" ? "user" : "model")
|
||||||
|
|
||||||
|
part = {}
|
||||||
|
|
||||||
|
if context[:type] == "tool"
|
||||||
|
part["functionResponse"] = { name: context[:name], content: context[:content] }
|
||||||
|
else
|
||||||
|
part[:text] = context[:content]
|
||||||
|
end
|
||||||
|
|
||||||
|
translated[:parts] = [part]
|
||||||
|
|
||||||
|
translated
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_prompt_tokens
|
||||||
|
16_384 # 50% of model tokens
|
||||||
|
end
|
||||||
|
|
||||||
|
protected
|
||||||
|
|
||||||
|
def calculate_message_token(context)
|
||||||
|
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,29 +3,74 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Dialects
|
module Dialects
|
||||||
class Llama2Classic
|
class Llama2Classic < Dialect
|
||||||
def self.can_translate?(model_name)
|
class << self
|
||||||
|
def can_translate?(model_name)
|
||||||
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
|
%w[Llama2-*-chat-hf Llama2-chat-hf].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
def translate(generic_prompt)
|
|
||||||
llama2_prompt =
|
|
||||||
+"[INST]<<SYS>>#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}<</SYS>>[/INST]\n"
|
|
||||||
|
|
||||||
if generic_prompt[:examples]
|
|
||||||
generic_prompt[:examples].each do |example_pair|
|
|
||||||
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
|
|
||||||
llama2_prompt << "#{example_pair.second}\n"
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
llama2_prompt << "[INST]#{generic_prompt[:input]}[/INST]\n"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::Llama2Tokenizer
|
DiscourseAi::Tokenizer::Llama2Tokenizer
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def translate
|
||||||
|
llama2_prompt = +<<~TEXT
|
||||||
|
[INST]
|
||||||
|
<<SYS>>
|
||||||
|
#{prompt[:insts]}
|
||||||
|
#{build_tools_prompt}#{prompt[:post_insts]}
|
||||||
|
<</SYS>>
|
||||||
|
[/INST]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
if prompt[:examples]
|
||||||
|
prompt[:examples].each do |example_pair|
|
||||||
|
llama2_prompt << "[INST]#{example_pair.first}[/INST]\n"
|
||||||
|
llama2_prompt << "#{example_pair.second}\n"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
llama2_prompt << conversation_context if prompt[:conversation_context].present?
|
||||||
|
|
||||||
|
llama2_prompt << "[INST]#{prompt[:input]}[/INST]\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
def conversation_context
|
||||||
|
return "" if prompt[:conversation_context].blank?
|
||||||
|
|
||||||
|
trimmed_context = trim_context(prompt[:conversation_context])
|
||||||
|
|
||||||
|
trimmed_context
|
||||||
|
.reverse
|
||||||
|
.reduce(+"") do |memo, context|
|
||||||
|
if context[:type] == "tool"
|
||||||
|
memo << <<~TEXT
|
||||||
|
[INST]
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>#{context[:name]}</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
[/INST]
|
||||||
|
TEXT
|
||||||
|
elsif context[:type] == "assistant"
|
||||||
|
memo << "[INST]" << context[:content] << "[/INST]\n"
|
||||||
|
else
|
||||||
|
memo << context[:content] << "\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_prompt_tokens
|
||||||
|
SiteSetting.ai_hugging_face_token_limit
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,31 +3,70 @@
|
||||||
module DiscourseAi
|
module DiscourseAi
|
||||||
module Completions
|
module Completions
|
||||||
module Dialects
|
module Dialects
|
||||||
class OrcaStyle
|
class OrcaStyle < Dialect
|
||||||
def self.can_translate?(model_name)
|
class << self
|
||||||
|
def can_translate?(model_name)
|
||||||
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
|
%w[StableBeluga2 Upstage-Llama-2-*-instruct-v2].include?(model_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
def translate(generic_prompt)
|
|
||||||
orca_style_prompt =
|
|
||||||
+"### System:\n#{[generic_prompt[:insts], generic_prompt[:post_insts].to_s].join("\n")}\n"
|
|
||||||
|
|
||||||
if generic_prompt[:examples]
|
|
||||||
generic_prompt[:examples].each do |example_pair|
|
|
||||||
orca_style_prompt << "### User:\n#{example_pair.first}\n"
|
|
||||||
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
orca_style_prompt << "### User:\n#{generic_prompt[:input]}\n"
|
|
||||||
|
|
||||||
orca_style_prompt << "### Assistant:\n"
|
|
||||||
end
|
|
||||||
|
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::Llama2Tokenizer
|
DiscourseAi::Tokenizer::Llama2Tokenizer
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def translate
|
||||||
|
orca_style_prompt = +<<~TEXT
|
||||||
|
### System:
|
||||||
|
#{prompt[:insts]}
|
||||||
|
#{build_tools_prompt}#{prompt[:post_insts]}
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
if prompt[:examples]
|
||||||
|
prompt[:examples].each do |example_pair|
|
||||||
|
orca_style_prompt << "### User:\n#{example_pair.first}\n"
|
||||||
|
orca_style_prompt << "### Assistant:\n#{example_pair.second}\n"
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
orca_style_prompt << "### User:\n#{prompt[:input]}\n"
|
||||||
|
|
||||||
|
orca_style_prompt << "### Assistant:\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
def conversation_context
|
||||||
|
return "" if prompt[:conversation_context].blank?
|
||||||
|
|
||||||
|
trimmed_context = trim_context(prompt[:conversation_context])
|
||||||
|
|
||||||
|
trimmed_context
|
||||||
|
.reverse
|
||||||
|
.reduce(+"") do |memo, context|
|
||||||
|
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
|
||||||
|
|
||||||
|
if context[:type] == "tool"
|
||||||
|
memo << <<~TEXT
|
||||||
|
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>#{context[:name]}</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
TEXT
|
||||||
|
else
|
||||||
|
memo << " " << context[:content] << "\n"
|
||||||
|
end
|
||||||
|
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
def max_prompt_tokens
|
||||||
|
SiteSetting.ai_hugging_face_token_limit
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -22,7 +22,7 @@ module DiscourseAi
|
||||||
@uri ||= URI("https://api.anthropic.com/v1/complete")
|
@uri ||= URI("https://api.anthropic.com/v1/complete")
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_payload(prompt, model_params)
|
def prepare_payload(prompt, model_params, _dialect)
|
||||||
default_options
|
default_options
|
||||||
.merge(model_params)
|
.merge(model_params)
|
||||||
.merge(prompt: prompt)
|
.merge(prompt: prompt)
|
||||||
|
|
|
@ -37,7 +37,7 @@ module DiscourseAi
|
||||||
URI(api_url)
|
URI(api_url)
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_payload(prompt, model_params)
|
def prepare_payload(prompt, model_params, _dialect)
|
||||||
default_options.merge(prompt: prompt).merge(model_params)
|
default_options.merge(prompt: prompt).merge(model_params)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -30,9 +30,11 @@ module DiscourseAi
|
||||||
@tokenizer = tokenizer
|
@tokenizer = tokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
def perform_completion!(prompt, user, model_params = {})
|
def perform_completion!(dialect, user, model_params = {})
|
||||||
@streaming_mode = block_given?
|
@streaming_mode = block_given?
|
||||||
|
|
||||||
|
prompt = dialect.translate
|
||||||
|
|
||||||
Net::HTTP.start(
|
Net::HTTP.start(
|
||||||
model_uri.host,
|
model_uri.host,
|
||||||
model_uri.port,
|
model_uri.port,
|
||||||
|
@ -43,7 +45,10 @@ module DiscourseAi
|
||||||
) do |http|
|
) do |http|
|
||||||
response_data = +""
|
response_data = +""
|
||||||
response_raw = +""
|
response_raw = +""
|
||||||
request_body = prepare_payload(prompt, model_params).to_json
|
|
||||||
|
# Needed to response token calculations. Cannot rely on response_data due to function buffering.
|
||||||
|
partials_raw = +""
|
||||||
|
request_body = prepare_payload(prompt, model_params, dialect).to_json
|
||||||
|
|
||||||
request = prepare_request(request_body)
|
request = prepare_request(request_body)
|
||||||
|
|
||||||
|
@ -66,6 +71,15 @@ module DiscourseAi
|
||||||
if !@streaming_mode
|
if !@streaming_mode
|
||||||
response_raw = response.read_body
|
response_raw = response.read_body
|
||||||
response_data = extract_completion_from(response_raw)
|
response_data = extract_completion_from(response_raw)
|
||||||
|
partials_raw = response_data.to_s
|
||||||
|
|
||||||
|
if has_tool?("", response_data)
|
||||||
|
function_buffer = build_buffer # Nokogiri document
|
||||||
|
function_buffer = add_to_buffer(function_buffer, "", response_data)
|
||||||
|
|
||||||
|
response_data = +function_buffer.at("function_calls").to_s
|
||||||
|
response_data << "\n"
|
||||||
|
end
|
||||||
|
|
||||||
return response_data
|
return response_data
|
||||||
end
|
end
|
||||||
|
@ -75,6 +89,7 @@ module DiscourseAi
|
||||||
cancel = lambda { cancelled = true }
|
cancel = lambda { cancelled = true }
|
||||||
|
|
||||||
leftover = ""
|
leftover = ""
|
||||||
|
function_buffer = build_buffer # Nokogiri document
|
||||||
|
|
||||||
response.read_body do |chunk|
|
response.read_body do |chunk|
|
||||||
if cancelled
|
if cancelled
|
||||||
|
@ -85,6 +100,12 @@ module DiscourseAi
|
||||||
decoded_chunk = decode(chunk)
|
decoded_chunk = decode(chunk)
|
||||||
response_raw << decoded_chunk
|
response_raw << decoded_chunk
|
||||||
|
|
||||||
|
# Buffering for extremely slow streaming.
|
||||||
|
if (leftover + decoded_chunk).length < "data: [DONE]".length
|
||||||
|
leftover += decoded_chunk
|
||||||
|
next
|
||||||
|
end
|
||||||
|
|
||||||
partials_from(leftover + decoded_chunk).each do |raw_partial|
|
partials_from(leftover + decoded_chunk).each do |raw_partial|
|
||||||
next if cancelled
|
next if cancelled
|
||||||
next if raw_partial.blank?
|
next if raw_partial.blank?
|
||||||
|
@ -93,11 +114,27 @@ module DiscourseAi
|
||||||
partial = extract_completion_from(raw_partial)
|
partial = extract_completion_from(raw_partial)
|
||||||
next if partial.nil?
|
next if partial.nil?
|
||||||
leftover = ""
|
leftover = ""
|
||||||
|
|
||||||
|
if has_tool?(response_data, partial)
|
||||||
|
function_buffer = add_to_buffer(function_buffer, response_data, partial)
|
||||||
|
|
||||||
|
if buffering_finished?(dialect.tools, function_buffer)
|
||||||
|
invocation = +function_buffer.at("function_calls").to_s
|
||||||
|
invocation << "\n"
|
||||||
|
|
||||||
|
partials_raw << partial.to_s
|
||||||
|
response_data << invocation
|
||||||
|
|
||||||
|
yield invocation, cancel
|
||||||
|
end
|
||||||
|
else
|
||||||
|
partials_raw << partial
|
||||||
response_data << partial
|
response_data << partial
|
||||||
|
|
||||||
yield partial, cancel if partial
|
yield partial, cancel if partial
|
||||||
|
end
|
||||||
rescue JSON::ParserError
|
rescue JSON::ParserError
|
||||||
leftover = raw_partial
|
leftover += decoded_chunk
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -109,7 +146,7 @@ module DiscourseAi
|
||||||
ensure
|
ensure
|
||||||
if log
|
if log
|
||||||
log.raw_response_payload = response_raw
|
log.raw_response_payload = response_raw
|
||||||
log.response_tokens = tokenizer.size(response_data)
|
log.response_tokens = tokenizer.size(partials_raw)
|
||||||
log.save!
|
log.save!
|
||||||
|
|
||||||
if Rails.env.development?
|
if Rails.env.development?
|
||||||
|
@ -165,6 +202,40 @@ module DiscourseAi
|
||||||
def extract_prompt_for_tokenizer(prompt)
|
def extract_prompt_for_tokenizer(prompt)
|
||||||
prompt
|
prompt
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def build_buffer
|
||||||
|
Nokogiri::HTML5.fragment(<<~TEXT)
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name></tool_name>
|
||||||
|
<tool_id></tool_id>
|
||||||
|
<parameters></parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
TEXT
|
||||||
|
end
|
||||||
|
|
||||||
|
def has_tool?(response, partial)
|
||||||
|
(response + partial).include?("<function_calls>")
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_buffer(function_buffer, response_data, partial)
|
||||||
|
new_buffer = Nokogiri::HTML5.fragment(response_data + partial)
|
||||||
|
if tool_name = new_buffer.at("tool_name").text
|
||||||
|
if new_buffer.at("tool_id").nil?
|
||||||
|
tool_id_node =
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse("\n<tool_id>#{tool_name}</tool_id>")
|
||||||
|
|
||||||
|
new_buffer.at("invoke").children[1].add_next_sibling(tool_id_node)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
new_buffer
|
||||||
|
end
|
||||||
|
|
||||||
|
def buffering_finished?(_available_functions, buffer)
|
||||||
|
buffer.to_s.include?("</function_calls>")
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -31,10 +31,15 @@ module DiscourseAi
|
||||||
cancelled = false
|
cancelled = false
|
||||||
cancel_fn = lambda { cancelled = true }
|
cancel_fn = lambda { cancelled = true }
|
||||||
|
|
||||||
|
# We buffer and return tool invocations in one go.
|
||||||
|
if is_tool?(response)
|
||||||
|
yield(response, cancel_fn)
|
||||||
|
else
|
||||||
response.each_char do |char|
|
response.each_char do |char|
|
||||||
break if cancelled
|
break if cancelled
|
||||||
yield(char, cancel_fn)
|
yield(char, cancel_fn)
|
||||||
end
|
end
|
||||||
|
end
|
||||||
else
|
else
|
||||||
response
|
response
|
||||||
end
|
end
|
||||||
|
@ -43,6 +48,12 @@ module DiscourseAi
|
||||||
def tokenizer
|
def tokenizer
|
||||||
DiscourseAi::Tokenizer::OpenAiTokenizer
|
DiscourseAi::Tokenizer::OpenAiTokenizer
|
||||||
end
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def is_tool?(response)
|
||||||
|
Nokogiri::HTML5.fragment(response).at("function_calls").present?
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -25,8 +25,11 @@ module DiscourseAi
|
||||||
URI(url)
|
URI(url)
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_payload(prompt, model_params)
|
def prepare_payload(prompt, model_params, dialect)
|
||||||
default_options.merge(model_params).merge(contents: prompt)
|
default_options
|
||||||
|
.merge(model_params)
|
||||||
|
.merge(contents: prompt)
|
||||||
|
.tap { |payload| payload[:tools] = dialect.tools if dialect.tools.present? }
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_request(payload)
|
def prepare_request(payload)
|
||||||
|
@ -36,25 +39,72 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def extract_completion_from(response_raw)
|
def extract_completion_from(response_raw)
|
||||||
if @streaming_mode
|
|
||||||
parsed = response_raw
|
|
||||||
else
|
|
||||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
parsed = JSON.parse(response_raw, symbolize_names: true)
|
||||||
end
|
|
||||||
|
|
||||||
completion = dig_text(parsed).to_s
|
response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
|
||||||
|
|
||||||
|
has_function_call = response_h.dig(:functionCall).present?
|
||||||
|
has_function_call ? response_h[:functionCall] : response_h.dig(:text)
|
||||||
end
|
end
|
||||||
|
|
||||||
def partials_from(decoded_chunk)
|
def partials_from(decoded_chunk)
|
||||||
JSON.parse(decoded_chunk, symbolize_names: true)
|
decoded_chunk
|
||||||
|
.split("\n")
|
||||||
|
.map do |line|
|
||||||
|
if line == ","
|
||||||
|
nil
|
||||||
|
elsif line.starts_with?("[")
|
||||||
|
line[1..-1]
|
||||||
|
elsif line.ends_with?("]")
|
||||||
|
line[0..-1]
|
||||||
|
else
|
||||||
|
line
|
||||||
|
end
|
||||||
|
end
|
||||||
|
.compact_blank
|
||||||
end
|
end
|
||||||
|
|
||||||
def extract_prompt_for_tokenizer(prompt)
|
def extract_prompt_for_tokenizer(prompt)
|
||||||
prompt.to_s
|
prompt.to_s
|
||||||
end
|
end
|
||||||
|
|
||||||
def dig_text(response)
|
def has_tool?(_response_data, partial)
|
||||||
response.dig(:candidates, 0, :content, :parts, 0, :text)
|
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_buffer(function_buffer, _response_data, partial)
|
||||||
|
if partial[:name].present?
|
||||||
|
function_buffer.at("tool_name").content = partial[:name]
|
||||||
|
function_buffer.at("tool_id").content = partial[:name]
|
||||||
|
end
|
||||||
|
|
||||||
|
if partial[:args]
|
||||||
|
argument_fragments =
|
||||||
|
partial[:args].reduce(+"") do |memo, (arg_name, value)|
|
||||||
|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
|
||||||
|
end
|
||||||
|
argument_fragments << "\n"
|
||||||
|
|
||||||
|
function_buffer.at("parameters").children =
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||||
|
end
|
||||||
|
|
||||||
|
function_buffer
|
||||||
|
end
|
||||||
|
|
||||||
|
def buffering_finished?(available_functions, buffer)
|
||||||
|
tool_name = buffer.at("tool_name")&.text
|
||||||
|
return false if tool_name.blank?
|
||||||
|
|
||||||
|
signature =
|
||||||
|
available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
|
||||||
|
|
||||||
|
signature[:parameters].reduce(true) do |memo, param|
|
||||||
|
param_present = buffer.at(param[:name]).present?
|
||||||
|
next(memo) if param_present || !signature[:required].include?(param[:name])
|
||||||
|
|
||||||
|
memo && param_present
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -11,7 +11,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def default_options
|
def default_options
|
||||||
{ parameters: { repetition_penalty: 1.1, temperature: 0.7 } }
|
{ parameters: { repetition_penalty: 1.1, temperature: 0.7, return_full_text: false } }
|
||||||
end
|
end
|
||||||
|
|
||||||
def provider_id
|
def provider_id
|
||||||
|
@ -24,7 +24,7 @@ module DiscourseAi
|
||||||
URI(SiteSetting.ai_hugging_face_api_url)
|
URI(SiteSetting.ai_hugging_face_api_url)
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_payload(prompt, model_params)
|
def prepare_payload(prompt, model_params, _dialect)
|
||||||
default_options
|
default_options
|
||||||
.merge(inputs: prompt)
|
.merge(inputs: prompt)
|
||||||
.tap do |payload|
|
.tap do |payload|
|
||||||
|
@ -33,7 +33,6 @@ module DiscourseAi
|
||||||
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
|
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
|
||||||
|
|
||||||
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
|
payload[:parameters][:max_new_tokens] = token_limit - prompt_size(prompt)
|
||||||
payload[:parameters][:return_full_text] = false
|
|
||||||
|
|
||||||
payload[:stream] = true if @streaming_mode
|
payload[:stream] = true if @streaming_mode
|
||||||
end
|
end
|
||||||
|
|
|
@ -37,11 +37,14 @@ module DiscourseAi
|
||||||
URI(url)
|
URI(url)
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_payload(prompt, model_params)
|
def prepare_payload(prompt, model_params, dialect)
|
||||||
default_options
|
default_options
|
||||||
.merge(model_params)
|
.merge(model_params)
|
||||||
.merge(messages: prompt)
|
.merge(messages: prompt)
|
||||||
.tap { |payload| payload[:stream] = true if @streaming_mode }
|
.tap do |payload|
|
||||||
|
payload[:stream] = true if @streaming_mode
|
||||||
|
payload[:tools] = dialect.tools if dialect.tools.present?
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def prepare_request(payload)
|
def prepare_request(payload)
|
||||||
|
@ -62,15 +65,12 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def extract_completion_from(response_raw)
|
def extract_completion_from(response_raw)
|
||||||
parsed = JSON.parse(response_raw, symbolize_names: true)
|
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
|
||||||
|
|
||||||
(
|
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
|
||||||
if @streaming_mode
|
|
||||||
parsed.dig(:choices, 0, :delta, :content)
|
has_function_call = response_h.dig(:tool_calls).present?
|
||||||
else
|
has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
|
||||||
parsed.dig(:choices, 0, :message, :content)
|
|
||||||
end
|
|
||||||
).to_s
|
|
||||||
end
|
end
|
||||||
|
|
||||||
def partials_from(decoded_chunk)
|
def partials_from(decoded_chunk)
|
||||||
|
@ -86,6 +86,42 @@ module DiscourseAi
|
||||||
def extract_prompt_for_tokenizer(prompt)
|
def extract_prompt_for_tokenizer(prompt)
|
||||||
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
|
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
|
||||||
end
|
end
|
||||||
|
|
||||||
|
def has_tool?(_response_data, partial)
|
||||||
|
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
|
||||||
|
end
|
||||||
|
|
||||||
|
def add_to_buffer(function_buffer, _response_data, partial)
|
||||||
|
function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
|
||||||
|
function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
|
||||||
|
|
||||||
|
if partial[:arguments]
|
||||||
|
argument_fragments =
|
||||||
|
partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
|
||||||
|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
|
||||||
|
end
|
||||||
|
argument_fragments << "\n"
|
||||||
|
|
||||||
|
function_buffer.at("parameters").children =
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||||
|
end
|
||||||
|
|
||||||
|
function_buffer
|
||||||
|
end
|
||||||
|
|
||||||
|
def buffering_finished?(available_functions, buffer)
|
||||||
|
tool_name = buffer.at("tool_name")&.text
|
||||||
|
return false if tool_name.blank?
|
||||||
|
|
||||||
|
signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
|
||||||
|
|
||||||
|
signature[:parameters].reduce(true) do |memo, param|
|
||||||
|
param_present = buffer.at(param[:name]).present?
|
||||||
|
next(memo) if param_present && !param[:required]
|
||||||
|
|
||||||
|
memo && param_present
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -24,35 +24,26 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def self.proxy(model_name)
|
def self.proxy(model_name)
|
||||||
dialects = [
|
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
|
||||||
DiscourseAi::Completions::Dialects::Claude,
|
|
||||||
DiscourseAi::Completions::Dialects::Llama2Classic,
|
|
||||||
DiscourseAi::Completions::Dialects::ChatGpt,
|
|
||||||
DiscourseAi::Completions::Dialects::OrcaStyle,
|
|
||||||
DiscourseAi::Completions::Dialects::Gemini,
|
|
||||||
]
|
|
||||||
|
|
||||||
dialect =
|
return new(dialect_klass, @canned_response, model_name) if @canned_response
|
||||||
dialects.detect(-> { raise UNKNOWN_MODEL }) { |d| d.can_translate?(model_name) }.new
|
|
||||||
|
|
||||||
return new(dialect, @canned_response, model_name) if @canned_response
|
|
||||||
|
|
||||||
gateway =
|
gateway =
|
||||||
DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
|
DiscourseAi::Completions::Endpoints::Base.endpoint_for(model_name).new(
|
||||||
model_name,
|
model_name,
|
||||||
dialect.tokenizer,
|
dialect_klass.tokenizer,
|
||||||
)
|
)
|
||||||
|
|
||||||
new(dialect, gateway, model_name)
|
new(dialect_klass, gateway, model_name)
|
||||||
end
|
end
|
||||||
|
|
||||||
def initialize(dialect, gateway, model_name)
|
def initialize(dialect_klass, gateway, model_name)
|
||||||
@dialect = dialect
|
@dialect_klass = dialect_klass
|
||||||
@gateway = gateway
|
@gateway = gateway
|
||||||
@model_name = model_name
|
@model_name = model_name
|
||||||
end
|
end
|
||||||
|
|
||||||
delegate :tokenizer, to: :dialect
|
delegate :tokenizer, to: :dialect_klass
|
||||||
|
|
||||||
# @param generic_prompt { Hash } - Prompt using our generic format.
|
# @param generic_prompt { Hash } - Prompt using our generic format.
|
||||||
# We use the following keys from the hash:
|
# We use the following keys from the hash:
|
||||||
|
@ -60,23 +51,64 @@ module DiscourseAi
|
||||||
# - input: String containing user input
|
# - input: String containing user input
|
||||||
# - examples (optional): Array of arrays with examples of input and responses. Each array is a input/response pair like [[example1, response1], [example2, response2]].
|
# - examples (optional): Array of arrays with examples of input and responses. Each array is a input/response pair like [[example1, response1], [example2, response2]].
|
||||||
# - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt.
|
# - post_insts (optional): Additional instructions for the LLM. Some dialects like Claude add these at the end of the prompt.
|
||||||
|
# - conversation_context (optional): Array of hashes to provide context about an ongoing conversation with the model.
|
||||||
|
# We translate the array in reverse order, meaning the first element would be the most recent message in the conversation.
|
||||||
|
# Example:
|
||||||
|
#
|
||||||
|
# [
|
||||||
|
# { type: "user", name: "user1", content: "This is a new message by a user" },
|
||||||
|
# { type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||||
|
# { type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||||
|
# ]
|
||||||
|
#
|
||||||
|
# - tools (optional - only functions supported): Array of functions a model can call. Each function is defined as a hash. Example:
|
||||||
|
#
|
||||||
|
# {
|
||||||
|
# name: "get_weather",
|
||||||
|
# description: "Get the weather in a city",
|
||||||
|
# parameters: [
|
||||||
|
# { name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
# {
|
||||||
|
# name: "unit",
|
||||||
|
# type: "string",
|
||||||
|
# description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
# enum: %w[c f],
|
||||||
|
# required: true,
|
||||||
|
# },
|
||||||
|
# ],
|
||||||
|
# }
|
||||||
#
|
#
|
||||||
# @param user { User } - User requesting the summary.
|
# @param user { User } - User requesting the summary.
|
||||||
#
|
#
|
||||||
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
|
# @param &on_partial_blk { Block - Optional } - The passed block will get called with the LLM partial response alongside a cancel function.
|
||||||
#
|
#
|
||||||
# @returns { String } - Completion result.
|
# @returns { String } - Completion result.
|
||||||
|
#
|
||||||
|
# When the model invokes a tool, we'll wait until the endpoint finishes replying and feed you a fully-formed tool,
|
||||||
|
# even if you passed a partial_read_blk block. Invocations are strings that look like this:
|
||||||
|
#
|
||||||
|
# <function_calls>
|
||||||
|
# <invoke>
|
||||||
|
# <tool_name>get_weather</tool_name>
|
||||||
|
# <tool_id>get_weather</tool_id>
|
||||||
|
# <parameters>
|
||||||
|
# <location>Sydney</location>
|
||||||
|
# <unit>c</unit>
|
||||||
|
# </parameters>
|
||||||
|
# </invoke>
|
||||||
|
# </function_calls>
|
||||||
|
#
|
||||||
def completion!(generic_prompt, user, &partial_read_blk)
|
def completion!(generic_prompt, user, &partial_read_blk)
|
||||||
prompt = dialect.translate(generic_prompt)
|
|
||||||
|
|
||||||
model_params = generic_prompt.dig(:params, model_name) || {}
|
model_params = generic_prompt.dig(:params, model_name) || {}
|
||||||
|
|
||||||
gateway.perform_completion!(prompt, user, model_params, &partial_read_blk)
|
dialect = dialect_klass.new(generic_prompt, model_name, opts: model_params)
|
||||||
|
|
||||||
|
gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
|
||||||
end
|
end
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
attr_reader :dialect, :gateway, :model_name
|
attr_reader :dialect_klass, :gateway, :model_name
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,7 +1,24 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||||
subject(:dialect) { described_class.new }
|
subject(:dialect) { described_class.new(prompt, "gpt-4") }
|
||||||
|
|
||||||
|
let(:tool) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
let(:prompt) do
|
let(:prompt) do
|
||||||
{
|
{
|
||||||
|
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||||
TEXT
|
TEXT
|
||||||
post_insts:
|
post_insts:
|
||||||
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
|
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
|
||||||
|
tools: [tool],
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -35,7 +53,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||||
{ role: "user", content: prompt[:input] },
|
{ role: "user", content: prompt[:input] },
|
||||||
]
|
]
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to contain_exactly(*open_ai_version)
|
expect(translated).to contain_exactly(*open_ai_version)
|
||||||
end
|
end
|
||||||
|
@ -55,9 +73,51 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||||
{ role: "user", content: prompt[:input] },
|
{ role: "user", content: prompt[:input] },
|
||||||
]
|
]
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to contain_exactly(*open_ai_version)
|
expect(translated).to contain_exactly(*open_ai_version)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "#conversation_context" do
|
||||||
|
let(:context) do
|
||||||
|
[
|
||||||
|
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||||
|
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||||
|
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
it "adds conversation in reverse order (first == newer)" do
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context).to eq(
|
||||||
|
[
|
||||||
|
{ role: "tool", content: context.last[:content], tool_call_id: context.last[:name] },
|
||||||
|
{ role: "assistant", content: context.second[:content] },
|
||||||
|
{ role: "user", content: context.first[:content], name: context.first[:name] },
|
||||||
|
],
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "trims content if it's getting too long" do
|
||||||
|
context.last[:content] = context.last[:content] * 1000
|
||||||
|
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context.last[:content].length).to be < context.last[:content].length
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#tools" do
|
||||||
|
it "returns a list of available tools" do
|
||||||
|
open_ai_tool_f = { type: "function", tool: tool }
|
||||||
|
|
||||||
|
expect(subject.tools).to contain_exactly(open_ai_tool_f)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,7 +1,24 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
subject(:dialect) { described_class.new }
|
subject(:dialect) { described_class.new(prompt, "claude-2") }
|
||||||
|
|
||||||
|
let(:tool) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
let(:prompt) do
|
let(:prompt) do
|
||||||
{
|
{
|
||||||
|
@ -37,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
Assistant:
|
Assistant:
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to eq(anthropic_version)
|
expect(translated).to eq(anthropic_version)
|
||||||
end
|
end
|
||||||
|
@ -60,9 +77,111 @@ RSpec.describe DiscourseAi::Completions::Dialects::Claude do
|
||||||
Assistant:
|
Assistant:
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
|
expect(translated).to eq(anthropic_version)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "include tools inside the prompt" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
anthropic_version = <<~TEXT
|
||||||
|
Human: #{prompt[:insts]}
|
||||||
|
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||||
|
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>$TOOL_NAME</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||||
|
...
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
|
||||||
|
Here are the tools available:
|
||||||
|
|
||||||
|
<tools>
|
||||||
|
#{dialect.tools}</tools>
|
||||||
|
#{prompt[:input]}
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
Assistant:
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to eq(anthropic_version)
|
expect(translated).to eq(anthropic_version)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "#conversation_context" do
|
||||||
|
let(:context) do
|
||||||
|
[
|
||||||
|
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||||
|
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||||
|
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
it "adds conversation in reverse order (first == newer)" do
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
expected = <<~TEXT
|
||||||
|
Assistant:
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>tool_id</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context.last[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
Assistant: #{context.second[:content]}
|
||||||
|
Human: #{context.first[:content]}
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context).to eq(expected)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "trims content if it's getting too long" do
|
||||||
|
context.last[:content] = context.last[:content] * 10_000
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context.length).to be < context.last[:content].length
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#tools" do
|
||||||
|
it "translates tools to the tool syntax" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
translated_tool = <<~TEXT
|
||||||
|
<tool_description>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<description>Get the weather in a city</description>
|
||||||
|
<parameters>
|
||||||
|
<parameter>
|
||||||
|
<name>location</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the city name</description>
|
||||||
|
<required>true</required>
|
||||||
|
</parameter>
|
||||||
|
<parameter>
|
||||||
|
<name>unit</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the unit of measurement celcius c or fahrenheit f</description>
|
||||||
|
<required>true</required>
|
||||||
|
<options>c,f</options>
|
||||||
|
</parameter>
|
||||||
|
</parameters>
|
||||||
|
</tool_description>
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
expect(dialect.tools).to eq(translated_tool)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,7 +1,24 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
subject(:dialect) { described_class.new }
|
subject(:dialect) { described_class.new(prompt, "gemini-pro") }
|
||||||
|
|
||||||
|
let(:tool) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
let(:prompt) do
|
let(:prompt) do
|
||||||
{
|
{
|
||||||
|
@ -25,6 +42,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
TEXT
|
TEXT
|
||||||
post_insts:
|
post_insts:
|
||||||
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
|
"Please put the translation between <ai></ai> tags and separate each title with a comma.",
|
||||||
|
tools: [tool],
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -36,7 +54,7 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
{ role: "user", parts: { text: prompt[:input] } },
|
{ role: "user", parts: { text: prompt[:input] } },
|
||||||
]
|
]
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to eq(gemini_version)
|
expect(translated).to eq(gemini_version)
|
||||||
end
|
end
|
||||||
|
@ -57,9 +75,79 @@ RSpec.describe DiscourseAi::Completions::Dialects::Gemini do
|
||||||
{ role: "user", parts: { text: prompt[:input] } },
|
{ role: "user", parts: { text: prompt[:input] } },
|
||||||
]
|
]
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to contain_exactly(*gemini_version)
|
expect(translated).to contain_exactly(*gemini_version)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "#conversation_context" do
|
||||||
|
let(:context) do
|
||||||
|
[
|
||||||
|
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||||
|
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||||
|
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
it "adds conversation in reverse order (first == newer)" do
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context).to eq(
|
||||||
|
[
|
||||||
|
{
|
||||||
|
role: "model",
|
||||||
|
parts: [
|
||||||
|
{
|
||||||
|
"functionResponse" => {
|
||||||
|
name: context.last[:name],
|
||||||
|
content: context.last[:content],
|
||||||
|
},
|
||||||
|
},
|
||||||
|
],
|
||||||
|
},
|
||||||
|
{ role: "model", parts: [{ text: context.second[:content] }] },
|
||||||
|
{ role: "user", parts: [{ text: context.first[:content] }] },
|
||||||
|
],
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "trims content if it's getting too long" do
|
||||||
|
context.last[:content] = context.last[:content] * 1000
|
||||||
|
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context.last.dig(:parts, 0, :text).length).to be <
|
||||||
|
context.last[:content].length
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#tools" do
|
||||||
|
it "returns a list of available tools" do
|
||||||
|
gemini_tools = {
|
||||||
|
function_declarations: [
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name" },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
required: %w[location unit],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
expect(subject.tools).to contain_exactly(gemini_tools)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,7 +1,24 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
|
RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
|
||||||
subject(:dialect) { described_class.new }
|
subject(:dialect) { described_class.new(prompt, "Llama2-chat-hf") }
|
||||||
|
|
||||||
|
let(:tool) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
let(:prompt) do
|
let(:prompt) do
|
||||||
{
|
{
|
||||||
|
@ -31,11 +48,16 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
|
||||||
describe "#translate" do
|
describe "#translate" do
|
||||||
it "translates a prompt written in our generic format to the Llama2 format" do
|
it "translates a prompt written in our generic format to the Llama2 format" do
|
||||||
llama2_classic_version = <<~TEXT
|
llama2_classic_version = <<~TEXT
|
||||||
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
|
[INST]
|
||||||
|
<<SYS>>
|
||||||
|
#{prompt[:insts]}
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
<</SYS>>
|
||||||
|
[/INST]
|
||||||
[INST]#{prompt[:input]}[/INST]
|
[INST]#{prompt[:input]}[/INST]
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to eq(llama2_classic_version)
|
expect(translated).to eq(llama2_classic_version)
|
||||||
end
|
end
|
||||||
|
@ -49,15 +71,126 @@ RSpec.describe DiscourseAi::Completions::Dialects::Llama2Classic do
|
||||||
]
|
]
|
||||||
|
|
||||||
llama2_classic_version = <<~TEXT
|
llama2_classic_version = <<~TEXT
|
||||||
[INST]<<SYS>>#{[prompt[:insts], prompt[:post_insts]].join("\n")}<</SYS>>[/INST]
|
[INST]
|
||||||
|
<<SYS>>
|
||||||
|
#{prompt[:insts]}
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
<</SYS>>
|
||||||
|
[/INST]
|
||||||
[INST]#{prompt[:examples][0][0]}[/INST]
|
[INST]#{prompt[:examples][0][0]}[/INST]
|
||||||
#{prompt[:examples][0][1]}
|
#{prompt[:examples][0][1]}
|
||||||
[INST]#{prompt[:input]}[/INST]
|
[INST]#{prompt[:input]}[/INST]
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
|
expect(translated).to eq(llama2_classic_version)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "include tools inside the prompt" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
llama2_classic_version = <<~TEXT
|
||||||
|
[INST]
|
||||||
|
<<SYS>>
|
||||||
|
#{prompt[:insts]}
|
||||||
|
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||||
|
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>$TOOL_NAME</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||||
|
...
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
|
||||||
|
Here are the tools available:
|
||||||
|
|
||||||
|
<tools>
|
||||||
|
#{dialect.tools}</tools>
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
<</SYS>>
|
||||||
|
[/INST]
|
||||||
|
[INST]#{prompt[:input]}[/INST]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to eq(llama2_classic_version)
|
expect(translated).to eq(llama2_classic_version)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "#conversation_context" do
|
||||||
|
let(:context) do
|
||||||
|
[
|
||||||
|
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||||
|
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||||
|
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
it "adds conversation in reverse order (first == newer)" do
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
expected = <<~TEXT
|
||||||
|
[INST]
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>tool_id</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context.last[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
[/INST]
|
||||||
|
[INST]#{context.second[:content]}[/INST]
|
||||||
|
#{context.first[:content]}
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context).to eq(expected)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "trims content if it's getting too long" do
|
||||||
|
context.last[:content] = context.last[:content] * 1_000
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context.length).to be < context.last[:content].length
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#tools" do
|
||||||
|
it "translates functions to the tool syntax" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
translated_tool = <<~TEXT
|
||||||
|
<tool_description>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<description>Get the weather in a city</description>
|
||||||
|
<parameters>
|
||||||
|
<parameter>
|
||||||
|
<name>location</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the city name</description>
|
||||||
|
<required>true</required>
|
||||||
|
</parameter>
|
||||||
|
<parameter>
|
||||||
|
<name>unit</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the unit of measurement celcius c or fahrenheit f</description>
|
||||||
|
<required>true</required>
|
||||||
|
<options>c,f</options>
|
||||||
|
</parameter>
|
||||||
|
</parameters>
|
||||||
|
</tool_description>
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
expect(dialect.tools).to eq(translated_tool)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,9 +1,25 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
|
RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
|
||||||
subject(:dialect) { described_class.new }
|
subject(:dialect) { described_class.new(prompt, "StableBeluga2") }
|
||||||
|
|
||||||
|
let(:tool) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
describe "#translate" do
|
|
||||||
let(:prompt) do
|
let(:prompt) do
|
||||||
{
|
{
|
||||||
insts: <<~TEXT,
|
insts: <<~TEXT,
|
||||||
|
@ -29,16 +45,18 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "#translate" do
|
||||||
it "translates a prompt written in our generic format to the Open AI format" do
|
it "translates a prompt written in our generic format to the Open AI format" do
|
||||||
orca_style_version = <<~TEXT
|
orca_style_version = <<~TEXT
|
||||||
### System:
|
### System:
|
||||||
#{[prompt[:insts], prompt[:post_insts]].join("\n")}
|
#{prompt[:insts]}
|
||||||
|
#{prompt[:post_insts]}
|
||||||
### User:
|
### User:
|
||||||
#{prompt[:input]}
|
#{prompt[:input]}
|
||||||
### Assistant:
|
### Assistant:
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to eq(orca_style_version)
|
expect(translated).to eq(orca_style_version)
|
||||||
end
|
end
|
||||||
|
@ -53,7 +71,8 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
|
||||||
|
|
||||||
orca_style_version = <<~TEXT
|
orca_style_version = <<~TEXT
|
||||||
### System:
|
### System:
|
||||||
#{[prompt[:insts], prompt[:post_insts]].join("\n")}
|
#{prompt[:insts]}
|
||||||
|
#{prompt[:post_insts]}
|
||||||
### User:
|
### User:
|
||||||
#{prompt[:examples][0][0]}
|
#{prompt[:examples][0][0]}
|
||||||
### Assistant:
|
### Assistant:
|
||||||
|
@ -63,9 +82,113 @@ RSpec.describe DiscourseAi::Completions::Dialects::OrcaStyle do
|
||||||
### Assistant:
|
### Assistant:
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
translated = dialect.translate(prompt)
|
translated = dialect.translate
|
||||||
|
|
||||||
|
expect(translated).to eq(orca_style_version)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "include tools inside the prompt" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
orca_style_version = <<~TEXT
|
||||||
|
### System:
|
||||||
|
#{prompt[:insts]}
|
||||||
|
In this environment you have access to a set of tools you can use to answer the user's question.
|
||||||
|
You may call them like this. Only invoke one function at a time and wait for the results before invoking another function:
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>$TOOL_NAME</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<$PARAMETER_NAME>$PARAMETER_VALUE</$PARAMETER_NAME>
|
||||||
|
...
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
|
||||||
|
Here are the tools available:
|
||||||
|
|
||||||
|
<tools>
|
||||||
|
#{dialect.tools}</tools>
|
||||||
|
#{prompt[:post_insts]}
|
||||||
|
### User:
|
||||||
|
#{prompt[:input]}
|
||||||
|
### Assistant:
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated = dialect.translate
|
||||||
|
|
||||||
expect(translated).to eq(orca_style_version)
|
expect(translated).to eq(orca_style_version)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "#conversation_context" do
|
||||||
|
let(:context) do
|
||||||
|
[
|
||||||
|
{ type: "user", name: "user1", content: "This is a new message by a user" },
|
||||||
|
{ type: "assistant", content: "I'm a previous bot reply, that's why there's no user" },
|
||||||
|
{ type: "tool", name: "tool_id", content: "I'm a tool result" },
|
||||||
|
]
|
||||||
|
end
|
||||||
|
|
||||||
|
it "adds conversation in reverse order (first == newer)" do
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
expected = <<~TEXT
|
||||||
|
### Assistant:
|
||||||
|
<function_results>
|
||||||
|
<result>
|
||||||
|
<tool_name>tool_id</tool_name>
|
||||||
|
<json>
|
||||||
|
#{context.last[:content]}
|
||||||
|
</json>
|
||||||
|
</result>
|
||||||
|
</function_results>
|
||||||
|
### Assistant: #{context.second[:content]}
|
||||||
|
### User: #{context.first[:content]}
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context).to eq(expected)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "trims content if it's getting too long" do
|
||||||
|
context.last[:content] = context.last[:content] * 1_000
|
||||||
|
prompt[:conversation_context] = context
|
||||||
|
|
||||||
|
translated_context = dialect.conversation_context
|
||||||
|
|
||||||
|
expect(translated_context.length).to be < context.last[:content].length
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
describe "#tools" do
|
||||||
|
it "translates tools to the tool syntax" do
|
||||||
|
prompt[:tools] = [tool]
|
||||||
|
|
||||||
|
translated_tool = <<~TEXT
|
||||||
|
<tool_description>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<description>Get the weather in a city</description>
|
||||||
|
<parameters>
|
||||||
|
<parameter>
|
||||||
|
<name>location</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the city name</description>
|
||||||
|
<required>true</required>
|
||||||
|
</parameter>
|
||||||
|
<parameter>
|
||||||
|
<name>unit</name>
|
||||||
|
<type>string</type>
|
||||||
|
<description>the unit of measurement celcius c or fahrenheit f</description>
|
||||||
|
<required>true</required>
|
||||||
|
<options>c,f</options>
|
||||||
|
</parameter>
|
||||||
|
</parameters>
|
||||||
|
</tool_description>
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
expect(dialect.tools).to eq(translated_tool)
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::AnthropicTokenizer) }
|
||||||
|
|
||||||
let(:model_name) { "claude-2" }
|
let(:model_name) { "claude-2" }
|
||||||
let(:prompt) { "Human: write 3 words\n\n" }
|
let(:generic_prompt) { { insts: "write 3 words" } }
|
||||||
|
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
|
||||||
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
||||||
|
@ -23,10 +25,10 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_response(prompt, response_text)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||||
.with(body: model.default_options.merge(prompt: prompt).to_json)
|
.with(body: request_body)
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -42,7 +44,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
}.to_json
|
}.to_json
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_streamed_response(prompt, deltas)
|
def stub_streamed_response(prompt, deltas, tool_call: false)
|
||||||
chunks =
|
chunks =
|
||||||
deltas.each_with_index.map do |_, index|
|
deltas.each_with_index.map do |_, index|
|
||||||
if index == (deltas.length - 1)
|
if index == (deltas.length - 1)
|
||||||
|
@ -52,13 +54,27 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
chunks = chunks.join("\n\n")
|
chunks = chunks.join("\n\n").split("")
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
.stub_request(:post, "https://api.anthropic.com/v1/complete")
|
||||||
.with(body: model.default_options.merge(prompt: prompt, stream: true).to_json)
|
.with(body: stream_request_body)
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
let(:tool_deltas) { ["<function", <<~REPLY] }
|
||||||
|
_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
REPLY
|
||||||
|
|
||||||
|
let(:tool_call) { invocation }
|
||||||
|
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||||
end
|
end
|
||||||
|
|
|
@ -9,10 +9,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
|
|
||||||
let(:model_name) { "claude-2" }
|
let(:model_name) { "claude-2" }
|
||||||
let(:bedrock_name) { "claude-v2" }
|
let(:bedrock_name) { "claude-v2" }
|
||||||
let(:prompt) { "Human: write 3 words\n\n" }
|
let(:generic_prompt) { { insts: "write 3 words" } }
|
||||||
|
let(:dialect) { DiscourseAi::Completions::Dialects::Claude.new(generic_prompt, model_name) }
|
||||||
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt).to_json }
|
let(:stream_request_body) { request_body }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
SiteSetting.ai_bedrock_access_key_id = "123456"
|
||||||
|
@ -20,39 +22,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
SiteSetting.ai_bedrock_region = "us-east-1"
|
SiteSetting.ai_bedrock_region = "us-east-1"
|
||||||
end
|
end
|
||||||
|
|
||||||
# Copied from https://github.com/bblimke/webmock/issues/629
|
|
||||||
# Workaround for stubbing a streamed response
|
|
||||||
before do
|
|
||||||
mocked_http =
|
|
||||||
Class.new(Net::HTTP) do
|
|
||||||
def request(*)
|
|
||||||
super do |response|
|
|
||||||
response.instance_eval do
|
|
||||||
def read_body(*, &block)
|
|
||||||
if block_given?
|
|
||||||
@body.each(&block)
|
|
||||||
else
|
|
||||||
super
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
yield response if block_given?
|
|
||||||
|
|
||||||
response
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
@original_net_http = Net.send(:remove_const, :HTTP)
|
|
||||||
Net.send(:const_set, :HTTP, mocked_http)
|
|
||||||
end
|
|
||||||
|
|
||||||
after do
|
|
||||||
Net.send(:remove_const, :HTTP)
|
|
||||||
Net.send(:const_set, :HTTP, @original_net_http)
|
|
||||||
end
|
|
||||||
|
|
||||||
def response(content)
|
def response(content)
|
||||||
{
|
{
|
||||||
completion: content,
|
completion: content,
|
||||||
|
@ -65,7 +34,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_response(prompt, response_text)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(
|
||||||
:post,
|
:post,
|
||||||
|
@ -102,7 +71,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
encoder.encode(message)
|
encoder.encode(message)
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_streamed_response(prompt, deltas)
|
def stub_streamed_response(prompt, deltas, tool_call: false)
|
||||||
chunks =
|
chunks =
|
||||||
deltas.each_with_index.map do |_, index|
|
deltas.each_with_index.map do |_, index|
|
||||||
if index == (deltas.length - 1)
|
if index == (deltas.length - 1)
|
||||||
|
@ -121,5 +90,19 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
let(:tool_deltas) { ["<function", <<~REPLY] }
|
||||||
|
_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
REPLY
|
||||||
|
|
||||||
|
let(:tool_call) { invocation }
|
||||||
|
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||||
end
|
end
|
||||||
|
|
|
@ -1,22 +1,86 @@
|
||||||
# frozen_string_literal: true
|
# frozen_string_literal: true
|
||||||
|
|
||||||
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
|
RSpec.shared_examples "an endpoint that can communicate with a completion service" do
|
||||||
|
# Copied from https://github.com/bblimke/webmock/issues/629
|
||||||
|
# Workaround for stubbing a streamed response
|
||||||
|
before do
|
||||||
|
mocked_http =
|
||||||
|
Class.new(Net::HTTP) do
|
||||||
|
def request(*)
|
||||||
|
super do |response|
|
||||||
|
response.instance_eval do
|
||||||
|
def read_body(*, &block)
|
||||||
|
if block_given?
|
||||||
|
@body.each(&block)
|
||||||
|
else
|
||||||
|
super
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
yield response if block_given?
|
||||||
|
|
||||||
|
response
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@original_net_http = Net.send(:remove_const, :HTTP)
|
||||||
|
Net.send(:const_set, :HTTP, mocked_http)
|
||||||
|
end
|
||||||
|
|
||||||
|
after do
|
||||||
|
Net.send(:remove_const, :HTTP)
|
||||||
|
Net.send(:const_set, :HTTP, @original_net_http)
|
||||||
|
end
|
||||||
|
|
||||||
describe "#perform_completion!" do
|
describe "#perform_completion!" do
|
||||||
fab!(:user) { Fabricate(:user) }
|
fab!(:user) { Fabricate(:user) }
|
||||||
|
|
||||||
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
|
let(:tool) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name", required: true },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
required: true,
|
||||||
|
},
|
||||||
|
],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:invocation) { <<~TEXT }
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<tool_id>get_weather</tool_id>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
TEXT
|
||||||
|
|
||||||
context "when using regular mode" do
|
context "when using regular mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
|
let(:response_text) { "1. Serenity\\n2. Laughter\\n3. Adventure" }
|
||||||
|
|
||||||
before { stub_response(prompt, response_text) }
|
before { stub_response(prompt, response_text) }
|
||||||
|
|
||||||
it "can complete a trivial prompt" do
|
it "can complete a trivial prompt" do
|
||||||
completion_response = model.perform_completion!(prompt, user)
|
completion_response = model.perform_completion!(dialect, user)
|
||||||
|
|
||||||
expect(completion_response).to eq(response_text)
|
expect(completion_response).to eq(response_text)
|
||||||
end
|
end
|
||||||
|
|
||||||
it "creates an audit log for the request" do
|
it "creates an audit log for the request" do
|
||||||
model.perform_completion!(prompt, user)
|
model.perform_completion!(dialect, user)
|
||||||
|
|
||||||
expect(AiApiAuditLog.count).to eq(1)
|
expect(AiApiAuditLog.count).to eq(1)
|
||||||
log = AiApiAuditLog.first
|
log = AiApiAuditLog.first
|
||||||
|
@ -32,7 +96,27 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
context "with functions" do
|
||||||
|
let(:generic_prompt) do
|
||||||
|
{
|
||||||
|
insts: "You can tell me the weather",
|
||||||
|
input: "Return the weather in Sydney",
|
||||||
|
tools: [tool],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
before { stub_response(prompt, tool_call, tool_call: true) }
|
||||||
|
|
||||||
|
it "returns a function invocation" do
|
||||||
|
completion_response = model.perform_completion!(dialect, user)
|
||||||
|
|
||||||
|
expect(completion_response).to eq(invocation)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
context "when using stream mode" do
|
context "when using stream mode" do
|
||||||
|
context "with simple prompts" do
|
||||||
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
|
let(:deltas) { ["Mount", "ain", " ", "Tree ", "Frog"] }
|
||||||
|
|
||||||
before { stub_streamed_response(prompt, deltas) }
|
before { stub_streamed_response(prompt, deltas) }
|
||||||
|
@ -40,7 +124,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
|
||||||
it "can complete a trivial prompt" do
|
it "can complete a trivial prompt" do
|
||||||
completion_response = +""
|
completion_response = +""
|
||||||
|
|
||||||
model.perform_completion!(prompt, user) do |partial, cancel|
|
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||||
completion_response << partial
|
completion_response << partial
|
||||||
cancel.call if completion_response.split(" ").length == 2
|
cancel.call if completion_response.split(" ").length == 2
|
||||||
end
|
end
|
||||||
|
@ -51,7 +135,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
|
||||||
it "creates an audit log and updates is on each read." do
|
it "creates an audit log and updates is on each read." do
|
||||||
completion_response = +""
|
completion_response = +""
|
||||||
|
|
||||||
model.perform_completion!(prompt, user) do |partial, cancel|
|
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||||
completion_response << partial
|
completion_response << partial
|
||||||
cancel.call if completion_response.split(" ").length == 2
|
cancel.call if completion_response.split(" ").length == 2
|
||||||
end
|
end
|
||||||
|
@ -67,5 +151,28 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
|
||||||
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
|
expect(log.response_tokens).to eq(model.tokenizer.size(deltas[0...-1].join))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
context "with functions" do
|
||||||
|
let(:generic_prompt) do
|
||||||
|
{
|
||||||
|
insts: "You can tell me the weather",
|
||||||
|
input: "Return the weather in Sydney",
|
||||||
|
tools: [tool],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
before { stub_streamed_response(prompt, tool_deltas, tool_call: true) }
|
||||||
|
|
||||||
|
it "waits for the invocation to finish before calling the partial" do
|
||||||
|
buffered_partial = ""
|
||||||
|
|
||||||
|
model.perform_completion!(dialect, user) do |partial, cancel|
|
||||||
|
buffered_partial = partial if partial.include?("<function_calls>")
|
||||||
|
end
|
||||||
|
|
||||||
|
expect(buffered_partial).to eq(invocation)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -6,22 +6,60 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
||||||
|
|
||||||
let(:model_name) { "gemini-pro" }
|
let(:model_name) { "gemini-pro" }
|
||||||
let(:prompt) do
|
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
|
||||||
|
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
|
||||||
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
|
let(:tool_payload) do
|
||||||
|
{
|
||||||
|
name: "get_weather",
|
||||||
|
description: "Get the weather in a city",
|
||||||
|
parameters: [
|
||||||
|
{ name: "location", type: "string", description: "the city name" },
|
||||||
|
{
|
||||||
|
name: "unit",
|
||||||
|
type: "string",
|
||||||
|
description: "the unit of measurement celcius c or fahrenheit f",
|
||||||
|
enum: %w[c f],
|
||||||
|
},
|
||||||
|
],
|
||||||
|
required: %w[location unit],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:request_body) do
|
||||||
|
model
|
||||||
|
.default_options
|
||||||
|
.merge(contents: prompt)
|
||||||
|
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
|
||||||
|
.to_json
|
||||||
|
end
|
||||||
|
let(:stream_request_body) do
|
||||||
|
model
|
||||||
|
.default_options
|
||||||
|
.merge(contents: prompt)
|
||||||
|
.tap { |b| b[:tools] = [{ function_declarations: [tool_payload] }] if generic_prompt[:tools] }
|
||||||
|
.to_json
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:tool_deltas) do
|
||||||
[
|
[
|
||||||
{ role: "system", content: "You are a helpful bot." },
|
{ "functionCall" => { name: "get_weather", args: {} } },
|
||||||
{ role: "user", content: "Write 3 words" },
|
{ "functionCall" => { name: "get_weather", args: { location: "" } } },
|
||||||
|
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } },
|
||||||
]
|
]
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(contents: prompt).to_json }
|
let(:tool_call) do
|
||||||
let(:stream_request_body) { model.default_options.merge(contents: prompt).to_json }
|
{ "functionCall" => { name: "get_weather", args: { location: "Sydney", unit: "c" } } }
|
||||||
|
end
|
||||||
|
|
||||||
def response(content)
|
def response(content, tool_call: false)
|
||||||
{
|
{
|
||||||
candidates: [
|
candidates: [
|
||||||
{
|
{
|
||||||
content: {
|
content: {
|
||||||
parts: [{ text: content }],
|
parts: [(tool_call ? content : { text: content })],
|
||||||
role: "model",
|
role: "model",
|
||||||
},
|
},
|
||||||
finishReason: "STOP",
|
finishReason: "STOP",
|
||||||
|
@ -45,22 +83,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_response(prompt, response_text)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(
|
||||||
:post,
|
:post,
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:generateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||||
)
|
)
|
||||||
.with(body: { contents: prompt })
|
.with(body: request_body)
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
||||||
end
|
end
|
||||||
|
|
||||||
def stream_line(delta, finish_reason: nil)
|
def stream_line(delta, finish_reason: nil, tool_call: false)
|
||||||
{
|
{
|
||||||
candidates: [
|
candidates: [
|
||||||
{
|
{
|
||||||
content: {
|
content: {
|
||||||
parts: [{ text: delta }],
|
parts: [(tool_call ? delta : { text: delta })],
|
||||||
role: "model",
|
role: "model",
|
||||||
},
|
},
|
||||||
finishReason: finish_reason,
|
finishReason: finish_reason,
|
||||||
|
@ -76,24 +114,24 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
}.to_json
|
}.to_json
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_streamed_response(prompt, deltas)
|
def stub_streamed_response(prompt, deltas, tool_call: false)
|
||||||
chunks =
|
chunks =
|
||||||
deltas.each_with_index.map do |_, index|
|
deltas.each_with_index.map do |_, index|
|
||||||
if index == (deltas.length - 1)
|
if index == (deltas.length - 1)
|
||||||
stream_line(deltas[index], finish_reason: "STOP")
|
stream_line(deltas[index], finish_reason: "STOP", tool_call: tool_call)
|
||||||
else
|
else
|
||||||
stream_line(deltas[index])
|
stream_line(deltas[index], tool_call: tool_call)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
chunks = chunks.join("\n,\n").prepend("[\n").concat("\n]")
|
chunks = chunks.join("\n,\n").prepend("[").concat("\n]").split("")
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(
|
||||||
:post,
|
:post,
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
"https://generativelanguage.googleapis.com/v1beta/models/#{model_name}:streamGenerateContent?key=#{SiteSetting.ai_gemini_api_key}",
|
||||||
)
|
)
|
||||||
.with(body: model.default_options.merge(contents: prompt).to_json)
|
.with(body: stream_request_body)
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -6,10 +6,11 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
|
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::Llama2Tokenizer) }
|
||||||
|
|
||||||
let(:model_name) { "Llama2-*-chat-hf" }
|
let(:model_name) { "Llama2-*-chat-hf" }
|
||||||
let(:prompt) { <<~TEXT }
|
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
|
||||||
[INST]<<SYS>>You are a helpful bot.<</SYS>>[/INST]
|
let(:dialect) do
|
||||||
[INST]Write 3 words[/INST]
|
DiscourseAi::Completions::Dialects::Llama2Classic.new(generic_prompt, model_name)
|
||||||
TEXT
|
end
|
||||||
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
let(:request_body) do
|
let(:request_body) do
|
||||||
model
|
model
|
||||||
|
@ -18,7 +19,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
.tap do |payload|
|
.tap do |payload|
|
||||||
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
||||||
model.prompt_size(prompt)
|
model.prompt_size(prompt)
|
||||||
payload[:parameters][:return_full_text] = false
|
|
||||||
end
|
end
|
||||||
.to_json
|
.to_json
|
||||||
end
|
end
|
||||||
|
@ -30,7 +30,6 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
payload[:parameters][:max_new_tokens] = (SiteSetting.ai_hugging_face_token_limit || 4_000) -
|
||||||
model.prompt_size(prompt)
|
model.prompt_size(prompt)
|
||||||
payload[:stream] = true
|
payload[:stream] = true
|
||||||
payload[:parameters][:return_full_text] = false
|
|
||||||
end
|
end
|
||||||
.to_json
|
.to_json
|
||||||
end
|
end
|
||||||
|
@ -41,14 +40,14 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
[{ generated_text: content }]
|
[{ generated_text: content }]
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_response(prompt, response_text)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||||
.with(body: request_body)
|
.with(body: request_body)
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
||||||
end
|
end
|
||||||
|
|
||||||
def stream_line(delta, finish_reason: nil)
|
def stream_line(delta, deltas, finish_reason: nil)
|
||||||
+"data: " << {
|
+"data: " << {
|
||||||
token: {
|
token: {
|
||||||
id: 29_889,
|
id: 29_889,
|
||||||
|
@ -56,22 +55,22 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
logprob: -0.08319092,
|
logprob: -0.08319092,
|
||||||
special: !!finish_reason,
|
special: !!finish_reason,
|
||||||
},
|
},
|
||||||
generated_text: finish_reason ? response_text : nil,
|
generated_text: finish_reason ? deltas.join : nil,
|
||||||
details: nil,
|
details: nil,
|
||||||
}.to_json
|
}.to_json
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_streamed_response(prompt, deltas)
|
def stub_streamed_response(prompt, deltas, tool_call: false)
|
||||||
chunks =
|
chunks =
|
||||||
deltas.each_with_index.map do |_, index|
|
deltas.each_with_index.map do |_, index|
|
||||||
if index == (deltas.length - 1)
|
if index == (deltas.length - 1)
|
||||||
stream_line(deltas[index], finish_reason: true)
|
stream_line(deltas[index], deltas, finish_reason: true)
|
||||||
else
|
else
|
||||||
stream_line(deltas[index])
|
stream_line(deltas[index], deltas)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
chunks = chunks.join("\n\n")
|
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
.stub_request(:post, "#{SiteSetting.ai_hugging_face_api_url}")
|
||||||
|
@ -79,5 +78,29 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
let(:tool_deltas) { ["<function", <<~REPLY, <<~REPLY] }
|
||||||
|
_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
REPLY
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>get_weather</tool_name>
|
||||||
|
<parameters>
|
||||||
|
<location>Sydney</location>
|
||||||
|
<unit>c</unit>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
REPLY
|
||||||
|
|
||||||
|
let(:tool_call) { invocation }
|
||||||
|
|
||||||
it_behaves_like "an endpoint that can communicate with a completion service"
|
it_behaves_like "an endpoint that can communicate with a completion service"
|
||||||
end
|
end
|
||||||
|
|
|
@ -6,17 +6,53 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
subject(:model) { described_class.new(model_name, DiscourseAi::Tokenizer::OpenAiTokenizer) }
|
||||||
|
|
||||||
let(:model_name) { "gpt-3.5-turbo" }
|
let(:model_name) { "gpt-3.5-turbo" }
|
||||||
let(:prompt) do
|
let(:generic_prompt) { { insts: "You are a helpful bot.", input: "write 3 words" } }
|
||||||
|
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
|
||||||
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
|
let(:tool_deltas) do
|
||||||
[
|
[
|
||||||
{ role: "system", content: "You are a helpful bot." },
|
{ id: "get_weather", name: "get_weather", arguments: {} },
|
||||||
{ role: "user", content: "Write 3 words" },
|
{ id: "get_weather", name: "get_weather", arguments: { location: "" } },
|
||||||
|
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
|
||||||
]
|
]
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:request_body) { model.default_options.merge(messages: prompt).to_json }
|
let(:tool_call) do
|
||||||
let(:stream_request_body) { model.default_options.merge(messages: prompt, stream: true).to_json }
|
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
|
||||||
|
end
|
||||||
|
|
||||||
|
let(:request_body) do
|
||||||
|
model
|
||||||
|
.default_options
|
||||||
|
.merge(messages: prompt)
|
||||||
|
.tap do |b|
|
||||||
|
b[:tools] = generic_prompt[:tools].map do |t|
|
||||||
|
{ type: "function", tool: t }
|
||||||
|
end if generic_prompt[:tools]
|
||||||
|
end
|
||||||
|
.to_json
|
||||||
|
end
|
||||||
|
let(:stream_request_body) do
|
||||||
|
model
|
||||||
|
.default_options
|
||||||
|
.merge(messages: prompt, stream: true)
|
||||||
|
.tap do |b|
|
||||||
|
b[:tools] = generic_prompt[:tools].map do |t|
|
||||||
|
{ type: "function", tool: t }
|
||||||
|
end if generic_prompt[:tools]
|
||||||
|
end
|
||||||
|
.to_json
|
||||||
|
end
|
||||||
|
|
||||||
|
def response(content, tool_call: false)
|
||||||
|
message_content =
|
||||||
|
if tool_call
|
||||||
|
{ tool_calls: [{ function: content }] }
|
||||||
|
else
|
||||||
|
{ content: content }
|
||||||
|
end
|
||||||
|
|
||||||
def response(content)
|
|
||||||
{
|
{
|
||||||
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
id: "chatcmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
||||||
object: "chat.completion",
|
object: "chat.completion",
|
||||||
|
@ -28,45 +64,52 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
total_tokens: 499,
|
total_tokens: 499,
|
||||||
},
|
},
|
||||||
choices: [
|
choices: [
|
||||||
{ message: { role: "assistant", content: content }, finish_reason: "stop", index: 0 },
|
{ message: { role: "assistant" }.merge(message_content), finish_reason: "stop", index: 0 },
|
||||||
],
|
],
|
||||||
}
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_response(prompt, response_text)
|
def stub_response(prompt, response_text, tool_call: false)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
||||||
.with(body: { model: model_name, messages: prompt })
|
.with(body: request_body)
|
||||||
.to_return(status: 200, body: JSON.dump(response(response_text)))
|
.to_return(status: 200, body: JSON.dump(response(response_text, tool_call: tool_call)))
|
||||||
|
end
|
||||||
|
|
||||||
|
def stream_line(delta, finish_reason: nil, tool_call: false)
|
||||||
|
message_content =
|
||||||
|
if tool_call
|
||||||
|
{ tool_calls: [{ function: delta }] }
|
||||||
|
else
|
||||||
|
{ content: delta }
|
||||||
end
|
end
|
||||||
|
|
||||||
def stream_line(delta, finish_reason: nil)
|
|
||||||
+"data: " << {
|
+"data: " << {
|
||||||
id: "chatcmpl-#{SecureRandom.hex}",
|
id: "chatcmpl-#{SecureRandom.hex}",
|
||||||
object: "chat.completion.chunk",
|
object: "chat.completion.chunk",
|
||||||
created: 1_681_283_881,
|
created: 1_681_283_881,
|
||||||
model: "gpt-3.5-turbo-0301",
|
model: "gpt-3.5-turbo-0301",
|
||||||
choices: [{ delta: { content: delta } }],
|
choices: [{ delta: message_content }],
|
||||||
finish_reason: finish_reason,
|
finish_reason: finish_reason,
|
||||||
index: 0,
|
index: 0,
|
||||||
}.to_json
|
}.to_json
|
||||||
end
|
end
|
||||||
|
|
||||||
def stub_streamed_response(prompt, deltas)
|
def stub_streamed_response(prompt, deltas, tool_call: false)
|
||||||
chunks =
|
chunks =
|
||||||
deltas.each_with_index.map do |_, index|
|
deltas.each_with_index.map do |_, index|
|
||||||
if index == (deltas.length - 1)
|
if index == (deltas.length - 1)
|
||||||
stream_line(deltas[index], finish_reason: "stop_sequence")
|
stream_line(deltas[index], finish_reason: "stop_sequence", tool_call: tool_call)
|
||||||
else
|
else
|
||||||
stream_line(deltas[index])
|
stream_line(deltas[index], tool_call: tool_call)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
chunks = chunks.join("\n\n")
|
chunks = (chunks.join("\n\n") << "data: [DONE]").split("")
|
||||||
|
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
.stub_request(:post, "https://api.openai.com/v1/chat/completions")
|
||||||
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
|
.with(body: stream_request_body)
|
||||||
.to_return(status: 200, body: chunks)
|
.to_return(status: 200, body: chunks)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
RSpec.describe DiscourseAi::Completions::Llm do
|
RSpec.describe DiscourseAi::Completions::Llm do
|
||||||
subject(:llm) do
|
subject(:llm) do
|
||||||
described_class.new(
|
described_class.new(
|
||||||
DiscourseAi::Completions::Dialects::OrcaStyle.new,
|
DiscourseAi::Completions::Dialects::OrcaStyle,
|
||||||
canned_response,
|
canned_response,
|
||||||
"Upstage-Llama-2-*-instruct-v2",
|
"Upstage-Llama-2-*-instruct-v2",
|
||||||
)
|
)
|
||||||
|
|
|
@ -103,7 +103,7 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
|
||||||
expect(response.status).to eq(200)
|
expect(response.status).to eq(200)
|
||||||
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
|
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
|
||||||
expect(response.parsed_body["diff"]).to eq(expected_diff)
|
expect(response.parsed_body["diff"]).to eq(expected_diff)
|
||||||
expect(spy.prompt.last[:content]).to eq(expected_input)
|
expect(spy.prompt.translate.last[:content]).to eq(expected_input)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue