FIX: support multiple tool calls (#502)

* FIX: support multiple tool calls

Prior to this change we had a hard limit of 1 tool call per llm
round trip. This meant you could not google multiple things at
once or perform searches across two tools.

Also:

- Hint when Google stops working
- Log topic_id / post_id when performing completions

* Also track id for title
This commit is contained in:
Sam 2024-03-02 07:53:21 +11:00 committed by GitHub
parent b72ee805b6
commit c02794cf2e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
13 changed files with 275 additions and 58 deletions

View File

@ -7,6 +7,7 @@ module DiscourseAi
BOT_NOT_FOUND = Class.new(StandardError)
MAX_COMPLETIONS = 5
MAX_TOOLS = 5
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
new(bot_user, persona, model)
@ -21,14 +22,19 @@ module DiscourseAi
attr_reader :bot_user
attr_accessor :persona
def get_updated_title(conversation_context, post_user)
def get_updated_title(conversation_context, post)
system_insts = <<~TEXT.strip
You are titlebot. Given a topic, you will figure out a title.
You will never respond with anything but 7 word topic title.
TEXT
title_prompt =
DiscourseAi::Completions::Prompt.new(system_insts, messages: conversation_context)
DiscourseAi::Completions::Prompt.new(
system_insts,
messages: conversation_context,
topic_id: post.topic_id,
post_id: post.id,
)
title_prompt.push(
type: :user,
@ -38,7 +44,7 @@ module DiscourseAi
DiscourseAi::Completions::Llm
.proxy(model)
.generate(title_prompt, user: post_user)
.generate(title_prompt, user: post.user)
.strip
.split("\n")
.last
@ -64,9 +70,35 @@ module DiscourseAi
result =
llm.generate(prompt, **llm_kwargs) do |partial, cancel|
if (tool = persona.find_tool(partial))
tools = persona.find_tools(partial)
if (tools.present?)
tool_found = true
ongoing_chain = tool.chain_next_response?
tools[0..MAX_TOOLS].each do |tool|
ongoing_chain &&= tool.chain_next_response?
process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
end
else
update_blk.call(partial, cancel, nil)
end
end
if !tool_found
ongoing_chain = false
raw_context << [result, bot_user.username]
end
total_completions += 1
# do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
prompt.tools = [] if total_completions == MAX_COMPLETIONS
end
raw_context
end
private
def process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
tool_call_id = tool.tool_call_id
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
@ -95,25 +127,7 @@ module DiscourseAi
raw_context << [tool_call_message[:content], tool_call_id, "tool_call"]
raw_context << [invocation_result_json, tool_call_id, "tool"]
else
update_blk.call(partial, cancel, nil)
end
end
if !tool_found
ongoing_chain = false
raw_context << [result, bot_user.username]
end
total_completions += 1
# do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
prompt.tools = [] if total_completions == MAX_COMPLETIONS
end
raw_context
end
private
def invoke_tool(tool, llm, cancel, &update_blk)
update_blk.call("", cancel, build_placeholder(tool.summary, ""))

View File

@ -117,6 +117,8 @@ module DiscourseAi
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
TEXT
messages: context[:conversation_context].to_a,
topic_id: context[:topic_id],
post_id: context[:post_id],
)
prompt.tools = available_tools.map(&:signature) if available_tools
@ -124,8 +126,16 @@ module DiscourseAi
prompt
end
def find_tool(partial)
def find_tools(partial)
return [] if !partial.include?("</invoke>")
parsed_function = Nokogiri::HTML5.fragment(partial)
parsed_function.css("invoke").map { |fragment| find_tool(fragment) }.compact
end
protected
def find_tool(parsed_function)
function_id = parsed_function.at("tool_id")&.text
function_name = parsed_function.at("tool_name")&.text
return false if function_name.nil?

View File

@ -156,7 +156,7 @@ module DiscourseAi
context = conversation_context(post)
bot
.get_updated_title(context, post.user)
.get_updated_title(context, post)
.tap do |new_title|
PostRevisor.new(post.topic.first_post, post.topic).revise!(
bot.bot_user,
@ -182,6 +182,8 @@ module DiscourseAi
participants: post.topic.allowed_users.map(&:username).join(", "),
conversation_context: conversation_context(post),
user: post.user,
post_id: post.id,
topic_id: post.topic_id,
}
reply_user = bot.bot_user

View File

@ -37,6 +37,7 @@ module DiscourseAi
URI(
"https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10",
)
body = Net::HTTP.get(uri)
parse_search_json(body, escaped_query, llm)
@ -65,6 +66,19 @@ module DiscourseAi
def parse_search_json(json_data, escaped_query, llm)
parsed = JSON.parse(json_data)
error_code = parsed.dig("error", "code")
if error_code == 429
Rails.logger.warn(
"Google Custom Search is Rate Limited, no search can be performed at the moment. #{json_data[0..1000]}",
)
return(
"Google Custom Search is Rate Limited, no search can be performed at the moment. Let the user know there is a problem."
)
elsif error_code
Rails.logger.warn("Google Custom Search returned an error. #{json_data[0..1000]}")
return "Google Custom Search returned an error. Let the user know there is a problem."
end
results = parsed["items"]
@results_count = parsed.dig("searchInformation", "totalResults").to_i

View File

@ -106,9 +106,11 @@ module DiscourseAi
raise NotImplemented
end
attr_reader :prompt
private
attr_reader :prompt, :model_name, :opts
attr_reader :model_name, :opts
def trim_messages(messages)
prompt_limit = max_prompt_tokens

View File

@ -100,6 +100,8 @@ module DiscourseAi
user_id: user&.id,
raw_request_payload: request_body,
request_tokens: prompt_size(prompt),
topic_id: dialect.prompt.topic_id,
post_id: dialect.prompt.post_id,
)
if !@streaming_mode
@ -273,13 +275,19 @@ module DiscourseAi
def build_buffer
Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
#{noop_function_call_text}
</function_calls>
TEXT
end
def noop_function_call_text
(<<~TEXT).strip
<invoke>
<tool_name></tool_name>
<tool_id></tool_id>
<parameters>
</parameters>
</invoke>
</function_calls>
TEXT
end

