FEATURE: improve tool support (#904)

This re-implements tool support in DiscourseAi::Completions::Llm#generate.

Previously, tool calls were always returned as XML and it was the caller's responsibility to parse that XML.

The new implementation has the endpoints return ToolCall objects instead.
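For example, a caller streaming a completion now receives either plain String partials or ToolCall objects, a minimal sketch based on the bot changes below:

    llm.generate(prompt, user: user) do |partial, cancel|
      if partial.is_a?(DiscourseAi::Completions::ToolCall)
        # a complete tool invocation: partial.id, partial.name, partial.parameters
      else
        # a plain text partial
      end
    end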

Additionally, this simplifies the LLM endpoint interface and makes it clearer. Endpoints must implement:

    decode, decode_chunk (for streaming)

It is the implementer's responsibility to figure out how to decode chunks; the base class no longer implements this. To make that easy, we ship a flexible JSON stream decoder which is easy to wire up.
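As a sketch, an OpenAI-compatible endpoint can wire the decoder up roughly like this (simplified from the endpoint classes in the diff below):

    def decode_chunk(chunk)
      @json_decoder ||= DiscourseAi::Completions::JsonStreamDecoder.new
      (@json_decoder << chunk)
        .map { |parsed| parsed.dig(:choices, 0, :delta, :content) }
        .compact
    end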

Also new:

    Better debugging for PMs: we now have next/previous buttons to step through all the LLM messages associated with a PM
    Token accounting is fixed for vLLM (we were not counting tokens correctly)
Sam committed 2024-11-12 08:14:30 +11:00 (via GitHub)
commit e817b7dc11, parent 644141ff08
43 changed files with 1685 additions and 1293 deletions

@@ -6,6 +6,14 @@ module DiscourseAi
     requires_plugin ::DiscourseAi::PLUGIN_NAME
     requires_login

+    def show_debug_info_by_id
+      log = AiApiAuditLog.find(params[:id])
+      raise Discourse::NotFound if !log.topic
+
+      guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+      render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+    end
+
     def show_debug_info
       post = Post.find(params[:post_id])
       guardian.ensure_can_debug_ai_bot_conversation!(post)

@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
     Ollama = 7
     SambaNova = 8
   end

+  def next_log_id
+    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+  end
+
+  def prev_log_id
+    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+  end
+
 end

 # == Schema Information

@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
              :post_id,
              :feature_name,
              :language_model,
-             :created_at
+             :created_at,
+             :prev_log_id,
+             :next_log_id
 end

@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
 import DButton from "discourse/components/d-button";
 import DModal from "discourse/components/d-modal";
 import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
 import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
 import i18n from "discourse-common/helpers/i18n";
 import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
     this.copy(this.info.raw_response_payload);
   }

+  async loadLog(logId) {
+    try {
+      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+        (result) => {
+          this.info = result;
+        }
+      );
+    } catch (e) {
+      popupAjaxError(e);
+    }
+  }
+
+  @action
+  prevLog() {
+    this.loadLog(this.info.prev_log_id);
+  }
+
+  @action
+  nextLog() {
+    this.loadLog(this.info.next_log_id);
+  }
+
   copy(text) {
     clipboardCopy(text);
     this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -73,10 +96,12 @@ export default class DebugAiModal extends Component {
   }

   loadApiRequestInfo() {
-    ajax(
-      `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
-    ).then((result) => {
-      this.info = result;
-    });
+    ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
+      .then((result) => {
+        this.info = result;
+      })
+      .catch((e) => {
+        popupAjaxError(e);
+      });
   }
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
           @action={{this.copyResponse}}
           @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
         />
+        {{#if this.info.prev_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-left"
+            @action={{this.prevLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+          />
+        {{/if}}
+        {{#if this.info.next_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-right"
+            @action={{this.nextLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+          />
+        {{/if}}
         <span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
       </:footer>
     </DModal>

@@ -415,6 +415,8 @@ en:
           response_tokens: "Response tokens:"
           request: "Request"
           response: "Response"
+          next_log: "Next"
+          previous_log: "Previous"
         share_full_topic_modal:
           title: "Share Conversation Publicly"

@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
     get "bot-username" => "bot#show_bot_username"
     get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
     post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
   end

@@ -100,6 +100,7 @@ module DiscourseAi
          llm_kwargs[:top_p] = persona.top_p if persona.top_p

          needs_newlines = false
+          tools_ran = 0

          while total_completions <= MAX_COMPLETIONS && ongoing_chain
            tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi
            result =
              llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
-                tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+                tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+                tool = nil if tools_ran >= MAX_TOOLS

-                if (tools.present?)
+                if tool.present?
                  tool_found = true
                  # a bit hacky, but extra newlines do no harm
                  if needs_newlines
@@ -117,15 +119,18 @@ module DiscourseAi
                    needs_newlines = false
                  end

-                  tools[0..MAX_TOOLS].each do |tool|
                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+                  tools_ran += 1
                  ongoing_chain &&= tool.chain_next_response?
-                  end
                else
                  needs_newlines = true
+                  if partial.is_a?(DiscourseAi::Completions::ToolCall)
+                    Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
+                  else
                    update_blk.call(partial, cancel)
+                  end
                end
              end

            if !tool_found
              ongoing_chain = false

@@ -199,23 +199,16 @@ module DiscourseAi
        prompt
      end

-      def find_tools(partial, bot_user:, llm:, context:)
-        return [] if !partial.include?("</invoke>")
-
-        parsed_function = Nokogiri::HTML5.fragment(partial)
-        parsed_function
-          .css("invoke")
-          .map do |fragment|
-            tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
-          end
-          .compact
+      def find_tool(partial, bot_user:, llm:, context:)
+        return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+        tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
      end

      protected

-      def tool_instance(parsed_function, bot_user:, llm:, context:)
-        function_id = parsed_function.at("tool_id")&.text
-        function_name = parsed_function.at("tool_name")&.text
+      def tool_instance(tool_call, bot_user:, llm:, context:)
+        function_id = tool_call.id
+        function_name = tool_call.name
        return nil if function_name.nil?

        tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +217,7 @@ module DiscourseAi
          arguments = {}
          tool_klass.signature[:parameters].to_a.each do |param|
            name = param[:name]
-            value = parsed_function.at(name)&.text
+            value = tool_call.parameters[name.to_sym]

            if param[:type] == "array" && value
              value =

@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
    def append(json)
      @raw_json << json
    end
+
+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+    end
  end

  attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,52 +25,32 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
  def initialize(streaming_mode:)
    @streaming_mode = streaming_mode
    @tool_calls = []
+    @current_tool_call = nil
  end

-  def to_xml_tool_calls(function_buffer)
-    return function_buffer if @tool_calls.blank?
-
-    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
-      <function_calls>
-      </function_calls>
-    TEXT
-
-    @tool_calls.each do |tool_call|
-      node =
-        function_buffer.at("function_calls").add_child(
-          Nokogiri::HTML5::DocumentFragment.parse(
-            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
-          ),
-        )
-
-      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
-      xml =
-        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
-      node.at("tool_name").content = tool_call.name
-      node.at("tool_id").content = tool_call.id
-      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
-    end
-
-    function_buffer
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
  end

-  def process_message(payload)
-    result = ""
-    parsed = JSON.parse(payload, symbolize_names: true)
-
-    if @streaming_mode
+  def process_streamed_message(parsed)
+    result = nil
+
    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
      tool_name = parsed.dig(:content_block, :name)
      tool_id = parsed.dig(:content_block, :id)
-      @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
+      result = @current_tool_call.to_tool_call if @current_tool_call
+      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-      if @tool_calls.present?
-        result = parsed.dig(:delta, :partial_json).to_s
-        @tool_calls.last.append(result)
+      if @current_tool_call
+        tool_delta = parsed.dig(:delta, :partial_json).to_s
+        @current_tool_call.append(tool_delta)
      else
        result = parsed.dig(:delta, :text).to_s
      end
+    elsif parsed[:type] == "content_block_stop"
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+        @current_tool_call = nil
+      end
    elsif parsed[:type] == "message_start"
      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
    elsif parsed[:type] == "message_delta"
@@ -78,21 +63,30 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
        @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
      end
    end
-    else
+
+    result
+  end
+
+  def process_message(payload)
+    result = ""
+    parsed = payload
+    parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)
+
    content = parsed.dig(:content)
    if content.is_a?(Array)
-      tool_call = content.find { |c| c[:type] == "tool_use" }
-      if tool_call
-        @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
-        @tool_calls.last.append(tool_call[:input].to_json)
+      result =
+        content.map do |data|
+          if data[:type] == "tool_use"
+            call = AnthropicToolCall.new(data[:name], data[:id])
+            call.append(data[:input].to_json)
+            call.to_tool_call
          else
-        result = parsed.dig(:content, 0, :text).to_s
+            data[:text]
          end
        end
    end
-    end

    @input_tokens = parsed.dig(:usage, :input_tokens)
    @output_tokens = parsed.dig(:usage, :output_tokens)

    result
  end

@@ -63,8 +63,23 @@ module DiscourseAi
      def user_msg(msg)
        user_message = { role: "user", content: msg[:content] }

-        # TODO: Add support for user messages with empbeded user ids
-        # TODO: Add support for user messages with attachments
+        encoded_uploads = prompt.encoded_uploads(msg)
+        if encoded_uploads.present?
+          images =
+            encoded_uploads
+              .map do |upload|
+                if upload[:mime_type].start_with?("image/")
+                  upload[:base64]
+                else
+                  nil
+                end
+              end
+              .compact
+
+          user_message[:images] = images if images.present?
+        end
+
+        # TODO: Add support for user messages with embedded user ids
        user_message
      end

@@ -63,6 +63,10 @@ module DiscourseAi
          URI(llm_model.url)
        end

+        def xml_tools_enabled?
+          !@native_tool_support
+        end
+
        def prepare_payload(prompt, model_params, dialect)
          @native_tool_support = dialect.native_tool_support?

@@ -90,35 +94,34 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

+        def decode_chunk(partial_data)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << partial_data)
+            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+            .compact
+        end
+
+        def decode(response_data)
+          processor.process_message(response_data)
+        end
+
        def processor
          @processor ||=
            DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
        end

-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
-        def extract_completion_from(response_raw)
-          processor.process_message(response_raw)
-        end
-
        def has_tool?(_response_data)
          processor.tool_calls.present?
        end

+        def tool_calls
+          processor.to_tool_calls
+        end
+
        def final_log_update(log)
          log.request_tokens = processor.input_tokens if processor.input_tokens
          log.response_tokens = processor.output_tokens if processor.output_tokens
        end
-
-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
-        end
      end
    end
  end

@@ -117,7 +117,24 @@ module DiscourseAi
          end
        end

-        def decode(chunk)
+        def decode_chunk(partial_data)
+          bedrock_decode(partial_data)
+            .map do |decoded_partial_data|
+              @raw_response ||= +""
+              @raw_response << decoded_partial_data
+              @raw_response << "\n"
+
+              parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true)
+              processor.process_streamed_message(parsed_json)
+            end
+            .compact
+        end
+
+        def decode(response_data)
+          processor.process_message(response_data)
+        end
+
+        def bedrock_decode(chunk)
          @decoder ||= Aws::EventStream::Decoder.new
          decoded, _done = @decoder.decode_chunk(chunk)

@@ -147,12 +164,13 @@ module DiscourseAi
               Aws::EventStream::Errors::MessageChecksumError,
               Aws::EventStream::Errors::PreludeChecksumError => e
          Rails.logger.error("#{self.class.name}: #{e.message}")
-          nil
+          []
        end

        def final_log_update(log)
          log.request_tokens = processor.input_tokens if processor.input_tokens
          log.response_tokens = processor.output_tokens if processor.output_tokens
+          log.raw_response_payload = @raw_response
        end

        def processor
@@ -160,30 +178,8 @@ module DiscourseAi
            DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
        end

-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
-        def extract_completion_from(response_raw)
-          processor.process_message(response_raw)
-        end
-
-        def has_tool?(_response_data)
-          processor.tool_calls.present?
-        end
-
-        def partials_from(decoded_chunks)
-          decoded_chunks
-        end
-
-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def chunk_to_string(chunk)
-          joined = +chunk.join("\n")
-          joined << "\n" if joined.length > 0
-          joined
+        def xml_tools_enabled?
+          !@native_tool_support
        end
      end
    end
  end

@@ -40,10 +40,6 @@ module DiscourseAi
          @llm_model = llm_model
        end

-        def native_tool_support?
-          false
-        end
-
        def use_ssl?
          if model_uri&.scheme.present?
            model_uri.scheme == "https"
@@ -64,22 +60,10 @@ module DiscourseAi
          feature_context: nil,
          &blk
        )
-          allow_tools = dialect.prompt.has_tools?
          model_params = normalize_model_params(model_params)
          orig_blk = blk

          @streaming_mode = block_given?

-          to_strip = xml_tags_to_strip(dialect)
-          @xml_stripper =
-            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
-
-          if @streaming_mode && @xml_stripper
-            blk =
-              lambda do |partial, cancel|
-                partial = @xml_stripper << partial
-                orig_blk.call(partial, cancel) if partial
-              end
-          end
-
          prompt = dialect.translate
@@ -108,177 +92,91 @@ module DiscourseAi
            raise CompletionFailed, response.body
          end

+          xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? &&
+            dialect.prompt.has_tools?
+
+          to_strip = xml_tags_to_strip(dialect)
+          xml_stripper =
+            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+          if @streaming_mode && xml_stripper
+            blk =
+              lambda do |partial, cancel|
+                partial = xml_stripper << partial if partial.is_a?(String)
+                orig_blk.call(partial, cancel) if partial
+              end
+          end
+
          log =
-            AiApiAuditLog.new(
+            start_log(
              provider_id: provider_id,
-              user_id: user&.id,
-              raw_request_payload: request_body,
-              request_tokens: prompt_size(prompt),
-              topic_id: dialect.prompt.topic_id,
-              post_id: dialect.prompt.post_id,
+              request_body: request_body,
+              dialect: dialect,
+              prompt: prompt,
+              user: user,
              feature_name: feature_name,
-              language_model: llm_model.name,
-              feature_context: feature_context.present? ? feature_context.as_json : nil,
+              feature_context: feature_context,
            )

          if !@streaming_mode
-            response_raw = response.read_body
-            response_data = extract_completion_from(response_raw)
-            partials_raw = response_data.to_s
-
-            if native_tool_support?
-              if allow_tools && has_tool?(response_data)
-                function_buffer = build_buffer # Nokogiri document
-                function_buffer = add_to_function_buffer(function_buffer, payload: response_data)
-                FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
-                response_data = +function_buffer.at("function_calls").to_s
-                response_data << "\n"
-              end
-            else
-              if allow_tools
-                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
-                response_data = function_calls if function_calls.present?
-              end
-            end
-
-            return response_data
+            return(
+              non_streaming_response(
+                response: response,
+                xml_tool_processor: xml_tool_processor,
+                xml_stripper: xml_stripper,
+                partials_raw: partials_raw,
+                response_raw: response_raw,
+              )
+            )
          end

-          has_tool = false
-
          begin
            cancelled = false
            cancel = -> { cancelled = true }

-            wrapped_blk = ->(partial, inner_cancel) do
-              response_data << partial
-              blk.call(partial, inner_cancel)
-            end
-
-            normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
-
-            leftover = ""
-            function_buffer = build_buffer # Nokogiri document
-            prev_processed_partials = 0
-
-            response.read_body do |chunk|
            if cancelled
              http.finish
              break
            end

-              decoded_chunk = decode(chunk)
-              if decoded_chunk.nil?
-                raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion"
-              end
-              response_raw << chunk_to_string(decoded_chunk)
-
-              if decoded_chunk.is_a?(String)
-                redo_chunk = leftover + decoded_chunk
-              else
-                # custom implementation for endpoint
-                # no implicit leftover support
-                redo_chunk = decoded_chunk
-              end
-
-              raw_partials = partials_from(redo_chunk)
-
-              raw_partials =
-                raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0
-
-              if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
-                leftover = redo_chunk
-                next
-              end
-
-              json_error = false
-
-              raw_partials.each do |raw_partial|
-                json_error = false
-                prev_processed_partials += 1
-                next if cancelled
-                next if raw_partial.blank?
-
-                begin
-                  partial = extract_completion_from(raw_partial)
-                  next if partial.nil?
-                  # empty vs blank... we still accept " "
-                  next if response_data.empty? && partial.empty?
-                  partials_raw << partial.to_s
-
-                  if native_tool_support?
-                    # Stop streaming the response as soon as you find a tool.
-                    # We'll buffer and yield it later.
-                    has_tool = true if allow_tools && has_tool?(partials_raw)
-
-                    if has_tool
-                      function_buffer = add_to_function_buffer(function_buffer, partial: partial)
-                    else
-                      response_data << partial
-                      blk.call(partial, cancel) if partial
-                    end
-                  else
-                    if allow_tools
-                      normalizer << partial
-                    else
-                      response_data << partial
-                      blk.call(partial, cancel) if partial
-                    end
-                  end
-                rescue JSON::ParserError
-                  leftover = redo_chunk
-                  json_error = true
-                end
-              end
-
-              if json_error
-                prev_processed_partials -= 1
-              else
-                leftover = ""
-              end
-
-              prev_processed_partials = 0 if leftover.blank?
-            end
+            response.read_body do |chunk|
+              response_raw << chunk
+
+              decode_chunk(chunk).each do |partial|
+                partials_raw << partial.to_s
+                response_data << partial if partial.is_a?(String)
+
+                partials = [partial]
+                if xml_tool_processor && partial.is_a?(String)
+                  partials = (xml_tool_processor << partial)
+                  if xml_tool_processor.should_cancel?
+                    cancel.call
+                    break
+                  end
+                end
+                partials.each { |inner_partial| blk.call(inner_partial, cancel) }
+              end
+            end
          rescue IOError, StandardError
            raise if !cancelled
          end

-          has_tool ||= has_tool?(partials_raw)
-          # Once we have the full response, try to return the tool as a XML doc.
-          if has_tool && native_tool_support?
-            function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
-
-            if function_buffer.at("tool_name").text.present?
-              FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
-              invocation = +function_buffer.at("function_calls").to_s
-              invocation << "\n"
-
-              response_data << invocation
-              blk.call(invocation, cancel)
-            end
-          end
-
-          if !native_tool_support? && function_calls = normalizer.function_calls
-            response_data << function_calls
-            blk.call(function_calls, cancel)
-          end
-
-          if @xml_stripper
-            leftover = @xml_stripper.finish
-            orig_blk.call(leftover, cancel) if leftover.present?
-          end
+          if xml_stripper
+            stripped = xml_stripper.finish
+            if stripped.present?
+              response_data << stripped
+              result = []
+              result = (xml_tool_processor << stripped) if xml_tool_processor
+              result.each { |partial| blk.call(partial, cancel) }
+            end
+          end
+
+          if xml_tool_processor
+            xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
+          end
+
+          decode_chunk_finish.each { |partial| blk.call(partial, cancel) }

          return response_data
        ensure
          if log
            log.raw_response_payload = response_raw
-            log.response_tokens = tokenizer.size(partials_raw)
            final_log_update(log)
+
+            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
            log.save!

            if Rails.env.development?
@@ -330,15 +228,15 @@ module DiscourseAi
          raise NotImplementedError
        end

-        def extract_completion_from(_response_raw)
+        def decode(_response_raw)
          raise NotImplementedError
        end

-        def decode(chunk)
-          chunk
+        def decode_chunk_finish
+          []
        end

-        def partials_from(_decoded_chunk)
+        def decode_chunk(_chunk)
          raise NotImplementedError
        end
@@ -346,49 +244,73 @@ module DiscourseAi
          prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
        end

-        def build_buffer
-          Nokogiri::HTML5.fragment(<<~TEXT)
-            <function_calls>
-            #{noop_function_call_text}
-            </function_calls>
-          TEXT
-        end
-
-        def self.noop_function_call_text
-          (<<~TEXT).strip
-            <invoke>
-            <tool_name></tool_name>
-            <parameters>
-            </parameters>
-            <tool_id></tool_id>
-            </invoke>
-          TEXT
-        end
-
-        def noop_function_call_text
-          self.class.noop_function_call_text
-        end
-
-        def has_tool?(response)
-          response.include?("<function_calls>")
-        end
-
-        def chunk_to_string(chunk)
-          if chunk.is_a?(String)
-            chunk
-          else
-            chunk.to_s
-          end
-        end
-
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if payload&.include?("</invoke>")
-            matches = payload.match(%r{<function_calls>.*</invoke>}m)
-            function_buffer =
-              Nokogiri::HTML5.fragment(matches[0] + "\n</function_calls>") if matches
-          end
-
-          function_buffer
+        def xml_tools_enabled?
+          raise NotImplementedError
+        end
+
+        private
+
+        def start_log(
+          provider_id:,
+          request_body:,
+          dialect:,
+          prompt:,
+          user:,
+          feature_name:,
+          feature_context:
+        )
+          AiApiAuditLog.new(
+            provider_id: provider_id,
+            user_id: user&.id,
+            raw_request_payload: request_body,
+            request_tokens: prompt_size(prompt),
+            topic_id: dialect.prompt.topic_id,
+            post_id: dialect.prompt.post_id,
+            feature_name: feature_name,
+            language_model: llm_model.name,
+            feature_context: feature_context.present? ? feature_context.as_json : nil,
+          )
+        end
+
+        def non_streaming_response(
+          response:,
+          xml_tool_processor:,
+          xml_stripper:,
+          partials_raw:,
+          response_raw:
+        )
+          response_raw << response.read_body
+          response_data = decode(response_raw)
+
+          response_data.each { |partial| partials_raw << partial.to_s }
+
+          if xml_tool_processor
+            response_data.each do |partial|
+              processed = (xml_tool_processor << partial)
+              processed << xml_tool_processor.finish
+              response_data = []
+              processed.flatten.compact.each { |inner| response_data << inner }
+            end
+          end
+
+          if xml_stripper
+            response_data.map! do |partial|
+              stripped = (xml_stripper << partial) if partial.is_a?(String)
+              if stripped.present?
+                stripped
+              else
+                partial
+              end
+            end
+            response_data << xml_stripper.finish
+          end
+
+          response_data.reject!(&:blank?)
+
+          # this is to keep stuff backwards compatible
+          response_data = response_data.first if response_data.length == 1
+          response_data
        end
      end
    end
  end
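With this in place the per-provider surface is small: implement decode, decode_chunk, and xml_tools_enabled?. As a rough sketch only, a hypothetical endpoint speaking a "one JSON object per line" protocol could look like this (the :text wire field is made up for illustration):

    class MyEndpoint < DiscourseAi::Completions::Endpoints::Base
      # this endpoint relies on native tool calls, not XML
      def xml_tools_enabled?
        false
      end

      # non-streaming: decode the full response body into an array of partials
      def decode(response_raw)
        [JSON.parse(response_raw, symbolize_names: true)[:text]]
      end

      # streaming: turn a raw network chunk into zero or more partials
      def decode_chunk(chunk)
        @decoder ||= DiscourseAi::Completions::JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
        (@decoder << chunk).map { |parsed| parsed[:text] }.compact
      end
    end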

@@ -45,6 +45,8 @@ module DiscourseAi
        cancel_fn = lambda { cancelled = true }

        # We buffer and return tool invocations in one go.
+        as_array = response.is_a?(Array) ? response : [response]
+        as_array.each do |response|
          if is_tool?(response)
            yield(response, cancel_fn)
          else
@@ -53,11 +55,13 @@ module DiscourseAi
              yield(char, cancel_fn)
            end
          end
-        else
-          response
        end
+
+        response = response.first if response.is_a?(Array) && response.length == 1
+        response
      end

      def tokenizer
        DiscourseAi::Tokenizer::OpenAiTokenizer
      end
@@ -65,7 +69,7 @@ module DiscourseAi
      private

      def is_tool?(response)
-        Nokogiri::HTML5.fragment(response).at("function_calls").present?
+        response.is_a?(DiscourseAi::Completions::ToolCall)
      end
    end
  end

@@ -49,6 +49,47 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

+        def decode(response_raw)
+          rval = []
+
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+
+          text = parsed[:text]
+          rval << parsed[:text] if !text.to_s.empty? # also allow " "
+
+          # TODO tool calls
+
+          update_usage(parsed)
+
+          rval
+        end
+
+        def decode_chunk(chunk)
+          @tool_idx ||= -1
+          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+          (@json_decoder << chunk)
+            .map do |parsed|
+              update_usage(parsed)
+              rval = []
+
+              rval << parsed[:text] if !parsed[:text].to_s.empty?
+
+              if tool_calls = parsed[:tool_calls]
+                tool_calls&.each do |tool_call|
+                  @tool_idx += 1
+                  tool_name = tool_call[:name]
+                  tool_params = tool_call[:parameters]
+                  tool_id = "tool_#{@tool_idx}"
+                  rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params)
+                end
+              end
+
+              rval
+            end
+            .flatten
+            .compact
+        end
+
        def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true)

@@ -77,36 +118,8 @@ module DiscourseAi
          end
        end

-        def has_tool?(_ignored)
-          @has_tool
-        end
-
-        def native_tool_support?
-          true
-        end
-
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if partial
-            tools = JSON.parse(partial)
-            tools.each do |tool|
-              name = tool["name"]
-              parameters = tool["parameters"]
-              xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
-
-              current_function = function_buffer.at("invoke")
-              if current_function.nil? || current_function.at("tool_name").content.present?
-                current_function =
-                  function_buffer.at("function_calls").add_child(
-                    Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                  )
-              end
-
-              current_function.at("tool_name").content = name == "search_local" ? "search" : name
-              current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(xml_params)
-            end
-          end
-
-          function_buffer
+        def xml_tools_enabled?
+          false
        end

        def final_log_update(log)
@@ -114,10 +127,6 @@ module DiscourseAi
          log.response_tokens = @output_tokens if @output_tokens
        end

-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").compact
-        end
-
        def extract_prompt_for_tokenizer(prompt)
          text = +""
          if prompt[:chat_history]
@@ -131,6 +140,18 @@ module DiscourseAi
          text
        end

+        private
+
+        def update_usage(parsed)
+          input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
+          input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens)
+          @input_tokens = input_tokens if input_tokens.present?
+
+          output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
+          output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens)
+          @output_tokens = output_tokens if output_tokens.present?
+        end
      end
    end
  end

@@ -133,6 +133,9 @@ module DiscourseAi
        content = content.shift if content.is_a?(Array)

        if block_given?
+          if content.is_a?(DiscourseAi::Completions::ToolCall)
+            yield(content, -> {})
+          else
          split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
          indexes = [0, *split_indices, content.length]
@@ -159,6 +162,7 @@ module DiscourseAi
            yield(chunk, cancel_proc)
          end
        end
+          end

        content
      end

@@ -103,15 +103,7 @@ module DiscourseAi
        end
      end

-      def partials_from(decoded_chunk)
-        decoded_chunk
-      end
-
-      def chunk_to_string(chunk)
-        chunk.to_s
-      end
-
-      class Decoder
+      class GeminiStreamingDecoder
        def initialize
          @buffer = +""
        end
@@ -151,43 +143,87 @@ module DiscourseAi
      end

      def decode(chunk)
-        @decoder ||= Decoder.new
-        @decoder.decode(chunk)
+        json = JSON.parse(chunk, symbolize_names: true)
+        idx = -1
+        json
+          .dig(:candidates, 0, :content, :parts)
+          .map do |part|
+            if part[:functionCall]
+              idx += 1
+              ToolCall.new(
+                id: "tool_#{idx}",
+                name: part[:functionCall][:name],
+                parameters: part[:functionCall][:args],
+              )
+            else
+              part = part[:text]
+              if part != ""
+                part
+              else
+                nil
+              end
+            end
+          end
+      end
+
+      def decode_chunk(chunk)
+        @tool_index ||= -1
+        streaming_decoder
+          .decode(chunk)
+          .map do |parsed|
+            update_usage(parsed)
+            parsed
+              .dig(:candidates, 0, :content, :parts)
+              .map do |part|
+                if part[:text]
+                  part = part[:text]
+                  if part != ""
+                    part
+                  else
+                    nil
+                  end
+                elsif part[:functionCall]
+                  @tool_index += 1
+                  ToolCall.new(
+                    id: "tool_#{@tool_index}",
+                    name: part[:functionCall][:name],
+                    parameters: part[:functionCall][:args],
+                  )
+                end
+              end
+          end
+          .flatten
+          .compact
+      end
+
+      def update_usage(parsed)
+        usage = parsed.dig(:usageMetadata)
+        if usage
+          if prompt_token_count = usage[:promptTokenCount]
+            @prompt_token_count = prompt_token_count
+          end
+          if candidate_token_count = usage[:candidatesTokenCount]
+            @candidate_token_count = candidate_token_count
+          end
+        end
+      end
+
+      def final_log_update(log)
+        log.request_tokens = @prompt_token_count if @prompt_token_count
+        log.response_tokens = @candidate_token_count if @candidate_token_count
+      end
+
+      def streaming_decoder
+        @decoder ||= GeminiStreamingDecoder.new
      end

      def extract_prompt_for_tokenizer(prompt)
        prompt.to_s
      end

-      def has_tool?(_response_data)
-        @has_function_call
-      end
-
-      def native_tool_support?
-        true
-      end
-
-      def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
-        if @streaming_mode
-          return function_buffer if !partial
-        else
-          partial = payload
-        end
-
-        function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
-
-        if partial[:args]
-          argument_fragments =
-            partial[:args].reduce(+"") do |memo, (arg_name, value)|
-              memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
-            end
-          argument_fragments << "\n"
-
-          function_buffer.at("parameters").children =
-            Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-        end
-
-        function_buffer
+      def xml_tools_enabled?
+        false
      end
    end
  end

@@ -59,22 +59,30 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

-        def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-          response_h.dig(:content)
+        def xml_tools_enabled?
+          true
        end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data:", 2)[1]
-              data&.squish == "[DONE]" ? nil : data
-            end
+        def decode(response_raw)
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+          text = parsed.dig(:choices, 0, :message, :content)
+          if text.to_s.empty?
+            [""]
+          else
+            [text]
+          end
+        end
+
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |parsed|
+              text = parsed.dig(:choices, 0, :delta, :content)
+              if text.to_s.empty?
+                nil
+              else
+                text
+              end
+            end
            .compact
        end

@@ -37,12 +37,8 @@ module DiscourseAi
          URI(llm_model.url)
        end

-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def has_tool?(_response_data)
-          @has_function_call
+        def xml_tools_enabled?
+          !@native_tool_support
        end

        def prepare_payload(prompt, model_params, dialect)
@@ -67,74 +63,30 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").compact
+        def decode_chunk(chunk)
+          # Native tool calls are not working right in streaming mode, use XML
+          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+          (@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact
        end

-        def extract_completion_from(response_raw)
+        def decode(response_raw)
+          rval = []
          parsed = JSON.parse(response_raw, symbolize_names: true)
-          return if !parsed
+          content = parsed.dig(:message, :content)
+          rval << content if !content.to_s.empty?

-          response_h = parsed.dig(:message)
-
-          @has_function_call ||= response_h.dig(:tool_calls).present?
-          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
-        end
+          idx = -1
+          parsed
+            .dig(:message, :tool_calls)
+            &.each do |tool_call|
+              idx += 1
+              id = "tool_#{idx}"
+              name = tool_call.dig(:function, :name)
+              args = tool_call.dig(:function, :arguments)
+              rval << ToolCall.new(id: id, name: name, parameters: args)
+            end

-        def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
-          @args_buffer ||= +""
-
-          if @streaming_mode
-            return function_buffer if !partial
-          else
-            partial = payload
-          end
-
-          f_name = partial.dig(:function, :name)
-
-          @current_function ||= function_buffer.at("invoke")
-
-          if f_name
-            current_name = function_buffer.at("tool_name").content
-
-            if current_name.blank?
-              # first call
-            else
-              # we have a previous function, so we need to add a noop
-              @args_buffer = +""
-              @current_function =
-                function_buffer.at("function_calls").add_child(
-                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                )
-            end
-          end
-
-          @current_function.at("tool_name").content = f_name if f_name
-          @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
-          args = partial.dig(:function, :arguments)
-
-          # allow for SPACE within arguments
-          if args && args != ""
-            @args_buffer << args.to_json
-
-            begin
-              json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
-              argument_fragments =
-                json_args.reduce(+"") do |memo, (arg_name, value)|
-                  memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
-                end
-              argument_fragments << "\n"
-
-              @current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-            rescue JSON::ParserError
-              return function_buffer
-            end
-          end
-
-          function_buffer
+          rval
        end
      end
    end
  end

@@ -93,98 +93,34 @@ module DiscourseAi
        end

        def final_log_update(log)
-          log.request_tokens = @prompt_tokens if @prompt_tokens
-          log.response_tokens = @completion_tokens if @completion_tokens
+          log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
+          log.response_tokens = processor.completion_tokens if processor.completion_tokens
        end

-        def extract_completion_from(response_raw)
-          json = JSON.parse(response_raw, symbolize_names: true)
-
-          if @streaming_mode
-            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
-            @completion_tokens ||= json.dig(:usage, :completion_tokens)
-          end
-
-          parsed = json.dig(:choices, 0)
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-          @has_function_call ||= response_h.dig(:tool_calls).present?
-          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
+        def decode(response_raw)
+          processor.process_message(JSON.parse(response_raw, symbolize_names: true))
        end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
+        def decode_chunk(chunk)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << chunk)
+            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+            .flatten
            .compact
        end

-        def has_tool?(_response_data)
-          @has_function_call
+        def decode_chunk_finish
+          @processor.finish
        end

-        def native_tool_support?
-          true
+        def xml_tools_enabled?
+          false
        end

-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if @streaming_mode
-            return function_buffer if !partial
-          else
-            partial = payload
-          end
-
-          @args_buffer ||= +""
-
-          f_name = partial.dig(:function, :name)
-
-          @current_function ||= function_buffer.at("invoke")
-
-          if f_name
-            current_name = function_buffer.at("tool_name").content
-
-            if current_name.blank?
-              # first call
-            else
-              # we have a previous function, so we need to add a noop
-              @args_buffer = +""
-              @current_function =
-                function_buffer.at("function_calls").add_child(
-                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                )
-            end
-          end
-
-          @current_function.at("tool_name").content = f_name if f_name
-          @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
-          args = partial.dig(:function, :arguments)
-
-          # allow for SPACE within arguments
-          if args && args != ""
-            @args_buffer << args
-
-            begin
-              json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
-              argument_fragments =
-                json_args.reduce(+"") do |memo, (arg_name, value)|
-                  memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
-                end
-              argument_fragments << "\n"
-
-              @current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-            rescue JSON::ParserError
-              return function_buffer
-            end
-          end
-
-          function_buffer
+        private
+
+        def processor
+          @processor ||= OpenAiMessageProcessor.new
        end
      end
    end
  end

@@ -55,27 +55,31 @@ module DiscourseAi
          log.response_tokens = @completion_tokens if @completion_tokens
        end

-        def extract_completion_from(response_raw)
+        def xml_tools_enabled?
+          true
+        end
+
+        def decode(response_raw)
          json = JSON.parse(response_raw, symbolize_names: true)
+          [json.dig(:choices, 0, :message, :content)]
+        end

-          if @streaming_mode
-            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
-            @completion_tokens ||= json.dig(:usage, :completion_tokens)
-          end
-
-          parsed = json.dig(:choices, 0)
-          return if !parsed
-
-          @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
-        end
-
-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |json|
+              text = json.dig(:choices, 0, :delta, :content)
+
+              @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+              @completion_tokens ||= json.dig(:usage, :completion_tokens)
+
+              if !text.to_s.empty?
+                text
+              else
+                nil
+              end
            end
+            .flatten
            .compact
        end
      end

@@ -42,7 +42,10 @@ module DiscourseAi
        def prepare_payload(prompt, model_params, dialect)
          payload = default_options.merge(model_params).merge(messages: prompt)

+          if @streaming_mode
            payload[:stream] = true if @streaming_mode
+            payload[:stream_options] = { include_usage: true }
+          end

          payload
        end
@@ -56,25 +59,43 @@ module DiscourseAi
          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
-            .compact
+        def xml_tools_enabled?
+          true
        end

-        def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-          response_h.dig(:content)
+        def final_log_update(log)
+          log.request_tokens = @prompt_tokens if @prompt_tokens
+          log.response_tokens = @completion_tokens if @completion_tokens
+        end
+
+        def decode(response_raw)
+          json = JSON.parse(response_raw, symbolize_names: true)
+          @prompt_tokens = json.dig(:usage, :prompt_tokens)
+          @completion_tokens = json.dig(:usage, :completion_tokens)
+          [json.dig(:choices, 0, :message, :content)]
+        end
+
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |parsed|
+              # vLLM keeps sending usage over and over again
+              prompt_tokens = parsed.dig(:usage, :prompt_tokens)
+              completion_tokens = parsed.dig(:usage, :completion_tokens)
+
+              @prompt_tokens = prompt_tokens if prompt_tokens
+              @completion_tokens = completion_tokens if completion_tokens
+
+              text = parsed.dig(:choices, 0, :delta, :content)
+              if text.to_s.empty?
+                nil
+              else
+                text
+              end
+            end
+            .compact
        end
      end
    end
  end

@@ -1,113 +0,0 @@
# frozen_string_literal: true
class DiscourseAi::Completions::FunctionCallNormalizer
attr_reader :done
# blk is the block to call with filtered data
def initialize(blk, cancel)
@blk = blk
@cancel = cancel
@done = false
@in_tool = false
@buffer = +""
@function_buffer = +""
end
def self.normalize(data)
text = +""
cancel = -> {}
blk = ->(partial, _) { text << partial }
normalizer = self.new(blk, cancel)
normalizer << data
[text, normalizer.function_calls]
end
def function_calls
return nil if @function_buffer.blank?
xml = Nokogiri::HTML5.fragment(@function_buffer)
self.class.normalize_function_ids!(xml)
last_invoke = xml.at("invoke:last")
if last_invoke
last_invoke.next_sibling.remove while last_invoke.next_sibling
xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
end
xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
end
def <<(text)
@buffer << text
if !@in_tool
# double check if we are clearly in a tool
search_length = text.length + 20
search_string = @buffer[-search_length..-1] || @buffer
index = search_string.rindex("<function_calls>")
@in_tool = !!index
if @in_tool
@function_buffer = @buffer[index..-1]
text_index = text.rindex("<function_calls>")
@blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
end
else
@function_buffer << text
end
if !@in_tool
if maybe_has_tool?(@buffer)
split_index = text.rindex("<").to_i - 1
if split_index >= 0
@function_buffer = text[split_index + 1..-1] || ""
text = text[0..split_index] || ""
else
@function_buffer << text
text = ""
end
else
if @function_buffer.length > 0
@blk.call(@function_buffer, @cancel)
@function_buffer = +""
end
end
@blk.call(text, @cancel) if text.length > 0
else
if text.include?("</function_calls>")
@done = true
@cancel.call
end
end
end
def self.normalize_function_ids!(function_buffer)
function_buffer
.css("invoke")
.each_with_index do |invoke, index|
if invoke.at("tool_id")
invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
else
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
end
end
end
private
def maybe_has_tool?(text)
# 16 is the length of function calls
substring = text[-16..-1] || text
split = substring.split("<")
if split.length > 1
match = "<" + split.last
"<function_calls>".start_with?(match)
else
substring.ends_with?("<")
end
end
end

@@ -0,0 +1,48 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
# will work for anthropic and open ai compatible
class JsonStreamDecoder
attr_reader :buffer
LINE_REGEX = /data: ({.*})\s*$/
def initialize(symbolize_keys: true, line_regex: LINE_REGEX)
@symbolize_keys = symbolize_keys
@buffer = +""
@line_regex = line_regex
end
def <<(raw)
@buffer << raw.to_s
rval = []
split = @buffer.scan(/.*\n?/)
split.pop if split.last.blank?
@buffer = +(split.pop.to_s)
split.each do |line|
matches = line.match(@line_regex)
next if !matches
rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
end
if @buffer.present?
matches = @buffer.match(@line_regex)
if matches
begin
rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
@buffer = +""
rescue JSON::ParserError
# maybe it is a partial line
end
end
end
rval
end
end
end
end
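For illustration, the decoder buffers partial lines across chunks (the inputs here are hypothetical):

    decoder = DiscourseAi::Completions::JsonStreamDecoder.new
    decoder << "data: {\"a\": 1}\ndata: {\"b\""  # => [{ a: 1 }], the incomplete line stays buffered
    decoder << ": 2}\n"                          # => [{ b: 2 }]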

@@ -0,0 +1,103 @@
# frozen_string_literal: true
module DiscourseAi::Completions
class OpenAiMessageProcessor
attr_reader :prompt_tokens, :completion_tokens
def initialize
@tool = nil
@tool_arguments = +""
@prompt_tokens = nil
@completion_tokens = nil
end
def process_message(json)
result = []
tool_calls = json.dig(:choices, 0, :message, :tool_calls)
message = json.dig(:choices, 0, :message, :content)
result << message if message.present?
if tool_calls.present?
tool_calls.each do |tool_call|
id = tool_call.dig(:id)
name = tool_call.dig(:function, :name)
arguments = tool_call.dig(:function, :arguments)
parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {}
result << ToolCall.new(id: id, name: name, parameters: parameters)
end
end
update_usage(json)
result
end
def process_streamed_message(json)
rval = nil
tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
content = json.dig(:choices, 0, :delta, :content)
finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []
if tool_calls.present?
id = tool_calls.dig(0, :id)
name = tool_calls.dig(0, :function, :name)
arguments = tool_calls.dig(0, :function, :arguments)
# TODO: multiple tool support may require index
#index = tool_calls[0].dig(:index)
if id.present? && @tool && @tool.id != id
process_arguments
rval = @tool
@tool = nil
end
if id.present? && name.present?
@tool_arguments = +""
@tool = ToolCall.new(id: id, name: name)
end
@tool_arguments << arguments.to_s
elsif finished_tools && @tool
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
@tool.parameters = parsed_args
rval = @tool
@tool = nil
elsif content.present?
rval = content
end
update_usage(json)
rval
end
def finish
rval = []
if @tool
process_arguments
rval << @tool
@tool = nil
end
rval
end
private
def process_arguments
if @tool_arguments.present?
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
@tool.parameters = parsed_args
@tool_arguments = nil
end
end
def update_usage(json)
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
@completion_tokens ||= json.dig(:usage, :completion_tokens)
end
end
end
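A rough trace of the streaming contract (the delta hashes are hypothetical but follow the OpenAI shape): text deltas come back as strings, tool deltas are buffered until the tool is complete:

    p = DiscourseAi::Completions::OpenAiMessageProcessor.new
    p.process_streamed_message({ choices: [{ delta: { content: "Hi" } }] })
    # => "Hi"
    p.process_streamed_message(
      { choices: [{ delta: { tool_calls: [{ id: "call_1", function: { name: "time", arguments: "{}" } }] } }] },
    )
    # => nil (the tool is buffered)
    p.process_streamed_message({ choices: [{ delta: {}, finish_reason: "tool_calls" }] })
    # => the buffered ToolCall (id "call_1", name "time")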

@@ -0,0 +1,29 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
class ToolCall
attr_reader :id, :name, :parameters
def initialize(id:, name:, parameters: nil)
@id = id
@name = name
self.parameters = parameters if parameters
@parameters ||= {}
end
def parameters=(parameters)
raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
@parameters = parameters.symbolize_keys
end
def ==(other)
id == other.id && name == other.name && parameters == other.parameters
end
def to_s
"#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
end
end
end
end
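For illustration (values are hypothetical), parameters are symbolized on assignment and equality is structural:

    call = DiscourseAi::Completions::ToolCall.new(
      id: "tool_0",
      name: "google",
      parameters: { "query" => "sydney weather today" },
    )
    call.parameters # => { query: "sydney weather today" }
    call == DiscourseAi::Completions::ToolCall.new(
      id: "tool_0",
      name: "google",
      parameters: { query: "sydney weather today" },
    ) # => true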

@@ -0,0 +1,124 @@
# frozen_string_literal: true
# This class can be used to process a stream of text that may contain XML tool
# calls.
# It will return either text or ToolCall objects.
module DiscourseAi
module Completions
class XmlToolProcessor
def initialize
@buffer = +""
@function_buffer = +""
@should_cancel = false
@in_tool = false
end
def <<(text)
@buffer << text
result = []
if !@in_tool
# double check if we are clearly in a tool
search_length = text.length + 20
search_string = @buffer[-search_length..-1] || @buffer
index = search_string.rindex("<function_calls>")
@in_tool = !!index
if @in_tool
@function_buffer = @buffer[index..-1]
text_index = text.rindex("<function_calls>")
result << text[0..text_index - 1].strip if text_index && text_index > 0
end
else
@function_buffer << text
end
if !@in_tool
if maybe_has_tool?(@buffer)
split_index = text.rindex("<").to_i - 1
if split_index >= 0
@function_buffer = text[split_index + 1..-1] || ""
text = text[0..split_index] || ""
else
@function_buffer << text
text = ""
end
else
if @function_buffer.length > 0
result << @function_buffer
@function_buffer = +""
end
end
result << text if text.length > 0
else
@should_cancel = true if text.include?("</function_calls>")
end
result
end
def finish
return [] if @function_buffer.blank?
xml = Nokogiri::HTML5.fragment(@function_buffer)
normalize_function_ids!(xml)
last_invoke = xml.at("invoke:last")
if last_invoke
last_invoke.next_sibling.remove while last_invoke.next_sibling
xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
end
xml
.css("invoke")
.map do |invoke|
tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
parameters = {}
invoke
.at("parameters")
&.children
&.each do |node|
next if node.text?
name = node.name
value = node.content.to_s
parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
end
ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
end
end
def should_cancel?
@should_cancel
end
private
def normalize_function_ids!(function_buffer)
function_buffer
.css("invoke")
.each_with_index do |invoke, index|
if invoke.at("tool_id")
invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
else
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
end
end
end
def maybe_has_tool?(text)
# 16 is the length of function calls
substring = text[-16..-1] || text
split = substring.split("<")
if split.length > 1
match = "<" + split.last
"<function_calls>".start_with?(match)
else
substring.ends_with?("<")
end
end
end
end
end
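A quick sketch of the processor in action (the input chunks are hypothetical):

    processor = DiscourseAi::Completions::XmlToolProcessor.new
    processor << "hello <function_calls><invoke><tool_name>search</tool_name>"
    # => ["hello"], the XML is buffered
    processor << "<parameters><query>sam</query></parameters></invoke></function_calls>"
    processor.should_cancel? # => true
    processor.finish
    # => [a ToolCall with id "tool_0", name "search", parameters { query: "sam" }]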

@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
      data: {"type":"message_stop"}
    STRING

-    result = +""
+    result = []
    body = body.scan(/.*\n/)
    EndpointMock.with_chunk_array_support do
      stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
      end
    end

-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters><search_query>s&lt;a&gt;m sam</search_query>
-      <category>general</category></parameters>
-      <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+        parameters: {
+          search_query: "s<a>m sam",
+          category: "general",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq([tool_call])
  end

  it "can stream a response" do
@@ -191,6 +190,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
    expect(log.feature_name).to eq("testing")
    expect(log.response_tokens).to eq(15)
    expect(log.request_tokens).to eq(25)
+    expect(log.raw_request_payload).to eq(expected_body.to_json)
+    expect(log.raw_response_payload.strip).to eq(body.strip)
  end

  it "supports non streaming tool calls" do
@@ -242,17 +243,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
    result = llm.generate(prompt, user: Discourse.system_user)

-    expected = <<~TEXT.strip
-      <function_calls>
-      <invoke>
-      <tool_name>calculate</tool_name>
-      <parameters><expression>2758975 + 21.11</expression></parameters>
-      <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "calculate",
+        id: "toolu_012kBdhG4eHaV68W56p4N94h",
+        parameters: {
+          expression: "2758975 + 21.11",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq(["Here is the calculation:", tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(345)
+    expect(log.response_tokens).to eq(65)
  end

  it "can send images via a completion prompt" do

@@ -79,7 +79,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
        }
        prompt.tools = [tool]

-        response = +""
+        response = []
        proxy.generate(prompt, user: user) { |partial| response << partial }

        expect(request.headers["Authorization"]).to be_present
@@ -90,21 +90,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
        expect(parsed_body["tools"]).to eq(nil)
        expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])

-        # note we now have a tool_id cause we were normalized
-        function_call = <<~XML.strip
-          hello
-
-          <function_calls>
-          <invoke>
-          <tool_name>google</tool_name>
-          <parameters><query>sydney weather today</query></parameters>
-          <tool_id>tool_0</tool_id>
-          </invoke>
-          </function_calls>
-        XML
-
-        expect(response.strip).to eq(function_call)
+        expected = [
+          "hello\n",
+          DiscourseAi::Completions::ToolCall.new(
+            id: "tool_0",
+            name: "google",
+            parameters: {
+              query: "sydney weather today",
+            },
+          ),
+        ]
+        expect(response).to eq(expected)
      end
    end
@@ -230,23 +227,23 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
        }
        prompt.tools = [tool]

-        response = +""
+        response = []
        proxy.generate(prompt, user: user) { |partial| response << partial }

        expect(request.headers["Authorization"]).to be_present
        expect(request.headers["X-Amz-Content-Sha256"]).to be_present

-        expected_response = (<<~RESPONSE).strip
-          <function_calls>
-          <invoke>
-          <tool_name>google</tool_name>
-          <parameters><query>sydney weather today</query></parameters>
-          <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
-          </invoke>
-          </function_calls>
-        RESPONSE
-
-        expect(response.strip).to eq(expected_response)
+        expected_response = [
+          DiscourseAi::Completions::ToolCall.new(
+            id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+            name: "google",
+            parameters: {
+              query: "sydney weather today",
+            },
+          ),
+        ]
+        expect(response).to eq(expected_response)

        expected = {
          "max_tokens" => 3000,


@ -66,7 +66,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
TEXT TEXT
parsed_body = nil parsed_body = nil
result = +"" result = []
sig = { sig = {
name: "google", name: "google",
@ -91,21 +91,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
}, },
).to_return(status: 200, body: body.split("|")) ).to_return(status: 200, body: body.split("|"))
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial } llm.generate(prompt, user: user) { |partial, cancel| result << partial }
end end
expected = <<~TEXT text = "I will search for 'who is sam saffron' and relay the information to the user."
<function_calls> tool_call =
<invoke> DiscourseAi::Completions::ToolCall.new(
<tool_name>google</tool_name> id: "tool_0",
<parameters><query>who is sam saffron</query> name: "google",
</parameters> parameters: {
<tool_id>tool_0</tool_id> query: "who is sam saffron",
</invoke> },
</function_calls> )
TEXT
expect(result.strip).to eq(expected.strip) expect(result).to eq([text, tool_call])
expected = { expected = {
model: "command-r-plus", model: "command-r-plus",


@ -62,18 +62,14 @@ class EndpointMock
end end
def invocation_response def invocation_response
<<~TEXT DiscourseAi::Completions::ToolCall.new(
<function_calls> id: "tool_0",
<invoke> name: "get_weather",
<tool_name>get_weather</tool_name> parameters: {
<parameters> location: "Sydney",
<location>Sydney</location> unit: "c",
<unit>c</unit> },
</parameters> )
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
TEXT
end end
def tool_id def tool_id
@ -185,7 +181,7 @@ class EndpointsCompliance
mock.stub_tool_call(a_dialect.translate) mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user) completion_response = endpoint.perform_completion!(a_dialect, user)
expect(completion_response.strip).to eq(mock.invocation_response.strip) expect(completion_response).to eq(mock.invocation_response)
end end
def streaming_mode_simple_prompt(mock) def streaming_mode_simple_prompt(mock)
@ -205,6 +201,7 @@ class EndpointsCompliance
expect(log.raw_request_payload).to be_present expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to be_present expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate)) expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq( expect(log.response_tokens).to eq(
endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join), endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
) )
@ -216,14 +213,14 @@ class EndpointsCompliance
a_dialect = dialect(prompt: prompt) a_dialect = dialect(prompt: prompt)
mock.stub_streamed_tool_call(a_dialect.translate) do mock.stub_streamed_tool_call(a_dialect.translate) do
buffered_partial = +"" buffered_partial = []
endpoint.perform_completion!(a_dialect, user) do |partial, cancel| endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
buffered_partial << partial buffered_partial << partial
cancel.call if buffered_partial.include?("<function_calls>") cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
end end
expect(buffered_partial.strip).to eq(mock.invocation_response.strip) expect(buffered_partial).to eq([mock.invocation_response])
end end
end end


@ -195,19 +195,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
response = llm.generate(prompt, user: user) response = llm.generate(prompt, user: user)
expected = (<<~XML).strip tool =
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> id: "tool_0",
<tool_name>echo</tool_name> name: "echo",
<parameters> parameters: {
<text>&lt;S&gt;ydney</text> text: "<S>ydney",
</parameters> },
<tool_id>tool_0</tool_id> )
</invoke>
</function_calls>
XML
expect(response.strip).to eq(expected) expect(response).to eq(tool)
end end
it "Supports Vision API" do it "Supports Vision API" do
@ -265,6 +262,68 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
expect(JSON.parse(req_body)).to eq(expected_prompt) expect(JSON.parse(req_body)).to eq(expected_prompt)
end end
it "Can stream tool calls correctly" do
rows = [
{
candidates: [
{
content: {
parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
role: "model",
},
safetyRatings: [
{ category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
{ category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
{ category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
{ category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
],
},
],
usageMetadata: {
promptTokenCount: 625,
totalTokenCount: 625,
},
modelVersion: "gemini-1.5-pro-002",
},
{
candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
usageMetadata: {
promptTokenCount: 625,
candidatesTokenCount: 4,
totalTokenCount: 629,
},
modelVersion: "gemini-1.5-pro-002",
},
]
payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
output = []
stub_request(:post, url).to_return(status: 200, body: payload)
llm.generate(prompt, user: user) { |partial| output << partial }
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "echo",
parameters: {
text: "sam<>wh!s",
},
)
expect(output).to eq([tool_call])
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(625)
expect(log.response_tokens).to eq(4)
end
it "Can correctly handle streamed responses even if they are chunked badly" do it "Can correctly handle streamed responses even if they are chunked badly" do
data = +"" data = +""
data << "da|ta: |" data << "da|ta: |"
@ -279,12 +338,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=123" url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
output = +"" output = []
gemini_mock.with_chunk_array_support do gemini_mock.with_chunk_array_support do
stub_request(:post, url).to_return(status: 200, body: split) stub_request(:post, url).to_return(status: 200, body: split)
llm.generate("Hello", user: user) { |partial| output << partial } llm.generate("Hello", user: user) { |partial| output << partial }
end end
expect(output).to eq("Hello World Sam") expect(output.join).to eq("Hello World Sam")
end end
end end


@ -150,7 +150,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
end end
describe "when using streaming mode" do describe "when using streaming mode" do
context "with simpel prompts" do context "with simple prompts" do
it "completes a trivial prompt and logs the response" do it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(ollama_mock) compliance.streaming_mode_simple_prompt(ollama_mock)
end end


@ -17,8 +17,8 @@ class OpenAiMock < EndpointMock
created: 1_678_464_820, created: 1_678_464_820,
model: "gpt-3.5-turbo-0301", model: "gpt-3.5-turbo-0301",
usage: { usage: {
prompt_tokens: 337, prompt_tokens: 8,
completion_tokens: 162, completion_tokens: 13,
total_tokens: 499, total_tokens: 499,
}, },
choices: [ choices: [
@ -231,19 +231,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
result = llm.generate(prompt, user: user) result = llm.generate(prompt, user: user)
expected = (<<~TXT).strip tool_call =
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> id: "call_I8LKnoijVuhKOM85nnEQgWwd",
<tool_name>echo</tool_name> name: "echo",
<parameters> parameters: {
<text>hello</text> text: "hello",
</parameters> },
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id> )
</invoke>
</function_calls>
TXT
expect(result.strip).to eq(expected) expect(result).to eq(tool_call)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json, body: { choices: [message: { content: "OK" }] }.to_json,
@ -320,19 +317,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } }) expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
expected = (<<~TXT).strip log = AiApiAuditLog.order(:id).last
<function_calls> expect(log.request_tokens).to eq(55)
<invoke> expect(log.response_tokens).to eq(13)
<tool_name>echo</tool_name>
<parameters>
<text>h&lt;e&gt;llo</text>
</parameters>
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
</invoke>
</function_calls>
TXT
expect(result.strip).to eq(expected) expected =
DiscourseAi::Completions::ToolCall.new(
id: "call_I8LKnoijVuhKOM85nnEQgWwd",
name: "echo",
parameters: {
text: "h<e>llo",
},
)
expect(result).to eq(expected)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json, body: { choices: [message: { content: "OK" }] }.to_json,
@ -487,7 +485,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]} data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]} data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]} data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
@ -495,32 +493,30 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
TEXT TEXT
open_ai_mock.stub_raw(raw_data) open_ai_mock.stub_raw(raw_data)
content = +"" response = []
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools)) dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
endpoint.perform_completion!(dialect, user) { |partial| content << partial } endpoint.perform_completion!(dialect, user) { |partial| response << partial }
expected = <<~TEXT tool_calls = [
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "search",
<tool_name>search</tool_name> id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
<parameters> parameters: {
<search_query>Discourse AI bot</search_query> search_query: "Discourse AI bot",
</parameters> },
<tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id> ),
</invoke> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "search",
<tool_name>search</tool_name> id: "call_H7YkbgYurHpyJqzwUN4bghwN",
<parameters> parameters: {
<query>Discourse AI bot</query> query: "Discourse AI bot2",
</parameters> },
<tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id> ),
</invoke> ]
</function_calls>
TEXT
expect(content).to eq(expected) expect(response).to eq(tool_calls)
end end
it "uses proper token accounting" do it "uses proper token accounting" do
@ -593,21 +589,16 @@ TEXT
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools)) dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
endpoint.perform_completion!(dialect, user) { |partial| partials << partial } endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
expect(partials.length).to eq(1) tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "func_id",
name: "google",
parameters: {
query: "Adabas 9.1",
},
)
function_call = (<<~TXT).strip expect(partials).to eq([tool_call])
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters>
<query>Adabas 9.1</query>
</parameters>
<tool_id>func_id</tool_id>
</invoke>
</function_calls>
TXT
expect(partials[0].strip).to eq(function_call)
end end
end end
end end


@ -22,10 +22,15 @@ data: [DONE]
}, },
).to_return(status: 200, body: body, headers: {}) ).to_return(status: 200, body: body, headers: {})
response = +"" response = []
llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial } llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }
expect(response).to eq("I am a bot") expect(response).to eq(["I am a bot"])
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(21)
expect(log.response_tokens).to eq(41)
end end
it "can perform regular completions" do it "can perform regular completions" do


@ -51,7 +51,13 @@ class VllmMock < EndpointMock
WebMock WebMock
.stub_request(:post, "https://test.dev/v1/chat/completions") .stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json) .with(
body:
model
.default_options
.merge(messages: prompt, stream: true, stream_options: { include_usage: true })
.to_json,
)
.to_return(status: 200, body: chunks) .to_return(status: 200, body: chunks)
end end
end end
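The vLLM token fix rides on the OpenAI-compatible stream_options flag the stub above now matches: with include_usage: true, vLLM reports prompt/completion token counts in the streamed chunks, so the endpoint can log real numbers instead of estimates. A hedged sketch of the request body shape being matched (model and message values are illustrative):

body = {
  model: "meta-llama/Meta-Llama-3.1-70B-Instruct", # illustrative model name
  messages: [{ role: "user", content: "say hello" }],
  stream: true,
  stream_options: { include_usage: true }, # ask vLLM to include usage in the stream
}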
@ -136,29 +142,115 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
result = llm.generate(prompt, user: Discourse.system_user) result = llm.generate(prompt, user: Discourse.system_user)
expected = <<~TEXT expected =
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "calculate",
<tool_name>calculate</tool_name> id: "tool_0",
<parameters> parameters: {
<expression>1+1</expression></parameters> expression: "1+1",
<tool_id>tool_0</tool_id> },
</invoke> )
</function_calls>
expect(result).to eq(expected)
end
end
it "correctly accounts for tokens in non streaming mode" do
body = (<<~TEXT).strip
{"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null}
TEXT TEXT
expect(result.strip).to eq(expected.strip) stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body)
result = llm.generate("generate a title", user: Discourse.system_user)
expect(result).to eq("Random Number Generator Produces Smallest Possible Result")
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(146)
expect(log.response_tokens).to eq(10)
end end
it "can properly include usage in streaming mode" do
payload = <<~TEXT.strip
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}
data: [DONE]
TEXT
stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
status: 200,
body: payload,
)
response = []
llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial }
expect(response.join).to eq(
"Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?",
)
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(46)
expect(log.response_tokens).to eq(26)
end end
describe "#perform_completion!" do describe "#perform_completion!" do
context "when using regular mode" do context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(vllm_mock)
end
end
context "with tools" do context "with tools" do
it "returns a function invocation" do it "returns a function invocation" do
compliance.regular_mode_tools(vllm_mock) compliance.regular_mode_tools(vllm_mock)


@ -1,182 +0,0 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
let(:buffer) { +"" }
let(:normalizer) do
blk = ->(data, cancel) { buffer << data }
cancel = -> { @done = true }
DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
end
def pass_through!(data)
normalizer << data
expect(buffer[-data.length..-1]).to eq(data)
end
it "is usable in non streaming mode" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
</invoke>
XML
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
expect(text).to eq("hello")
expected_function_calls = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(function_calls).to eq(expected_function_calls)
end
it "strips junk from end of function calls" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
</invoke>
junk
XML
_text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
expected_function_calls = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(function_calls).to eq(expected_function_calls)
end
it "returns nil for function calls if there are none" do
input = "hello world\n"
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)
expect(text).to eq(input)
expect(function_calls).to eq(nil)
end
it "passes through data if there are no function calls detected" do
pass_through!("hello")
pass_through!("<tool_name>hello</tool_name>")
pass_through!("<parameters><hello>world</hello></parameters>")
pass_through!("<function_call>")
end
it "properly handles non English tools" do
normalizer << "hello<function"
expect(buffer).to eq("hello")
normalizer << "_calls>\n"
normalizer << (<<~XML).strip
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello></hello>
</parameters>
</invoke>
XML
expected = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello></hello>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
function_calls = normalizer.function_calls
expect(function_calls).to eq(expected)
end
it "works correctly even if you only give it 1 letter at a time" do
xml = (<<~XML).strip
abc
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>abc</tool_id>
</invoke>
<invoke>
<tool_name>hello2</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>aba</tool_id>
</invoke>
</function_calls>
XML
xml.each_char { |char| normalizer << char }
expect(buffer + normalizer.function_calls).to eq(xml)
end
it "supports multiple invokes" do
xml = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>abc</tool_id>
</invoke>
<invoke>
<tool_name>hello2</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>aba</tool_id>
</invoke>
</function_calls>
XML
normalizer << xml
expect(normalizer.function_calls).to eq(xml)
end
it "can will cancel if it encounteres </function_calls>" do
normalizer << "<function_calls>"
expect(normalizer.done).to eq(false)
normalizer << "</function_calls>"
expect(normalizer.done).to eq(true)
expect(@done).to eq(true)
expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
end
it "pauses on function call and starts buffering" do
normalizer << "hello<function_call"
expect(buffer).to eq("hello")
expect(normalizer.done).to eq(false)
normalizer << ">"
expect(buffer).to eq("hello<function_call>")
expect(normalizer.done).to eq(false)
end
end


@ -0,0 +1,47 @@
# frozen_string_literal: true
describe DiscourseAi::Completions::JsonStreamDecoder do
let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
it "should be able to parse simple messages" do
result = decoder << "data: #{{ hello: "world" }.to_json}"
expect(result).to eq([{ hello: "world" }])
end
it "should handle anthropic mixed stlye streams" do
stream = (<<~TEXT).split("|")
event: |message_start|
data: |{"hel|lo": "world"}|
event: |message_start
data: {"foo": "bar"}
event: |message_start
data: {"ba|z": "qux"|}
[DONE]
TEXT
results = []
stream.each { |chunk| results << (decoder << chunk) }
expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
end
it "should be able to handle complex overlaps" do
stream = (<<~TEXT).split("|")
data: |{"hel|lo": "world"}
data: {"foo": "bar"}
data: {"ba|z": "qux"|}
[DONE]
TEXT
results = []
stream.each { |chunk| results << (decoder << chunk) }
expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
end
end
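Taken together, the decoder's contract is small: push raw SSE chunks in with << and get back an array of fully parsed JSON hashes (or nil while a payload is still incomplete). A hedged sketch of wiring it into an endpoint's decode_chunk; the dig path follows the OpenAI-style chunks used elsewhere in these specs, and a real endpoint would also map tool-call payloads to ToolCall objects:

def decode_chunk(chunk)
  @decoder ||= DiscourseAi::Completions::JsonStreamDecoder.new
  (@decoder << chunk).to_a.filter_map do |parsed|
    # `parsed` is a Hash with symbolized keys, e.g. { choices: [{ delta: ... }] }
    parsed.dig(:choices, 0, :delta, :content)
  end
end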


@ -0,0 +1,188 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::XmlToolProcessor do
let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new }
it "can process simple text" do
result = []
result << (processor << "hello")
result << (processor << " world ")
expect(result).to eq([["hello"], [" world "]])
expect(processor.finish).to eq([])
expect(processor.should_cancel?).to eq(false)
end
it "is usable for simple single message mode" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
<test>value</test>
</parameters>
</invoke>
XML
result = []
result << (processor << xml)
result << (processor.finish)
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "hello",
parameters: {
hello: "world",
test: "value",
},
)
expect(result).to eq([["hello"], [tool_call]])
expect(processor.should_cancel?).to eq(false)
end
it "handles multiple tool calls in sequence" do
xml = (<<~XML).strip
start
<function_calls>
<invoke>
<tool_name>first_tool</tool_name>
<parameters>
<param1>value1</param1>
</parameters>
</invoke>
<invoke>
<tool_name>second_tool</tool_name>
<parameters>
<param2>value2</param2>
</parameters>
</invoke>
</function_calls>
end
XML
result = []
result << (processor << xml)
result << (processor.finish)
first_tool =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "first_tool",
parameters: {
param1: "value1",
},
)
second_tool =
DiscourseAi::Completions::ToolCall.new(
id: "tool_1",
name: "second_tool",
parameters: {
param2: "value2",
},
)
expect(result).to eq([["start"], [first_tool, second_tool]])
expect(processor.should_cancel?).to eq(true)
end
it "handles non-English parameters correctly" do
xml = (<<~XML).strip
こんにちは
<function_calls>
<invoke>
<tool_name>translator</tool_name>
<parameters>
<text>世界</text>
</parameters>
</invoke>
XML
result = []
result << (processor << xml)
result << (processor.finish)
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "translator",
parameters: {
text: "世界",
},
)
expect(result).to eq([["こんにちは"], [tool_call]])
end
it "processes input character by character" do
xml =
"hi<function_calls><invoke><tool_name>test</tool_name><parameters><p>v</p></parameters></invoke>"
result = []
xml.each_char { |char| result << (processor << char) }
result << processor.finish
tool_call =
DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })
filtered_result = result.reject(&:empty?)
expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
end
it "handles malformed XML gracefully" do
xml = (<<~XML).strip
text
<function_calls>
<invoke>
<tool_name>test</tool_name>
<parameters>
<param>value
</parameters>
</invoke>
malformed
XML
result = []
result << (processor << xml)
result << (processor.finish)
# Should just do its best to parse the XML
tool_call =
DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
expect(result).to eq([["text"], [tool_call]])
end
it "correctly processes empty parameter sets" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>no_params</tool_name>
<parameters>
</parameters>
</invoke>
XML
result = []
result << (processor << xml)
result << (processor.finish)
tool_call =
DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})
expect(result).to eq([["hello"], [tool_call]])
end
it "properly handles cancelled processing" do
xml = "start<function_calls></function_calls>"
result = []
result << (processor << xml)
result << (processor << "more text")
result << processor.finish
expect(result).to eq([["start"], [], []])
expect(processor.should_cancel?).to eq(true)
end
end
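As the specs above exercise it, the processor is push-based: << returns any plain-text segments that are safe to emit, finish flushes the tail and returns trailing text plus ToolCall objects, and should_cancel? signals that </function_calls> was seen and nothing useful follows. A minimal consumption sketch under those assumptions:

def stream_with_tools(chunks)
  processor = DiscourseAi::Completions::XmlToolProcessor.new
  chunks.each do |chunk|
    (processor << chunk).each { |partial| yield partial } # emit text as it arrives
    break if processor.should_cancel? # stop reading once the tool block is closed
  end
  processor.finish.each { |partial| yield partial } # trailing text / ToolCall objects
end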


@ -72,40 +72,27 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
it "can parse string that are wrapped in quotes" do it "can parse string that are wrapped in quotes" do
SiteSetting.ai_stability_api_key = "123" SiteSetting.ai_stability_api_key = "123"
xml = <<~XML
<function_calls>
<invoke>
<tool_name>image</tool_name>
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
<parameters>
<prompts>["cat oil painting", "big car"]</prompts>
<aspect_ratio>"16:9"</aspect_ratio>
</parameters>
</invoke>
<invoke>
<tool_name>image</tool_name>
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
<parameters>
<prompts>["cat oil painting", "big car"]</prompts>
<aspect_ratio>'16:9'</aspect_ratio>
</parameters>
</invoke>
</function_calls>
XML
image1, image2 = tool_call =
tools = DiscourseAi::Completions::ToolCall.new(
DiscourseAi::AiBot::Personas::Artist.new.find_tools( name: "image",
xml, id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
prompts: ["cat oil painting", "big car"],
aspect_ratio: "16:9",
},
)
tool_instance =
DiscourseAi::AiBot::Personas::Artist.new.find_tool(
tool_call,
bot_user: nil, bot_user: nil,
llm: nil, llm: nil,
context: nil, context: nil,
) )
expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
expect(image1.parameters[:aspect_ratio]).to eq("16:9")
expect(image2.parameters[:aspect_ratio]).to eq("16:9")
expect(tools.length).to eq(2) expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
end end
it "enforces enums" do it "enforces enums" do
@ -132,38 +119,64 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
</function_calls> </function_calls>
XML XML
search1, search2 = tool_call =
tools = DiscourseAi::Completions::ToolCall.new(
DiscourseAi::AiBot::Personas::General.new.find_tools( name: "search",
xml, id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
max_posts: "3.2",
status: "cow",
foo: "bar",
},
)
tool_instance =
DiscourseAi::AiBot::Personas::General.new.find_tool(
tool_call,
bot_user: nil, bot_user: nil,
llm: nil, llm: nil,
context: nil, context: nil,
) )
expect(search1.parameters.key?(:status)).to eq(false) expect(tool_instance.parameters.key?(:status)).to eq(false)
expect(search2.parameters[:status]).to eq("open")
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
max_posts: "3.2",
status: "open",
foo: "bar",
},
)
tool_instance =
DiscourseAi::AiBot::Personas::General.new.find_tool(
tool_call,
bot_user: nil,
llm: nil,
context: nil,
)
expect(tool_instance.parameters[:status]).to eq("open")
end end
it "can coerce integers" do it "can coerce integers" do
xml = <<~XML tool_call =
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "search",
<tool_name>search</tool_name> id: "call_JtYQMful5QKqw97XFsHzPweB",
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id> parameters: {
<parameters> max_posts: "3.2",
<max_posts>"3.2"</max_posts> search_query: "hello world",
<search_query>hello world</search_query> foo: "bar",
<foo>bar</foo> },
</parameters> )
</invoke>
</function_calls>
XML
search, = search =
tools = DiscourseAi::AiBot::Personas::General.new.find_tool(
DiscourseAi::AiBot::Personas::General.new.find_tools( tool_call,
xml,
bot_user: nil, bot_user: nil,
llm: nil, llm: nil,
context: nil, context: nil,
@ -177,43 +190,23 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
it "can correctly parse arrays in tools" do it "can correctly parse arrays in tools" do
SiteSetting.ai_openai_api_key = "123" SiteSetting.ai_openai_api_key = "123"
# Dall E tool uses an array for params tool_call =
xml = <<~XML DiscourseAi::Completions::ToolCall.new(
<function_calls> name: "dall_e",
<invoke> id: "call_JtYQMful5QKqw97XFsHzPweB",
<tool_name>dall_e</tool_name> parameters: {
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id> prompts: ["cat oil painting", "big car"],
<parameters> },
<prompts>["cat oil painting", "big car"]</prompts> )
</parameters>
</invoke> tool_instance =
<invoke> DiscourseAi::AiBot::Personas::DallE3.new.find_tool(
<tool_name>dall_e</tool_name> tool_call,
<tool_id>abc</tool_id>
<parameters>
<prompts>["pic3"]</prompts>
</parameters>
</invoke>
<invoke>
<tool_name>unknown</tool_name>
<tool_id>abc</tool_id>
<parameters>
<prompts>["pic3"]</prompts>
</parameters>
</invoke>
</function_calls>
XML
dall_e1, dall_e2 =
tools =
DiscourseAi::AiBot::Personas::DallE3.new.find_tools(
xml,
bot_user: nil, bot_user: nil,
llm: nil, llm: nil,
context: nil, context: nil,
) )
expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"]) expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
expect(tools.length).to eq(2)
end end
describe "custom personas" do describe "custom personas" do


@ -55,6 +55,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
) )
end end
before { SiteSetting.ai_embeddings_enabled = false }
after do after do
# we must reset cache on persona cause data can be rolled back # we must reset cache on persona cause data can be rolled back
AiPersona.persona_cache.flush! AiPersona.persona_cache.flush!
@ -83,17 +85,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end end
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
let(:function_call) { (<<~XML).strip } let(:tool_call) do
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "search",
<tool_name>search</tool_name> id: "666",
<tool_id>666</tool_id> parameters: {
<parameters> query: "Can you use the custom tool",
<query>Can you use the custom tool</query> },
</parameters> )
</invoke> end
</function_calls>",
XML
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) } let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
@ -115,7 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
reply_post = nil reply_post = nil
prompts = nil prompts = nil
responses = [function_call] responses = [tool_call]
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
new_post = Fabricate(:post, raw: "Can you use the custom tool?") new_post = Fabricate(:post, raw: "Can you use the custom tool?")
reply_post = playground.reply_to(new_post) reply_post = playground.reply_to(new_post)
@ -133,7 +133,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can force usage of a tool" do it "can force usage of a tool" do
tool_name = "custom-#{custom_tool.id}" tool_name = "custom-#{custom_tool.id}"
ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1) ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
responses = [function_call, "custom tool did stuff (maybe)"] responses = [tool_call, "custom tool did stuff (maybe)"]
prompts = nil prompts = nil
reply_post = nil reply_post = nil
@ -166,7 +166,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot) playground = DiscourseAi::AiBot::Playground.new(bot)
responses = [function_call, "custom tool did stuff (maybe)"] responses = [tool_call, "custom tool did stuff (maybe)"]
reply_post = nil reply_post = nil
@ -206,13 +206,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot) playground = DiscourseAi::AiBot::Playground.new(bot)
responses = ["custom tool did stuff (maybe)", tool_call]
# lets ensure tool does not run... # lets ensure tool does not run...
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt| DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
new_post = Fabricate(:post, raw: "Can you use the custom tool?") new_post = Fabricate(:post, raw: "Can you use the custom tool?")
reply_post = playground.reply_to(new_post) reply_post = playground.reply_to(new_post)
end end
expect(reply_post.raw.strip).to eq(function_call) expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)")
end end
end end
@ -452,10 +454,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can run tools" do it "can run tools" do
persona.update!(tools: ["Time"]) persona.update!(tools: ["Time"])
responses = [ tool_call1 =
"<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>", DiscourseAi::Completions::ToolCall.new(
"The time is 2023-12-14 17:24:00 -0300", name: "time",
] id: "time",
parameters: {
timezone: "Buenos Aires",
},
)
tool_call2 =
DiscourseAi::Completions::ToolCall.new(
name: "time",
id: "time",
parameters: {
timezone: "Sydney",
},
)
responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
message = message =
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@ -470,7 +487,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# it also needs to have tool details now set on message # it also needs to have tool details now set on message
prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id) prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
expect(prompt.custom_prompt.length).to eq(3)
expect(prompt.custom_prompt.length).to eq(5)
# TODO in chat I am mixed on including this in the context, but I guess maybe? # TODO in chat I am mixed on including this in the context, but I guess maybe?
# thinking about this # thinking about this
@ -782,30 +800,29 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end end
it "supports multiple function calls" do it "supports multiple function calls" do
response1 = (<<~TXT).strip tool_call1 =
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "search",
<tool_name>search</tool_name> id: "search",
<tool_id>search</tool_id> parameters: {
<parameters> search_query: "testing various things",
<search_query>testing various things</search_query> },
</parameters> )
</invoke>
<invoke> tool_call2 =
<tool_name>search</tool_name> DiscourseAi::Completions::ToolCall.new(
<tool_id>search</tool_id> name: "search",
<parameters> id: "search",
<search_query>another search</search_query> parameters: {
</parameters> search_query: "another search",
</invoke> },
</function_calls> )
TXT
response2 = "I found stuff" response2 = "I found stuff"
DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do DiscourseAi::Completions::Llm.with_prepared_responses(
playground.reply_to(third_post) [[tool_call1, tool_call2], response2],
end ) { playground.reply_to(third_post) }
last_post = third_post.topic.reload.posts.order(:post_number).last last_post = third_post.topic.reload.posts.order(:post_number).last
@ -819,17 +836,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new) bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new)
playground = described_class.new(bot) playground = described_class.new(bot)
response1 = (<<~TXT).strip response1 =
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "search",
<tool_name>search</tool_name> id: "search",
<tool_id>search</tool_id> parameters: {
<parameters> search_query: "testing various things",
<search_query>testing various things</search_query> },
</parameters> )
</invoke>
</function_calls>
TXT
response2 = "I found stuff" response2 = "I found stuff"
@ -843,17 +857,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end end
it "does not include placeholders in conversation context but includes all completions" do it "does not include placeholders in conversation context but includes all completions" do
response1 = (<<~TXT).strip response1 =
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "search",
<tool_name>search</tool_name> id: "search",
<tool_id>search</tool_id> parameters: {
<parameters> search_query: "testing various things",
<search_query>testing various things</search_query> },
</parameters> )
</invoke>
</function_calls>
TXT
response2 = "I found some really amazing stuff!" response2 = "I found some really amazing stuff!"
@ -889,17 +900,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
[{ b64_json: image, revised_prompt: "a pink cow 1" }] [{ b64_json: image, revised_prompt: "a pink cow 1" }]
end end
let(:response) { (<<~TXT).strip } let(:response) do
<function_calls> DiscourseAi::Completions::ToolCall.new(
<invoke> name: "dall_e",
<tool_name>dall_e</tool_name> id: "dall_e",
<tool_id>dall_e</tool_id> parameters: {
<parameters> prompts: ["a pink cow"],
<prompts>["a pink cow"]</prompts> },
</parameters> )
</invoke> end
</function_calls>
TXT
it "properly returns an image when skipping tool details" do it "properly returns an image when skipping tool details" do
persona.update!(tool_details: false) persona.update!(tool_details: false)


@ -541,16 +541,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(topic.title).to eq("An amazing title") expect(topic.title).to eq("An amazing title")
expect(topic.posts.count).to eq(2) expect(topic.posts.count).to eq(2)
# now let's try to make a reply with a tool call tool_call =
function_call = <<~XML DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1")
<function_calls>
<invoke>
<tool_name>categories</tool_name>
</invoke>
</function_calls>
XML
fake_endpoint.fake_content = [function_call, "this is the response after the tool"] fake_endpoint.fake_content = [tool_call, "this is the response after the tool"]
# this simplifies function calls # this simplifies function calls
fake_endpoint.chunk_count = 1 fake_endpoint.chunk_count = 1


@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
fab!(:user) fab!(:user)
fab!(:pm_topic) { Fabricate(:private_message_topic) } fab!(:pm_topic) { Fabricate(:private_message_topic) }
fab!(:pm_post) { Fabricate(:post, topic: pm_topic) } fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }
before { sign_in(user) } before { sign_in(user) }
@ -22,6 +24,17 @@ RSpec.describe DiscourseAi::AiBot::BotController do
user = pm_topic.topic_allowed_users.first.user user = pm_topic.topic_allowed_users.first.user
sign_in(user) sign_in(user)
log1 =
AiApiAuditLog.create!(
provider_id: 1,
topic_id: pm_topic.id,
raw_request_payload: "request",
raw_response_payload: "response",
request_tokens: 1,
response_tokens: 2,
)
log2 =
AiApiAuditLog.create!( AiApiAuditLog.create!(
post_id: pm_post.id, post_id: pm_post.id,
provider_id: 1, provider_id: 1,
@ -32,24 +45,43 @@ RSpec.describe DiscourseAi::AiBot::BotController do
response_tokens: 2, response_tokens: 2,
) )
log3 =
AiApiAuditLog.create!(
post_id: pm_post2.id,
provider_id: 1,
topic_id: pm_topic.id,
raw_request_payload: "request",
raw_response_payload: "response",
request_tokens: 1,
response_tokens: 2,
)
Group.refresh_automatic_groups! Group.refresh_automatic_groups!
SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s
get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info" get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
expect(response.status).to eq(200) expect(response.status).to eq(200)
expect(response.parsed_body["id"]).to eq(log2.id)
expect(response.parsed_body["next_log_id"]).to eq(log3.id)
expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2) expect(response.parsed_body["response_tokens"]).to eq(2)
expect(response.parsed_body["raw_request_payload"]).to eq("request") expect(response.parsed_body["raw_request_payload"]).to eq("request")
expect(response.parsed_body["raw_response_payload"]).to eq("response") expect(response.parsed_body["raw_response_payload"]).to eq("response")
post2 = Fabricate(:post, topic: pm_topic)
# return previous post if current has no debug info # return previous post if current has no debug info
get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info" get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
expect(response.status).to eq(200) expect(response.status).to eq(200)
expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2) expect(response.parsed_body["response_tokens"]).to eq(2)
# can return debug info by id as well
get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
expect(response.status).to eq(200)
expect(response.parsed_body["id"]).to eq(log1.id)
end end
end end