FIX: Correctly translate and read tools for Claude and ChatGPT. (#393)

I tested against the live models for the AI bot migration. This ensures OpenAI's tool syntax is correct and that we read the replies correctly.
This commit is contained in:
Roman Rizzi 2024-01-02 11:21:13 -03:00 committed by GitHub
parent cec9bb8910
commit 4182af230a
18 changed files with 154 additions and 100 deletions
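
For reference, OpenAI's chat completions API expects each tool wrapped under a function key with JSON-schema style parameters; the old code nested the definition under a nonexistent tool key. A minimal sketch of the corrected shape the dialect now emits (the get_weather example mirrors the specs below; the description strings are illustrative):

tool = {
  type: "function",
  function: {
    name: "get_weather",
    description: "Get the weather in a city", # illustrative description
    parameters: {
      type: "object",
      properties: {
        location: { type: "string", description: "the city name" },
        unit: { type: "string", enum: %w[c f] },
      },
      required: %w[location unit],
    },
  },
}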

View File

@@ -33,7 +33,7 @@ module DiscourseAi
        end
      end
-     open_ai_prompt.concat!(conversation_context) if prompt[:conversation_context]
+     open_ai_prompt.concat(conversation_context) if prompt[:conversation_context]
      open_ai_prompt << { role: "user", content: prompt[:input] } if prompt[:input]
@@ -43,7 +43,25 @@ module DiscourseAi
    def tools
      return if prompt[:tools].blank?
-     prompt[:tools].map { |t| { type: "function", tool: t } }
+     prompt[:tools].map do |t|
+       tool = t.dup
+       if tool[:parameters]
+         tool[:parameters] = t[:parameters].reduce(
+           { type: "object", properties: {}, required: [] },
+         ) do |memo, p|
+           name = p[:name]
+           memo[:required] << name if p[:required]
+           memo[:properties][name] = p.except(:name, :required, :item_type)
+           memo[:properties][name][:items] = { type: p[:item_type] } if p[:item_type]
+           memo
+         end
+       end
+       { type: "function", function: tool }
+     end
    end
    def conversation_context
@@ -52,6 +70,12 @@ module DiscourseAi
      trimmed_context = trim_context(prompt[:conversation_context])
      trimmed_context.reverse.map do |context|
+       if context[:type] == "tool_call"
+         {
+           role: "assistant",
+           tool_calls: [{ type: "function", function: context[:content], id: context[:name] }],
+         }
+       else
          translated = context.slice(:content)
          translated[:role] = context[:type]
@@ -66,6 +90,7 @@ module DiscourseAi
          translated
        end
      end
+   end
    def max_prompt_tokens
      # provide a buffer of 120 tokens - our function counting is not
@@ -94,7 +119,7 @@ module DiscourseAi
    def model_max_tokens
      case model_name
-     when "gpt-3.5-turbo", "gpt-3.5-turbo-16k"
+     when "gpt-3.5-turbo-16k"
        16_384
      when "gpt-4"
        8192
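
One consequence of the new tool_call branch in conversation_context: a prior tool invocation stored in the context is replayed to OpenAI as an assistant message carrying tool_calls. A sketch with a hypothetical context entry (the exact :content payload depends on how the invocation was recorded):

context = {
  type: "tool_call",
  name: "tool_0", # hypothetical stored id
  content: { name: "get_weather", arguments: { location: "Sydney" } }.to_json,
}

translated = {
  role: "assistant",
  tool_calls: [{ type: "function", function: context[:content], id: context[:name] }],
}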

View File

@@ -44,6 +44,7 @@ module DiscourseAi
      trimmed_context
        .reverse
        .reduce(+"") do |memo, context|
+         next(memo) if context[:type] == "tool_call"
          memo << (context[:type] == "user" ? "Human:" : "Assistant:")
          if context[:type] == "tool"

View File

@@ -32,7 +32,7 @@ module DiscourseAi
        end
      end
-     gemini_prompt.concat!(conversation_context) if prompt[:conversation_context]
+     gemini_prompt.concat(conversation_context) if prompt[:conversation_context]
      gemini_prompt << { role: "user", parts: { text: prompt[:input] } }
    end

View File

@@ -44,6 +44,7 @@ module DiscourseAi
      trimmed_context
        .reverse
        .reduce(+"") do |memo, context|
+         next(memo) if context[:type] == "tool_call"
          if context[:type] == "tool"
            memo << <<~TEXT
              [INST]

View File

@@ -44,6 +44,7 @@ module DiscourseAi
      trimmed_context
        .reverse
        .reduce(+"") do |memo, context|
+         next(memo) if context[:type] == "tool_call"
          memo << "[INST] " if context[:type] == "user"
          if context[:type] == "tool"

View File

@@ -41,6 +41,7 @@ module DiscourseAi
      trimmed_context
        .reverse
        .reduce(+"") do |memo, context|
+         next(memo) if context[:type] == "tool_call"
          memo << (context[:type] == "user" ? "### User:" : "### Assistant:")
          if context[:type] == "tool"

View File

@@ -14,7 +14,7 @@ module DiscourseAi
    end
    def default_options
-     { max_tokens_to_sample: 2_000 }
+     { max_tokens_to_sample: 2_000, stop_sequences: ["\n\nHuman:", "</function_calls>"] }
    end
    def provider_id

View File

@@ -74,7 +74,7 @@ module DiscourseAi
        response_data = extract_completion_from(response_raw)
        partials_raw = response_data.to_s
-       if has_tool?("", response_data)
+       if has_tool?(response_data)
          function_buffer = build_buffer # Nokogiri document
          function_buffer = add_to_buffer(function_buffer, "", response_data)
@@ -125,26 +125,19 @@ module DiscourseAi
            begin
              partial = extract_completion_from(raw_partial)
-             next if response_data.empty? && partial.blank?
              next if partial.nil?
-             if has_tool?(response_data, partial)
-               function_buffer = add_to_buffer(function_buffer, response_data, partial)
-               if buffering_finished?(dialect.tools, function_buffer)
-                 invocation = +function_buffer.at("function_calls").to_s
-                 invocation << "\n"
-                 partials_raw << partial.to_s
-                 response_data << invocation
-                 yield invocation, cancel
-               end
+             # Skip yield for tools. We'll buffer and yield later.
+             if has_tool?(partials_raw)
+               function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
              else
-               partials_raw << partial
                response_data << partial
                yield partial, cancel if partial
              end
+             partials_raw << partial.to_s
            rescue JSON::ParserError
              leftover = redo_chunk
              json_error = true
@@ -162,6 +155,17 @@ module DiscourseAi
          raise if !cancelled
        end
+       # Once we have the full response, try to return the tool as a XML doc.
+       if has_tool?(partials_raw)
+         if function_buffer.at("tool_name").text.present?
+           invocation = +function_buffer.at("function_calls").to_s
+           invocation << "\n"
+           response_data << invocation
+           yield invocation, cancel
+         end
+       end
        return response_data
      ensure
        if log
@@ -236,12 +240,22 @@ module DiscourseAi
          TEXT
        end
-       def has_tool?(response, partial)
-         (response + partial).include?("<function_calls>")
+       def has_tool?(response)
+         response.include?("<function")
        end
        def add_to_buffer(function_buffer, response_data, partial)
-         read_function = Nokogiri::HTML5.fragment(response_data + partial)
+         raw_data = (response_data + partial)
+         # recover stop word potentially
+         raw_data =
+           raw_data.split("</invoke>").first + "</invoke>\n</function_calls>" if raw_data.split(
+             "</invoke>",
+           ).length > 1
+         return function_buffer unless raw_data.include?("</invoke>")
+         read_function = Nokogiri::HTML5.fragment(raw_data)
          if tool_name = read_function.at("tool_name").text
            function_buffer.at("tool_name").inner_html = tool_name
@@ -264,10 +278,6 @@ module DiscourseAi
          function_buffer
        end
-       def buffering_finished?(_available_functions, buffer)
-         buffer.to_s.include?("</function_calls>")
-       end
      end
    end
  end
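
To see why the "recover stop word" branch is needed: </function_calls> is now an Anthropic stop sequence (see the default_options change above), so generation halts right after the final </invoke> and the closing tag never arrives in the output. A self-contained sketch of the recovery, assuming that truncated shape:

require "nokogiri"

# Assumed output: the model stopped at the "</function_calls>" stop
# sequence, so the buffered reply ends at "</invoke>".
raw_data = <<~XML
  <function_calls>
  <invoke>
  <tool_name>get_weather</tool_name>
  <parameters>
  <location>Sydney</location>
  <unit>c</unit>
  </parameters>
  </invoke>
XML

# Re-append the tag the stop sequence swallowed, as add_to_buffer does.
raw_data = raw_data.split("</invoke>").first + "</invoke>\n</function_calls>"

doc = Nokogiri::HTML5.fragment(raw_data)
puts doc.at("tool_name").text # => "get_weather"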

View File

@@ -43,8 +43,8 @@ module DiscourseAi
        response_h = parsed.dig(:candidates, 0, :content, :parts, 0)
-       has_function_call = response_h.dig(:functionCall).present?
-       has_function_call ? response_h[:functionCall] : response_h.dig(:text)
+       @has_function_call ||= response_h.dig(:functionCall).present?
+       @has_function_call ? response_h[:functionCall] : response_h.dig(:text)
      end
      def partials_from(decoded_chunk)
@@ -68,8 +68,8 @@ module DiscourseAi
        prompt.to_s
      end
-     def has_tool?(_response_data, partial)
-       partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
+     def has_tool?(_response_data)
+       @has_function_call
      end
      def add_to_buffer(function_buffer, _response_data, partial)
@@ -91,21 +91,6 @@ module DiscourseAi
        function_buffer
      end
-     def buffering_finished?(available_functions, buffer)
-       tool_name = buffer.at("tool_name")&.text
-       return false if tool_name.blank?
-       signature =
-         available_functions.dig(0, :function_declarations).find { |f| f[:name] == tool_name }
-       signature[:parameters].reduce(true) do |memo, param|
-         param_present = buffer.at(param[:name]).present?
-         next(memo) if param_present || !signature[:required].include?(param[:name])
-         memo && param_present
-       end
-     end
    end
  end
end
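
The move from per-partial inspection to a memoized @has_function_call means that once any streamed chunk contains a functionCall, the endpoint stays in tool mode for the rest of the stream instead of re-deciding per chunk. A stripped-down illustration of the pattern (the class and hash keys are illustrative, not the plugin's API):

class StreamState
  # Mirrors the ||= memoization above: the flag latches on first sight
  # of a function call and stays set for later partials.
  def extract(parsed)
    @has_function_call ||= parsed.key?(:functionCall)
    @has_function_call ? parsed[:functionCall] : parsed[:text]
  end
end

state = StreamState.new
state.extract(text: "Checking the weather")          # => "Checking the weather"
state.extract(functionCall: { name: "get_weather" }) # => { name: "get_weather" }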

View File

@@ -83,8 +83,9 @@ module DiscourseAi
        response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-       has_function_call = response_h.dig(:tool_calls).present?
-       has_function_call ? response_h.dig(:tool_calls, 0, :function) : response_h.dig(:content)
+       @has_function_call ||= response_h.dig(:tool_calls).present?
+       @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
      end
      def partials_from(decoded_chunk)
@@ -101,41 +102,38 @@ module DiscourseAi
        prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
      end
-     def has_tool?(_response_data, partial)
-       partial.is_a?(Hash) && partial.has_key?(:name) # Has function name
+     def has_tool?(_response_data)
+       @has_function_call
      end
      def add_to_buffer(function_buffer, _response_data, partial)
-       function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
-       function_buffer.at("tool_id").content = partial[:id] if partial[:id].present?
+       @args_buffer ||= +""
+       f_name = partial.dig(:function, :name)
+       function_buffer.at("tool_name").content = f_name if f_name
+       function_buffer.at("tool_id").content = partial[:id] if partial[:id]
+       if partial.dig(:function, :arguments).present?
+         @args_buffer << partial.dig(:function, :arguments)
+         begin
+           json_args = JSON.parse(@args_buffer, symbolize_names: true)
-       if partial[:arguments]
           argument_fragments =
-            partial[:arguments].reduce(+"") do |memo, (arg_name, value)|
+            json_args.reduce(+"") do |memo, (arg_name, value)|
              memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
            end
          argument_fragments << "\n"
          function_buffer.at("parameters").children =
            Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
+        rescue JSON::ParserError
+          return function_buffer
+        end
        end
        function_buffer
      end
-     def buffering_finished?(available_functions, buffer)
-       tool_name = buffer.at("tool_name")&.text
-       return false if tool_name.blank?
-       signature = available_functions.find { |f| f.dig(:tool, :name) == tool_name }[:tool]
-       signature[:parameters].reduce(true) do |memo, param|
-         param_present = buffer.at(param[:name]).present?
-         next(memo) if param_present && !param[:required]
-         memo && param_present
-       end
-     end
    end
  end
end
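
OpenAI streams tool-call arguments as raw JSON string fragments rather than complete hashes, which is why add_to_buffer now accumulates fragments in @args_buffer and only renders the parameters block once the buffer parses cleanly. A minimal sketch of that accumulate-and-retry approach (the fragments mirror the tool_deltas in the specs below):

require "json"

fragments = ["{", " \"location\": \"Sydney\"", " ,\"unit\": \"c\" }"]

args_buffer = +""
parsed_args = nil

fragments.each do |fragment|
  args_buffer << fragment
  begin
    # Succeeds only once the buffered fragments form complete JSON.
    parsed_args = JSON.parse(args_buffer, symbolize_names: true)
  rescue JSON::ParserError
    next # incomplete JSON so far; keep buffering
  end
end

parsed_args # => { location: "Sydney", unit: "c" }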

View File

@@ -115,7 +115,25 @@ RSpec.describe DiscourseAi::Completions::Dialects::ChatGpt do
  describe "#tools" do
    it "returns a list of available tools" do
-     open_ai_tool_f = { type: "function", tool: tool }
+     open_ai_tool_f = {
+       function: {
+         description: tool[:description],
+         name: tool[:name],
+         parameters: {
+           properties:
+             tool[:parameters].reduce({}) do |memo, p|
+               memo[p[:name]] = { description: p[:description], type: p[:type] }
+               memo[p[:name]][:enum] = p[:enum] if p[:enum]
+               memo
+             end,
+           required: %w[location unit],
+           type: "object",
+         },
+       },
+       type: "function",
+     }
      expect(subject.tools).to contain_exactly(open_ai_tool_f)
    end

View File

@@ -13,6 +13,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
  let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
  let(:stream_request_body) { model.default_options.merge(prompt: prompt, stream: true).to_json }
+ let(:tool_id) { "get_weather" }
  def response(content)
    {
      completion: content,

View File

@@ -16,6 +16,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
  let(:request_body) { model.default_options.merge(prompt: prompt).to_json }
  let(:stream_request_body) { request_body }
+ let(:tool_id) { "get_weather" }
  before do
    SiteSetting.ai_bedrock_access_key_id = "123456"
    SiteSetting.ai_bedrock_secret_access_key = "asd-asd-asd"

View File

@@ -58,7 +58,7 @@ RSpec.shared_examples "an endpoint that can communicate with a completion service"
  <function_calls>
  <invoke>
  <tool_name>get_weather</tool_name>
- <tool_id>get_weather</tool_id>
+ <tool_id>#{tool_id || "get_weather"}</tool_id>
  <parameters>
  <location>Sydney</location>
  <unit>c</unit>

View File

@@ -10,6 +10,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
  let(:dialect) { DiscourseAi::Completions::Dialects::Gemini.new(generic_prompt, model_name) }
  let(:prompt) { dialect.translate }
+ let(:tool_id) { "get_weather" }
  let(:tool_payload) do
    {
      name: "get_weather",

View File

@@ -12,6 +12,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
  end
  let(:prompt) { dialect.translate }
+ let(:tool_id) { "get_weather" }
  let(:request_body) do
    model
      .default_options

View File

@@ -10,45 +10,49 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
  let(:dialect) { DiscourseAi::Completions::Dialects::ChatGpt.new(generic_prompt, model_name) }
  let(:prompt) { dialect.translate }
+ let(:tool_id) { "eujbuebfe" }
  let(:tool_deltas) do
    [
-     { id: "get_weather", name: "get_weather", arguments: {} },
-     { id: "get_weather", name: "get_weather", arguments: { location: "" } },
-     { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } },
+     { id: tool_id, function: {} },
+     { id: tool_id, function: { name: "get_weather", arguments: "" } },
+     { id: tool_id, function: { name: "get_weather", arguments: "" } },
+     { id: tool_id, function: { name: "get_weather", arguments: "{" } },
+     { id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
+     { id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
    ]
  end
  let(:tool_call) do
-   { id: "get_weather", name: "get_weather", arguments: { location: "Sydney", unit: "c" } }
+   {
+     id: tool_id,
+     function: {
+       name: "get_weather",
+       arguments: { location: "Sydney", unit: "c" }.to_json,
+     },
+   }
  end
  let(:request_body) do
    model
      .default_options
      .merge(messages: prompt)
-     .tap do |b|
-       b[:tools] = generic_prompt[:tools].map do |t|
-         { type: "function", tool: t }
-       end if generic_prompt[:tools]
-     end
+     .tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] }
      .to_json
  end
  let(:stream_request_body) do
    model
      .default_options
      .merge(messages: prompt, stream: true)
-     .tap do |b|
-       b[:tools] = generic_prompt[:tools].map do |t|
-         { type: "function", tool: t }
-       end if generic_prompt[:tools]
-     end
+     .tap { |b| b[:tools] = dialect.tools if generic_prompt[:tools] }
      .to_json
  end
  def response(content, tool_call: false)
    message_content =
      if tool_call
-       { tool_calls: [{ function: content }] }
+       { tool_calls: [content] }
      else
        { content: content }
      end
@@ -79,7 +83,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
  def stream_line(delta, finish_reason: nil, tool_call: false)
    message_content =
      if tool_call
-       { tool_calls: [{ function: delta }] }
+       { tool_calls: [delta] }
      else
        { content: delta }
      end

View File

@@ -15,6 +15,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
  before { SiteSetting.ai_vllm_endpoint = "https://test.dev" }
+ let(:tool_id) { "get_weather" }
  def response(content)
    {
      id: "cmpl-6sZfAb30Rnv9Q7ufzFwvQsMpjZh8S",