View File

@ -167,8 +167,26 @@ module DiscourseAi
@args_buffer ||= +""
f_name = partial.dig(:function, :name)
function_buffer.at("tool_name").content = f_name if f_name
function_buffer.at("tool_id").content = partial[:id] if partial[:id]
@current_function ||= function_buffer.at("invoke")
if f_name
current_name = function_buffer.at("tool_name").content
if current_name.blank?
# first call
else
# we have a previous function, so we need to add a noop
@args_buffer = +""
@current_function =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
)
end
end
@current_function.at("tool_name").content = f_name if f_name
@current_function.at("tool_id").content = partial[:id] if partial[:id]
args = partial.dig(:function, :arguments)
@ -185,7 +203,7 @@ module DiscourseAi
end
argument_fragments << "\n"
function_buffer.at("parameters").children =
@current_function.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
rescue JSON::ParserError
return function_buffer

View File

@ -6,12 +6,22 @@ module DiscourseAi
INVALID_TURN = Class.new(StandardError)
attr_reader :messages
attr_accessor :tools
attr_accessor :tools, :topic_id, :post_id
def initialize(system_message_text = nil, messages: [], tools: [], skip_validations: false)
def initialize(
system_message_text = nil,
messages: [],
tools: [],
skip_validations: false,
topic_id: nil,
post_id: nil
)
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
@topic_id = topic_id
@post_id = post_id
@messages = []
@skip_validations = skip_validations

View File

@ -178,11 +178,9 @@ class EndpointsCompliance
def regular_mode_tools(mock)
prompt = generic_prompt(tools: [mock.tool])
a_dialect = dialect(prompt: prompt)
mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user)
expect(completion_response).to eq(mock.invocation_response)
end

View File

