diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb
index 777f4184..0c9676a8 100644
--- a/lib/completions/dialects/chat_gpt.rb
+++ b/lib/completions/dialects/chat_gpt.rb
@@ -33,7 +33,7 @@ module DiscourseAi
end
end
- open_ai_prompt.concat!(conversation_context) if prompt[:conversation_context]
+ open_ai_prompt.concat(conversation_context) if prompt[:conversation_context]
open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
@@ -43,7 +43,25 @@ module DiscourseAi
def tools
return if prompt[:tools].blank?
- prompt[:tools].map { |t| { type: "function", tool: t } }
+ prompt[:tools].map do |t|
+ tool = t.dup
+
+ if tool[:parameters]
+ tool[:parameters] = t[:parameters].reduce(
+ { type: "object", properties: {}, required: [] },
+ ) do |memo, p|
+ name = p[:name]
+ memo[:required] << name if p[:required]
+
+ memo[:properties][name] = p.except(:name, :required, :item_type)
+
+ memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
+ memo
+ end
+ end
+
+ { type: "function", function: tool }
+ end
end
def conversation_context
@@ -52,18 +70,25 @@ module DiscourseAi
trimmed_context = trim_context(prompt[:conversation_context])
trimmed_context.reverse.map do |context|
- translated = context.slice(:content)
- translated[:role] = context[:type]
+ if context[:type] == "tool_call"
+ {
+ role: "assistant",
+ tool_calls: [{ type: "function", function: context[:content], id: context[:name] }],
+ }
+ else
+ translated = context.slice(:content)
+ translated[:role] = context[:type]
- if context[:name]
- if translated[:role] == "tool"
- translated[:tool_call_id] = context[:name]
- else
- translated[:name] = context[:name]
+ if context[:name]
+ if translated[:role] == "tool"
+ translated[:tool_call_id] = context[:name]
+ else
+ translated[:name] = context[:name]
+ end
end
- end
- translated
+ translated
+ end
end
end
@@ -94,7 +119,7 @@ module DiscourseAi
def model_max_tokens
case model_name
- when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
+ when "gpt-3.5-turbo-16k"
16_384
when "gpt-4"
8192
diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb
index 28fff8ee..73f9d231 100644
--- a/lib/completions/dialects/claude.rb
+++ b/lib/completions/dialects/claude.rb
@@ -44,6 +44,7 @@ module DiscourseAi
trimmed_context
.reverse
.reduce(+"") do |memo, context|
+ next(memo) if context[:type] == "tool_call"
memo << (context[:type] == "user" ? "Human:" : "Assistant:")
if context[:type] == "tool"
diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb
index b6b938bf..72f3a8ff 100644
--- a/lib/completions/dialects/gemini.rb
+++ b/lib/completions/dialects/gemini.rb
@@ -32,7 +32,7 @@ module DiscourseAi
end
end
- gemini_prompt.concat!(conversation_context) if prompt[:conversation_context]
+ gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
end
diff --git a/lib/completions/dialects/llama2_classic.rb b/lib/completions/dialects/llama2_classic.rb
index 0470a61c..26e541a3 100644
--- a/lib/completions/dialects/llama2_classic.rb
+++ b/lib/completions/dialects/llama2_classic.rb
@@ -44,6 +44,7 @@ module DiscourseAi
trimmed_context
.reverse
.reduce(+"") do |memo, context|
+ next(memo) if context[:type] == "tool_call"
if context[:type] == "tool"
memo << <<~TEXT
[INST]
diff --git a/lib/completions/dialects/mixtral.rb b/lib/completions/dialects/mixtral.rb
index 6fb93d04..75e0f954 100644
--- a/lib/completions/dialects/mixtral.rb
+++ b/lib/completions/dialects/mixtral.rb
@@ -44,6 +44,7 @@ module DiscourseAi
trimmed_context
.reverse
.reduce(+"") do |memo, context|
+ next(memo) if context[:type] == "tool_call"
memo << "[INST] " if context[:type] == "user"
if context[:type] == "tool"
diff --git a/lib/completions/dialects/orca_style.rb b/lib/completions/dialects/orca_style.rb
index fd76f3b5..b89dca01 100644
--- a/lib/completions/dialects/orca_style.rb
+++ b/lib/completions/dialects/orca_style.rb
@@ -41,6 +41,7 @@ module DiscourseAi
trimmed_context
.reverse
.reduce(+"") do |memo, context|
+ next(memo) if context[:type] == "tool_call"
memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
if context[:type] == "tool"
diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb
index 902a375c..98f29634 100644
--- a/lib/completions/endpoints/aws_bedrock.rb
+++ b/lib/completions/endpoints/aws_bedrock.rb
@@ -14,7 +14,7 @@ module DiscourseAi
end
def default_options
- { max_tokens_to_sample: 2_000 }
+ { max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
end
def provider_id
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index a3cb6d5b..da433c18 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -74,7 +74,7 @@ module DiscourseAi
response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
- if has_tool?("", response_data)
+ if has_tool?(response_data)
function_buffer = build_buffer # Nokogiri document
function_buffer = add_to_buffer(function_buffer, "", response_data)
@@ -125,26 +125,19 @@ module DiscourseAi
begin
partial = extract_completion_from(raw_partial)
+ next if response_data.empty? && partial.blank?
next if partial.nil?
- if has_tool?(response_data, partial)
- function_buffer = add_to_buffer(function_buffer, response_data, partial)
-
- if buffering_finished?(dialect.tools, function_buffer)
- invocation = +function_buffer.at("function_calls").to_s
- invocation << "\n"
-
- partials_raw << partial.to_s
- response_data << invocation
-
- yield invocation, cancel
- end
+ # Skip yield for tools. We'll buffer and yield later.
+ if has_tool?(partials_raw)
+ function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
else
- partials_raw << partial
response_data << partial
yield partial, cancel if partial
end
+
+ partials_raw << partial.to_s
rescue JSON::ParserError
leftover = redo_chunk
json_error = true
@@ -162,6 +155,17 @@ module DiscourseAi
raise if !cancelled
end
+ # Once we have the full response, try to return the tool as a XML doc.
+ if has_tool?(partials_raw)
+ if function_buffer.at("tool_name").text.present?
+ invocation = +function_buffer.at("function_calls").to_s
+ invocation << "\n"
+
+ response_data << invocation
+ yield invocation, cancel
+ end
+ end
+
return response_data
ensure
if log
@@ -236,12 +240,22 @@ module DiscourseAi
TEXT
end
- def has_tool?(response, partial)
- (response + partial).include?("<function_calls>")
+ def has_tool?(response)
+ response.include?("<function_calls>")
end

def add_to_buffer(function_buffer, response_data, partial)
- read_function = Nokogiri::HTML5.fragment(response_data + partial)
+ raw_data = (response_data + partial)
+
+ raw_data = raw_data.split("</function_calls>").first + "</function_calls>\n" if raw_data.split(
+ "</function_calls>",
+ ).length > 1
+
+ return function_buffer unless raw_data.include?("</invoke>")
+
+ read_function = Nokogiri::HTML5.fragment(raw_data)
if tool_name = read_function.at("tool_name").text
function_buffer.at("tool_name").inner_html = tool_name
@@ -264,10 +278,6 @@ module DiscourseAi
function_buffer
end
-
- def buffering_finished?(_available_functions, buffer)
- buffer.to_s.include?("</function_calls>")
- end
end
end
end
diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb
index 8fa8a2d3..231309b2 100644
--- a/lib/completions/endpoints/gemini.rb
+++ b/lib/completions/endpoints/gemini.rb
@@ -43,8 +43,8 @@ module DiscourseAi
response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
- has_function_call = response_h.dig(:functionCall).present?
- has_function_call ? response_h[:functionCall] : response_h.dig(:text)
+ @has_function_call ||= response_h.dig(:functionCall).present?
+ @has_function_call ? response_h[:functionCall] : response_h.dig(:text)
end
def partials_from(decoded_chunk)
@@ -68,8 +68,8 @@ module DiscourseAi
prompt.to_s
end
- def has_tool?(_response_data, partial)
- partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
+ def has_tool?(_response_data)
+ @has_function_call
end
def add_to_buffer(function_buffer, _response_data, partial)
@@ -91,21 +91,6 @@ module DiscourseAi
function_buffer
end
-
- def buffering_finished?(available_functions, buffer)
- tool_name = buffer.at("tool_name")&.text
- return false if tool_name.blank?
-
- signature =
- available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
-
- signature[:parameters].reduce(true) do |memo, param|
- param_present = buffer.at(param[:name]).present?
- next(memo) if param_present || !signature[:required].include?(param[:name])
-
- memo && param_present
- end
- end
end
end
end
diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb
index bb51090d..2a1d29cb 100644
--- a/lib/completions/endpoints/open_ai.rb
+++ b/lib/completions/endpoints/open_ai.rb
@@ -83,8 +83,9 @@ module DiscourseAi
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
- has_function_call = response_h.dig(:tool_calls).present?
- has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
+ @has_function_call ||= response_h.dig(:tool_calls).present?
+
+ @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
end
def partials_from(decoded_chunk)
@@ -101,41 +102,38 @@ module DiscourseAi
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
- def has_tool?(_response_data, partial)
- partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
+ def has_tool?(_response_data)
+ @has_function_call
end
def add_to_buffer(function_buffer, _response_data, partial)
- function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
- function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
+ @args_buffer ||= +""
- if partial[:arguments]
- argument_fragments =
- partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
- memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
- end
- argument_fragments << "\n"
+ f_name = partial.dig(:function, :name)
+ function_buffer.at("tool_name").content = f_name if f_name
+ function_buffer.at("tool_id").content = partial[:id] if partial[:id]
- function_buffer.at("parameters").children =
- Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
+ if partial.dig(:function, :arguments).present?
+ @args_buffer << partial.dig(:function, :arguments)
+
+ begin
+ json_args = JSON.parse(@args_buffer, symbolize_names: true)
+
+ argument_fragments =
+ json_args.reduce(+"") do |memo, (arg_name, value)|
+ memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
+ end
+ argument_fragments << "\n"
+
+ function_buffer.at("parameters").children =
+ Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
+ rescue JSON::ParserError
+ return function_buffer
+ end
end
function_buffer
end
-
- def buffering_finished?(available_functions, buffer)
- tool_name = buffer.at("tool_name")&.text
- return false if tool_name.blank?
-
- signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
-
- signature[:parameters].reduce(true) do |memo, param|
- param_present = buffer.at(param[:name]).present?
- next(memo) if param_present && !param[:required]
-
- memo && param_present
- end
- end
end
end
end
diff --git a/spec/lib/completions/dialects/chat_gpt_spec.rb b/spec/lib/completions/dialects/chat_gpt_spec.rb
index 2792dbf6..84e348bc 100644
--- a/spec/lib/completions/dialects/chat_gpt_spec.rb
+++ b/spec/lib/completions/dialects/chat_gpt_spec.rb
@@ -115,7 +115,25 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
describe "#tools" do
it "returns a list of available tools" do
- open_ai_tool_f = { type: "function", tool: tool }
+ open_ai_tool_f = {
+ function: {
+ description: tool[:description],
+ name: tool[:name],
+ parameters: {
+ properties:
+ tool[:parameters].reduce({}) do |memo, p|
+ memo[p[:name]] = { description: p[:description], type: p[:type] }
+
+ memo[p[:name]][:enum] = p[:enum] if p[:enum]
+
+ memo
+ end,
+ required: %w[location unit],
+ type: "object",
+ },
+ },
+ type: "function",
+ }
expect(subject.tools).to contain_exactly(open_ai_tool_f)
end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 0a57ad29..4c39c2e7 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -13,6 +13,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
+ let(:tool_id) { "get_weather" }
+
def response(content)
{
completion: content,
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index 2c866898..65999393 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -16,6 +16,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
let(:stream_request_body) { request_body }
+ let(:tool_id) { "get_weather" }
+
before do
SiteSetting.ai_bedrock_access_key_id = "123456"
SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"
diff --git a/spec/lib/completions/endpoints/endpoint_examples.rb b/spec/lib/completions/endpoints/endpoint_examples.rb
index 0de7ef73..0448fcbf 100644
--- a/spec/lib/completions/endpoints/endpoint_examples.rb
+++ b/spec/lib/completions/endpoints/endpoint_examples.rb
@@ -58,7 +58,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion servic
<tool_name>get_weather</tool_name>
- <tool_id>get_weather</tool_id>
+ <tool_id>#{tool_id || "get_weather"}</tool_id>
<location>Sydney</location>
<unit>c</unit>
diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb
index df855e73..195351e3 100644
--- a/spec/lib/completions/endpoints/gemini_spec.rb
+++ b/spec/lib/completions/endpoints/gemini_spec.rb
@@ -10,6 +10,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
+ let(:tool_id) { "get_weather" }
+
let(:tool_payload) do
{
name: "get_weather",
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index 087ca1fc..de69f8ed 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -12,6 +12,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
end
let(:prompt) { dialect.translate }
+ let(:tool_id) { "get_weather" }
+
let(:request_body) do
model
.default_options
diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb
index 2b72164f..7caf1a1f 100644
--- a/spec/lib/completions/endpoints/open_ai_spec.rb
+++ b/spec/lib/completions/endpoints/open_ai_spec.rb
@@ -10,45 +10,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
let(:prompt) { dialect.translate }
+ let(:tool_id) { "eujbuebfe" }
+
let(:tool_deltas) do
[
- { id: "get_weather", name: "get_weather", arguments: {} },
- { id: "get_weather", name: "get_weather", arguments: { location: "" } },
- { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
+ { id: tool_id, function: {} },
+ { id: tool_id, function: { name: "get_weather", arguments: "" } },
+ { id: tool_id, function: { name: "get_weather", arguments: "" } },
+ { id: tool_id, function: { name: "get_weather", arguments: "{" } },
+ { id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
+ { id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
]
end
let(:tool_call) do
- { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
+ {
+ id: tool_id,
+ function: {
+ name: "get_weather",
+ arguments: { location: "Sydney", unit: "c" }.to_json,
+ },
+ }
end
let(:request_body) do
model
.default_options
.merge(messages: prompt)
- .tap do |b|
- b[:tools] = generic_prompt[:tools].map do |t|
- { type: "function", tool: t }
- end if generic_prompt[:tools]
- end
+ .tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] }
.to_json
end
+
let(:stream_request_body) do
model
.default_options
.merge(messages: prompt, stream: true)
- .tap do |b|
- b[:tools] = generic_prompt[:tools].map do |t|
- { type: "function", tool: t }
- end if generic_prompt[:tools]
- end
+ .tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] }
.to_json
end
def response(content, tool_call: false)
message_content =
if tool_call
- { tool_calls: [{ function: content }] }
+ { tool_calls: [content] }
else
{ content: content }
end
@@ -79,7 +83,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
def stream_line(delta, finish_reason: nil, tool_call: false)
message_content =
if tool_call
- { tool_calls: [{ function: delta }] }
+ { tool_calls: [delta] }
else
{ content: delta }
end
diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb
index 99bb0151..54d9955d 100644
--- a/spec/lib/completions/endpoints/vllm_spec.rb
+++ b/spec/lib/completions/endpoints/vllm_spec.rb
@@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
+ let(:tool_id) { "get_weather" }
+
def response(content)
{
id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",