FIX: support multiple tool calls (#502)
* FIX: support multiple tool calls Prior to this change we had a hard limit of 1 tool call per llm round trip. This meant you could not google multiple things at once or perform searches across two tools. Also: - Hint when Google stops working - Log topic_id / post_id when performing completions * Also track id for title
This commit is contained in:
parent
b72ee805b6
commit
c02794cf2e
|
@ -7,6 +7,7 @@ module DiscourseAi
|
||||||
|
|
||||||
BOT_NOT_FOUND = Class.new(StandardError)
|
BOT_NOT_FOUND = Class.new(StandardError)
|
||||||
MAX_COMPLETIONS = 5
|
MAX_COMPLETIONS = 5
|
||||||
|
MAX_TOOLS = 5
|
||||||
|
|
||||||
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
|
def self.as(bot_user, persona: DiscourseAi::AiBot::Personas::General.new, model: nil)
|
||||||
new(bot_user, persona, model)
|
new(bot_user, persona, model)
|
||||||
|
@ -21,14 +22,19 @@ module DiscourseAi
|
||||||
attr_reader :bot_user
|
attr_reader :bot_user
|
||||||
attr_accessor :persona
|
attr_accessor :persona
|
||||||
|
|
||||||
def get_updated_title(conversation_context, post_user)
|
def get_updated_title(conversation_context, post)
|
||||||
system_insts = <<~TEXT.strip
|
system_insts = <<~TEXT.strip
|
||||||
You are titlebot. Given a topic, you will figure out a title.
|
You are titlebot. Given a topic, you will figure out a title.
|
||||||
You will never respond with anything but 7 word topic title.
|
You will never respond with anything but 7 word topic title.
|
||||||
TEXT
|
TEXT
|
||||||
|
|
||||||
title_prompt =
|
title_prompt =
|
||||||
DiscourseAi::Completions::Prompt.new(system_insts, messages: conversation_context)
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
system_insts,
|
||||||
|
messages: conversation_context,
|
||||||
|
topic_id: post.topic_id,
|
||||||
|
post_id: post.id,
|
||||||
|
)
|
||||||
|
|
||||||
title_prompt.push(
|
title_prompt.push(
|
||||||
type: :user,
|
type: :user,
|
||||||
|
@ -38,7 +44,7 @@ module DiscourseAi
|
||||||
|
|
||||||
DiscourseAi::Completions::Llm
|
DiscourseAi::Completions::Llm
|
||||||
.proxy(model)
|
.proxy(model)
|
||||||
.generate(title_prompt, user: post_user)
|
.generate(title_prompt, user: post.user)
|
||||||
.strip
|
.strip
|
||||||
.split("\n")
|
.split("\n")
|
||||||
.last
|
.last
|
||||||
|
@ -64,9 +70,35 @@ module DiscourseAi
|
||||||
|
|
||||||
result =
|
result =
|
||||||
llm.generate(prompt, **llm_kwargs) do |partial, cancel|
|
llm.generate(prompt, **llm_kwargs) do |partial, cancel|
|
||||||
if (tool = persona.find_tool(partial))
|
tools = persona.find_tools(partial)
|
||||||
|
|
||||||
|
if (tools.present?)
|
||||||
tool_found = true
|
tool_found = true
|
||||||
ongoing_chain = tool.chain_next_response?
|
tools[0..MAX_TOOLS].each do |tool|
|
||||||
|
ongoing_chain &&= tool.chain_next_response?
|
||||||
|
process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
|
||||||
|
end
|
||||||
|
else
|
||||||
|
update_blk.call(partial, cancel, nil)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
if !tool_found
|
||||||
|
ongoing_chain = false
|
||||||
|
raw_context << [result, bot_user.username]
|
||||||
|
end
|
||||||
|
total_completions += 1
|
||||||
|
|
||||||
|
# do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
|
||||||
|
prompt.tools = [] if total_completions == MAX_COMPLETIONS
|
||||||
|
end
|
||||||
|
|
||||||
|
raw_context
|
||||||
|
end
|
||||||
|
|
||||||
|
private
|
||||||
|
|
||||||
|
def process_tool(tool, raw_context, llm, cancel, update_blk, prompt)
|
||||||
tool_call_id = tool.tool_call_id
|
tool_call_id = tool.tool_call_id
|
||||||
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
|
invocation_result_json = invoke_tool(tool, llm, cancel, &update_blk).to_json
|
||||||
|
|
||||||
|
@ -95,25 +127,7 @@ module DiscourseAi
|
||||||
|
|
||||||
raw_context << [tool_call_message[:content], tool_call_id, "tool_call"]
|
raw_context << [tool_call_message[:content], tool_call_id, "tool_call"]
|
||||||
raw_context << [invocation_result_json, tool_call_id, "tool"]
|
raw_context << [invocation_result_json, tool_call_id, "tool"]
|
||||||
else
|
|
||||||
update_blk.call(partial, cancel, nil)
|
|
||||||
end
|
end
|
||||||
end
|
|
||||||
|
|
||||||
if !tool_found
|
|
||||||
ongoing_chain = false
|
|
||||||
raw_context << [result, bot_user.username]
|
|
||||||
end
|
|
||||||
total_completions += 1
|
|
||||||
|
|
||||||
# do not allow tools when we are at the end of a chain (total_completions == MAX_COMPLETIONS)
|
|
||||||
prompt.tools = [] if total_completions == MAX_COMPLETIONS
|
|
||||||
end
|
|
||||||
|
|
||||||
raw_context
|
|
||||||
end
|
|
||||||
|
|
||||||
private
|
|
||||||
|
|
||||||
def invoke_tool(tool, llm, cancel, &update_blk)
|
def invoke_tool(tool, llm, cancel, &update_blk)
|
||||||
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
|
update_blk.call("", cancel, build_placeholder(tool.summary, ""))
|
||||||
|
|
|
@ -117,6 +117,8 @@ module DiscourseAi
|
||||||
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
|
#{available_tools.map(&:custom_system_message).compact_blank.join("\n")}
|
||||||
TEXT
|
TEXT
|
||||||
messages: context[:conversation_context].to_a,
|
messages: context[:conversation_context].to_a,
|
||||||
|
topic_id: context[:topic_id],
|
||||||
|
post_id: context[:post_id],
|
||||||
)
|
)
|
||||||
|
|
||||||
prompt.tools = available_tools.map(&:signature) if available_tools
|
prompt.tools = available_tools.map(&:signature) if available_tools
|
||||||
|
@ -124,8 +126,16 @@ module DiscourseAi
|
||||||
prompt
|
prompt
|
||||||
end
|
end
|
||||||
|
|
||||||
def find_tool(partial)
|
def find_tools(partial)
|
||||||
|
return [] if !partial.include?("</invoke>")
|
||||||
|
|
||||||
parsed_function = Nokogiri::HTML5.fragment(partial)
|
parsed_function = Nokogiri::HTML5.fragment(partial)
|
||||||
|
parsed_function.css("invoke").map { |fragment| find_tool(fragment) }.compact
|
||||||
|
end
|
||||||
|
|
||||||
|
protected
|
||||||
|
|
||||||
|
def find_tool(parsed_function)
|
||||||
function_id = parsed_function.at("tool_id")&.text
|
function_id = parsed_function.at("tool_id")&.text
|
||||||
function_name = parsed_function.at("tool_name")&.text
|
function_name = parsed_function.at("tool_name")&.text
|
||||||
return false if function_name.nil?
|
return false if function_name.nil?
|
||||||
|
|
|
@ -156,7 +156,7 @@ module DiscourseAi
|
||||||
context = conversation_context(post)
|
context = conversation_context(post)
|
||||||
|
|
||||||
bot
|
bot
|
||||||
.get_updated_title(context, post.user)
|
.get_updated_title(context, post)
|
||||||
.tap do |new_title|
|
.tap do |new_title|
|
||||||
PostRevisor.new(post.topic.first_post, post.topic).revise!(
|
PostRevisor.new(post.topic.first_post, post.topic).revise!(
|
||||||
bot.bot_user,
|
bot.bot_user,
|
||||||
|
@ -182,6 +182,8 @@ module DiscourseAi
|
||||||
participants: post.topic.allowed_users.map(&:username).join(", "),
|
participants: post.topic.allowed_users.map(&:username).join(", "),
|
||||||
conversation_context: conversation_context(post),
|
conversation_context: conversation_context(post),
|
||||||
user: post.user,
|
user: post.user,
|
||||||
|
post_id: post.id,
|
||||||
|
topic_id: post.topic_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
reply_user = bot.bot_user
|
reply_user = bot.bot_user
|
||||||
|
|
|
@ -37,6 +37,7 @@ module DiscourseAi
|
||||||
URI(
|
URI(
|
||||||
"https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10",
|
"https://www.googleapis.com/customsearch/v1?key=#{api_key}&cx=#{cx}&q=#{escaped_query}&num=10",
|
||||||
)
|
)
|
||||||
|
|
||||||
body = Net::HTTP.get(uri)
|
body = Net::HTTP.get(uri)
|
||||||
|
|
||||||
parse_search_json(body, escaped_query, llm)
|
parse_search_json(body, escaped_query, llm)
|
||||||
|
@ -65,6 +66,19 @@ module DiscourseAi
|
||||||
|
|
||||||
def parse_search_json(json_data, escaped_query, llm)
|
def parse_search_json(json_data, escaped_query, llm)
|
||||||
parsed = JSON.parse(json_data)
|
parsed = JSON.parse(json_data)
|
||||||
|
error_code = parsed.dig("error", "code")
|
||||||
|
if error_code == 429
|
||||||
|
Rails.logger.warn(
|
||||||
|
"Google Custom Search is Rate Limited, no search can be performed at the moment. #{json_data[0..1000]}",
|
||||||
|
)
|
||||||
|
return(
|
||||||
|
"Google Custom Search is Rate Limited, no search can be performed at the moment. Let the user know there is a problem."
|
||||||
|
)
|
||||||
|
elsif error_code
|
||||||
|
Rails.logger.warn("Google Custom Search returned an error. #{json_data[0..1000]}")
|
||||||
|
return "Google Custom Search returned an error. Let the user know there is a problem."
|
||||||
|
end
|
||||||
|
|
||||||
results = parsed["items"]
|
results = parsed["items"]
|
||||||
|
|
||||||
@results_count = parsed.dig("searchInformation", "totalResults").to_i
|
@results_count = parsed.dig("searchInformation", "totalResults").to_i
|
||||||
|
|
|
@ -106,9 +106,11 @@ module DiscourseAi
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
end
|
end
|
||||||
|
|
||||||
|
attr_reader :prompt
|
||||||
|
|
||||||
private
|
private
|
||||||
|
|
||||||
attr_reader :prompt, :model_name, :opts
|
attr_reader :model_name, :opts
|
||||||
|
|
||||||
def trim_messages(messages)
|
def trim_messages(messages)
|
||||||
prompt_limit = max_prompt_tokens
|
prompt_limit = max_prompt_tokens
|
||||||
|
|
|
@ -100,6 +100,8 @@ module DiscourseAi
|
||||||
user_id: user&.id,
|
user_id: user&.id,
|
||||||
raw_request_payload: request_body,
|
raw_request_payload: request_body,
|
||||||
request_tokens: prompt_size(prompt),
|
request_tokens: prompt_size(prompt),
|
||||||
|
topic_id: dialect.prompt.topic_id,
|
||||||
|
post_id: dialect.prompt.post_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
if !@streaming_mode
|
if !@streaming_mode
|
||||||
|
@ -273,13 +275,19 @@ module DiscourseAi
|
||||||
def build_buffer
|
def build_buffer
|
||||||
Nokogiri::HTML5.fragment(<<~TEXT)
|
Nokogiri::HTML5.fragment(<<~TEXT)
|
||||||
<function_calls>
|
<function_calls>
|
||||||
|
#{noop_function_call_text}
|
||||||
|
</function_calls>
|
||||||
|
TEXT
|
||||||
|
end
|
||||||
|
|
||||||
|
def noop_function_call_text
|
||||||
|
(<<~TEXT).strip
|
||||||
<invoke>
|
<invoke>
|
||||||
<tool_name></tool_name>
|
<tool_name></tool_name>
|
||||||
<tool_id></tool_id>
|
<tool_id></tool_id>
|
||||||
<parameters>
|
<parameters>
|
||||||
</parameters>
|
</parameters>
|
||||||
</invoke>
|
</invoke>
|
||||||
</function_calls>
|
|
||||||
TEXT
|
TEXT
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -167,8 +167,26 @@ module DiscourseAi
|
||||||
@args_buffer ||= +""
|
@args_buffer ||= +""
|
||||||
|
|
||||||
f_name = partial.dig(:function, :name)
|
f_name = partial.dig(:function, :name)
|
||||||
function_buffer.at("tool_name").content = f_name if f_name
|
|
||||||
function_buffer.at("tool_id").content = partial[:id] if partial[:id]
|
@current_function ||= function_buffer.at("invoke")
|
||||||
|
|
||||||
|
if f_name
|
||||||
|
current_name = function_buffer.at("tool_name").content
|
||||||
|
|
||||||
|
if current_name.blank?
|
||||||
|
# first call
|
||||||
|
else
|
||||||
|
# we have a previous function, so we need to add a noop
|
||||||
|
@args_buffer = +""
|
||||||
|
@current_function =
|
||||||
|
function_buffer.at("function_calls").add_child(
|
||||||
|
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
|
||||||
|
)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
@current_function.at("tool_name").content = f_name if f_name
|
||||||
|
@current_function.at("tool_id").content = partial[:id] if partial[:id]
|
||||||
|
|
||||||
args = partial.dig(:function, :arguments)
|
args = partial.dig(:function, :arguments)
|
||||||
|
|
||||||
|
@ -185,7 +203,7 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
argument_fragments << "\n"
|
argument_fragments << "\n"
|
||||||
|
|
||||||
function_buffer.at("parameters").children =
|
@current_function.at("parameters").children =
|
||||||
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
|
||||||
rescue JSON::ParserError
|
rescue JSON::ParserError
|
||||||
return function_buffer
|
return function_buffer
|
||||||
|
|
|
@ -6,12 +6,22 @@ module DiscourseAi
|
||||||
INVALID_TURN = Class.new(StandardError)
|
INVALID_TURN = Class.new(StandardError)
|
||||||
|
|
||||||
attr_reader :messages
|
attr_reader :messages
|
||||||
attr_accessor :tools
|
attr_accessor :tools, :topic_id, :post_id
|
||||||
|
|
||||||
def initialize(system_message_text = nil, messages: [], tools: [], skip_validations: false)
|
def initialize(
|
||||||
|
system_message_text = nil,
|
||||||
|
messages: [],
|
||||||
|
tools: [],
|
||||||
|
skip_validations: false,
|
||||||
|
topic_id: nil,
|
||||||
|
post_id: nil
|
||||||
|
)
|
||||||
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
|
raise ArgumentError, "messages must be an array" if !messages.is_a?(Array)
|
||||||
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
|
raise ArgumentError, "tools must be an array" if !tools.is_a?(Array)
|
||||||
|
|
||||||
|
@topic_id = topic_id
|
||||||
|
@post_id = post_id
|
||||||
|
|
||||||
@messages = []
|
@messages = []
|
||||||
@skip_validations = skip_validations
|
@skip_validations = skip_validations
|
||||||
|
|
||||||
|
|
|
@ -178,11 +178,9 @@ class EndpointsCompliance
|
||||||
def regular_mode_tools(mock)
|
def regular_mode_tools(mock)
|
||||||
prompt = generic_prompt(tools: [mock.tool])
|
prompt = generic_prompt(tools: [mock.tool])
|
||||||
a_dialect = dialect(prompt: prompt)
|
a_dialect = dialect(prompt: prompt)
|
||||||
|
|
||||||
mock.stub_tool_call(a_dialect.translate)
|
mock.stub_tool_call(a_dialect.translate)
|
||||||
|
|
||||||
completion_response = endpoint.perform_completion!(a_dialect, user)
|
completion_response = endpoint.perform_completion!(a_dialect, user)
|
||||||
|
|
||||||
expect(completion_response).to eq(mock.invocation_response)
|
expect(completion_response).to eq(mock.invocation_response)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -84,10 +84,10 @@ class OpenAiMock < EndpointMock
|
||||||
[
|
[
|
||||||
{ id: tool_id, function: {} },
|
{ id: tool_id, function: {} },
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: "" } },
|
{ id: tool_id, function: { arguments: "" } },
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: "{" } },
|
{ id: tool_id, function: { arguments: "{" } },
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: " \"location\": \"Sydney\"" } },
|
{ id: tool_id, function: { arguments: " \"location\": \"Sydney\"" } },
|
||||||
{ id: tool_id, function: { name: "get_weather", arguments: " ,\"unit\": \"c\" }" } },
|
{ id: tool_id, function: { arguments: " ,\"unit\": \"c\" }" } },
|
||||||
]
|
]
|
||||||
end
|
end
|
||||||
|
|
||||||
|
@ -216,9 +216,77 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
compliance.streaming_mode_tools(open_ai_mock)
|
compliance.streaming_mode_tools(open_ai_mock)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "properly handles multiple tool calls" do
|
||||||
|
raw_data = <<~TEXT.strip
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"role":"assistant","content":null},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"id":"call_3Gyr3HylFJwfrtKrL6NaIit1","type":"function","function":{"name":"search","arguments":""}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"{\\"se"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"arch_"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"query\\""}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":": \\"D"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"iscou"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"rse AI"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":" bot"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":0,"function":{"arguments":"\\"}"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"id":"call_H7YkbgYurHpyJqzwUN4bghwN","type":"function","function":{"name":"search","arguments":""}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"{\\"qu"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"ery\\":"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":" \\"Disc"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"ours"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
open_ai_mock.stub_raw(raw_data)
|
||||||
|
content = +""
|
||||||
|
|
||||||
|
endpoint.perform_completion!(compliance.dialect, user) { |partial| content << partial }
|
||||||
|
|
||||||
|
expected = <<~TEXT
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>search</tool_name>
|
||||||
|
<tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
|
||||||
|
<parameters>
|
||||||
|
<search_query>Discourse AI bot</search_query>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>search</tool_name>
|
||||||
|
<tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
|
||||||
|
<parameters>
|
||||||
|
<query>Discourse AI bot</query>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
expect(content).to eq(expected)
|
||||||
|
end
|
||||||
|
|
||||||
it "properly handles spaces in tools payload" do
|
it "properly handles spaces in tools payload" do
|
||||||
raw_data = <<~TEXT.strip
|
raw_data = <<~TEXT.strip
|
||||||
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"google","arguments":""}}]}}]}
|
data: {"choices":[{"index":0,"delta":{"role":"assistant","content":null,"tool_calls":[{"index":0,"id":"func_id","type":"function","function":{"name":"go|ogle","arg|uments":""}}]}}]}
|
||||||
|
|
||||||
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\""}}]}}]}
|
data: {"choices": [{"index": 0, "delta": {"tool_calls": [{"index": 0, "function": {"arguments": "{\\""}}]}}]}
|
||||||
|
|
||||||
|
@ -253,9 +321,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
|
||||||
open_ai_mock.stub_raw(chunks)
|
open_ai_mock.stub_raw(chunks)
|
||||||
partials = []
|
partials = []
|
||||||
|
|
||||||
endpoint.perform_completion!(compliance.dialect, user) do |partial, x, y|
|
endpoint.perform_completion!(compliance.dialect, user) { |partial| partials << partial }
|
||||||
partials << partial
|
|
||||||
end
|
|
||||||
|
|
||||||
expect(partials.length).to eq(1)
|
expect(partials.length).to eq(1)
|
||||||
|
|
||||||
|
|
|
@ -21,6 +21,40 @@ RSpec.describe DiscourseAi::Completions::Llm do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
describe "AiApiAuditLog" do
|
||||||
|
it "is able to keep track of post and topic id" do
|
||||||
|
prompt =
|
||||||
|
DiscourseAi::Completions::Prompt.new(
|
||||||
|
"You are fake",
|
||||||
|
messages: [{ type: :user, content: "fake orders" }],
|
||||||
|
topic_id: 123,
|
||||||
|
post_id: 1,
|
||||||
|
)
|
||||||
|
|
||||||
|
result = <<~TEXT
|
||||||
|
data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: {"id":"chatcmpl-8xoPOYRmiuBANTmGqdCGVk4ZA3Orz","object":"chat.completion.chunk","created":1709265814,"model":"gpt-4-0125-preview","system_fingerprint":"fp_70b2088885","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}]}
|
||||||
|
|
||||||
|
data: [DONE]
|
||||||
|
TEXT
|
||||||
|
|
||||||
|
WebMock.stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
|
||||||
|
status: 200,
|
||||||
|
body: result,
|
||||||
|
)
|
||||||
|
result = +""
|
||||||
|
described_class
|
||||||
|
.proxy("open_ai:gpt-3.5-turbo")
|
||||||
|
.generate(prompt, user: user) { |partial| result << partial }
|
||||||
|
|
||||||
|
expect(result).to eq("Hello")
|
||||||
|
log = AiApiAuditLog.order("id desc").first
|
||||||
|
expect(log.topic_id).to eq(123)
|
||||||
|
expect(log.post_id).to eq(1)
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
describe "#generate with fake model" do
|
describe "#generate with fake model" do
|
||||||
before do
|
before do
|
||||||
DiscourseAi::Completions::Endpoints::Fake.delays = []
|
DiscourseAi::Completions::Endpoints::Fake.delays = []
|
||||||
|
|
|
@ -82,10 +82,18 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
|
||||||
<prompts>["cat oil painting", "big car"]</prompts>
|
<prompts>["cat oil painting", "big car"]</prompts>
|
||||||
</parameters>
|
</parameters>
|
||||||
</invoke>
|
</invoke>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>dall_e</tool_name>
|
||||||
|
<tool_id>abc</tool_id>
|
||||||
|
<parameters>
|
||||||
|
<prompts>["pic3"]</prompts>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
</function_calls>
|
</function_calls>
|
||||||
XML
|
XML
|
||||||
dall_e = DiscourseAi::AiBot::Personas::DallE3.new.find_tool(xml)
|
dall_e1, dall_e2 = DiscourseAi::AiBot::Personas::DallE3.new.find_tools(xml)
|
||||||
expect(dall_e.parameters[:prompts]).to eq(["cat oil painting", "big car"])
|
expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
|
||||||
|
expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
|
||||||
end
|
end
|
||||||
|
|
||||||
describe "custom personas" do
|
describe "custom personas" do
|
||||||
|
|
|
@ -212,6 +212,39 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
it "supports multiple function calls" do
|
||||||
|
response1 = (<<~TXT).strip
|
||||||
|
<function_calls>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>search</tool_name>
|
||||||
|
<tool_id>search</tool_id>
|
||||||
|
<parameters>
|
||||||
|
<search_query>testing various things</search_query>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
<invoke>
|
||||||
|
<tool_name>search</tool_name>
|
||||||
|
<tool_id>search</tool_id>
|
||||||
|
<parameters>
|
||||||
|
<search_query>another search</search_query>
|
||||||
|
</parameters>
|
||||||
|
</invoke>
|
||||||
|
</function_calls>
|
||||||
|
TXT
|
||||||
|
|
||||||
|
response2 = "I found stuff"
|
||||||
|
|
||||||
|
DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
|
||||||
|
playground.reply_to(third_post)
|
||||||
|
end
|
||||||
|
|
||||||
|
last_post = third_post.topic.reload.posts.order(:post_number).last
|
||||||
|
|
||||||
|
expect(last_post.raw).to include("testing various things")
|
||||||
|
expect(last_post.raw).to include("another search")
|
||||||
|
expect(last_post.raw).to include("I found stuff")
|
||||||
|
end
|
||||||
|
|
||||||
it "does not include placeholders in conversation context but includes all completions" do
|
it "does not include placeholders in conversation context but includes all completions" do
|
||||||
response1 = (<<~TXT).strip
|
response1 = (<<~TXT).strip
|
||||||
<function_calls>
|
<function_calls>
|
||||||
|
|
Loading…
Reference in New Issue