FIX: support multiple tool calls (#502)
* FIX: support multiple tool calls Prior to this change we had a hard limit of 1 tool call per llm round trip. This meant you could not google multiple things at once or perform searches across two tools. Also: - Hint when Google stops working - Log topic_id / post_id when performing completions * Also track id for title
This commit is contained in:
parent
b72ee805b6
commit
c02794cf2e
|
@ -7,6 +7,7 @@ module DiscourseAi
|
|||
|
||||
BOT_NOT_FOUND = Class.new(StandardError)
|
||||
MAX_COMPLETIONS = 5
|
||||
MAX_TOOLS = 5
|
||||
|
||||
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
|
||||
new(bot_user, persona, model)
|
||||
|
@ -21,14 +22,19 @@ module DiscourseAi
|
|||
attr_reader :bot_user
|
||||
attr_accessor :persona
|
||||
|
||||
def get_updated_title(conversation_context, post_user)
|
||||
def get_updated_title(conversation_context, post)
|
||||
system_insts = <<~TEXT.strip
|
||||
You are titlebot. Given a topic, you will figure out a title.
|
||||
You will never respond with anything but 7 word topic title.
|
||||
TEXT
|
||||
|
||||
title_prompt =
|
||||
DiscourseAi::Completions::Prompt.new(system_insts, messages: conversation_context)
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
system_insts,
|
||||
messages: conversation_context,
|
||||
topic_id: post.topic_id,
|
||||
post_id: post.id,
|
||||
)
|
||||
|
||||
title_prompt.push(
|
||||
type: :user,
|
||||
|
@ -38,7 +44,7 @@ module DiscourseAi
|
|||
|
||||
DiscourseAi::Completions::Llm
|
||||
.proxy(model)
|
||||
.generate(title_prompt, user: post_user)
|
||||
.generate(title_prompt, user: post.user)
|
||||
.strip
|
||||
.split("\n")
|
||||
.last
|
||||
|
@ -64,37 +70,14 @@ module DiscourseAi
|
|||
|
||||
result =
|
||||
llm.generate(prompt, **llm_kwargs) do |partial, cancel|
|
||||
if (tool = persona.find_tool(partial))
|
||||
tools = persona.find_tools(partial)
|
||||
|
||||
if (tools.present?)
|
||||
tool_found = true
|
||||
ongoing_chain = tool.chain_next_response?
|
||||
tool_call_id = tool.tool_call_id
|
||||
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
|
||||
|
||||
tool_call_message = {
|
||||
type: :tool_call,
|
||||
id: tool_call_id,
|
||||
content: { name: tool.name, arguments: tool.parameters }.to_json,
|
||||
}
|
||||
|
||||
tool_message = { type: :tool, id: tool_call_id, content: invocation_result_json }
|
||||
|
||||
if tool.standalone?
|
||||
standalone_context =
|
||||
context.dup.merge(
|
||||
conversation_context: [
|
||||
context[:conversation_context].last,
|
||||
tool_call_message,
|
||||
tool_message,
|
||||
],
|
||||
)
|
||||
prompt = persona.craft_prompt(standalone_context)
|
||||
else
|
||||
prompt.push(**tool_call_message)
|
||||
prompt.push(**tool_message)
|
||||
tools[0..MAX_TOOLS].each do |tool|
|
||||
ongoing_chain &&= tool.chain_next_response?
|
||||
process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
|
||||
end
|
||||
|
||||
raw_context << [tool_call_message[:content], tool_call_id, "tool_call"]
|
||||
raw_context << [invocation_result_json, tool_call_id, "tool"]
|
||||
else
|
||||
update_blk.call(partial, cancel, nil)
|
||||
end
|
||||
|
@ -115,6 +98,37 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
def process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
|
||||
tool_call_id = tool.tool_call_id
|
||||
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
|
||||
|
||||
tool_call_message = {
|
||||
type: :tool_call,
|
||||
id: tool_call_id,
|
||||
content: { name: tool.name, arguments: tool.parameters }.to_json,
|
||||
}
|
||||
|
||||
tool_message = { type: :tool, id: tool_call_id, content: invocation_result_json }
|
||||
|
||||
if tool.standalone?
|
||||
standalone_context =
|
||||
context.dup.merge(
|
||||
conversation_context: [
|
||||
context[:conversation_context].last,
|
||||
tool_call_message,
|
||||
tool_message,
|
||||
],
|
||||
)
|
||||
prompt = persona.craft_prompt(standalone_context)
|
||||
else
|
||||
prompt.push(**tool_call_message)
|
||||
prompt.push(**tool_message)
|
||||
end
|
||||
|
||||
raw_context << [tool_call_message[:content], tool_call_id, "tool_call"]
|
||||
raw_context << [invocation_result_json, tool_call_id, "tool"]
|
||||
end
|
||||
|
||||
def invoke_tool(tool, llm, cancel, &update_blk)
|
||||
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
|
||||
|
||||
|
|
|
@ -117,6 +117,8 @@ module DiscourseAi
|
|||
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
|
||||
TEXT
|
||||
messages: context[:conversation_context].to_a,
|
||||
topic_id: context[:topic_id],
|
||||
post_id: context[:post_id],
|
||||
)
|
||||
|
||||
prompt.tools = available_tools.map(&:signature) if available_tools
|
||||
|
@ -124,8 +126,16 @@ module DiscourseAi
|
|||
prompt
|
||||
end
|
||||
|
||||
def find_tool(partial)
|
||||
def find_tools(partial)
|
||||
return [] if !partial.include?("</invoke>")
|
||||
|
||||
parsed_function = Nokogiri::HTML5.fragment(partial)
|
||||
parsed_function.css("invoke").map { |fragment| find_tool(fragment) }.compact
|
||||
end
|
||||
|
||||
protected
|
||||
|
||||
def find_tool(parsed_function)
|
||||
function_id = parsed_function.at("tool_id")&.text
|
||||
function_name = parsed_function.at("tool_name")&.text
|
||||
return false if function_name.nil?
|
||||
|
|
|
@ -156,7 +156,7 @@ module DiscourseAi
|
|||
context = conversation_context(post)
|
||||
|
||||
bot
|
||||
.get_updated_title(context, post.user)
|
||||
.get_updated_title(context, post)
|
||||
.tap do |new_title|
|
||||
PostRevisor.new(post.topic.first_post, post.topic).revise!(
|
||||
bot.bot_user,
|
||||
|
@ -182,6 +182,8 @@ module DiscourseAi
|
|||
participants: post.topic.allowed_users.map(&:username).join(", "),
|
||||
conversation_context: conversation_context(post),
|
||||
user: post.user,
|
||||
post_id: post.id,
|
||||
topic_id: post.topic_id,
|
||||
}
|
||||
|
||||
reply_user = bot.bot_user
|
||||
|
|
|
@ -37,6 +37,7 @@ module DiscourseAi
|
|||
URI(
|
||||
"https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10",
|
||||
)
|
||||
|
||||
body = Net::HTTP.get(uri)
|
||||
|
||||
parse_search_json(body, escaped_query, llm)
|
||||
|
@ -65,6 +66,19 @@ module DiscourseAi
|
|||
|
||||
def parse_search_json(json_data, escaped_query, llm)
|
||||
parsed = JSON.parse(json_data)
|
||||
error_code = parsed.dig("error", "code")
|
||||
if error_code == 429
|
||||
Rails.logger.warn(
|
||||
"Google Custom Search is Rate Limited, no search can be performed at the moment. #{json_data[0..1000]}",
|
||||
)
|
||||
return(
|
||||
"Google Custom Search is Rate Limited, no search can be performed at the moment. Let the user know there is a problem."
|
||||
)
|
||||
elsif error_code
|
||||
Rails.logger.warn("Google Custom Search returned an error. #{json_data[0..1000]}")
|
||||
return "Google Custom Search returned an error. Let the user know there is a problem."
|
||||
end
|
||||
|
||||
results = parsed["items"]
|
||||
|
||||
@results_count = parsed.dig("searchInformation", "totalResults").to_i
|
||||
|
|
|
@ -106,9 +106,11 @@ module DiscourseAi
|
|||
raise NotImplemented
|
||||
end
|
||||
|
||||
attr_reader :prompt
|
||||
|
||||
private
|
||||
|
||||
attr_reader :prompt, :model_name, :opts
|
||||
attr_reader :model_name, :opts
|
||||
|
||||
def trim_messages(messages)
|
||||
prompt_limit = max_prompt_tokens
|
||||
|
|
|
@ -100,6 +100,8 @@ module DiscourseAi
|
|||
user_id: user&.id,
|
||||
raw_request_payload: request_body,
|
||||
request_tokens: prompt_size(prompt),
|
||||
topic_id: dialect.prompt.topic_id,
|
||||
post_id: dialect.prompt.post_id,
|
||||
)
|
||||
|
||||
if !@streaming_mode
|
||||
|
@ -273,16 +275,22 @@ module DiscourseAi
|
|||
def build_buffer
|
||||
Nokogiri::HTML5.fragment(<<~TEXT)
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name></tool_name>
|
||||
<tool_id></tool_id>
|
||||
<parameters>
|
||||
</parameters>
|
||||
</invoke>
|
||||
#{noop_function_call_text}
|
||||
</function_calls>
|
||||
TEXT
|
||||
end
|
||||
|
||||
def noop_function_call_text
|
||||
(<<~TEXT).strip
|
||||
<invoke>
|
||||
<tool_name></tool_name>
|
||||
<tool_id></tool_id>
|
||||
<parameters>
|
||||
</parameters>
|
||||
</invoke>
|
||||
TEXT
|
||||
end
|
||||
|
||||
def has_tool?(response)
|
||||
response.include?("<function")
|
||||
end
|
||||
|
|
|
@ -167,8 +167,26 @@ module DiscourseAi
|
|||
@args_buffer ||= +""
|
||||
|
||||
f_name = partial.dig(:function, :name)
|
||||
function_buffer.at("tool_name").content = f_name if f_name
|
||||
function_buffer.at("tool_id").content = partial[:id] if partial[:id]
|
||||
|
||||
@current_function ||= function_buffer.at("invoke")
|
||||
|
||||
if f_name
|
||||
current_name = function_buffer.at("tool_name").content
|
||||
|
||||
if current_name.blank?
|
||||
# first call
|
||||
else
|
||||
# we have a previous function, so we need to add a noop
|
||||
@args_buffer = +""
|
||||
@current_function =
|
||||
function_buffer.at("function_calls").add_child(
|
||||
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
|
||||
)
|
||||
end
|
||||
end
|
||||
|
||||
@current_function.at("tool_name").content = f_name if f_name
|
||||
@current_function.at("tool_id").content = partial[:id] if partial[:id]
|
||||
|
||||
args = partial.dig(:function, :arguments)
|
||||
|
||||
|
@ -185,7 +203,7 @@ module DiscourseAi
|
|||
end
|
||||
argument_fragments << "\n"
|
||||
|
||||
function_buffer.at("parameters").children =
|
||||
@current_function.at("parameters").children =
|
||||
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||
rescue JSON::ParserError
|
||||
return function_buffer
|
||||
|
|
|
@ -6,12 +6,22 @@ module DiscourseAi
|
|||
INVALID_TURN = Class.new(StandardError)
|
||||
|
||||
attr_reader :messages
|
||||
attr_accessor :tools
|
||||
attr_accessor :tools, :topic_id, :post_id
|
||||
|
||||
def initialize(system_message_text = nil, messages: [], tools: [], skip_validations: false)
|
||||
def initialize(
|
||||
system_message_text = nil,
|
||||
messages: [],
|
||||
tools: [],
|
||||
skip_validations: false,
|
||||
topic_id: nil,
|
||||
post_id: nil
|
||||
)
|
||||
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
|
||||
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
|
||||
|
||||
@topic_id = topic_id
|
||||
@post_id = post_id
|
||||
|
||||
@messages = []
|
||||
@skip_validations = skip_validations
|
||||
|
||||
|
|
|
@ -178,11 +178,9 @@ class EndpointsCompliance
|
|||
def regular_mode_tools(mock)
|
||||
prompt = generic_prompt(tools: [mock.tool])
|
||||
a_dialect = dialect(prompt: prompt)
|
||||
|
||||
mock.stub_tool_call(a_dialect.translate)
|
||||
|
||||
completion_response = endpoint.perform_completion!(a_dialect, user)
|
||||
|
||||
expect(completion_response).to eq(mock.invocation_response)
|
||||
end
|
||||
|
||||
|
|
|
@ -84,10 +84,10 @@ class OpenAiMock < EndpointMock
|
|||
[
|
||||
{ id: tool_id, function: {} },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
|
||||
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
|
||||
{ id: tool_id, function: { arguments: "" } },
|
||||
{ id: tool_id, function: { arguments: "{" } },
|
||||
{ id: tool_id, function: { arguments: " \"location\": \"Sydney\"" } },
|
||||
{ id: tool_id, function: { arguments: " ,\"unit\": \"c\" }" } },
|
||||
]
|
||||
end
|
||||
|
||||
|
@ -216,9 +216,77 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|||
compliance.streaming_mode_tools(open_ai_mock)
|
||||
end
|
||||
|
||||
it "properly handles multiple tool calls" do
|
||||
raw_data = <<~TEXT.strip
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_3Gyr3HylFJwfrtKrL6NaIit1","type":"function","function":{"name":"search","arguments":""}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"se"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"arch_"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"query\\""}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":": \\"D"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"iscou"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"rse AI"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" bot"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_H7YkbgYurHpyJqzwUN4bghwN","type":"function","function":{"name":"search","arguments":""}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\\"qu"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"ery\\":"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":" \\"Disc"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"ours"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
|
||||
|
||||
data: [DONE]
|
||||
TEXT
|
||||
|
||||
open_ai_mock.stub_raw(raw_data)
|
||||
content = +""
|
||||
|
||||
endpoint.perform_completion!(compliance.dialect, user) { |partial| content << partial }
|
||||
|
||||
expected = <<~TEXT
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
<tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
|
||||
<parameters>
|
||||
<search_query>Discourse AI bot</search_query>
|
||||
</parameters>
|
||||
</invoke>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
<tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
|
||||
<parameters>
|
||||
<query>Discourse AI bot</query>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TEXT
|
||||
|
||||
expect(content).to eq(expected)
|
||||
end
|
||||
|
||||
it "properly handles spaces in tools payload" do
|
||||
raw_data = <<~TEXT.strip
|
||||
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"google","arguments":""}}]}}]}
|
||||
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
|
||||
|
||||
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\""}}]}}]}
|
||||
|
||||
|
@ -253,9 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
|||
open_ai_mock.stub_raw(chunks)
|
||||
partials = []
|
||||
|
||||
endpoint.perform_completion!(compliance.dialect, user) do |partial, x, y|
|
||||
partials << partial
|
||||
end
|
||||
endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial }
|
||||
|
||||
expect(partials.length).to eq(1)
|
||||
|
||||
|
|
|
@ -21,6 +21,40 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
|||
end
|
||||
end
|
||||
|
||||
describe "AiApiAuditLog" do
|
||||
it "is able to keep track of post and topic id" do
|
||||
prompt =
|
||||
DiscourseAi::Completions::Prompt.new(
|
||||
"You are fake",
|
||||
messages: [{ type: :user, content: "fake orders" }],
|
||||
topic_id: 123,
|
||||
post_id: 1,
|
||||
)
|
||||
|
||||
result = <<~TEXT
|
||||
data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
|
||||
|
||||
data: [DONE]
|
||||
TEXT
|
||||
|
||||
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||
status: 200,
|
||||
body: result,
|
||||
)
|
||||
result = +""
|
||||
described_class
|
||||
.proxy("open_ai:gpt-3.5-turbo")
|
||||
.generate(prompt, user: user) { |partial| result << partial }
|
||||
|
||||
expect(result).to eq("Hello")
|
||||
log = AiApiAuditLog.order("id desc").first
|
||||
expect(log.topic_id).to eq(123)
|
||||
expect(log.post_id).to eq(1)
|
||||
end
|
||||
end
|
||||
|
||||
describe "#generate with fake model" do
|
||||
before do
|
||||
DiscourseAi::Completions::Endpoints::Fake.delays = []
|
||||
|
|
|
@ -82,10 +82,18 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
|||
<prompts>["cat oil painting", "big car"]</prompts>
|
||||
</parameters>
|
||||
</invoke>
|
||||
<invoke>
|
||||
<tool_name>dall_e</tool_name>
|
||||
<tool_id>abc</tool_id>
|
||||
<parameters>
|
||||
<prompts>["pic3"]</prompts>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
XML
|
||||
dall_e = DiscourseAi::AiBot::Personas::DallE3.new.find_tool(xml)
|
||||
expect(dall_e.parameters[:prompts]).to eq(["cat oil painting", "big car"])
|
||||
dall_e1, dall_e2 = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml)
|
||||
expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
|
||||
expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
|
||||
end
|
||||
|
||||
describe "custom personas" do
|
||||
|
|
|
@ -212,6 +212,39 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
end
|
||||
end
|
||||
|
||||
it "supports multiple function calls" do
|
||||
response1 = (<<~TXT).strip
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
<tool_id>search</tool_id>
|
||||
<parameters>
|
||||
<search_query>testing various things</search_query>
|
||||
</parameters>
|
||||
</invoke>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
<tool_id>search</tool_id>
|
||||
<parameters>
|
||||
<search_query>another search</search_query>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>
|
||||
TXT
|
||||
|
||||
response2 = "I found stuff"
|
||||
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
|
||||
playground.reply_to(third_post)
|
||||
end
|
||||
|
||||
last_post = third_post.topic.reload.posts.order(:post_number).last
|
||||
|
||||
expect(last_post.raw).to include("testing various things")
|
||||
expect(last_post.raw).to include("another search")
|
||||
expect(last_post.raw).to include("I found stuff")
|
||||
end
|
||||
|
||||
it "does not include placeholders in conversation context but includes all completions" do
|
||||
response1 = (<<~TXT).strip
|
||||
<function_calls>
|
||||
|
|
Loading…
Reference in New Issue