FEATURE: improve tool support (#904)
This re-implements tool support in DiscourseAi::Completions::Llm#generate. Previously, tool calls were always returned as XML and it was the caller's responsibility to parse it; endpoints now return ToolCall objects instead. This also simplifies the Llm endpoint interface and gives it more clarity: LLMs must implement decode and decode_chunk (for streaming), and it is the implementer's responsibility to figure out how to decode chunks, since the base class no longer does. To make this easy, we ship a flexible JSON stream decoder that is simple to wire up. Also new: better debugging for PMs, with next/previous buttons to step through all the LLM messages associated with a PM. Token accounting is fixed for vLLM (we were not counting tokens correctly).
This commit is contained in:
parent 644141ff08
commit e817b7dc11
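With this change, callers of DiscourseAi::Completions::Llm#generate receive plain strings for text and DiscourseAi::Completions::ToolCall objects for tool invocations. A minimal sketch of the new calling convention, modeled on the bot code in this diff (handle_tool and stream_text are hypothetical helpers, not part of the commit):

llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
  if partial.is_a?(DiscourseAi::Completions::ToolCall)
    # tool invocations arrive as objects; no XML parsing required
    handle_tool(partial.name, partial.parameters) # hypothetical helper
  else
    stream_text(partial) # hypothetical helper; partial is a String chunk
  end
end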
@@ -6,6 +6,14 @@ module DiscourseAi
  requires_plugin ::DiscourseAi::PLUGIN_NAME
  requires_login

  def show_debug_info_by_id
    log = AiApiAuditLog.find(params[:id])
    raise Discourse::NotFound if !log.topic

    guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
    render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
  end

  def show_debug_info
    post = Post.find(params[:post_id])
    guardian.ensure_can_debug_ai_bot_conversation!(post)

@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
    Ollama = 7
    SambaNova = 8
  end

  def next_log_id
    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
  end

  def prev_log_id
    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
  end
end

# == Schema Information

@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
    :post_id,
    :feature_name,
    :language_model,
    :created_at
    :created_at,
    :prev_log_id,
    :next_log_id
end

@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
import i18n from "discourse-common/helpers/i18n";
import discourseLater from "discourse-common/lib/later";

@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
    this.copy(this.info.raw_response_payload);
  }

  async loadLog(logId) {
    try {
      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
        (result) => {
          this.info = result;
        }
      );
    } catch (e) {
      popupAjaxError(e);
    }
  }

  @action
  prevLog() {
    this.loadLog(this.info.prev_log_id);
  }

  @action
  nextLog() {
    this.loadLog(this.info.next_log_id);
  }

  copy(text) {
    clipboardCopy(text);
    this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");

@@ -73,11 +96,13 @@ export default class DebugAiModal extends Component {
  }

  loadApiRequestInfo() {
    ajax(
      `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
    ).then((result) => {
      this.info = result;
    });
    ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
      .then((result) => {
        this.info = result;
      })
      .catch((e) => {
        popupAjaxError(e);
      });
  }

  get requestActive() {

@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
          @action={{this.copyResponse}}
          @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
        />
        {{#if this.info.prev_log_id}}
          <DButton
            class="btn"
            @icon="angles-left"
            @action={{this.prevLog}}
            @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
          />
        {{/if}}
        {{#if this.info.next_log_id}}
          <DButton
            class="btn"
            @icon="angles-right"
            @action={{this.nextLog}}
            @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
          />
        {{/if}}
        <span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
      </:footer>
    </DModal>

@@ -415,6 +415,8 @@ en:
      response_tokens: "Response tokens:"
      request: "Request"
      response: "Response"
      next_log: "Next"
      previous_log: "Previous"

    share_full_topic_modal:
      title: "Share Conversation Publicly"

@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
  scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
    get "bot-username" => "bot#show_bot_username"
    get "post/:post_id/show-debug-info" => "bot#show_debug_info"
    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
    post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
  end

@@ -100,6 +100,7 @@ module DiscourseAi
        llm_kwargs[:top_p] = persona.top_p if persona.top_p

        needs_newlines = false
        tools_ran = 0

        while total_completions <= MAX_COMPLETIONS && ongoing_chain
          tool_found = false

@@ -107,9 +108,10 @@ module DiscourseAi

          result =
            llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
              tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
              tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
              tool = nil if tools_ran >= MAX_TOOLS

              if (tools.present?)
              if tool.present?
                tool_found = true
                # a bit hacky, but extra newlines do no harm
                if needs_newlines

@@ -117,13 +119,16 @@ module DiscourseAi
                  needs_newlines = false
                end

                tools[0..MAX_TOOLS].each do |tool|
                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
                  ongoing_chain &&= tool.chain_next_response?
                end
                process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
                tools_ran += 1
                ongoing_chain &&= tool.chain_next_response?
              else
                needs_newlines = true
                update_blk.call(partial, cancel)
                if partial.is_a?(DiscourseAi::Completions::ToolCall)
                  Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
                else
                  update_blk.call(partial, cancel)
                end
              end
            end

@@ -199,23 +199,16 @@ module DiscourseAi
        prompt
      end

      def find_tools(partial, bot_user:, llm:, context:)
        return [] if !partial.include?("</invoke>")

        parsed_function = Nokogiri::HTML5.fragment(partial)
        parsed_function
          .css("invoke")
          .map do |fragment|
            tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
          end
          .compact
      def find_tool(partial, bot_user:, llm:, context:)
        return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
        tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
      end

      protected

      def tool_instance(parsed_function, bot_user:, llm:, context:)
        function_id = parsed_function.at("tool_id")&.text
        function_name = parsed_function.at("tool_name")&.text
      def tool_instance(tool_call, bot_user:, llm:, context:)
        function_id = tool_call.id
        function_name = tool_call.name
        return nil if function_name.nil?

        tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }

@@ -224,7 +217,7 @@ module DiscourseAi
        arguments = {}
        tool_klass.signature[:parameters].to_a.each do |param|
          name = param[:name]
          value = parsed_function.at(name)&.text
          value = tool_call.parameters[name.to_sym]

          if param[:type] == "array" && value
            value =

@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
    def append(json)
      @raw_json << json
    end

    def to_tool_call
      parameters = JSON.parse(raw_json, symbolize_names: true)
      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
    end
  end

  attr_reader :tool_calls, :input_tokens, :output_tokens

@@ -20,80 +25,69 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
  def initialize(streaming_mode:)
    @streaming_mode = streaming_mode
    @tool_calls = []
    @current_tool_call = nil
  end

  def to_xml_tool_calls(function_buffer)
    return function_buffer if @tool_calls.blank?
  def to_tool_calls
    @tool_calls.map { |tool_call| tool_call.to_tool_call }
  end

    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
      <function_calls>
      </function_calls>
    TEXT

    @tool_calls.each do |tool_call|
      node =
        function_buffer.at("function_calls").add_child(
          Nokogiri::HTML5::DocumentFragment.parse(
            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
          ),
        )

      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
      xml =
        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")

      node.at("tool_name").content = tool_call.name
      node.at("tool_id").content = tool_call.id
      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
  def process_streamed_message(parsed)
    result = nil
    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
      tool_name = parsed.dig(:content_block, :name)
      tool_id = parsed.dig(:content_block, :id)
      result = @current_tool_call.to_tool_call if @current_tool_call
      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
      if @current_tool_call
        tool_delta = parsed.dig(:delta, :partial_json).to_s
        @current_tool_call.append(tool_delta)
      else
        result = parsed.dig(:delta, :text).to_s
      end
    elsif parsed[:type] == "content_block_stop"
      if @current_tool_call
        result = @current_tool_call.to_tool_call
        @current_tool_call = nil
      end
    elsif parsed[:type] == "message_start"
      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
    elsif parsed[:type] == "message_delta"
      @output_tokens =
        parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
    elsif parsed[:type] == "message_stop"
      # bedrock has this ...
      if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
        @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
        @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
      end
    end

    function_buffer
    result
  end

  def process_message(payload)
    result = ""
    parsed = JSON.parse(payload, symbolize_names: true)
    parsed = payload
    parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)

    if @streaming_mode
      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
        tool_name = parsed.dig(:content_block, :name)
        tool_id = parsed.dig(:content_block, :id)
        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
        if @tool_calls.present?
          result = parsed.dig(:delta, :partial_json).to_s
          @tool_calls.last.append(result)
        else
          result = parsed.dig(:delta, :text).to_s
    content = parsed.dig(:content)
    if content.is_a?(Array)
      result =
        content.map do |data|
          if data[:type] == "tool_use"
            call = AnthropicToolCall.new(data[:name], data[:id])
            call.append(data[:input].to_json)
            call.to_tool_call
          else
            data[:text]
          end
        end
      elsif parsed[:type] == "message_start"
        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
      elsif parsed[:type] == "message_delta"
        @output_tokens =
          parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
      elsif parsed[:type] == "message_stop"
        # bedrock has this ...
        if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
          @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
          @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
        end
      end
    else
      content = parsed.dig(:content)
      if content.is_a?(Array)
        tool_call = content.find { |c| c[:type] == "tool_use" }
        if tool_call
          @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
          @tool_calls.last.append(tool_call[:input].to_json)
        else
          result = parsed.dig(:content, 0, :text).to_s
        end
      end

      @input_tokens = parsed.dig(:usage, :input_tokens)
      @output_tokens = parsed.dig(:usage, :output_tokens)
    end

    @input_tokens = parsed.dig(:usage, :input_tokens)
    @output_tokens = parsed.dig(:usage, :output_tokens)

    result
  end
end

@@ -63,8 +63,23 @@ module DiscourseAi
        def user_msg(msg)
          user_message = { role: "user", content: msg[:content] }

          # TODO: Add support for user messages with empbeded user ids
          # TODO: Add support for user messages with attachments
          encoded_uploads = prompt.encoded_uploads(msg)
          if encoded_uploads.present?
            images =
              encoded_uploads
                .map do |upload|
                  if upload[:mime_type].start_with?("image/")
                    upload[:base64]
                  else
                    nil
                  end
                end
                .compact

            user_message[:images] = images if images.present?
          end

          # TODO: Add support for user messages with embedded user ids

          user_message
        end

@@ -63,6 +63,10 @@ module DiscourseAi
          URI(llm_model.url)
        end

        def xml_tools_enabled?
          !@native_tool_support
        end

        def prepare_payload(prompt, model_params, dialect)
          @native_tool_support = dialect.native_tool_support?

@@ -90,35 +94,34 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def decode_chunk(partial_data)
          @decoder ||= JsonStreamDecoder.new
          (@decoder << partial_data)
            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
            .compact
        end

        def decode(response_data)
          processor.process_message(response_data)
        end

        def processor
          @processor ||=
            DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
        end

        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
          processor.to_xml_tool_calls(function_buffer) if !partial
        end

        def extract_completion_from(response_raw)
          processor.process_message(response_raw)
        end

        def has_tool?(_response_data)
          processor.tool_calls.present?
        end

        def tool_calls
          processor.to_tool_calls
        end

        def final_log_update(log)
          log.request_tokens = processor.input_tokens if processor.input_tokens
          log.response_tokens = processor.output_tokens if processor.output_tokens
        end

        def native_tool_support?
          @native_tool_support
        end

        def partials_from(decoded_chunk)
          decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
        end
      end
    end
  end

@@ -117,7 +117,24 @@ module DiscourseAi
          end
        end

        def decode(chunk)
        def decode_chunk(partial_data)
          bedrock_decode(partial_data)
            .map do |decoded_partial_data|
              @raw_response ||= +""
              @raw_response << decoded_partial_data
              @raw_response << "\n"

              parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true)
              processor.process_streamed_message(parsed_json)
            end
            .compact
        end

        def decode(response_data)
          processor.process_message(response_data)
        end

        def bedrock_decode(chunk)
          @decoder ||= Aws::EventStream::Decoder.new

          decoded, _done = @decoder.decode_chunk(chunk)

@@ -147,12 +164,13 @@ module DiscourseAi
          Aws::EventStream::Errors::MessageChecksumError,
          Aws::EventStream::Errors::PreludeChecksumError => e
          Rails.logger.error("#{self.class.name}: #{e.message}")
          nil
          []
        end

        def final_log_update(log)
          log.request_tokens = processor.input_tokens if processor.input_tokens
          log.response_tokens = processor.output_tokens if processor.output_tokens
          log.raw_response_payload = @raw_response
        end

        def processor

@@ -160,30 +178,8 @@ module DiscourseAi
            DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
        end

        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
          processor.to_xml_tool_calls(function_buffer) if !partial
        end

        def extract_completion_from(response_raw)
          processor.process_message(response_raw)
        end

        def has_tool?(_response_data)
          processor.tool_calls.present?
        end

        def partials_from(decoded_chunks)
          decoded_chunks
        end

        def native_tool_support?
          @native_tool_support
        end

        def chunk_to_string(chunk)
          joined = +chunk.join("\n")
          joined << "\n" if joined.length > 0
          joined
        def xml_tools_enabled?
          !@native_tool_support
        end
      end
    end

@@ -40,10 +40,6 @@ module DiscourseAi
          @llm_model = llm_model
        end

        def native_tool_support?
          false
        end

        def use_ssl?
          if model_uri&.scheme.present?
            model_uri.scheme == "https"

@@ -64,22 +60,10 @@ module DiscourseAi
          feature_context: nil,
          &blk
        )
          allow_tools = dialect.prompt.has_tools?
          model_params = normalize_model_params(model_params)
          orig_blk = blk

          @streaming_mode = block_given?
          to_strip = xml_tags_to_strip(dialect)
          @xml_stripper =
            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?

          if @streaming_mode && @xml_stripper
            blk =
              lambda do |partial, cancel|
                partial = @xml_stripper << partial
                orig_blk.call(partial, cancel) if partial
              end
          end

          prompt = dialect.translate

@@ -108,177 +92,91 @@ module DiscourseAi
            raise CompletionFailed, response.body
          end

          xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? &&
            dialect.prompt.has_tools?

          to_strip = xml_tags_to_strip(dialect)
          xml_stripper =
            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?

          if @streaming_mode && xml_stripper
            blk =
              lambda do |partial, cancel|
                partial = xml_stripper << partial if partial.is_a?(String)
                orig_blk.call(partial, cancel) if partial
              end
          end

          log =
            AiApiAuditLog.new(
            start_log(
              provider_id: provider_id,
              user_id: user&.id,
              raw_request_payload: request_body,
              request_tokens: prompt_size(prompt),
              topic_id: dialect.prompt.topic_id,
              post_id: dialect.prompt.post_id,
              request_body: request_body,
              dialect: dialect,
              prompt: prompt,
              user: user,
              feature_name: feature_name,
              language_model: llm_model.name,
              feature_context: feature_context.present? ? feature_context.as_json : nil,
              feature_context: feature_context,
            )

          if !@streaming_mode
            response_raw = response.read_body
            response_data = extract_completion_from(response_raw)
            partials_raw = response_data.to_s

            if native_tool_support?
              if allow_tools && has_tool?(response_data)
                function_buffer = build_buffer # Nokogiri document
                function_buffer =
                  add_to_function_buffer(function_buffer, payload: response_data)
                FunctionCallNormalizer.normalize_function_ids!(function_buffer)

                response_data = +function_buffer.at("function_calls").to_s
                response_data << "\n"
              end
            else
              if allow_tools
                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
                response_data = function_calls if function_calls.present?
              end
            end

            return response_data
            return(
              non_streaming_response(
                response: response,
                xml_tool_processor: xml_tool_processor,
                xml_stripper: xml_stripper,
                partials_raw: partials_raw,
                response_raw: response_raw,
              )
            )
          end

          has_tool = false

          begin
            cancelled = false
            cancel = -> { cancelled = true }

            wrapped_blk = ->(partial, inner_cancel) do
              response_data << partial
              blk.call(partial, inner_cancel)
            if cancelled
              http.finish
              break
            end

            normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)

            leftover = ""
            function_buffer = build_buffer # Nokogiri document
            prev_processed_partials = 0

            response.read_body do |chunk|
              if cancelled
                http.finish
                break
              end

              decoded_chunk = decode(chunk)
              if decoded_chunk.nil?
                raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion"
              end
              response_raw << chunk_to_string(decoded_chunk)

              if decoded_chunk.is_a?(String)
                redo_chunk = leftover + decoded_chunk
              else
                # custom implementation for endpoint
                # no implicit leftover support
                redo_chunk = decoded_chunk
              end

              raw_partials = partials_from(redo_chunk)

              raw_partials =
                raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0

              if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
                leftover = redo_chunk
                next
              end

              json_error = false

              raw_partials.each do |raw_partial|
                json_error = false
                prev_processed_partials += 1

                next if cancelled
                next if raw_partial.blank?

                begin
                  partial = extract_completion_from(raw_partial)
                  next if partial.nil?
                  # empty vs blank... we still accept " "
                  next if response_data.empty? && partial.empty?
                  partials_raw << partial.to_s

                  if native_tool_support?
                    # Stop streaming the response as soon as you find a tool.
                    # We'll buffer and yield it later.
                    has_tool = true if allow_tools && has_tool?(partials_raw)

                    if has_tool
                      function_buffer =
                        add_to_function_buffer(function_buffer, partial: partial)
                    else
                      response_data << partial
                      blk.call(partial, cancel) if partial
                    end
                  else
                    if allow_tools
                      normalizer << partial
                    else
                      response_data << partial
                      blk.call(partial, cancel) if partial
                    end
              response_raw << chunk
              decode_chunk(chunk).each do |partial|
                partials_raw << partial.to_s
                response_data << partial if partial.is_a?(String)
                partials = [partial]
                if xml_tool_processor && partial.is_a?(String)
                  partials = (xml_tool_processor << partial)
                  if xml_tool_processor.should_cancel?
                    cancel.call
                    break
                  end
                rescue JSON::ParserError
                  leftover = redo_chunk
                  json_error = true
                end
                partials.each { |inner_partial| blk.call(inner_partial, cancel) }
              end

              if json_error
                prev_processed_partials -= 1
              else
                leftover = ""
              end

              prev_processed_partials = 0 if leftover.blank?
            end
          rescue IOError, StandardError
            raise if !cancelled
          end

          has_tool ||= has_tool?(partials_raw)
          # Once we have the full response, try to return the tool as a XML doc.
          if has_tool && native_tool_support?
            function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)

            if function_buffer.at("tool_name").text.present?
              FunctionCallNormalizer.normalize_function_ids!(function_buffer)

              invocation = +function_buffer.at("function_calls").to_s
              invocation << "\n"

              response_data << invocation
              blk.call(invocation, cancel)
          if xml_stripper
            stripped = xml_stripper.finish
            if stripped.present?
              response_data << stripped
              result = []
              result = (xml_tool_processor << stripped) if xml_tool_processor
              result.each { |partial| blk.call(partial, cancel) }
            end
          end

          if !native_tool_support? && function_calls = normalizer.function_calls
            response_data << function_calls
            blk.call(function_calls, cancel)
          if xml_tool_processor
            xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
          end

          if @xml_stripper
            leftover = @xml_stripper.finish
            orig_blk.call(leftover, cancel) if leftover.present?
          end

          decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
          return response_data
        ensure
          if log
            log.raw_response_payload = response_raw
            log.response_tokens = tokenizer.size(partials_raw)
            final_log_update(log)

            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
            log.save!

            if Rails.env.development?

@@ -330,15 +228,15 @@ module DiscourseAi
          raise NotImplementedError
        end

        def extract_completion_from(_response_raw)
        def decode(_response_raw)
          raise NotImplementedError
        end

        def decode(chunk)
          chunk
        def decode_chunk_finish
          []
        end

        def partials_from(_decoded_chunk)
        def decode_chunk(_chunk)
          raise NotImplementedError
        end

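Taken together, a concrete endpoint now only has to supply the decode hooks; the shared streaming loop in Base drives everything else. A minimal sketch of a subclass, assuming OpenAI-style SSE payloads (the JSON paths mirror the SambaNova endpoint later in this diff):

def decode(response_raw)
  # non-streaming: return an array of partials (strings and/or ToolCall objects)
  json = JSON.parse(response_raw, symbolize_names: true)
  [json.dig(:choices, 0, :message, :content)]
end

def decode_chunk(chunk)
  # streaming: feed raw bytes to the shipped JsonStreamDecoder and map
  # each complete JSON object to a partial
  @json_decoder ||= JsonStreamDecoder.new
  (@json_decoder << chunk).map { |json| json.dig(:choices, 0, :delta, :content) }.compact
end

def xml_tools_enabled?
  true # no native tool API; let XmlToolProcessor extract ToolCalls from XML
end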
@@ -346,49 +244,73 @@ module DiscourseAi
          prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
        end

        def build_buffer
          Nokogiri::HTML5.fragment(<<~TEXT)
            <function_calls>
            #{noop_function_call_text}
            </function_calls>
          TEXT
        def xml_tools_enabled?
          raise NotImplementedError
        end

        def self.noop_function_call_text
          (<<~TEXT).strip
            <invoke>
            <tool_name></tool_name>
            <parameters>
            </parameters>
            <tool_id></tool_id>
            </invoke>
          TEXT
        private

        def start_log(
          provider_id:,
          request_body:,
          dialect:,
          prompt:,
          user:,
          feature_name:,
          feature_context:
        )
          AiApiAuditLog.new(
            provider_id: provider_id,
            user_id: user&.id,
            raw_request_payload: request_body,
            request_tokens: prompt_size(prompt),
            topic_id: dialect.prompt.topic_id,
            post_id: dialect.prompt.post_id,
            feature_name: feature_name,
            language_model: llm_model.name,
            feature_context: feature_context.present? ? feature_context.as_json : nil,
          )
        end

        def noop_function_call_text
          self.class.noop_function_call_text
        end
        def non_streaming_response(
          response:,
          xml_tool_processor:,
          xml_stripper:,
          partials_raw:,
          response_raw:
        )
          response_raw << response.read_body
          response_data = decode(response_raw)

        def has_tool?(response)
          response.include?("<function_calls>")
        end
          response_data.each { |partial| partials_raw << partial.to_s }

        def chunk_to_string(chunk)
          if chunk.is_a?(String)
            chunk
          else
            chunk.to_s
          end
        end

        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
          if payload&.include?("</invoke>")
            matches = payload.match(%r{<function_calls>.*</invoke>}m)
            function_buffer =
              Nokogiri::HTML5.fragment(matches[0] + "\n</function_calls>") if matches
          if xml_tool_processor
            response_data.each do |partial|
              processed = (xml_tool_processor << partial)
              processed << xml_tool_processor.finish
              response_data = []
              processed.flatten.compact.each { |inner| response_data << inner }
            end
          end

          function_buffer
          if xml_stripper
            response_data.map! do |partial|
              stripped = (xml_stripper << partial) if partial.is_a?(String)
              if stripped.present?
                stripped
              else
                partial
              end
            end
            response_data << xml_stripper.finish
          end

          response_data.reject!(&:blank?)

          # this is to keep stuff backwards compatible
          response_data = response_data.first if response_data.length == 1

          response_data
        end
      end
    end

@@ -45,17 +45,21 @@ module DiscourseAi
        cancel_fn = lambda { cancelled = true }

        # We buffer and return tool invocations in one go.
        if is_tool?(response)
          yield(response, cancel_fn)
        else
          response.each_char do |char|
            break if cancelled
            yield(char, cancel_fn)
        as_array = response.is_a?(Array) ? response : [response]
        as_array.each do |response|
          if is_tool?(response)
            yield(response, cancel_fn)
          else
            response.each_char do |char|
              break if cancelled
              yield(char, cancel_fn)
            end
          end
        end
      else
        response
      end

      response = response.first if response.is_a?(Array) && response.length == 1
      response
    end

    def tokenizer

@@ -65,7 +69,7 @@ module DiscourseAi
    private

    def is_tool?(response)
      Nokogiri::HTML5.fragment(response).at("function_calls").present?
      response.is_a?(DiscourseAi::Completions::ToolCall)
    end
  end
end

@@ -49,6 +49,47 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def decode(response_raw)
          rval = []

          parsed = JSON.parse(response_raw, symbolize_names: true)

          text = parsed[:text]
          rval << parsed[:text] if !text.to_s.empty? # also allow " "

          # TODO tool calls

          update_usage(parsed)

          rval
        end

        def decode_chunk(chunk)
          @tool_idx ||= -1
          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
          (@json_decoder << chunk)
            .map do |parsed|
              update_usage(parsed)
              rval = []

              rval << parsed[:text] if !parsed[:text].to_s.empty?

              if tool_calls = parsed[:tool_calls]
                tool_calls&.each do |tool_call|
                  @tool_idx += 1
                  tool_name = tool_call[:name]
                  tool_params = tool_call[:parameters]
                  tool_id = "tool_#{@tool_idx}"
                  rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params)
                end
              end

              rval
            end
            .flatten
            .compact
        end

        def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true)

@@ -77,36 +118,8 @@ module DiscourseAi
          end
        end

        def has_tool?(_ignored)
          @has_tool
        end

        def native_tool_support?
          true
        end

        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
          if partial
            tools = JSON.parse(partial)
            tools.each do |tool|
              name = tool["name"]
              parameters = tool["parameters"]
              xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join

              current_function = function_buffer.at("invoke")
              if current_function.nil? || current_function.at("tool_name").content.present?
                current_function =
                  function_buffer.at("function_calls").add_child(
                    Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
                  )
              end

              current_function.at("tool_name").content = name == "search_local" ? "search" : name
              current_function.at("parameters").children =
                Nokogiri::HTML5::DocumentFragment.parse(xml_params)
            end
          end
          function_buffer
        def xml_tools_enabled?
          false
        end

        def final_log_update(log)

@@ -114,10 +127,6 @@ module DiscourseAi
          log.response_tokens = @output_tokens if @output_tokens
        end

        def partials_from(decoded_chunk)
          decoded_chunk.split("\n").compact
        end

        def extract_prompt_for_tokenizer(prompt)
          text = +""
          if prompt[:chat_history]

@@ -131,6 +140,18 @@ module DiscourseAi

          text
        end

        private

        def update_usage(parsed)
          input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
          input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens)
          @input_tokens = input_tokens if input_tokens.present?

          output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
          output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens)
          @output_tokens = output_tokens if output_tokens.present?
        end
      end
    end
  end

@@ -133,31 +133,35 @@ module DiscourseAi
        content = content.shift if content.is_a?(Array)

        if block_given?
          split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
          indexes = [0, *split_indices, content.length]
          if content.is_a?(DiscourseAi::Completions::ToolCall)
            yield(content, -> {})
          else
            split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
            indexes = [0, *split_indices, content.length]

          original_content = content
          content = +""
            original_content = content
            content = +""

          cancel = false
          cancel_proc = -> { cancel = true }
            cancel = false
            cancel_proc = -> { cancel = true }

          i = 0
          indexes
            .each_cons(2)
            .map { |start, finish| original_content[start...finish] }
            .each do |chunk|
              break if cancel
              if self.class.delays.present? &&
                   (delay = self.class.delays[i % self.class.delays.length])
                sleep(delay)
                i += 1
            i = 0
            indexes
              .each_cons(2)
              .map { |start, finish| original_content[start...finish] }
              .each do |chunk|
                break if cancel
                if self.class.delays.present? &&
                     (delay = self.class.delays[i % self.class.delays.length])
                  sleep(delay)
                  i += 1
                end
                break if cancel

                content << chunk
                yield(chunk, cancel_proc)
              end
              break if cancel

              content << chunk
              yield(chunk, cancel_proc)
            end
          end
        end

        content

@@ -103,15 +103,7 @@ module DiscourseAi
          end
        end

        def partials_from(decoded_chunk)
          decoded_chunk
        end

        def chunk_to_string(chunk)
          chunk.to_s
        end

        class Decoder
        class GeminiStreamingDecoder
          def initialize
            @buffer = +""
          end

@@ -151,43 +143,87 @@ module DiscourseAi
        end

        def decode(chunk)
          @decoder ||= Decoder.new
          @decoder.decode(chunk)
          json = JSON.parse(chunk, symbolize_names: true)
          idx = -1
          json
            .dig(:candidates, 0, :content, :parts)
            .map do |part|
              if part[:functionCall]
                idx += 1
                ToolCall.new(
                  id: "tool_#{idx}",
                  name: part[:functionCall][:name],
                  parameters: part[:functionCall][:args],
                )
              else
                part = part[:text]
                if part != ""
                  part
                else
                  nil
                end
              end
            end
        end

        def decode_chunk(chunk)
          @tool_index ||= -1

          streaming_decoder
            .decode(chunk)
            .map do |parsed|
              update_usage(parsed)
              parsed
                .dig(:candidates, 0, :content, :parts)
                .map do |part|
                  if part[:text]
                    part = part[:text]
                    if part != ""
                      part
                    else
                      nil
                    end
                  elsif part[:functionCall]
                    @tool_index += 1
                    ToolCall.new(
                      id: "tool_#{@tool_index}",
                      name: part[:functionCall][:name],
                      parameters: part[:functionCall][:args],
                    )
                  end
                end
            end
            .flatten
            .compact
        end

        def update_usage(parsed)
          usage = parsed.dig(:usageMetadata)
          if usage
            if prompt_token_count = usage[:promptTokenCount]
              @prompt_token_count = prompt_token_count
            end
            if candidate_token_count = usage[:candidatesTokenCount]
              @candidate_token_count = candidate_token_count
            end
          end
        end

        def final_log_update(log)
          log.request_tokens = @prompt_token_count if @prompt_token_count
          log.response_tokens = @candidate_token_count if @candidate_token_count
        end

        def streaming_decoder
          @decoder ||= GeminiStreamingDecoder.new
        end

        def extract_prompt_for_tokenizer(prompt)
          prompt.to_s
        end

        def has_tool?(_response_data)
          @has_function_call
        end

        def native_tool_support?
          true
        end

        def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
          if @streaming_mode
            return function_buffer if !partial
          else
            partial = payload
          end

          function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?

          if partial[:args]
            argument_fragments =
              partial[:args].reduce(+"") do |memo, (arg_name, value)|
                memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
              end
            argument_fragments << "\n"

            function_buffer.at("parameters").children =
              Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
          end

          function_buffer
        def xml_tools_enabled?
          false
        end
      end
    end

@@ -59,22 +59,30 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
          # half a line sent here
          return if !parsed

          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)

          response_h.dig(:content)
        def xml_tools_enabled?
          true
        end

        def partials_from(decoded_chunk)
          decoded_chunk
            .split("\n")
            .map do |line|
              data = line.split("data:", 2)[1]
              data&.squish == "[DONE]" ? nil : data
        def decode(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true)
          text = parsed.dig(:choices, 0, :message, :content)
          if text.to_s.empty?
            [""]
          else
            [text]
          end
        end

        def decode_chunk(chunk)
          @json_decoder ||= JsonStreamDecoder.new
          (@json_decoder << chunk)
            .map do |parsed|
              text = parsed.dig(:choices, 0, :delta, :content)
              if text.to_s.empty?
                nil
              else
                text
              end
            end
            .compact
        end

@@ -37,12 +37,8 @@ module DiscourseAi
          URI(llm_model.url)
        end

        def native_tool_support?
          @native_tool_support
        end

        def has_tool?(_response_data)
          @has_function_call
        def xml_tools_enabled?
          !@native_tool_support
        end

        def prepare_payload(prompt, model_params, dialect)

@@ -67,74 +63,30 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def partials_from(decoded_chunk)
          decoded_chunk.split("\n").compact
        def decode_chunk(chunk)
          # Native tool calls are not working right in streaming mode, use XML
          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
          (@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact
        end

        def extract_completion_from(response_raw)
        def decode(response_raw)
          rval = []
          parsed = JSON.parse(response_raw, symbolize_names: true)
          return if !parsed
          content = parsed.dig(:message, :content)
          rval << content if !content.to_s.empty?

          response_h = parsed.dig(:message)

          @has_function_call ||= response_h.dig(:tool_calls).present?
          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
        end

        def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
          @args_buffer ||= +""

          if @streaming_mode
            return function_buffer if !partial
          else
            partial = payload
          end

          f_name = partial.dig(:function, :name)

          @current_function ||= function_buffer.at("invoke")

          if f_name
            current_name = function_buffer.at("tool_name").content

            if current_name.blank?
              # first call
            else
              # we have a previous function, so we need to add a noop
              @args_buffer = +""
              @current_function =
                function_buffer.at("function_calls").add_child(
                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
                )
          idx = -1
          parsed
            .dig(:message, :tool_calls)
            &.each do |tool_call|
              idx += 1
              id = "tool_#{idx}"
              name = tool_call.dig(:function, :name)
              args = tool_call.dig(:function, :arguments)
              rval << ToolCall.new(id: id, name: name, parameters: args)
            end
          end

          @current_function.at("tool_name").content = f_name if f_name
          @current_function.at("tool_id").content = partial[:id] if partial[:id]

          args = partial.dig(:function, :arguments)

          # allow for SPACE within arguments
          if args && args != ""
            @args_buffer << args.to_json

            begin
              json_args = JSON.parse(@args_buffer, symbolize_names: true)

              argument_fragments =
                json_args.reduce(+"") do |memo, (arg_name, value)|
                  memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
                end
              argument_fragments << "\n"

              @current_function.at("parameters").children =
                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
            rescue JSON::ParserError
              return function_buffer
            end
          end

          function_buffer
          rval
        end
      end
    end

@@ -93,98 +93,34 @@ module DiscourseAi
        end

        def final_log_update(log)
          log.request_tokens = @prompt_tokens if @prompt_tokens
          log.response_tokens = @completion_tokens if @completion_tokens
          log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
          log.response_tokens = processor.completion_tokens if processor.completion_tokens
        end

        def extract_completion_from(response_raw)
          json = JSON.parse(response_raw, symbolize_names: true)

          if @streaming_mode
            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
            @completion_tokens ||= json.dig(:usage, :completion_tokens)
          end

          parsed = json.dig(:choices, 0)
          return if !parsed

          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
          @has_function_call ||= response_h.dig(:tool_calls).present?
          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
        def decode(response_raw)
          processor.process_message(JSON.parse(response_raw, symbolize_names: true))
        end

        def partials_from(decoded_chunk)
          decoded_chunk
            .split("\n")
            .map do |line|
              data = line.split("data: ", 2)[1]
              data == "[DONE]" ? nil : data
            end
        def decode_chunk(chunk)
          @decoder ||= JsonStreamDecoder.new
          (@decoder << chunk)
            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
            .flatten
            .compact
        end

        def has_tool?(_response_data)
          @has_function_call
        def decode_chunk_finish
          @processor.finish
        end

        def native_tool_support?
          true
        def xml_tools_enabled?
          false
        end

        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
          if @streaming_mode
            return function_buffer if !partial
          else
            partial = payload
          end
        private

          @args_buffer ||= +""

          f_name = partial.dig(:function, :name)

          @current_function ||= function_buffer.at("invoke")

          if f_name
            current_name = function_buffer.at("tool_name").content

            if current_name.blank?
              # first call
            else
              # we have a previous function, so we need to add a noop
              @args_buffer = +""
              @current_function =
                function_buffer.at("function_calls").add_child(
                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
                )
            end
          end

          @current_function.at("tool_name").content = f_name if f_name
          @current_function.at("tool_id").content = partial[:id] if partial[:id]

          args = partial.dig(:function, :arguments)

          # allow for SPACE within arguments
          if args && args != ""
            @args_buffer << args

            begin
              json_args = JSON.parse(@args_buffer, symbolize_names: true)

              argument_fragments =
                json_args.reduce(+"") do |memo, (arg_name, value)|
                  memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
                end
              argument_fragments << "\n"

              @current_function.at("parameters").children =
                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
            rescue JSON::ParserError
              return function_buffer
            end
          end

          function_buffer
        def processor
          @processor ||= OpenAiMessageProcessor.new
        end

@@ -55,27 +55,31 @@ module DiscourseAi
          log.response_tokens = @completion_tokens if @completion_tokens
        end

        def extract_completion_from(response_raw)
          json = JSON.parse(response_raw, symbolize_names: true)

          if @streaming_mode
            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
            @completion_tokens ||= json.dig(:usage, :completion_tokens)
          end

          parsed = json.dig(:choices, 0)
          return if !parsed

          @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
        def xml_tools_enabled?
          true
        end

        def partials_from(decoded_chunk)
          decoded_chunk
            .split("\n")
            .map do |line|
              data = line.split("data: ", 2)[1]
              data == "[DONE]" ? nil : data
        def decode(response_raw)
          json = JSON.parse(response_raw, symbolize_names: true)
          [json.dig(:choices, 0, :message, :content)]
        end

        def decode_chunk(chunk)
          @json_decoder ||= JsonStreamDecoder.new
          (@json_decoder << chunk)
            .map do |json|
              text = json.dig(:choices, 0, :delta, :content)

              @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
              @completion_tokens ||= json.dig(:usage, :completion_tokens)

              if !text.to_s.empty?
                text
              else
                nil
              end
            end
            .flatten
            .compact
        end
      end

@@ -42,7 +42,10 @@ module DiscourseAi

        def prepare_payload(prompt, model_params, dialect)
          payload = default_options.merge(model_params).merge(messages: prompt)
          payload[:stream] = true if @streaming_mode
          if @streaming_mode
            payload[:stream] = true if @streaming_mode
            payload[:stream_options] = { include_usage: true }
          end

          payload
        end

@@ -56,24 +59,42 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def partials_from(decoded_chunk)
          decoded_chunk
            .split("\n")
            .map do |line|
              data = line.split("data: ", 2)[1]
              data == "[DONE]" ? nil : data
            end
            .compact
        def xml_tools_enabled?
          true
        end

        def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
          # half a line sent here
          return if !parsed
        def final_log_update(log)
          log.request_tokens = @prompt_tokens if @prompt_tokens
          log.response_tokens = @completion_tokens if @completion_tokens
        end

          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
        def decode(response_raw)
          json = JSON.parse(response_raw, symbolize_names: true)
          @prompt_tokens = json.dig(:usage, :prompt_tokens)
          @completion_tokens = json.dig(:usage, :completion_tokens)
          [json.dig(:choices, 0, :message, :content)]
        end

          response_h.dig(:content)
        def decode_chunk(chunk)
          @json_decoder ||= JsonStreamDecoder.new
          (@json_decoder << chunk)
            .map do |parsed|
              # vLLM keeps sending usage over and over again
              prompt_tokens = parsed.dig(:usage, :prompt_tokens)
              completion_tokens = parsed.dig(:usage, :completion_tokens)

              @prompt_tokens = prompt_tokens if prompt_tokens

              @completion_tokens = completion_tokens if completion_tokens

              text = parsed.dig(:choices, 0, :delta, :content)
              if text.to_s.empty?
                nil
              else
                text
              end
            end
            .compact
        end
      end
    end

@@ -1,113 +0,0 @@
# frozen_string_literal: true

class DiscourseAi::Completions::FunctionCallNormalizer
  attr_reader :done

  # blk is the block to call with filtered data
  def initialize(blk, cancel)
    @blk = blk
    @cancel = cancel
    @done = false

    @in_tool = false

    @buffer = +""
    @function_buffer = +""
  end

  def self.normalize(data)
    text = +""
    cancel = -> {}
    blk = ->(partial, _) { text << partial }

    normalizer = self.new(blk, cancel)
    normalizer << data

    [text, normalizer.function_calls]
  end

  def function_calls
    return nil if @function_buffer.blank?

    xml = Nokogiri::HTML5.fragment(@function_buffer)
    self.class.normalize_function_ids!(xml)
    last_invoke = xml.at("invoke:last")
    if last_invoke
      last_invoke.next_sibling.remove while last_invoke.next_sibling
      xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
    end
    xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
  end

  def <<(text)
    @buffer << text

    if !@in_tool
      # double check if we are clearly in a tool
      search_length = text.length + 20
      search_string = @buffer[-search_length..-1] || @buffer

      index = search_string.rindex("<function_calls>")
      @in_tool = !!index
      if @in_tool
        @function_buffer = @buffer[index..-1]
        text_index = text.rindex("<function_calls>")
        @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
      end
    else
      @function_buffer << text
    end

    if !@in_tool
      if maybe_has_tool?(@buffer)
        split_index = text.rindex("<").to_i - 1
        if split_index >= 0
          @function_buffer = text[split_index + 1..-1] || ""
          text = text[0..split_index] || ""
        else
          @function_buffer << text
          text = ""
        end
      else
        if @function_buffer.length > 0
          @blk.call(@function_buffer, @cancel)
          @function_buffer = +""
        end
      end

      @blk.call(text, @cancel) if text.length > 0
    else
      if text.include?("</function_calls>")
        @done = true
        @cancel.call
      end
    end
  end

  def self.normalize_function_ids!(function_buffer)
    function_buffer
      .css("invoke")
      .each_with_index do |invoke, index|
        if invoke.at("tool_id")
          invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
        else
          invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
        end
      end
  end

  private

  def maybe_has_tool?(text)
    # 16 is the length of function calls
    substring = text[-16..-1] || text
    split = substring.split("<")

    if split.length > 1
      match = "<" + split.last
      "<function_calls>".start_with?(match)
    else
      substring.ends_with?("<")
    end
  end
end

@@ -0,0 +1,48 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    # will work for anthropic and open ai compatible
    class JsonStreamDecoder
      attr_reader :buffer

      LINE_REGEX = /data: ({.*})\s*$/

      def initialize(symbolize_keys: true, line_regex: LINE_REGEX)
        @symbolize_keys = symbolize_keys
        @buffer = +""
        @line_regex = line_regex
      end

      def <<(raw)
        @buffer << raw.to_s
        rval = []

        split = @buffer.scan(/.*\n?/)
        split.pop if split.last.blank?

        @buffer = +(split.pop.to_s)

        split.each do |line|
          matches = line.match(@line_regex)
          next if !matches
          rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
        end

        if @buffer.present?
          matches = @buffer.match(@line_regex)
          if matches
            begin
              rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
              @buffer = +""
            rescue JSON::ParserError
              # maybe it is a partial line
            end
          end
        end

        rval
      end
    end
  end
end

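Usage is push-based: feed raw network chunks in and get back every fully parsed JSON object, while a trailing partial line stays buffered until the next call. An illustrative sequence (the SSE payloads are made up):

decoder = DiscourseAi::Completions::JsonStreamDecoder.new
decoder << "data: {\"a\": 1}\ndata: {\"b\"" # => [{ a: 1 }]  (second line kept in the buffer)
decoder << ": 2}\n"                         # => [{ b: 2 }]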
@ -0,0 +1,103 @@
# frozen_string_literal: true
module DiscourseAi::Completions
  class OpenAiMessageProcessor
    attr_reader :prompt_tokens, :completion_tokens

    def initialize
      @tool = nil
      @tool_arguments = +""
      @prompt_tokens = nil
      @completion_tokens = nil
    end

    def process_message(json)
      result = []
      tool_calls = json.dig(:choices, 0, :message, :tool_calls)

      message = json.dig(:choices, 0, :message, :content)
      result << message if message.present?

      if tool_calls.present?
        tool_calls.each do |tool_call|
          id = tool_call.dig(:id)
          name = tool_call.dig(:function, :name)
          arguments = tool_call.dig(:function, :arguments)
          parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {}
          result << ToolCall.new(id: id, name: name, parameters: parameters)
        end
      end

      update_usage(json)

      result
    end

    def process_streamed_message(json)
      rval = nil

      tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
      content = json.dig(:choices, 0, :delta, :content)

      finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []

      if tool_calls.present?
        id = tool_calls.dig(0, :id)
        name = tool_calls.dig(0, :function, :name)
        arguments = tool_calls.dig(0, :function, :arguments)

        # TODO: multiple tool support may require index
        #index = tool_calls[0].dig(:index)

        if id.present? && @tool && @tool.id != id
          process_arguments
          rval = @tool
          @tool = nil
        end

        if id.present? && name.present?
          @tool_arguments = +""
          @tool = ToolCall.new(id: id, name: name)
        end

        @tool_arguments << arguments.to_s
      elsif finished_tools && @tool
        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
        @tool.parameters = parsed_args
        rval = @tool
        @tool = nil
      elsif content.present?
        rval = content
      end

      update_usage(json)

      rval
    end

    def finish
      rval = []
      if @tool
        process_arguments
        rval << @tool
        @tool = nil
      end

      rval
    end

    private

    def process_arguments
      if @tool_arguments.present?
        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
        @tool.parameters = parsed_args
        @tool_arguments = nil
      end
    end

    def update_usage(json)
      @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
      @completion_tokens ||= json.dig(:usage, :completion_tokens)
    end
  end
end
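A rough sketch of how an endpoint's streaming loop might wire the decoder and this processor together; the chunk source and yield plumbing here are hypothetical stand-ins, not code from this commit:

    processor = DiscourseAi::Completions::OpenAiMessageProcessor.new
    decoder = DiscourseAi::Completions::JsonStreamDecoder.new

    raw_sse_chunks.each do |raw| # raw_sse_chunks is a hypothetical chunk source
      (decoder << raw).each do |json|
        partial = processor.process_streamed_message(json)
        yield partial if partial # a String delta or a completed ToolCall
      end
    end
    processor.finish.each { |tool_call| yield tool_call } # flush a trailing tool call
    # processor.prompt_tokens / processor.completion_tokens hold usage when the API sends it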
@ -0,0 +1,29 @@
# frozen_string_literal: true

module DiscourseAi
  module Completions
    class ToolCall
      attr_reader :id, :name, :parameters

      def initialize(id:, name:, parameters: nil)
        @id = id
        @name = name
        self.parameters = parameters if parameters
        @parameters ||= {}
      end

      def parameters=(parameters)
        raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
        @parameters = parameters.symbolize_keys
      end

      def ==(other)
        id == other.id && name == other.name && parameters == other.parameters
      end

      def to_s
        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
      end
    end
  end
end
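A quick sketch of the value semantics, assuming only the class above: string keys are symbolized on assignment, and equality compares id, name and parameters:

    a = DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "search", parameters: { "query" => "cats" })
    b = DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "search", parameters: { query: "cats" })

    a.parameters # => { query: "cats" } -- string keys are symbolized
    a == b       # => true -- compared by id, name and parameters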
@ -0,0 +1,124 @@
# frozen_string_literal: true

# This class can be used to process a stream of text that may contain XML tool
# calls.
# It will return either text or ToolCall objects.

module DiscourseAi
  module Completions
    class XmlToolProcessor
      def initialize
        @buffer = +""
        @function_buffer = +""
        @should_cancel = false
        @in_tool = false
      end

      def <<(text)
        @buffer << text
        result = []

        if !@in_tool
          # double check if we are clearly in a tool
          search_length = text.length + 20
          search_string = @buffer[-search_length..-1] || @buffer

          index = search_string.rindex("<function_calls>")
          @in_tool = !!index
          if @in_tool
            @function_buffer = @buffer[index..-1]
            text_index = text.rindex("<function_calls>")
            result << text[0..text_index - 1].strip if text_index && text_index > 0
          end
        else
          @function_buffer << text
        end

        if !@in_tool
          if maybe_has_tool?(@buffer)
            split_index = text.rindex("<").to_i - 1
            if split_index >= 0
              @function_buffer = text[split_index + 1..-1] || ""
              text = text[0..split_index] || ""
            else
              @function_buffer << text
              text = ""
            end
          else
            if @function_buffer.length > 0
              result << @function_buffer
              @function_buffer = +""
            end
          end

          result << text if text.length > 0
        else
          @should_cancel = true if text.include?("</function_calls>")
        end

        result
      end

      def finish
        return [] if @function_buffer.blank?

        xml = Nokogiri::HTML5.fragment(@function_buffer)
        normalize_function_ids!(xml)
        last_invoke = xml.at("invoke:last")
        if last_invoke
          last_invoke.next_sibling.remove while last_invoke.next_sibling
          xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
        end

        xml
          .css("invoke")
          .map do |invoke|
            tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
            tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
            parameters = {}
            invoke
              .at("parameters")
              &.children
              &.each do |node|
                next if node.text?
                name = node.name
                value = node.content.to_s
                parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
              end
            ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
          end
      end

      def should_cancel?
        @should_cancel
      end

      private

      def normalize_function_ids!(function_buffer)
        function_buffer
          .css("invoke")
          .each_with_index do |invoke, index|
            if invoke.at("tool_id")
              invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
            else
              invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
            end
          end
      end

      def maybe_has_tool?(text)
        # 16 is the length of "<function_calls>"
        substring = text[-16..-1] || text
        split = substring.split("<")

        if split.length > 1
          match = "<" + split.last
          "<function_calls>".start_with?(match)
        else
          substring.ends_with?("<")
        end
      end
    end
  end
end
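A short sketch of the intended push/finish usage, mirroring what the specs below exercise (chunk boundaries are illustrative):

    processor = DiscourseAi::Completions::XmlToolProcessor.new

    chunks = ["hello ", "<function_calls><invoke><tool_name>search</tool_name>", "</invoke>"]
    chunks.each do |chunk|
      (processor << chunk).each { |piece| print piece } # plain text passes straight through
    end
    processor.finish.each { |tool_call| p tool_call } # buffered XML becomes ToolCall objects
    processor.should_cancel? # => true only once </function_calls> has been seen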
@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
        data: {"type":"message_stop"}
      STRING

      result = +""
      result = []
      body = body.scan(/.*\n/)
      EndpointMock.with_chunk_array_support do
        stub_request(:post, url).to_return(status: 200, body: body)
@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
        end
      end

      expected = (<<~TEXT).strip
        <function_calls>
        <invoke>
        <tool_name>search</tool_name>
        <parameters><search_query>s<a>m sam</search_query>
        <category>general</category></parameters>
        <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
        </invoke>
        </function_calls>
      TEXT
      tool_call =
        DiscourseAi::Completions::ToolCall.new(
          name: "search",
          id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
          parameters: {
            search_query: "s<a>m sam",
            category: "general",
          },
        )

      expect(result.strip).to eq(expected)
      expect(result).to eq([tool_call])
    end

    it "can stream a response" do
@ -191,6 +190,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
      expect(log.feature_name).to eq("testing")
      expect(log.response_tokens).to eq(15)
      expect(log.request_tokens).to eq(25)
      expect(log.raw_request_payload).to eq(expected_body.to_json)
      expect(log.raw_response_payload.strip).to eq(body.strip)
    end

    it "supports non streaming tool calls" do
@ -242,17 +243,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do

      result = llm.generate(prompt, user: Discourse.system_user)

      expected = <<~TEXT.strip
        <function_calls>
        <invoke>
        <tool_name>calculate</tool_name>
        <parameters><expression>2758975 + 21.11</expression></parameters>
        <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
        </invoke>
        </function_calls>
      TEXT
      tool_call =
        DiscourseAi::Completions::ToolCall.new(
          name: "calculate",
          id: "toolu_012kBdhG4eHaV68W56p4N94h",
          parameters: {
            expression: "2758975 + 21.11",
          },
        )

      expect(result.strip).to eq(expected)
      expect(result).to eq(["Here is the calculation:", tool_call])

      log = AiApiAuditLog.order(:id).last
      expect(log.request_tokens).to eq(345)
      expect(log.response_tokens).to eq(65)
    end

    it "can send images via a completion prompt" do
@ -79,7 +79,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
          }

          prompt.tools = [tool]
          response = +""
          response = []
          proxy.generate(prompt, user: user) { |partial| response << partial }

          expect(request.headers["Authorization"]).to be_present
@ -90,21 +90,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
          expect(parsed_body["tools"]).to eq(nil)
          expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])

          # note we now have a tool_id cause we were normalized
          function_call = <<~XML.strip
            hello

            <function_calls>
            <invoke>
            <tool_name>google</tool_name>
            <parameters><query>sydney weather today</query></parameters>
            <tool_id>tool_0</tool_id>
            </invoke>
            </function_calls>
          XML
          expected = [
            "hello\n",
            DiscourseAi::Completions::ToolCall.new(
              id: "tool_0",
              name: "google",
              parameters: {
                query: "sydney weather today",
              },
            ),
          ]

          expect(response.strip).to eq(function_call)
          expect(response).to eq(expected)
        end
      end

@ -230,23 +227,23 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
          }

          prompt.tools = [tool]
          response = +""
          response = []
          proxy.generate(prompt, user: user) { |partial| response << partial }

          expect(request.headers["Authorization"]).to be_present
          expect(request.headers["X-Amz-Content-Sha256"]).to be_present

          expected_response = (<<~RESPONSE).strip
            <function_calls>
            <invoke>
            <tool_name>google</tool_name>
            <parameters><query>sydney weather today</query></parameters>
            <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
            </invoke>
            </function_calls>
          RESPONSE
          expected_response = [
            DiscourseAi::Completions::ToolCall.new(
              id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
              name: "google",
              parameters: {
                query: "sydney weather today",
              },
            ),
          ]

          expect(response.strip).to eq(expected_response)
          expect(response).to eq(expected_response)

          expected = {
            "max_tokens" => 3000,
@ -66,7 +66,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
      TEXT

      parsed_body = nil
      result = +""
      result = []

      sig = {
        name: "google",
@ -91,21 +91,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
          },
        ).to_return(status: 200, body: body.split("|"))

        result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
        llm.generate(prompt, user: user) { |partial, cancel| result << partial }
      end

      expected = <<~TEXT
        <function_calls>
        <invoke>
        <tool_name>google</tool_name>
        <parameters><query>who is sam saffron</query>
        </parameters>
        <tool_id>tool_0</tool_id>
        </invoke>
        </function_calls>
      TEXT
      text = "I will search for 'who is sam saffron' and relay the information to the user."
      tool_call =
        DiscourseAi::Completions::ToolCall.new(
          id: "tool_0",
          name: "google",
          parameters: {
            query: "who is sam saffron",
          },
        )

      expect(result.strip).to eq(expected.strip)
      expect(result).to eq([text, tool_call])

      expected = {
        model: "command-r-plus",
@ -62,18 +62,14 @@ class EndpointMock
  end

  def invocation_response
    <<~TEXT
      <function_calls>
      <invoke>
      <tool_name>get_weather</tool_name>
      <parameters>
      <location>Sydney</location>
      <unit>c</unit>
      </parameters>
      <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    TEXT
    DiscourseAi::Completions::ToolCall.new(
      id: "tool_0",
      name: "get_weather",
      parameters: {
        location: "Sydney",
        unit: "c",
      },
    )
  end

  def tool_id
@ -185,7 +181,7 @@ class EndpointsCompliance
    mock.stub_tool_call(a_dialect.translate)

    completion_response = endpoint.perform_completion!(a_dialect, user)
    expect(completion_response.strip).to eq(mock.invocation_response.strip)
    expect(completion_response).to eq(mock.invocation_response)
  end

  def streaming_mode_simple_prompt(mock)
@ -205,6 +201,7 @@ class EndpointsCompliance
    expect(log.raw_request_payload).to be_present
    expect(log.raw_response_payload).to be_present
    expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))

    expect(log.response_tokens).to eq(
      endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
    )
@ -216,14 +213,14 @@ class EndpointsCompliance
    a_dialect = dialect(prompt: prompt)

    mock.stub_streamed_tool_call(a_dialect.translate) do
      buffered_partial = +""
      buffered_partial = []

      endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
        buffered_partial << partial
        cancel.call if buffered_partial.include?("<function_calls>")
        cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
      end

      expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
      expect(buffered_partial).to eq([mock.invocation_response])
    end
  end
@ -195,19 +195,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do

      response = llm.generate(prompt, user: user)

      expected = (<<~XML).strip
        <function_calls>
        <invoke>
        <tool_name>echo</tool_name>
        <parameters>
        <text><S>ydney</text>
        </parameters>
        <tool_id>tool_0</tool_id>
        </invoke>
        </function_calls>
      XML
      tool =
        DiscourseAi::Completions::ToolCall.new(
          id: "tool_0",
          name: "echo",
          parameters: {
            text: "<S>ydney",
          },
        )

      expect(response.strip).to eq(expected)
      expect(response).to eq(tool)
    end

    it "Supports Vision API" do
@ -265,6 +262,68 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
      expect(JSON.parse(req_body)).to eq(expected_prompt)
    end

    it "Can stream tool calls correctly" do
      rows = [
        {
          candidates: [
            {
              content: {
                parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
                role: "model",
              },
              safetyRatings: [
                { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
                { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
                { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
                { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
              ],
            },
          ],
          usageMetadata: {
            promptTokenCount: 625,
            totalTokenCount: 625,
          },
          modelVersion: "gemini-1.5-pro-002",
        },
        {
          candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
          usageMetadata: {
            promptTokenCount: 625,
            candidatesTokenCount: 4,
            totalTokenCount: 629,
          },
          modelVersion: "gemini-1.5-pro-002",
        },
      ]

      payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join

      llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
      url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

      prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])

      output = []

      stub_request(:post, url).to_return(status: 200, body: payload)
      llm.generate(prompt, user: user) { |partial| output << partial }

      tool_call =
        DiscourseAi::Completions::ToolCall.new(
          id: "tool_0",
          name: "echo",
          parameters: {
            text: "sam<>wh!s",
          },
        )

      expect(output).to eq([tool_call])

      log = AiApiAuditLog.order(:id).last
      expect(log.request_tokens).to eq(625)
      expect(log.response_tokens).to eq(4)
    end

    it "Can correctly handle streamed responses even if they are chunked badly" do
      data = +""
      data << "da|ta: |"
@ -279,12 +338,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
      llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
      url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

      output = +""
      output = []
      gemini_mock.with_chunk_array_support do
        stub_request(:post, url).to_return(status: 200, body: split)
        llm.generate("Hello", user: user) { |partial| output << partial }
      end

      expect(output).to eq("Hello World Sam")
      expect(output.join).to eq("Hello World Sam")
    end
  end
@ -150,7 +150,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
  end

  describe "when using streaming mode" do
    context "with simpel prompts" do
    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
        compliance.streaming_mode_simple_prompt(ollama_mock)
      end
@ -17,8 +17,8 @@ class OpenAiMock < EndpointMock
      created: 1_678_464_820,
      model: "gpt-3.5-turbo-0301",
      usage: {
        prompt_tokens: 337,
        completion_tokens: 162,
        prompt_tokens: 8,
        completion_tokens: 13,
        total_tokens: 499,
      },
      choices: [
@ -231,19 +231,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do

      result = llm.generate(prompt, user: user)

      expected = (<<~TXT).strip
        <function_calls>
        <invoke>
        <tool_name>echo</tool_name>
        <parameters>
        <text>hello</text>
        </parameters>
        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
        </invoke>
        </function_calls>
      TXT
      tool_call =
        DiscourseAi::Completions::ToolCall.new(
          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
          name: "echo",
          parameters: {
            text: "hello",
          },
        )

      expect(result.strip).to eq(expected)
      expect(result).to eq(tool_call)

      stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
        body: { choices: [message: { content: "OK" }] }.to_json,
@ -320,19 +317,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do

      expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })

      expected = (<<~TXT).strip
        <function_calls>
        <invoke>
        <tool_name>echo</tool_name>
        <parameters>
        <text>h<e>llo</text>
        </parameters>
        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
        </invoke>
        </function_calls>
      TXT

      expect(result.strip).to eq(expected)
      log = AiApiAuditLog.order(:id).last
      expect(log.request_tokens).to eq(55)
      expect(log.response_tokens).to eq(13)

      expected =
        DiscourseAi::Completions::ToolCall.new(
          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
          name: "echo",
          parameters: {
            text: "h<e>llo",
          },
        )

      expect(result).to eq(expected)

      stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
        body: { choices: [message: { content: "OK" }] }.to_json,
@ -487,7 +485,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do

      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}

      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}

      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}

@ -495,32 +493,30 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
      TEXT

      open_ai_mock.stub_raw(raw_data)
      content = +""
      response = []

      dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))

      endpoint.perform_completion!(dialect, user) { |partial| content << partial }
      endpoint.perform_completion!(dialect, user) { |partial| response << partial }

      expected = <<~TEXT
        <function_calls>
        <invoke>
        <tool_name>search</tool_name>
        <parameters>
        <search_query>Discourse AI bot</search_query>
        </parameters>
        <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
        </invoke>
        <invoke>
        <tool_name>search</tool_name>
        <parameters>
        <query>Discourse AI bot</query>
        </parameters>
        <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
        </invoke>
        </function_calls>
      TEXT
      tool_calls = [
        DiscourseAi::Completions::ToolCall.new(
          name: "search",
          id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
          parameters: {
            search_query: "Discourse AI bot",
          },
        ),
        DiscourseAi::Completions::ToolCall.new(
          name: "search",
          id: "call_H7YkbgYurHpyJqzwUN4bghwN",
          parameters: {
            query: "Discourse AI bot2",
          },
        ),
      ]

      expect(content).to eq(expected)
      expect(response).to eq(tool_calls)
    end

    it "uses proper token accounting" do
@ -593,21 +589,16 @@ TEXT
      dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
      endpoint.perform_completion!(dialect, user) { |partial| partials << partial }

      expect(partials.length).to eq(1)

      function_call = (<<~TXT).strip
        <function_calls>
        <invoke>
        <tool_name>google</tool_name>
        <parameters>
        <query>Adabas 9.1</query>
        </parameters>
        <tool_id>func_id</tool_id>
        </invoke>
        </function_calls>
      TXT

      expect(partials[0].strip).to eq(function_call)
      tool_call =
        DiscourseAi::Completions::ToolCall.new(
          id: "func_id",
          name: "google",
          parameters: {
            query: "Adabas 9.1",
          },
        )

      expect(partials).to eq([tool_call])
    end
  end
end
@ -22,10 +22,15 @@ data: [DONE]
        },
      ).to_return(status: 200, body: body, headers: {})

      response = +""
      response = []
      llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }

      expect(response).to eq("I am a bot")
      expect(response).to eq(["I am a bot"])

      log = AiApiAuditLog.order(:id).last

      expect(log.request_tokens).to eq(21)
      expect(log.response_tokens).to eq(41)
    end

    it "can perform regular completions" do
@ -51,7 +51,13 @@ class VllmMock < EndpointMock

    WebMock
      .stub_request(:post, "https://test.dev/v1/chat/completions")
      .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
      .with(
        body:
          model
            .default_options
            .merge(messages: prompt, stream: true, stream_options: { include_usage: true })
            .to_json,
      )
      .to_return(status: 200, body: chunks)
  end
end
@ -136,29 +142,115 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do

      result = llm.generate(prompt, user: Discourse.system_user)

      expected = <<~TEXT
        <function_calls>
        <invoke>
        <tool_name>calculate</tool_name>
        <parameters>
        <expression>1+1</expression></parameters>
        <tool_id>tool_0</tool_id>
        </invoke>
        </function_calls>
      TEXT
      expected =
        DiscourseAi::Completions::ToolCall.new(
          name: "calculate",
          id: "tool_0",
          parameters: {
            expression: "1+1",
          },
        )

      expect(result.strip).to eq(expected.strip)
      expect(result).to eq(expected)
    end
  end

  it "correctly accounts for tokens in non streaming mode" do
    body = (<<~TEXT).strip
      {"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null}
    TEXT

    stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body)

    result = llm.generate("generate a title", user: Discourse.system_user)

    expect(result).to eq("Random Number Generator Produces Smallest Possible Result")

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(146)
    expect(log.response_tokens).to eq(10)
  end

  it "can properly include usage in streaming mode" do
    payload = <<~TEXT.strip
      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}

      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}

      data: [DONE]
    TEXT

    stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
      status: 200,
      body: payload,
    )

    response = []
    llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial }

    expect(response.join).to eq(
      "Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?",
    )

    log = AiApiAuditLog.order(:id).last
    expect(log.request_tokens).to eq(46)
    expect(log.response_tokens).to eq(26)
  end

  describe "#perform_completion!" do
    context "when using regular mode" do
      context "with simple prompts" do
        it "completes a trivial prompt and logs the response" do
          compliance.regular_mode_simple_prompt(vllm_mock)
        end
      end

      context "with tools" do
        it "returns a function invocation" do
          compliance.regular_mode_tools(vllm_mock)
@ -1,182 +0,0 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
  let(:buffer) { +"" }

  let(:normalizer) do
    blk = ->(data, cancel) { buffer << data }
    cancel = -> { @done = true }
    DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
  end

  def pass_through!(data)
    normalizer << data
    expect(buffer[-data.length..-1]).to eq(data)
  end

  it "is usable in non streaming mode" do
    xml = (<<~XML).strip
      hello
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      </invoke>
    XML

    text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)

    expect(text).to eq("hello")

    expected_function_calls = (<<~XML).strip
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    XML

    expect(function_calls).to eq(expected_function_calls)
  end

  it "strips junk from end of function calls" do
    xml = (<<~XML).strip
      hello
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      </invoke>
      junk
    XML

    _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)

    expected_function_calls = (<<~XML).strip
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    XML

    expect(function_calls).to eq(expected_function_calls)
  end

  it "returns nil for function calls if there are none" do
    input = "hello world\n"
    text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)

    expect(text).to eq(input)
    expect(function_calls).to eq(nil)
  end

  it "passes through data if there are no function calls detected" do
    pass_through!("hello")
    pass_through!("<tool_name>hello</tool_name>")
    pass_through!("<parameters><hello>world</hello></parameters>")
    pass_through!("<function_call>")
  end

  it "properly handles non English tools" do
    normalizer << "hello<function"
    expect(buffer).to eq("hello")

    normalizer << "_calls>\n"

    normalizer << (<<~XML).strip
      <invoke>
      <tool_name>hello</tool_name>
      <parameters>
      <hello>世界</hello>
      </parameters>
      </invoke>
    XML

    expected = (<<~XML).strip
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      <parameters>
      <hello>世界</hello>
      </parameters>
      <tool_id>tool_0</tool_id>
      </invoke>
      </function_calls>
    XML

    function_calls = normalizer.function_calls
    expect(function_calls).to eq(expected)
  end

  it "works correctly even if you only give it 1 letter at a time" do
    xml = (<<~XML).strip
      abc
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      <parameters>
      <hello>world</hello>
      </parameters>
      <tool_id>abc</tool_id>
      </invoke>
      <invoke>
      <tool_name>hello2</tool_name>
      <parameters>
      <hello>world</hello>
      </parameters>
      <tool_id>aba</tool_id>
      </invoke>
      </function_calls>
    XML

    xml.each_char { |char| normalizer << char }

    expect(buffer + normalizer.function_calls).to eq(xml)
  end

  it "supports multiple invokes" do
    xml = (<<~XML).strip
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      <parameters>
      <hello>world</hello>
      </parameters>
      <tool_id>abc</tool_id>
      </invoke>
      <invoke>
      <tool_name>hello2</tool_name>
      <parameters>
      <hello>world</hello>
      </parameters>
      <tool_id>aba</tool_id>
      </invoke>
      </function_calls>
    XML

    normalizer << xml

    expect(normalizer.function_calls).to eq(xml)
  end

  it "will cancel if it encounters </function_calls>" do
    normalizer << "<function_calls>"
    expect(normalizer.done).to eq(false)
    normalizer << "</function_calls>"
    expect(normalizer.done).to eq(true)
    expect(@done).to eq(true)

    expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
  end

  it "pauses on function call and starts buffering" do
    normalizer << "hello<function_call"
    expect(buffer).to eq("hello")
    expect(normalizer.done).to eq(false)

    normalizer << ">"
    expect(buffer).to eq("hello<function_call>")
    expect(normalizer.done).to eq(false)
  end
end
@ -0,0 +1,47 @@
# frozen_string_literal: true

describe DiscourseAi::Completions::JsonStreamDecoder do
  let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }

  it "should be able to parse simple messages" do
    result = decoder << "data: #{{ hello: "world" }.to_json}"
    expect(result).to eq([{ hello: "world" }])
  end

  it "should handle anthropic mixed style streams" do
    stream = (<<~TEXT).split("|")
      event: |message_start|
      data: |{"hel|lo": "world"}|

      event: |message_start
      data: {"foo": "bar"}

      event: |message_start
      data: {"ba|z": "qux"|}

      [DONE]
    TEXT

    results = []
    stream.each { |chunk| results << (decoder << chunk) }

    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
  end

  it "should be able to handle complex overlaps" do
    stream = (<<~TEXT).split("|")
      data: |{"hel|lo": "world"}

      data: {"foo": "bar"}

      data: {"ba|z": "qux"|}

      [DONE]
    TEXT

    results = []
    stream.each { |chunk| results << (decoder << chunk) }

    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
  end
end
@ -0,0 +1,188 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::XmlToolProcessor do
  let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new }

  it "can process simple text" do
    result = []
    result << (processor << "hello")
    result << (processor << " world ")
    expect(result).to eq([["hello"], [" world "]])
    expect(processor.finish).to eq([])
    expect(processor.should_cancel?).to eq(false)
  end

  it "is usable for simple single message mode" do
    xml = (<<~XML).strip
      hello
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      <parameters>
      <hello>world</hello>
      <test>value</test>
      </parameters>
      </invoke>
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "hello",
        parameters: {
          hello: "world",
          test: "value",
        },
      )
    expect(result).to eq([["hello"], [tool_call]])
    expect(processor.should_cancel?).to eq(false)
  end

  it "handles multiple tool calls in sequence" do
    xml = (<<~XML).strip
      start
      <function_calls>
      <invoke>
      <tool_name>first_tool</tool_name>
      <parameters>
      <param1>value1</param1>
      </parameters>
      </invoke>
      <invoke>
      <tool_name>second_tool</tool_name>
      <parameters>
      <param2>value2</param2>
      </parameters>
      </invoke>
      </function_calls>
      end
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    first_tool =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "first_tool",
        parameters: {
          param1: "value1",
        },
      )

    second_tool =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_1",
        name: "second_tool",
        parameters: {
          param2: "value2",
        },
      )

    expect(result).to eq([["start"], [first_tool, second_tool]])
    expect(processor.should_cancel?).to eq(true)
  end

  it "handles non-English parameters correctly" do
    xml = (<<~XML).strip
      こんにちは
      <function_calls>
      <invoke>
      <tool_name>translator</tool_name>
      <parameters>
      <text>世界</text>
      </parameters>
      </invoke>
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "translator",
        parameters: {
          text: "世界",
        },
      )

    expect(result).to eq([["こんにちは"], [tool_call]])
  end

  it "processes input character by character" do
    xml =
      "hi<function_calls><invoke><tool_name>test</tool_name><parameters><p>v</p></parameters></invoke>"

    result = []
    xml.each_char { |char| result << (processor << char) }
    result << processor.finish

    tool_call =
      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })

    filtered_result = result.reject(&:empty?)
    expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
  end

  it "handles malformed XML gracefully" do
    xml = (<<~XML).strip
      text
      <function_calls>
      <invoke>
      <tool_name>test</tool_name>
      <parameters>
      <param>value
      </parameters>
      </invoke>
      malformed
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    # Should just do its best to parse the XML
    tool_call =
      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
    expect(result).to eq([["text"], [tool_call]])
  end

  it "correctly processes empty parameter sets" do
    xml = (<<~XML).strip
      hello
      <function_calls>
      <invoke>
      <tool_name>no_params</tool_name>
      <parameters>
      </parameters>
      </invoke>
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})

    expect(result).to eq([["hello"], [tool_call]])
  end

  it "properly handles cancelled processing" do
    xml = "start<function_calls></function_calls>"
    result = []
    result << (processor << xml)
    result << (processor << "more text")
    result << processor.finish

    expect(result).to eq([["start"], [], []])
    expect(processor.should_cancel?).to eq(true)
  end
end
@ -72,40 +72,27 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do

  it "can parse string that are wrapped in quotes" do
    SiteSetting.ai_stability_api_key = "123"
    xml = <<~XML
      <function_calls>
      <invoke>
      <tool_name>image</tool_name>
      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
      <parameters>
      <prompts>["cat oil painting", "big car"]</prompts>
      <aspect_ratio>"16:9"</aspect_ratio>
      </parameters>
      </invoke>
      <invoke>
      <tool_name>image</tool_name>
      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
      <parameters>
      <prompts>["cat oil painting", "big car"]</prompts>
      <aspect_ratio>'16:9'</aspect_ratio>
      </parameters>
      </invoke>
      </function_calls>
    XML

    image1, image2 =
      tools =
        DiscourseAi::AiBot::Personas::Artist.new.find_tools(
          xml,
          bot_user: nil,
          llm: nil,
          context: nil,
        )
    expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
    expect(image1.parameters[:aspect_ratio]).to eq("16:9")
    expect(image2.parameters[:aspect_ratio]).to eq("16:9")
    expect(tools.length).to eq(2)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        name: "image",
        id: "call_JtYQMful5QKqw97XFsHzPweB",
        parameters: {
          prompts: ["cat oil painting", "big car"],
          aspect_ratio: "16:9",
        },
      )

    tool_instance =
      DiscourseAi::AiBot::Personas::Artist.new.find_tool(
        tool_call,
        bot_user: nil,
        llm: nil,
        context: nil,
      )

    expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
    expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
  end

  it "enforces enums" do
@ -132,42 +119,68 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
      </function_calls>
    XML

    search1, search2 =
      tools =
        DiscourseAi::AiBot::Personas::General.new.find_tools(
          xml,
          bot_user: nil,
          llm: nil,
          context: nil,
        )

    expect(search1.parameters.key?(:status)).to eq(false)
    expect(search2.parameters[:status]).to eq("open")

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "call_JtYQMful5QKqw97XFsHzPweB",
        parameters: {
          max_posts: "3.2",
          status: "cow",
          foo: "bar",
        },
      )

    tool_instance =
      DiscourseAi::AiBot::Personas::General.new.find_tool(
        tool_call,
        bot_user: nil,
        llm: nil,
        context: nil,
      )

    expect(tool_instance.parameters.key?(:status)).to eq(false)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "call_JtYQMful5QKqw97XFsHzPweB",
        parameters: {
          max_posts: "3.2",
          status: "open",
          foo: "bar",
        },
      )

    tool_instance =
      DiscourseAi::AiBot::Personas::General.new.find_tool(
        tool_call,
        bot_user: nil,
        llm: nil,
        context: nil,
      )

    expect(tool_instance.parameters[:status]).to eq("open")
  end

  it "can coerce integers" do
    xml = <<~XML
      <function_calls>
      <invoke>
      <tool_name>search</tool_name>
      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
      <parameters>
      <max_posts>"3.2"</max_posts>
      <search_query>hello world</search_query>
      <foo>bar</foo>
      </parameters>
      </invoke>
      </function_calls>
    XML

    search, =
      tools =
        DiscourseAi::AiBot::Personas::General.new.find_tools(
          xml,
          bot_user: nil,
          llm: nil,
          context: nil,
        )

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "call_JtYQMful5QKqw97XFsHzPweB",
        parameters: {
          max_posts: "3.2",
          search_query: "hello world",
          foo: "bar",
        },
      )

    search =
      DiscourseAi::AiBot::Personas::General.new.find_tool(
        tool_call,
        bot_user: nil,
        llm: nil,
        context: nil,
      )

    expect(search.parameters[:max_posts]).to eq(3)
    expect(search.parameters[:search_query]).to eq("hello world")
@ -177,43 +190,23 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
  it "can correctly parse arrays in tools" do
    SiteSetting.ai_openai_api_key = "123"

    # Dall E tool uses an array for params
    xml = <<~XML
      <function_calls>
      <invoke>
      <tool_name>dall_e</tool_name>
      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
      <parameters>
      <prompts>["cat oil painting", "big car"]</prompts>
      </parameters>
      </invoke>
      <invoke>
      <tool_name>dall_e</tool_name>
      <tool_id>abc</tool_id>
      <parameters>
      <prompts>["pic3"]</prompts>
      </parameters>
      </invoke>
      <invoke>
      <tool_name>unknown</tool_name>
      <tool_id>abc</tool_id>
      <parameters>
      <prompts>["pic3"]</prompts>
      </parameters>
      </invoke>
      </function_calls>
    XML
    dall_e1, dall_e2 =
      tools =
        DiscourseAi::AiBot::Personas::DallE3.new.find_tools(
          xml,
          bot_user: nil,
          llm: nil,
          context: nil,
        )
    expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
    expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
    expect(tools.length).to eq(2)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        name: "dall_e",
        id: "call_JtYQMful5QKqw97XFsHzPweB",
        parameters: {
          prompts: ["cat oil painting", "big car"],
        },
      )

    tool_instance =
      DiscourseAi::AiBot::Personas::DallE3.new.find_tool(
        tool_call,
        bot_user: nil,
        llm: nil,
        context: nil,
      )
    expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
  end

  describe "custom personas" do
@ -55,6 +55,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
)
|
||||
end
|
||||
|
||||
before { SiteSetting.ai_embeddings_enabled = false }
|
||||
|
||||
after do
|
||||
# we must reset cache on persona cause data can be rolled back
|
||||
AiPersona.persona_cache.flush!
|
||||
|
@ -83,17 +85,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
end
|
||||
|
||||
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
|
||||
let(:function_call) { (<<~XML).strip }
|
||||
<function_calls>
|
||||
<invoke>
|
||||
<tool_name>search</tool_name>
|
||||
<tool_id>666</tool_id>
|
||||
<parameters>
|
||||
<query>Can you use the custom tool</query>
|
||||
</parameters>
|
||||
</invoke>
|
||||
</function_calls>",
|
||||
XML
|
||||
let(:tool_call) do
|
||||
DiscourseAi::Completions::ToolCall.new(
|
||||
name: "search",
|
||||
id: "666",
|
||||
parameters: {
|
||||
query: "Can you use the custom tool",
|
||||
},
|
||||
)
|
||||
end
|
||||
|
||||
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
|
||||
|
||||
|
@ -115,7 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
|||
reply_post = nil
|
||||
prompts = nil
|
||||
|
||||
responses = [function_call]
|
||||
responses = [tool_call]
|
||||
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
|
||||
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
|
||||
reply_post = playground.reply_to(new_post)
|
||||
|

@@ -133,7 +133,7 @@ RSpec.describe DiscourseAi::AiBot::Playground
  it "can force usage of a tool" do
    tool_name = "custom-#{custom_tool.id}"
    ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
    responses = [function_call, "custom tool did stuff (maybe)"]
    responses = [tool_call, "custom tool did stuff (maybe)"]

    prompts = nil
    reply_post = nil

@@ -166,7 +166,7 @@ RSpec.describe DiscourseAi::AiBot::Playground
  bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
  playground = DiscourseAi::AiBot::Playground.new(bot)

  responses = [function_call, "custom tool did stuff (maybe)"]
  responses = [tool_call, "custom tool did stuff (maybe)"]

  reply_post = nil

@@ -206,13 +206,15 @@ RSpec.describe DiscourseAi::AiBot::Playground
  bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
  playground = DiscourseAi::AiBot::Playground.new(bot)

  responses = ["custom tool did stuff (maybe)", tool_call]

  # let's ensure the tool does not run...
  DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
    new_post = Fabricate(:post, raw: "Can you use the custom tool?")
    reply_post = playground.reply_to(new_post)
  end

  expect(reply_post.raw.strip).to eq(function_call)
  expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)")
end
end

@@ -452,10 +454,25 @@ RSpec.describe DiscourseAi::AiBot::Playground
it "can run tools" do
  persona.update!(tools: ["Time"])

  responses = [
    "<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
    "The time is 2023-12-14 17:24:00 -0300",
  ]
  tool_call1 =
    DiscourseAi::Completions::ToolCall.new(
      name: "time",
      id: "time",
      parameters: {
        timezone: "Buenos Aires",
      },
    )

  tool_call2 =
    DiscourseAi::Completions::ToolCall.new(
      name: "time",
      id: "time",
      parameters: {
        timezone: "Sydney",
      },
    )

  responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
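  # a nested array is a single prepared LLM response that carries two tool calls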

  message =
    DiscourseAi::Completions::Llm.with_prepared_responses(responses) do

@@ -470,7 +487,8 @@ RSpec.describe DiscourseAi::AiBot::Playground

  # it also needs to have tool details now set on message
  prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
  expect(prompt.custom_prompt.length).to eq(3)
  expect(prompt.custom_prompt.length).to eq(5)
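  # presumably 5 = two tool-call entries + two tool-result entries + the final reply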

  # TODO in chat I am mixed on including this in the context, but I guess maybe?
  # thinking about this

@@ -782,30 +800,29 @@ RSpec.describe DiscourseAi::AiBot::Playground
  end

  it "supports multiple function calls" do
    response1 = (<<~TXT).strip
      <function_calls>
      <invoke>
      <tool_name>search</tool_name>
      <tool_id>search</tool_id>
      <parameters>
      <search_query>testing various things</search_query>
      </parameters>
      </invoke>
      <invoke>
      <tool_name>search</tool_name>
      <tool_id>search</tool_id>
      <parameters>
      <search_query>another search</search_query>
      </parameters>
      </invoke>
      </function_calls>
    TXT
    tool_call1 =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "search",
        parameters: {
          search_query: "testing various things",
        },
      )

    tool_call2 =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "search",
        parameters: {
          search_query: "another search",
        },
      )

    response2 = "I found stuff"

    DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
      playground.reply_to(third_post)
    end
    DiscourseAi::Completions::Llm.with_prepared_responses(
      [[tool_call1, tool_call2], response2],
    ) { playground.reply_to(third_post) }

    last_post = third_post.topic.reload.posts.order(:post_number).last

@@ -819,17 +836,14 @@ RSpec.describe DiscourseAi::AiBot::Playground
  bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new)
  playground = described_class.new(bot)

  response1 = (<<~TXT).strip
    <function_calls>
    <invoke>
    <tool_name>search</tool_name>
    <tool_id>search</tool_id>
    <parameters>
    <search_query>testing various things</search_query>
    </parameters>
    </invoke>
    </function_calls>
  TXT
  response1 =
    DiscourseAi::Completions::ToolCall.new(
      name: "search",
      id: "search",
      parameters: {
        search_query: "testing various things",
      },
    )

  response2 = "I found stuff"

@@ -843,17 +857,14 @@ RSpec.describe DiscourseAi::AiBot::Playground
  end

  it "does not include placeholders in conversation context but includes all completions" do
    response1 = (<<~TXT).strip
      <function_calls>
      <invoke>
      <tool_name>search</tool_name>
      <tool_id>search</tool_id>
      <parameters>
      <search_query>testing various things</search_query>
      </parameters>
      </invoke>
      </function_calls>
    TXT
    response1 =
      DiscourseAi::Completions::ToolCall.new(
        name: "search",
        id: "search",
        parameters: {
          search_query: "testing various things",
        },
      )

    response2 = "I found some really amazing stuff!"

@@ -889,17 +900,15 @@ RSpec.describe DiscourseAi::AiBot::Playground
  [{ b64_json: image, revised_prompt: "a pink cow 1" }]
end

let(:response) { (<<~TXT).strip }
  <function_calls>
  <invoke>
  <tool_name>dall_e</tool_name>
  <tool_id>dall_e</tool_id>
  <parameters>
  <prompts>["a pink cow"]</prompts>
  </parameters>
  </invoke>
  </function_calls>
TXT
let(:response) do
  DiscourseAi::Completions::ToolCall.new(
    name: "dall_e",
    id: "dall_e",
    parameters: {
      prompts: ["a pink cow"],
    },
  )
end

it "properly returns an image when skipping tool details" do
  persona.update!(tool_details: false)

@@ -541,16 +541,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController
  expect(topic.title).to eq("An amazing title")
  expect(topic.posts.count).to eq(2)

  # now let's try to make a reply with a tool call
  function_call = <<~XML
    <function_calls>
    <invoke>
    <tool_name>categories</tool_name>
    </invoke>
    </function_calls>
  XML
  tool_call =
    DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1")

  fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
  fake_endpoint.fake_content = [tool_call, "this is the response after the tool"]
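  # the fake endpoint now hands back the ToolCall object itself; no XML parsing step remains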
  # a single chunk keeps the fake endpoint's function call handling simple
  fake_endpoint.chunk_count = 1

@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController
  fab!(:user)
  fab!(:pm_topic) { Fabricate(:private_message_topic) }
  fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
  fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
  fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }

  before { sign_in(user) }

@@ -22,15 +24,37 @@ RSpec.describe DiscourseAi::AiBot::BotController
  user = pm_topic.topic_allowed_users.first.user
  sign_in(user)

  AiApiAuditLog.create!(
    post_id: pm_post.id,
    provider_id: 1,
    topic_id: pm_topic.id,
    raw_request_payload: "request",
    raw_response_payload: "response",
    request_tokens: 1,
    response_tokens: 2,
  )
  log1 =
    AiApiAuditLog.create!(
      provider_id: 1,
      topic_id: pm_topic.id,
      raw_request_payload: "request",
      raw_response_payload: "response",
      request_tokens: 1,
      response_tokens: 2,
    )

  log2 =
    AiApiAuditLog.create!(
      post_id: pm_post.id,
      provider_id: 1,
      topic_id: pm_topic.id,
      raw_request_payload: "request",
      raw_response_payload: "response",
      request_tokens: 1,
      response_tokens: 2,
    )

  log3 =
    AiApiAuditLog.create!(
      post_id: pm_post2.id,
      provider_id: 1,
      topic_id: pm_topic.id,
      raw_request_payload: "request",
      raw_response_payload: "response",
      request_tokens: 1,
      response_tokens: 2,
    )
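  # logs are created in ascending id order, so for pm_post's log (log2) the previous log is log1 and the next is log3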

  Group.refresh_automatic_groups!
  SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s
|
@ -38,18 +62,26 @@ RSpec.describe DiscourseAi::AiBot::BotController do
|
|||
get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
|
||||
expect(response.status).to eq(200)
|
||||
|
||||
expect(response.parsed_body["id"]).to eq(log2.id)
|
||||
expect(response.parsed_body["next_log_id"]).to eq(log3.id)
|
||||
expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
|
||||
expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
|
||||
|
||||
expect(response.parsed_body["request_tokens"]).to eq(1)
|
||||
expect(response.parsed_body["response_tokens"]).to eq(2)
|
||||
expect(response.parsed_body["raw_request_payload"]).to eq("request")
|
||||
expect(response.parsed_body["raw_response_payload"]).to eq("response")
|
||||
|
||||
post2 = Fabricate(:post, topic: pm_topic)
|
||||
|
||||
# return previous post if current has no debug info
|
||||
get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
|
||||
get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
|
||||
expect(response.status).to eq(200)
|
||||
expect(response.parsed_body["request_tokens"]).to eq(1)
|
||||
expect(response.parsed_body["response_tokens"]).to eq(2)
|
||||
|
||||
# can return debug info by id as well
|
||||
get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
|
||||
expect(response.status).to eq(200)
|
||||
expect(response.parsed_body["id"]).to eq(log1.id)
|
||||
end
|
||||
end
|
||||
|
||||
|
|