FIX: Correctly translate and read tools for Claude and Chat GPT. (#393)
I tested against the live models for the AI bot migration. It ensures Open AI's tool syntax is correct and we can correctly read the replies. :
This commit is contained in:
parent
cec9bb8910
commit
4182af230a
|
@ -33,7 +33,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
open_ai_prompt.concat!(conversation_context) if prompt[:conversation_context]
|
open_ai_prompt.concat(conversation_context) if prompt[:conversation_context]
|
||||||
|
|
||||||
open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
|
open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
|
||||||
|
|
||||||
|
@ -43,7 +43,25 @@ module DiscourseAi
|
||||||
def tools
|
def tools
|
||||||
return if prompt[:tools].blank?
|
return if prompt[:tools].blank?
|
||||||
|
|
||||||
prompt[:tools].map { |t| { type: "function", tool: t } }
|
prompt[:tools].map do |t|
|
||||||
|
tool = t.dup
|
||||||
|
|
||||||
|
if tool[:parameters]
|
||||||
|
tool[:parameters] = t[:parameters].reduce(
|
||||||
|
{ type: "object", properties: {}, required: [] },
|
||||||
|
) do |memo, p|
|
||||||
|
name = p[:name]
|
||||||
|
memo[:required] << name if p[:required]
|
||||||
|
|
||||||
|
memo[:properties][name] = p.except(:name, :required, :item_type)
|
||||||
|
|
||||||
|
memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
|
||||||
|
memo
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
{ type: "function", function: tool }
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def conversation_context
|
def conversation_context
|
||||||
|
@ -52,18 +70,25 @@ module DiscourseAi
|
||||||
trimmed_context = trim_context(prompt[:conversation_context])
|
trimmed_context = trim_context(prompt[:conversation_context])
|
||||||
|
|
||||||
trimmed_context.reverse.map do |context|
|
trimmed_context.reverse.map do |context|
|
||||||
translated = context.slice(:content)
|
if context[:type] == "tool_call"
|
||||||
translated[:role] = context[:type]
|
{
|
||||||
|
role: "assistant",
|
||||||
|
tool_calls: [{ type: "function", function: context[:content], id: context[:name] }],
|
||||||
|
}
|
||||||
|
else
|
||||||
|
translated = context.slice(:content)
|
||||||
|
translated[:role] = context[:type]
|
||||||
|
|
||||||
if context[:name]
|
if context[:name]
|
||||||
if translated[:role] == "tool"
|
if translated[:role] == "tool"
|
||||||
translated[:tool_call_id] = context[:name]
|
translated[:tool_call_id] = context[:name]
|
||||||
else
|
else
|
||||||
translated[:name] = context[:name]
|
translated[:name] = context[:name]
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
|
||||||
|
|
||||||
translated
|
translated
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -94,7 +119,7 @@ module DiscourseAi
|
||||||
|
|
||||||
def model_max_tokens
|
def model_max_tokens
|
||||||
case model_name
|
case model_name
|
||||||
when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
|
when "gpt-3.5-turbo-16k"
|
||||||
16_384
|
16_384
|
||||||
when "gpt-4"
|
when "gpt-4"
|
||||||
8192
|
8192
|
||||||
|
|
|
@ -44,6 +44,7 @@ module DiscourseAi
|
||||||
trimmed_context
|
trimmed_context
|
||||||
.reverse
|
.reverse
|
||||||
.reduce(+"") do |memo, context|
|
.reduce(+"") do |memo, context|
|
||||||
|
next(memo) if context[:type] == "tool_call"
|
||||||
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
|
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
|
||||||
|
|
||||||
if context[:type] == "tool"
|
if context[:type] == "tool"
|
||||||
|
|
|
@ -32,7 +32,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
gemini_prompt.concat!(conversation_context) if prompt[:conversation_context]
|
gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
|
||||||
|
|
||||||
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
|
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
|
||||||
end
|
end
|
||||||
|
|
|
@ -44,6 +44,7 @@ module DiscourseAi
|
||||||
trimmed_context
|
trimmed_context
|
||||||
.reverse
|
.reverse
|
||||||
.reduce(+"") do |memo, context|
|
.reduce(+"") do |memo, context|
|
||||||
|
next(memo) if context[:type] == "tool_call"
|
||||||
if context[:type] == "tool"
|
if context[:type] == "tool"
|
||||||
memo << <<~TEXT
|
memo << <<~TEXT
|
||||||
[INST]
|
[INST]
|
||||||
|
|
|
@ -44,6 +44,7 @@ module DiscourseAi
|
||||||
trimmed_context
|
trimmed_context
|
||||||
.reverse
|
.reverse
|
||||||
.reduce(+"") do |memo, context|
|
.reduce(+"") do |memo, context|
|
||||||
|
next(memo) if context[:type] == "tool_call"
|
||||||
memo << "[INST] " if context[:type] == "user"
|
memo << "[INST] " if context[:type] == "user"
|
||||||
|
|
||||||
if context[:type] == "tool"
|
if context[:type] == "tool"
|
||||||
|
|
|
@ -41,6 +41,7 @@ module DiscourseAi
|
||||||
trimmed_context
|
trimmed_context
|
||||||
.reverse
|
.reverse
|
||||||
.reduce(+"") do |memo, context|
|
.reduce(+"") do |memo, context|
|
||||||
|
next(memo) if context[:type] == "tool_call"
|
||||||
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
|
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
|
||||||
|
|
||||||
if context[:type] == "tool"
|
if context[:type] == "tool"
|
||||||
|
|
|
@ -14,7 +14,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
def default_options
|
def default_options
|
||||||
{ max_tokens_to_sample: 2_000 }
|
{ max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
|
||||||
end
|
end
|
||||||
|
|
||||||
def provider_id
|
def provider_id
|
||||||
|
|
|
@ -74,7 +74,7 @@ module DiscourseAi
|
||||||
response_data = extract_completion_from(response_raw)
|
response_data = extract_completion_from(response_raw)
|
||||||
partials_raw = response_data.to_s
|
partials_raw = response_data.to_s
|
||||||
|
|
||||||
if has_tool?("", response_data)
|
if has_tool?(response_data)
|
||||||
function_buffer = build_buffer # Nokogiri document
|
function_buffer = build_buffer # Nokogiri document
|
||||||
function_buffer = add_to_buffer(function_buffer, "", response_data)
|
function_buffer = add_to_buffer(function_buffer, "", response_data)
|
||||||
|
|
||||||
|
@ -125,26 +125,19 @@ module DiscourseAi
|
||||||
|
|
||||||
begin
|
begin
|
||||||
partial = extract_completion_from(raw_partial)
|
partial = extract_completion_from(raw_partial)
|
||||||
|
next if response_data.empty? && partial.blank?
|
||||||
next if partial.nil?
|
next if partial.nil?
|
||||||
|
|
||||||
if has_tool?(response_data, partial)
|
# Skip yield for tools. We'll buffer and yield later.
|
||||||
function_buffer = add_to_buffer(function_buffer, response_data, partial)
|
if has_tool?(partials_raw)
|
||||||
|
function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
|
||||||
if buffering_finished?(dialect.tools, function_buffer)
|
|
||||||
invocation = +function_buffer.at("function_calls").to_s
|
|
||||||
invocation << "\n"
|
|
||||||
|
|
||||||
partials_raw << partial.to_s
|
|
||||||
response_data << invocation
|
|
||||||
|
|
||||||
yield invocation, cancel
|
|
||||||
end
|
|
||||||
else
|
else
|
||||||
partials_raw << partial
|
|
||||||
response_data << partial
|
response_data << partial
|
||||||
|
|
||||||
yield partial, cancel if partial
|
yield partial, cancel if partial
|
||||||
end
|
end
|
||||||
|
|
||||||
|
partials_raw << partial.to_s
|
||||||
rescue JSON::ParserError
|
rescue JSON::ParserError
|
||||||
leftover = redo_chunk
|
leftover = redo_chunk
|
||||||
json_error = true
|
json_error = true
|
||||||
|
@ -162,6 +155,17 @@ module DiscourseAi
|
||||||
raise if !cancelled
|
raise if !cancelled
|
||||||
end
|
end
|
||||||
|
|
||||||
|
# Once we have the full response, try to return the tool as a XML doc.
|
||||||
|
if has_tool?(partials_raw)
|
||||||
|
if function_buffer.at("tool_name").text.present?
|
||||||
|
invocation = +function_buffer.at("function_calls").to_s
|
||||||
|
invocation << "\n"
|
||||||
|
|
||||||
|
response_data << invocation
|
||||||
|
yield invocation, cancel
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
return response_data
|
return response_data
|
||||||
ensure
|
ensure
|
||||||
if log
|
if log
|
||||||
|
@ -236,12 +240,22 @@ module DiscourseAi
|
||||||
TEXT
|
TEXT
|
||||||
end
|
end
|
||||||
|
|
||||||
def has_tool?(response, partial)
|
def has_tool?(response)
|
||||||
(response + partial).include?("<function_calls>")
|
response.include?("<function")
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_to_buffer(function_buffer, response_data, partial)
|
def add_to_buffer(function_buffer, response_data, partial)
|
||||||
read_function = Nokogiri::HTML5.fragment(response_data + partial)
|
raw_data = (response_data + partial)
|
||||||
|
|
||||||
|
# recover stop word potentially
|
||||||
|
raw_data =
|
||||||
|
raw_data.split("</invoke>").first + "</invoke>\n</function_calls>" if raw_data.split(
|
||||||
|
"</invoke>",
|
||||||
|
).length > 1
|
||||||
|
|
||||||
|
return function_buffer unless raw_data.include?("</invoke>")
|
||||||
|
|
||||||
|
read_function = Nokogiri::HTML5.fragment(raw_data)
|
||||||
|
|
||||||
if tool_name = read_function.at("tool_name").text
|
if tool_name = read_function.at("tool_name").text
|
||||||
function_buffer.at("tool_name").inner_html = tool_name
|
function_buffer.at("tool_name").inner_html = tool_name
|
||||||
|
@ -264,10 +278,6 @@ module DiscourseAi
|
||||||
|
|
||||||
function_buffer
|
function_buffer
|
||||||
end
|
end
|
||||||
|
|
||||||
def buffering_finished?(_available_functions, buffer)
|
|
||||||
buffer.to_s.include?("</function_calls>")
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -43,8 +43,8 @@ module DiscourseAi
|
||||||
|
|
||||||
response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
|
response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
|
||||||
|
|
||||||
has_function_call = response_h.dig(:functionCall).present?
|
@has_function_call ||= response_h.dig(:functionCall).present?
|
||||||
has_function_call ? response_h[:functionCall] : response_h.dig(:text)
|
@has_function_call ? response_h[:functionCall] : response_h.dig(:text)
|
||||||
end
|
end
|
||||||
|
|
||||||
def partials_from(decoded_chunk)
|
def partials_from(decoded_chunk)
|
||||||
|
@ -68,8 +68,8 @@ module DiscourseAi
|
||||||
prompt.to_s
|
prompt.to_s
|
||||||
end
|
end
|
||||||
|
|
||||||
def has_tool?(_response_data, partial)
|
def has_tool?(_response_data)
|
||||||
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
|
@has_function_call
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_to_buffer(function_buffer, _response_data, partial)
|
def add_to_buffer(function_buffer, _response_data, partial)
|
||||||
|
@ -91,21 +91,6 @@ module DiscourseAi
|
||||||
|
|
||||||
function_buffer
|
function_buffer
|
||||||
end
|
end
|
||||||
|
|
||||||
def buffering_finished?(available_functions, buffer)
|
|
||||||
tool_name = buffer.at("tool_name")&.text
|
|
||||||
return false if tool_name.blank?
|
|
||||||
|
|
||||||
signature =
|
|
||||||
available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
|
|
||||||
|
|
||||||
signature[:parameters].reduce(true) do |memo, param|
|
|
||||||
param_present = buffer.at(param[:name]).present?
|
|
||||||
next(memo) if param_present || !signature[:required].include?(param[:name])
|
|
||||||
|
|
||||||
memo && param_present
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -83,8 +83,9 @@ module DiscourseAi
|
||||||
|
|
||||||
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
|
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
|
||||||
|
|
||||||
has_function_call = response_h.dig(:tool_calls).present?
|
@has_function_call ||= response_h.dig(:tool_calls).present?
|
||||||
has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
|
|
||||||
|
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
|
||||||
end
|
end
|
||||||
|
|
||||||
def partials_from(decoded_chunk)
|
def partials_from(decoded_chunk)
|
||||||
|
@ -101,41 +102,38 @@ module DiscourseAi
|
||||||
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
|
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
|
||||||
end
|
end
|
||||||
|
|
||||||
def has_tool?(_response_data, partial)
|
def has_tool?(_response_data)
|
||||||
partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
|
@has_function_call
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_to_buffer(function_buffer, _response_data, partial)
|
def add_to_buffer(function_buffer, _response_data, partial)
|
||||||
function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
|
@args_buffer ||= +""
|
||||||
function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
|
|
||||||
|
|
||||||
if partial[:arguments]
|
f_name = partial.dig(:function, :name)
|
||||||
argument_fragments =
|
function_buffer.at("tool_name").content = f_name if f_name
|
||||||
partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
|
function_buffer.at("tool_id").content = partial[:id] if partial[:id]
|
||||||
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
|
|
||||||
end
|
|
||||||
argument_fragments << "\n"
|
|
||||||
|
|
||||||
function_buffer.at("parameters").children =
|
if partial.dig(:function, :arguments).present?
|
||||||
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
@args_buffer << partial.dig(:function, :arguments)
|
||||||
|
|
||||||
|
begin
|
||||||
|
json_args = JSON.parse(@args_buffer, symbolize_names: true)
|
||||||
|
|
||||||
|
argument_fragments =
|
||||||
|
json_args.reduce(+"") do |memo, (arg_name, value)|
|
||||||
|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
|
||||||
|
end
|
||||||
|
argument_fragments << "\n"
|
||||||
|
|
||||||
|
function_buffer.at("parameters").children =
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||||
|
rescue JSON::ParserError
|
||||||
|
return function_buffer
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
function_buffer
|
function_buffer
|
||||||
end
|
end
|
||||||
|
|
||||||
def buffering_finished?(available_functions, buffer)
|
|
||||||
tool_name = buffer.at("tool_name")&.text
|
|
||||||
return false if tool_name.blank?
|
|
||||||
|
|
||||||
signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
|
|
||||||
|
|
||||||
signature[:parameters].reduce(true) do |memo, param|
|
|
||||||
param_present = buffer.at(param[:name]).present?
|
|
||||||
next(memo) if param_present && !param[:required]
|
|
||||||
|
|
||||||
memo && param_present
|
|
||||||
end
|
|
||||||
end
|
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -115,7 +115,25 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
|
||||||
|
|
||||||
describe "#tools" do
|
describe "#tools" do
|
||||||
it "returns a list of available tools" do
|
it "returns a list of available tools" do
|
||||||
open_ai_tool_f = { type: "function", tool: tool }
|
open_ai_tool_f = {
|
||||||
|
function: {
|
||||||
|
description: tool[:description],
|
||||||
|
name: tool[:name],
|
||||||
|
parameters: {
|
||||||
|
properties:
|
||||||
|
tool[:parameters].reduce({}) do |memo, p|
|
||||||
|
memo[p[:name]] = { description: p[:description], type: p[:type] }
|
||||||
|
|
||||||
|
memo[p[:name]][:enum] = p[:enum] if p[:enum]
|
||||||
|
|
||||||
|
memo
|
||||||
|
end,
|
||||||
|
required: %w[location unit],
|
||||||
|
type: "object",
|
||||||
|
},
|
||||||
|
},
|
||||||
|
type: "function",
|
||||||
|
}
|
||||||
|
|
||||||
expect(subject.tools).to contain_exactly(open_ai_tool_f)
|
expect(subject.tools).to contain_exactly(open_ai_tool_f)
|
||||||
end
|
end
|
||||||
|
|
|
@ -13,6 +13,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
|
||||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||||
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
|
||||||
|
|
||||||
|
let(:tool_id) { "get_weather" }
|
||||||
|
|
||||||
def response(content)
|
def response(content)
|
||||||
{
|
{
|
||||||
completion: content,
|
completion: content,
|
||||||
|
|
|
@ -16,6 +16,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
|
||||||
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
|
||||||
let(:stream_request_body) { request_body }
|
let(:stream_request_body) { request_body }
|
||||||
|
|
||||||
|
let(:tool_id) { "get_weather" }
|
||||||
|
|
||||||
before do
|
before do
|
||||||
SiteSetting.ai_bedrock_access_key_id = "123456"
|
SiteSetting.ai_bedrock_access_key_id = "123456"
|
||||||
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
|
||||||
|
|
|
@ -58,7 +58,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
|
||||||
<function_calls>
|
<function_calls>
|
||||||
<invoke>
|
<invoke>
|
||||||
<tool_name>get_weather</tool_name>
|
<tool_name>get_weather</tool_name>
|
||||||
<tool_id>get_weather</tool_id>
|
<tool_id>#{tool_id || "get_weather"}</tool_id>
|
||||||
<parameters>
|
<parameters>
|
||||||
<location>Sydney</location>
|
<location>Sydney</location>
|
||||||
<unit>c</unit>
|
<unit>c</unit>
|
||||||
|
|
|
@ -10,6 +10,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
|
||||||
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
|
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
|
||||||
let(:prompt) { dialect.translate }
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
|
let(:tool_id) { "get_weather" }
|
||||||
|
|
||||||
let(:tool_payload) do
|
let(:tool_payload) do
|
||||||
{
|
{
|
||||||
name: "get_weather",
|
name: "get_weather",
|
||||||
|
|
|
@ -12,6 +12,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
|
||||||
end
|
end
|
||||||
let(:prompt) { dialect.translate }
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
|
let(:tool_id) { "get_weather" }
|
||||||
|
|
||||||
let(:request_body) do
|
let(:request_body) do
|
||||||
model
|
model
|
||||||
.default_options
|
.default_options
|
||||||
|
|
|
@ -10,45 +10,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
|
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
|
||||||
let(:prompt) { dialect.translate }
|
let(:prompt) { dialect.translate }
|
||||||
|
|
||||||
|
let(:tool_id) { "eujbuebfe" }
|
||||||
|
|
||||||
let(:tool_deltas) do
|
let(:tool_deltas) do
|
||||||
[
|
[
|
||||||
{ id: "get_weather", name: "get_weather", arguments: {} },
|
{ id: tool_id, function: {} },
|
||||||
{ id: "get_weather", name: "get_weather", arguments: { location: "" } },
|
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||||
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
|
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
|
||||||
|
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
|
||||||
]
|
]
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:tool_call) do
|
let(:tool_call) do
|
||||||
{ id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
|
{
|
||||||
|
id: tool_id,
|
||||||
|
function: {
|
||||||
|
name: "get_weather",
|
||||||
|
arguments: { location: "Sydney", unit: "c" }.to_json,
|
||||||
|
},
|
||||||
|
}
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:request_body) do
|
let(:request_body) do
|
||||||
model
|
model
|
||||||
.default_options
|
.default_options
|
||||||
.merge(messages: prompt)
|
.merge(messages: prompt)
|
||||||
.tap do |b|
|
.tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] }
|
||||||
b[:tools] = generic_prompt[:tools].map do |t|
|
|
||||||
{ type: "function", tool: t }
|
|
||||||
end if generic_prompt[:tools]
|
|
||||||
end
|
|
||||||
.to_json
|
.to_json
|
||||||
end
|
end
|
||||||
|
|
||||||
let(:stream_request_body) do
|
let(:stream_request_body) do
|
||||||
model
|
model
|
||||||
.default_options
|
.default_options
|
||||||
.merge(messages: prompt, stream: true)
|
.merge(messages: prompt, stream: true)
|
||||||
.tap do |b|
|
.tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] }
|
||||||
b[:tools] = generic_prompt[:tools].map do |t|
|
|
||||||
{ type: "function", tool: t }
|
|
||||||
end if generic_prompt[:tools]
|
|
||||||
end
|
|
||||||
.to_json
|
.to_json
|
||||||
end
|
end
|
||||||
|
|
||||||
def response(content, tool_call: false)
|
def response(content, tool_call: false)
|
||||||
message_content =
|
message_content =
|
||||||
if tool_call
|
if tool_call
|
||||||
{ tool_calls: [{ function: content }] }
|
{ tool_calls: [content] }
|
||||||
else
|
else
|
||||||
{ content: content }
|
{ content: content }
|
||||||
end
|
end
|
||||||
|
@ -79,7 +83,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
def stream_line(delta, finish_reason: nil, tool_call: false)
|
def stream_line(delta, finish_reason: nil, tool_call: false)
|
||||||
message_content =
|
message_content =
|
||||||
if tool_call
|
if tool_call
|
||||||
{ tool_calls: [{ function: delta }] }
|
{ tool_calls: [delta] }
|
||||||
else
|
else
|
||||||
{ content: delta }
|
{ content: delta }
|
||||||
end
|
end
|
||||||
|
|
|
@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
|
||||||
|
|
||||||
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
|
||||||
|
|
||||||
|
let(:tool_id) { "get_weather" }
|
||||||
|
|
||||||
def response(content)
|
def response(content)
|
||||||
{
|
{
|
||||||
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",
|
||||||
|
|
Loading…
Reference in New Issue