@ -84,10 +84,10 @@ class OpenAiMock < EndpointMock
[
{ id: tool_id, function: {} },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
{ id: tool_id, function: { arguments: "" } },
{ id: tool_id, function: { arguments: "{" } },
{ id: tool_id, function: { arguments: " \"location\": \"Sydney\"" } },
{ id: tool_id, function: { arguments: " ,\"unit\": \"c\" }" } },
]
end
@ -216,9 +216,77 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
compliance.streaming_mode_tools(open_ai_mock)
end
it "properly handles multiple tool calls" do
raw_data = <<~TEXT.strip
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_3Gyr3HylFJwfrtKrL6NaIit1","type":"function","function":{"name":"search","arguments":""}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"se"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"arch_"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"query\\""}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":": \\"D"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"iscou"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"rse AI"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" bot"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_H7YkbgYurHpyJqzwUN4bghwN","type":"function","function":{"name":"search","arguments":""}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\\"qu"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"ery\\":"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":" \\"Disc"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"ours"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
data: [DONE]
TEXT
open_ai_mock.stub_raw(raw_data)
content = +""
endpoint.perform_completion!(compliance.dialect, user) { |partial| content << partial }
expected = <<~TEXT
<function_calls>
<invoke>
<tool_name>search</tool_name>
<tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
<parameters>
<search_query>Discourse AI bot</search_query>
</parameters>
</invoke>
<invoke>
<tool_name>search</tool_name>
<tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
<parameters>
<query>Discourse AI bot</query>
</parameters>
</invoke>
</function_calls>
TEXT
expect(content).to eq(expected)
end
it "properly handles spaces in tools payload" do
raw_data = <<~TEXT.strip
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"google","arguments":""}}]}}]}
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\""}}]}}]}
@ -253,9 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
open_ai_mock.stub_raw(chunks)
partials = []
endpoint.perform_completion!(compliance.dialect, user) do |partial, x, y|
partials << partial
end
endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial }
expect(partials.length).to eq(1)

View File

@ -21,6 +21,40 @@ RSpec.describe DiscourseAi::Completions::Llm do
end
end
describe "AiApiAuditLog" do
it "is able to keep track of post and topic id" do
prompt =
DiscourseAi::Completions::Prompt.new(
"You are fake",
messages: [{ type: :user, content: "fake orders" }],
topic_id: 123,
post_id: 1,
)
result = <<~TEXT
data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
data: [DONE]
TEXT
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
status: 200,
body: result,
)
result = +""
described_class
.proxy("open_ai:gpt-3.5-turbo")
.generate(prompt, user: user) { |partial| result << partial }
expect(result).to eq("Hello")
log = AiApiAuditLog.order("id desc").first
expect(log.topic_id).to eq(123)
expect(log.post_id).to eq(1)
end
end
describe "#generate with fake model" do
before do
DiscourseAi::Completions::Endpoints::Fake.delays = []

View File

@ -82,10 +82,18 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
<prompts>["cat oil painting", "big car"]</prompts>
</parameters>
</invoke>
<invoke>
<tool_name>dall_e</tool_name>
<tool_id>abc</tool_id>
<parameters>
<prompts>["pic3"]</prompts>
</parameters>
</invoke>
</function_calls>
XML
dall_e = DiscourseAi::AiBot::Personas::DallE3.new.find_tool(xml)
expect(dall_e.parameters[:prompts]).to eq(["cat oil painting", "big car"])
dall_e1, dall_e2 = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml)
expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
end
describe "custom personas" do

View File

@ -212,6 +212,39 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
end
it "supports multiple function calls" do
response1 = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>search</tool_name>
<tool_id>search</tool_id>
<parameters>
<search_query>testing various things</search_query>
</parameters>
</invoke>
<invoke>
<tool_name>search</tool_name>
<tool_id>search</tool_id>
<parameters>
<search_query>another search</search_query>
</parameters>
</invoke>
</function_calls>
TXT
response2 = "I found stuff"
DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
playground.reply_to(third_post)
end
last_post = third_post.topic.reload.posts.order(:post_number).last
expect(last_post.raw).to include("testing various things")
expect(last_post.raw).to include("another search")
expect(last_post.raw).to include("I found stuff")
end
it "does not include placeholders in conversation context but includes all completions" do
response1 = (<<~TXT).strip
<function_calls>