FEATURE: improve tool support (#904)

This re-implements tool support in DiscourseAi::Completions::Llm#generate.

Previously, tool calls were always returned as XML, and it was the caller's responsibility to parse that XML.

The new implementation has the endpoints return ToolCall objects instead.
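
For example (a sketch based on the new interface; run_tool is a hypothetical helper, not part of this commit), a caller streaming a completion now type-checks each partial instead of scanning for XML:

    llm.generate(prompt, user: user) do |partial, cancel|
      if partial.is_a?(DiscourseAi::Completions::ToolCall)
        # a fully buffered tool invocation: id, name and parameters (symbolized hash)
        run_tool(partial.name, partial.parameters) # hypothetical helper
      else
        # a plain text partial
        print partial
      end
    end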

Additionally, this simplifies the Llm endpoint interface and gives it more clarity. Endpoints must implement:

decode, decode_chunk (for streaming)

It is the implementer's responsibility to figure out how to decode chunks; the base class no longer implements this. To make this easy we ship a flexible JSON decoder which is simple to wire up.
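
As a minimal sketch (assuming an OpenAI-compatible SSE stream, mirroring the endpoints updated in this commit), an implementation can be as small as:

    def decode(response_raw)
      json = JSON.parse(response_raw, symbolize_names: true)
      [json.dig(:choices, 0, :message, :content)]
    end

    def decode_chunk(chunk)
      # JsonStreamDecoder buffers partial lines and yields parsed JSON hashes
      @json_decoder ||= JsonStreamDecoder.new
      (@json_decoder << chunk)
        .map { |parsed| parsed.dig(:choices, 0, :delta, :content) }
        .compact
    end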

Also (new):

    Better debugging for PMs: we now have next/previous buttons to step through all the LLM messages associated with a PM
    Token accounting is fixed for vLLM (we were not counting tokens correctly)
Sam 2024-11-12 08:14:30 +11:00 committed by GitHub
parent 644141ff08
commit e817b7dc11
43 changed files with 1685 additions and 1293 deletions

View File

@@ -6,6 +6,14 @@ module DiscourseAi
requires_plugin ::DiscourseAi::PLUGIN_NAME
requires_login
def show_debug_info_by_id
log = AiApiAuditLog.find(params[:id])
raise Discourse::NotFound if !log.topic
guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
end
def show_debug_info
post = Post.find(params[:post_id])
guardian.ensure_can_debug_ai_bot_conversation!(post)

View File

@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
Ollama = 7
SambaNova = 8
end
def next_log_id
self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
end
def prev_log_id
self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
end
end
# == Schema Information

View File

@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
:post_id,
:feature_name,
:language_model,
:created_at
:created_at,
:prev_log_id,
:next_log_id
end

View File

@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
import i18n from "discourse-common/helpers/i18n";
import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
this.copy(this.info.raw_response_payload);
}
async loadLog(logId) {
try {
await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
(result) => {
this.info = result;
}
);
} catch (e) {
popupAjaxError(e);
}
}
@action
prevLog() {
this.loadLog(this.info.prev_log_id);
}
@action
nextLog() {
this.loadLog(this.info.next_log_id);
}
copy(text) {
clipboardCopy(text);
this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -73,11 +96,13 @@ export default class DebugAiModal extends Component {
}
loadApiRequestInfo() {
ajax(
`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
).then((result) => {
this.info = result;
});
ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
.then((result) => {
this.info = result;
})
.catch((e) => {
popupAjaxError(e);
});
}
get requestActive() {
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
@action={{this.copyResponse}}
@label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
/>
{{#if this.info.prev_log_id}}
<DButton
class="btn"
@icon="angles-left"
@action={{this.prevLog}}
@label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
/>
{{/if}}
{{#if this.info.next_log_id}}
<DButton
class="btn"
@icon="angles-right"
@action={{this.nextLog}}
@label="discourse_ai.ai_bot.debug_ai_modal.next_log"
/>
{{/if}}
<span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
</:footer>
</DModal>

View File

@@ -415,6 +415,8 @@ en:
response_tokens: "Response tokens:"
request: "Request"
response: "Response"
next_log: "Next"
previous_log: "Previous"
share_full_topic_modal:
title: "Share Conversation Publicly"

View File

@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
get "bot-username" => "bot#show_bot_username"
get "post/:post_id/show-debug-info" => "bot#show_debug_info"
get "show-debug-info/:id" => "bot#show_debug_info_by_id"
post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
end

View File

@@ -100,6 +100,7 @@ module DiscourseAi
llm_kwargs[:top_p] = persona.top_p if persona.top_p
needs_newlines = false
tools_ran = 0
while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi
result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
tool = nil if tools_ran >= MAX_TOOLS
if (tools.present?)
if tool.present?
tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
@@ -117,13 +119,16 @@ module DiscourseAi
needs_newlines = false
end
tools[0..MAX_TOOLS].each do |tool|
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
ongoing_chain &&= tool.chain_next_response?
end
process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
tools_ran += 1
ongoing_chain &&= tool.chain_next_response?
else
needs_newlines = true
update_blk.call(partial, cancel)
if partial.is_a?(DiscourseAi::Completions::ToolCall)
Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
else
update_blk.call(partial, cancel)
end
end
end

View File

@@ -199,23 +199,16 @@ module DiscourseAi
prompt
end
def find_tools(partial, bot_user:, llm:, context:)
return [] if !partial.include?("</invoke>")
parsed_function = Nokogiri::HTML5.fragment(partial)
parsed_function
.css("invoke")
.map do |fragment|
tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
end
.compact
def find_tool(partial, bot_user:, llm:, context:)
return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
end
protected
def tool_instance(parsed_function, bot_user:, llm:, context:)
function_id = parsed_function.at("tool_id")&.text
function_name = parsed_function.at("tool_name")&.text
def tool_instance(tool_call, bot_user:, llm:, context:)
function_id = tool_call.id
function_name = tool_call.name
return nil if function_name.nil?
tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +217,7 @@ module DiscourseAi
arguments = {}
tool_klass.signature[:parameters].to_a.each do |param|
name = param[:name]
value = parsed_function.at(name)&.text
value = tool_call.parameters[name.to_sym]
if param[:type] == "array" && value
value =

View File

@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
def append(json)
@raw_json << json
end
def to_tool_call
parameters = JSON.parse(raw_json, symbolize_names: true)
DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
end
end
attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,80 +25,69 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
def initialize(streaming_mode:)
@streaming_mode = streaming_mode
@tool_calls = []
@current_tool_call = nil
end
def to_xml_tool_calls(function_buffer)
return function_buffer if @tool_calls.blank?
def to_tool_calls
@tool_calls.map { |tool_call| tool_call.to_tool_call }
end
function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
</function_calls>
TEXT
@tool_calls.each do |tool_call|
node =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(
DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
),
)
params = JSON.parse(tool_call.raw_json, symbolize_names: true)
xml =
params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
node.at("tool_name").content = tool_call.name
node.at("tool_id").content = tool_call.id
node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
def process_streamed_message(parsed)
result = nil
if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
tool_name = parsed.dig(:content_block, :name)
tool_id = parsed.dig(:content_block, :id)
result = @current_tool_call.to_tool_call if @current_tool_call
@current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
if @current_tool_call
tool_delta = parsed.dig(:delta, :partial_json).to_s
@current_tool_call.append(tool_delta)
else
result = parsed.dig(:delta, :text).to_s
end
elsif parsed[:type] == "content_block_stop"
if @current_tool_call
result = @current_tool_call.to_tool_call
@current_tool_call = nil
end
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens =
parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
elsif parsed[:type] == "message_stop"
# bedrock has this ...
if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
@input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
@output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
end
end
function_buffer
result
end
def process_message(payload)
result = ""
parsed = JSON.parse(payload, symbolize_names: true)
parsed = payload
parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)
if @streaming_mode
if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
tool_name = parsed.dig(:content_block, :name)
tool_id = parsed.dig(:content_block, :id)
@tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
if @tool_calls.present?
result = parsed.dig(:delta, :partial_json).to_s
@tool_calls.last.append(result)
else
result = parsed.dig(:delta, :text).to_s
content = parsed.dig(:content)
if content.is_a?(Array)
result =
content.map do |data|
if data[:type] == "tool_use"
call = AnthropicToolCall.new(data[:name], data[:id])
call.append(data[:input].to_json)
call.to_tool_call
else
data[:text]
end
end
elsif parsed[:type] == "message_start"
@input_tokens = parsed.dig(:message, :usage, :input_tokens)
elsif parsed[:type] == "message_delta"
@output_tokens =
parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
elsif parsed[:type] == "message_stop"
# bedrock has this ...
if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
@input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
@output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
end
end
else
content = parsed.dig(:content)
if content.is_a?(Array)
tool_call = content.find { |c| c[:type] == "tool_use" }
if tool_call
@tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
@tool_calls.last.append(tool_call[:input].to_json)
else
result = parsed.dig(:content, 0, :text).to_s
end
end
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
end
@input_tokens = parsed.dig(:usage, :input_tokens)
@output_tokens = parsed.dig(:usage, :output_tokens)
result
end
end

View File

@@ -63,8 +63,23 @@ module DiscourseAi
def user_msg(msg)
user_message = { role: "user", content: msg[:content] }
# TODO: Add support for user messages with empbeded user ids
# TODO: Add support for user messages with attachments
encoded_uploads = prompt.encoded_uploads(msg)
if encoded_uploads.present?
images =
encoded_uploads
.map do |upload|
if upload[:mime_type].start_with?("image/")
upload[:base64]
else
nil
end
end
.compact
user_message[:images] = images if images.present?
end
# TODO: Add support for user messages with embedded user ids
user_message
end

View File

@@ -63,6 +63,10 @@ module DiscourseAi
URI(llm_model.url)
end
def xml_tools_enabled?
!@native_tool_support
end
def prepare_payload(prompt, model_params, dialect)
@native_tool_support = dialect.native_tool_support?
@@ -90,35 +94,34 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def decode_chunk(partial_data)
@decoder ||= JsonStreamDecoder.new
(@decoder << partial_data)
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
.compact
end
def decode(response_data)
processor.process_message(response_data)
end
def processor
@processor ||=
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
processor.to_xml_tool_calls(function_buffer) if !partial
end
def extract_completion_from(response_raw)
processor.process_message(response_raw)
end
def has_tool?(_response_data)
processor.tool_calls.present?
end
def tool_calls
processor.to_tool_calls
end
def final_log_update(log)
log.request_tokens = processor.input_tokens if processor.input_tokens
log.response_tokens = processor.output_tokens if processor.output_tokens
end
def native_tool_support?
@native_tool_support
end
def partials_from(decoded_chunk)
decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
end
end
end
end

View File

@@ -117,7 +117,24 @@ module DiscourseAi
end
end
def decode(chunk)
def decode_chunk(partial_data)
bedrock_decode(partial_data)
.map do |decoded_partial_data|
@raw_response ||= +""
@raw_response << decoded_partial_data
@raw_response << "\n"
parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true)
processor.process_streamed_message(parsed_json)
end
.compact
end
def decode(response_data)
processor.process_message(response_data)
end
def bedrock_decode(chunk)
@decoder ||= Aws::EventStream::Decoder.new
decoded, _done = @decoder.decode_chunk(chunk)
@@ -147,12 +164,13 @@ module DiscourseAi
Aws::EventStream::Errors::MessageChecksumError,
Aws::EventStream::Errors::PreludeChecksumError => e
Rails.logger.error("#{self.class.name}: #{e.message}")
nil
[]
end
def final_log_update(log)
log.request_tokens = processor.input_tokens if processor.input_tokens
log.response_tokens = processor.output_tokens if processor.output_tokens
log.raw_response_payload = @raw_response
end
def processor
@@ -160,30 +178,8 @@ module DiscourseAi
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
processor.to_xml_tool_calls(function_buffer) if !partial
end
def extract_completion_from(response_raw)
processor.process_message(response_raw)
end
def has_tool?(_response_data)
processor.tool_calls.present?
end
def partials_from(decoded_chunks)
decoded_chunks
end
def native_tool_support?
@native_tool_support
end
def chunk_to_string(chunk)
joined = +chunk.join("\n")
joined << "\n" if joined.length > 0
joined
def xml_tools_enabled?
!@native_tool_support
end
end
end

View File

@@ -40,10 +40,6 @@ module DiscourseAi
@llm_model = llm_model
end
def native_tool_support?
false
end
def use_ssl?
if model_uri&.scheme.present?
model_uri.scheme == "https"
@@ -64,22 +60,10 @@ module DiscourseAi
feature_context: nil,
&blk
)
allow_tools = dialect.prompt.has_tools?
model_params = normalize_model_params(model_params)
orig_blk = blk
@streaming_mode = block_given?
to_strip = xml_tags_to_strip(dialect)
@xml_stripper =
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
if @streaming_mode && @xml_stripper
blk =
lambda do |partial, cancel|
partial = @xml_stripper << partial
orig_blk.call(partial, cancel) if partial
end
end
prompt = dialect.translate
@@ -108,177 +92,91 @@ module DiscourseAi
raise CompletionFailed, response.body
end
xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? &&
dialect.prompt.has_tools?
to_strip = xml_tags_to_strip(dialect)
xml_stripper =
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
if @streaming_mode && xml_stripper
blk =
lambda do |partial, cancel|
partial = xml_stripper << partial if partial.is_a?(String)
orig_blk.call(partial, cancel) if partial
end
end
log =
AiApiAuditLog.new(
start_log(
provider_id: provider_id,
user_id: user&.id,
raw_request_payload: request_body,
request_tokens: prompt_size(prompt),
topic_id: dialect.prompt.topic_id,
post_id: dialect.prompt.post_id,
request_body: request_body,
dialect: dialect,
prompt: prompt,
user: user,
feature_name: feature_name,
language_model: llm_model.name,
feature_context: feature_context.present? ? feature_context.as_json : nil,
feature_context: feature_context,
)
if !@streaming_mode
response_raw = response.read_body
response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
if native_tool_support?
if allow_tools && has_tool?(response_data)
function_buffer = build_buffer # Nokogiri document
function_buffer =
add_to_function_buffer(function_buffer, payload: response_data)
FunctionCallNormalizer.normalize_function_ids!(function_buffer)
response_data = +function_buffer.at("function_calls").to_s
response_data << "\n"
end
else
if allow_tools
response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
response_data = function_calls if function_calls.present?
end
end
return response_data
return(
non_streaming_response(
response: response,
xml_tool_processor: xml_tool_processor,
xml_stripper: xml_stripper,
partials_raw: partials_raw,
response_raw: response_raw,
)
)
end
has_tool = false
begin
cancelled = false
cancel = -> { cancelled = true }
wrapped_blk = ->(partial, inner_cancel) do
response_data << partial
blk.call(partial, inner_cancel)
if cancelled
http.finish
break
end
normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
leftover = ""
function_buffer = build_buffer # Nokogiri document
prev_processed_partials = 0
response.read_body do |chunk|
if cancelled
http.finish
break
end
decoded_chunk = decode(chunk)
if decoded_chunk.nil?
raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion"
end
response_raw << chunk_to_string(decoded_chunk)
if decoded_chunk.is_a?(String)
redo_chunk = leftover + decoded_chunk
else
# custom implementation for endpoint
# no implicit leftover support
redo_chunk = decoded_chunk
end
raw_partials = partials_from(redo_chunk)
raw_partials =
raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0
if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
leftover = redo_chunk
next
end
json_error = false
raw_partials.each do |raw_partial|
json_error = false
prev_processed_partials += 1
next if cancelled
next if raw_partial.blank?
begin
partial = extract_completion_from(raw_partial)
next if partial.nil?
# empty vs blank... we still accept " "
next if response_data.empty? && partial.empty?
partials_raw << partial.to_s
if native_tool_support?
# Stop streaming the response as soon as you find a tool.
# We'll buffer and yield it later.
has_tool = true if allow_tools && has_tool?(partials_raw)
if has_tool
function_buffer =
add_to_function_buffer(function_buffer, partial: partial)
else
response_data << partial
blk.call(partial, cancel) if partial
end
else
if allow_tools
normalizer << partial
else
response_data << partial
blk.call(partial, cancel) if partial
end
response_raw << chunk
decode_chunk(chunk).each do |partial|
partials_raw << partial.to_s
response_data << partial if partial.is_a?(String)
partials = [partial]
if xml_tool_processor && partial.is_a?(String)
partials = (xml_tool_processor << partial)
if xml_tool_processor.should_cancel?
cancel.call
break
end
rescue JSON::ParserError
leftover = redo_chunk
json_error = true
end
partials.each { |inner_partial| blk.call(inner_partial, cancel) }
end
if json_error
prev_processed_partials -= 1
else
leftover = ""
end
prev_processed_partials = 0 if leftover.blank?
end
rescue IOError, StandardError
raise if !cancelled
end
has_tool ||= has_tool?(partials_raw)
# Once we have the full response, try to return the tool as a XML doc.
if has_tool && native_tool_support?
function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
if function_buffer.at("tool_name").text.present?
FunctionCallNormalizer.normalize_function_ids!(function_buffer)
invocation = +function_buffer.at("function_calls").to_s
invocation << "\n"
response_data << invocation
blk.call(invocation, cancel)
if xml_stripper
stripped = xml_stripper.finish
if stripped.present?
response_data << stripped
result = []
result = (xml_tool_processor << stripped) if xml_tool_processor
result.each { |partial| blk.call(partial, cancel) }
end
end
if !native_tool_support? && function_calls = normalizer.function_calls
response_data << function_calls
blk.call(function_calls, cancel)
if xml_tool_processor
xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
end
if @xml_stripper
leftover = @xml_stripper.finish
orig_blk.call(leftover, cancel) if leftover.present?
end
decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
return response_data
ensure
if log
log.raw_response_payload = response_raw
log.response_tokens = tokenizer.size(partials_raw)
final_log_update(log)
log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
log.save!
if Rails.env.development?
@@ -330,15 +228,15 @@ module DiscourseAi
raise NotImplementedError
end
def extract_completion_from(_response_raw)
def decode(_response_raw)
raise NotImplementedError
end
def decode(chunk)
chunk
def decode_chunk_finish
[]
end
def partials_from(_decoded_chunk)
def decode_chunk(_chunk)
raise NotImplementedError
end
@@ -346,49 +244,73 @@ module DiscourseAi
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
def build_buffer
Nokogiri::HTML5.fragment(<<~TEXT)
<function_calls>
#{noop_function_call_text}
</function_calls>
TEXT
def xml_tools_enabled?
raise NotImplementedError
end
def self.noop_function_call_text
(<<~TEXT).strip
<invoke>
<tool_name></tool_name>
<parameters>
</parameters>
<tool_id></tool_id>
</invoke>
TEXT
private
def start_log(
provider_id:,
request_body:,
dialect:,
prompt:,
user:,
feature_name:,
feature_context:
)
AiApiAuditLog.new(
provider_id: provider_id,
user_id: user&.id,
raw_request_payload: request_body,
request_tokens: prompt_size(prompt),
topic_id: dialect.prompt.topic_id,
post_id: dialect.prompt.post_id,
feature_name: feature_name,
language_model: llm_model.name,
feature_context: feature_context.present? ? feature_context.as_json : nil,
)
end
def noop_function_call_text
self.class.noop_function_call_text
end
def non_streaming_response(
response:,
xml_tool_processor:,
xml_stripper:,
partials_raw:,
response_raw:
)
response_raw << response.read_body
response_data = decode(response_raw)
def has_tool?(response)
response.include?("<function_calls>")
end
response_data.each { |partial| partials_raw << partial.to_s }
def chunk_to_string(chunk)
if chunk.is_a?(String)
chunk
else
chunk.to_s
end
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
if payload&.include?("</invoke>")
matches = payload.match(%r{<function_calls>.*</invoke>}m)
function_buffer =
Nokogiri::HTML5.fragment(matches[0] + "\n</function_calls>") if matches
if xml_tool_processor
response_data.each do |partial|
processed = (xml_tool_processor << partial)
processed << xml_tool_processor.finish
response_data = []
processed.flatten.compact.each { |inner| response_data << inner }
end
end
function_buffer
if xml_stripper
response_data.map! do |partial|
stripped = (xml_stripper << partial) if partial.is_a?(String)
if stripped.present?
stripped
else
partial
end
end
response_data << xml_stripper.finish
end
response_data.reject!(&:blank?)
# this is to keep stuff backwards compatible
response_data = response_data.first if response_data.length == 1
response_data
end
end
end

View File

@@ -45,17 +45,21 @@ module DiscourseAi
cancel_fn = lambda { cancelled = true }
# We buffer and return tool invocations in one go.
if is_tool?(response)
yield(response, cancel_fn)
else
response.each_char do |char|
break if cancelled
yield(char, cancel_fn)
as_array = response.is_a?(Array) ? response : [response]
as_array.each do |response|
if is_tool?(response)
yield(response, cancel_fn)
else
response.each_char do |char|
break if cancelled
yield(char, cancel_fn)
end
end
end
else
response
end
response = response.first if response.is_a?(Array) && response.length == 1
response
end
def tokenizer
@@ -65,7 +69,7 @@ module DiscourseAi
private
def is_tool?(response)
Nokogiri::HTML5.fragment(response).at("function_calls").present?
response.is_a?(DiscourseAi::Completions::ToolCall)
end
end
end

View File

@@ -49,6 +49,47 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def decode(response_raw)
rval = []
parsed = JSON.parse(response_raw, symbolize_names: true)
text = parsed[:text]
rval << parsed[:text] if !text.to_s.empty? # also allow " "
# TODO tool calls
update_usage(parsed)
rval
end
def decode_chunk(chunk)
@tool_idx ||= -1
@json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
(@json_decoder << chunk)
.map do |parsed|
update_usage(parsed)
rval = []
rval << parsed[:text] if !parsed[:text].to_s.empty?
if tool_calls = parsed[:tool_calls]
tool_calls&.each do |tool_call|
@tool_idx += 1
tool_name = tool_call[:name]
tool_params = tool_call[:parameters]
tool_id = "tool_#{@tool_idx}"
rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params)
end
end
rval
end
.flatten
.compact
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true)
@@ -77,36 +118,8 @@ module DiscourseAi
end
end
def has_tool?(_ignored)
@has_tool
end
def native_tool_support?
true
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
if partial
tools = JSON.parse(partial)
tools.each do |tool|
name = tool["name"]
parameters = tool["parameters"]
xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
current_function = function_buffer.at("invoke")
if current_function.nil? || current_function.at("tool_name").content.present?
current_function =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
)
end
current_function.at("tool_name").content = name == "search_local" ? "search" : name
current_function.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(xml_params)
end
end
function_buffer
def xml_tools_enabled?
false
end
def final_log_update(log)
@@ -114,10 +127,6 @@ module DiscourseAi
log.response_tokens = @output_tokens if @output_tokens
end
def partials_from(decoded_chunk)
decoded_chunk.split("\n").compact
end
def extract_prompt_for_tokenizer(prompt)
text = +""
if prompt[:chat_history]
@ -131,6 +140,18 @@ module DiscourseAi
text
end
private
def update_usage(parsed)
input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens)
@input_tokens = input_tokens if input_tokens.present?
output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens)
@output_tokens = output_tokens if output_tokens.present?
end
end
end
end

View File

@@ -133,31 +133,35 @@ module DiscourseAi
content = content.shift if content.is_a?(Array)
if block_given?
split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
indexes = [0, *split_indices, content.length]
if content.is_a?(DiscourseAi::Completions::ToolCall)
yield(content, -> {})
else
split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
indexes = [0, *split_indices, content.length]
original_content = content
content = +""
original_content = content
content = +""
cancel = false
cancel_proc = -> { cancel = true }
cancel = false
cancel_proc = -> { cancel = true }
i = 0
indexes
.each_cons(2)
.map { |start, finish| original_content[start...finish] }
.each do |chunk|
break if cancel
if self.class.delays.present? &&
(delay = self.class.delays[i % self.class.delays.length])
sleep(delay)
i += 1
i = 0
indexes
.each_cons(2)
.map { |start, finish| original_content[start...finish] }
.each do |chunk|
break if cancel
if self.class.delays.present? &&
(delay = self.class.delays[i % self.class.delays.length])
sleep(delay)
i += 1
end
break if cancel
content << chunk
yield(chunk, cancel_proc)
end
break if cancel
content << chunk
yield(chunk, cancel_proc)
end
end
end
content

View File

@@ -103,15 +103,7 @@ module DiscourseAi
end
end
def partials_from(decoded_chunk)
decoded_chunk
end
def chunk_to_string(chunk)
chunk.to_s
end
class Decoder
class GeminiStreamingDecoder
def initialize
@buffer = +""
end
@@ -151,43 +143,87 @@ module DiscourseAi
end
def decode(chunk)
@decoder ||= Decoder.new
@decoder.decode(chunk)
json = JSON.parse(chunk, symbolize_names: true)
idx = -1
json
.dig(:candidates, 0, :content, :parts)
.map do |part|
if part[:functionCall]
idx += 1
ToolCall.new(
id: "tool_#{idx}",
name: part[:functionCall][:name],
parameters: part[:functionCall][:args],
)
else
part = part[:text]
if part != ""
part
else
nil
end
end
end
end
def decode_chunk(chunk)
@tool_index ||= -1
streaming_decoder
.decode(chunk)
.map do |parsed|
update_usage(parsed)
parsed
.dig(:candidates, 0, :content, :parts)
.map do |part|
if part[:text]
part = part[:text]
if part != ""
part
else
nil
end
elsif part[:functionCall]
@tool_index += 1
ToolCall.new(
id: "tool_#{@tool_index}",
name: part[:functionCall][:name],
parameters: part[:functionCall][:args],
)
end
end
end
.flatten
.compact
end
def update_usage(parsed)
usage = parsed.dig(:usageMetadata)
if usage
if prompt_token_count = usage[:promptTokenCount]
@prompt_token_count = prompt_token_count
end
if candidate_token_count = usage[:candidatesTokenCount]
@candidate_token_count = candidate_token_count
end
end
end
def final_log_update(log)
log.request_tokens = @prompt_token_count if @prompt_token_count
log.response_tokens = @candidate_token_count if @candidate_token_count
end
def streaming_decoder
@decoder ||= GeminiStreamingDecoder.new
end
def extract_prompt_for_tokenizer(prompt)
prompt.to_s
end
def has_tool?(_response_data)
@has_function_call
end
def native_tool_support?
true
end
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
if @streaming_mode
return function_buffer if !partial
else
partial = payload
end
function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
if partial[:args]
argument_fragments =
partial[:args].reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
end
argument_fragments << "\n"
function_buffer.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
end
function_buffer
def xml_tools_enabled?
false
end
end
end

View File

@@ -59,22 +59,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
response_h.dig(:content)
def xml_tools_enabled?
true
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
.map do |line|
data = line.split("data:", 2)[1]
data&.squish == "[DONE]" ? nil : data
def decode(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true)
text = parsed.dig(:choices, 0, :message, :content)
if text.to_s.empty?
[""]
else
[text]
end
end
def decode_chunk(chunk)
@json_decoder ||= JsonStreamDecoder.new
(@json_decoder << chunk)
.map do |parsed|
text = parsed.dig(:choices, 0, :delta, :content)
if text.to_s.empty?
nil
else
text
end
end
.compact
end

View File

@@ -37,12 +37,8 @@ module DiscourseAi
URI(llm_model.url)
end
def native_tool_support?
@native_tool_support
end
def has_tool?(_response_data)
@has_function_call
def xml_tools_enabled?
!@native_tool_support
end
def prepare_payload(prompt, model_params, dialect)
@@ -67,74 +63,30 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def partials_from(decoded_chunk)
decoded_chunk.split("\n").compact
def decode_chunk(chunk)
# Native tool calls are not working right in streaming mode, use XML
@json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
(@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact
end
def extract_completion_from(response_raw)
def decode(response_raw)
rval = []
parsed = JSON.parse(response_raw, symbolize_names: true)
return if !parsed
content = parsed.dig(:message, :content)
rval << content if !content.to_s.empty?
response_h = parsed.dig(:message)
@has_function_call ||= response_h.dig(:tool_calls).present?
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
end
def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
@args_buffer ||= +""
if @streaming_mode
return function_buffer if !partial
else
partial = payload
end
f_name = partial.dig(:function, :name)
@current_function ||= function_buffer.at("invoke")
if f_name
current_name = function_buffer.at("tool_name").content
if current_name.blank?
# first call
else
# we have a previous function, so we need to add a noop
@args_buffer = +""
@current_function =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
)
idx = -1
parsed
.dig(:message, :tool_calls)
&.each do |tool_call|
idx += 1
id = "tool_#{idx}"
name = tool_call.dig(:function, :name)
args = tool_call.dig(:function, :arguments)
rval << ToolCall.new(id: id, name: name, parameters: args)
end
end
@current_function.at("tool_name").content = f_name if f_name
@current_function.at("tool_id").content = partial[:id] if partial[:id]
args = partial.dig(:function, :arguments)
# allow for SPACE within arguments
if args && args != ""
@args_buffer << args.to_json
begin
json_args = JSON.parse(@args_buffer, symbolize_names: true)
argument_fragments =
json_args.reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
end
argument_fragments << "\n"
@current_function.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
rescue JSON::ParserError
return function_buffer
end
end
function_buffer
rval
end
end
end

View File

@@ -93,98 +93,34 @@ module DiscourseAi
end
def final_log_update(log)
log.request_tokens = @prompt_tokens if @prompt_tokens
log.response_tokens = @completion_tokens if @completion_tokens
log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
log.response_tokens = processor.completion_tokens if processor.completion_tokens
end
def extract_completion_from(response_raw)
json = JSON.parse(response_raw, symbolize_names: true)
if @streaming_mode
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
@completion_tokens ||= json.dig(:usage, :completion_tokens)
end
parsed = json.dig(:choices, 0)
return if !parsed
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
@has_function_call ||= response_h.dig(:tool_calls).present?
@has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
def decode(response_raw)
processor.process_message(JSON.parse(response_raw, symbolize_names: true))
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
.map do |line|
data = line.split("data: ", 2)[1]
data == "[DONE]" ? nil : data
end
def decode_chunk(chunk)
@decoder ||= JsonStreamDecoder.new
(@decoder << chunk)
.map { |parsed_json| processor.process_streamed_message(parsed_json) }
.flatten
.compact
end
def has_tool?(_response_data)
@has_function_call
def decode_chunk_finish
@processor.finish
end
def native_tool_support?
true
def xml_tools_enabled?
false
end
def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
if @streaming_mode
return function_buffer if !partial
else
partial = payload
end
private
@args_buffer ||= +""
f_name = partial.dig(:function, :name)
@current_function ||= function_buffer.at("invoke")
if f_name
current_name = function_buffer.at("tool_name").content
if current_name.blank?
# first call
else
# we have a previous function, so we need to add a noop
@args_buffer = +""
@current_function =
function_buffer.at("function_calls").add_child(
Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
)
end
end
@current_function.at("tool_name").content = f_name if f_name
@current_function.at("tool_id").content = partial[:id] if partial[:id]
args = partial.dig(:function, :arguments)
# allow for SPACE within arguments
if args && args != ""
@args_buffer << args
begin
json_args = JSON.parse(@args_buffer, symbolize_names: true)
argument_fragments =
json_args.reduce(+"") do |memo, (arg_name, value)|
memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
end
argument_fragments << "\n"
@current_function.at("parameters").children =
Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
rescue JSON::ParserError
return function_buffer
end
end
function_buffer
def processor
@processor ||= OpenAiMessageProcessor.new
end
end
end

View File

@@ -55,27 +55,31 @@ module DiscourseAi
log.response_tokens = @completion_tokens if @completion_tokens
end
def extract_completion_from(response_raw)
json = JSON.parse(response_raw, symbolize_names: true)
if @streaming_mode
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
@completion_tokens ||= json.dig(:usage, :completion_tokens)
end
parsed = json.dig(:choices, 0)
return if !parsed
@streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
def xml_tools_enabled?
true
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
.map do |line|
data = line.split("data: ", 2)[1]
data == "[DONE]" ? nil : data
def decode(response_raw)
json = JSON.parse(response_raw, symbolize_names: true)
[json.dig(:choices, 0, :message, :content)]
end
def decode_chunk(chunk)
@json_decoder ||= JsonStreamDecoder.new
(@json_decoder << chunk)
.map do |json|
text = json.dig(:choices, 0, :delta, :content)
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
@completion_tokens ||= json.dig(:usage, :completion_tokens)
if !text.to_s.empty?
text
else
nil
end
end
.flatten
.compact
end
end

View File

@@ -42,7 +42,10 @@ module DiscourseAi
def prepare_payload(prompt, model_params, dialect)
payload = default_options.merge(model_params).merge(messages: prompt)
payload[:stream] = true if @streaming_mode
if @streaming_mode
payload[:stream] = true if @streaming_mode
payload[:stream_options] = { include_usage: true }
end
payload
end
@@ -56,24 +59,42 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
def partials_from(decoded_chunk)
decoded_chunk
.split("\n")
.map do |line|
data = line.split("data: ", 2)[1]
data == "[DONE]" ? nil : data
end
.compact
def xml_tools_enabled?
true
end
def extract_completion_from(response_raw)
parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
# half a line sent here
return if !parsed
def final_log_update(log)
log.request_tokens = @prompt_tokens if @prompt_tokens
log.response_tokens = @completion_tokens if @completion_tokens
end
response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
def decode(response_raw)
json = JSON.parse(response_raw, symbolize_names: true)
@prompt_tokens = json.dig(:usage, :prompt_tokens)
@completion_tokens = json.dig(:usage, :completion_tokens)
[json.dig(:choices, 0, :message, :content)]
end
response_h.dig(:content)
def decode_chunk(chunk)
@json_decoder ||= JsonStreamDecoder.new
(@json_decoder << chunk)
.map do |parsed|
# vLLM keeps sending usage over and over again
prompt_tokens = parsed.dig(:usage, :prompt_tokens)
completion_tokens = parsed.dig(:usage, :completion_tokens)
@prompt_tokens = prompt_tokens if prompt_tokens
@completion_tokens = completion_tokens if completion_tokens
text = parsed.dig(:choices, 0, :delta, :content)
if text.to_s.empty?
nil
else
text
end
end
.compact
end
end
end

View File

@@ -1,113 +0,0 @@
# frozen_string_literal: true
class DiscourseAi::Completions::FunctionCallNormalizer
attr_reader :done
# blk is the block to call with filtered data
def initialize(blk, cancel)
@blk = blk
@cancel = cancel
@done = false
@in_tool = false
@buffer = +""
@function_buffer = +""
end
def self.normalize(data)
text = +""
cancel = -> {}
blk = ->(partial, _) { text << partial }
normalizer = self.new(blk, cancel)
normalizer << data
[text, normalizer.function_calls]
end
def function_calls
return nil if @function_buffer.blank?
xml = Nokogiri::HTML5.fragment(@function_buffer)
self.class.normalize_function_ids!(xml)
last_invoke = xml.at("invoke:last")
if last_invoke
last_invoke.next_sibling.remove while last_invoke.next_sibling
xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
end
xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
end
def <<(text)
@buffer << text
if !@in_tool
# double check if we are clearly in a tool
search_length = text.length + 20
search_string = @buffer[-search_length..-1] || @buffer
index = search_string.rindex("<function_calls>")
@in_tool = !!index
if @in_tool
@function_buffer = @buffer[index..-1]
text_index = text.rindex("<function_calls>")
@blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
end
else
@function_buffer << text
end
if !@in_tool
if maybe_has_tool?(@buffer)
split_index = text.rindex("<").to_i - 1
if split_index >= 0
@function_buffer = text[split_index + 1..-1] || ""
text = text[0..split_index] || ""
else
@function_buffer << text
text = ""
end
else
if @function_buffer.length > 0
@blk.call(@function_buffer, @cancel)
@function_buffer = +""
end
end
@blk.call(text, @cancel) if text.length > 0
else
if text.include?("</function_calls>")
@done = true
@cancel.call
end
end
end
def self.normalize_function_ids!(function_buffer)
function_buffer
.css("invoke")
.each_with_index do |invoke, index|
if invoke.at("tool_id")
invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
else
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
end
end
end
private
def maybe_has_tool?(text)
# 16 is the length of function calls
substring = text[-16..-1] || text
split = substring.split("<")
if split.length > 1
match = "<" + split.last
"<function_calls>".start_with?(match)
else
substring.ends_with?("<")
end
end
end

View File

@@ -0,0 +1,48 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
# will work for anthropic and open ai compatible
class JsonStreamDecoder
attr_reader :buffer
LINE_REGEX = /data: ({.*})\s*$/
def initialize(symbolize_keys: true, line_regex: LINE_REGEX)
@symbolize_keys = symbolize_keys
@buffer = +""
@line_regex = line_regex
end
def <<(raw)
@buffer << raw.to_s
rval = []
split = @buffer.scan(/.*\n?/)
split.pop if split.last.blank?
@buffer = +(split.pop.to_s)
split.each do |line|
matches = line.match(@line_regex)
next if !matches
rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
end
if @buffer.present?
matches = @buffer.match(@line_regex)
if matches
begin
rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
@buffer = +""
rescue JSON::ParserError
# maybe it is a partial line
end
end
end
rval
end
end
end
end
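
Usage sketch (illustrative data): each << returns every complete JSON object found so far, buffering any partial trailing line until the rest arrives.

    decoder = DiscourseAi::Completions::JsonStreamDecoder.new
    decoder << "data: {\"a\": 1}\ndata: {\"b\"" #=> [{ a: 1 }] (the incomplete second line stays buffered)
    decoder << ": 2}\n"                         #=> [{ b: 2 }]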

View File

@@ -0,0 +1,103 @@
# frozen_string_literal: true
module DiscourseAi::Completions
class OpenAiMessageProcessor
attr_reader :prompt_tokens, :completion_tokens
def initialize
@tool = nil
@tool_arguments = +""
@prompt_tokens = nil
@completion_tokens = nil
end
def process_message(json)
result = []
tool_calls = json.dig(:choices, 0, :message, :tool_calls)
message = json.dig(:choices, 0, :message, :content)
result << message if message.present?
if tool_calls.present?
tool_calls.each do |tool_call|
id = tool_call.dig(:id)
name = tool_call.dig(:function, :name)
arguments = tool_call.dig(:function, :arguments)
parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {}
result << ToolCall.new(id: id, name: name, parameters: parameters)
end
end
update_usage(json)
result
end
def process_streamed_message(json)
rval = nil
tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
content = json.dig(:choices, 0, :delta, :content)
finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []
if tool_calls.present?
id = tool_calls.dig(0, :id)
name = tool_calls.dig(0, :function, :name)
arguments = tool_calls.dig(0, :function, :arguments)
# TODO: multiple tool support may require index
#index = tool_calls[0].dig(:index)
if id.present? && @tool && @tool.id != id
process_arguments
rval = @tool
@tool = nil
end
if id.present? && name.present?
@tool_arguments = +""
@tool = ToolCall.new(id: id, name: name)
end
@tool_arguments << arguments.to_s
elsif finished_tools && @tool
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
@tool.parameters = parsed_args
rval = @tool
@tool = nil
elsif content.present?
rval = content
end
update_usage(json)
rval
end
def finish
rval = []
if @tool
process_arguments
rval << @tool
@tool = nil
end
rval
end
private
def process_arguments
if @tool_arguments.present?
parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
@tool.parameters = parsed_args
@tool_arguments = nil
end
end
def update_usage(json)
@prompt_tokens ||= json.dig(:usage, :prompt_tokens)
@completion_tokens ||= json.dig(:usage, :completion_tokens)
end
end
end
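
A rough streaming sketch (the delta hashes below are hand-written in the standard OpenAI chunk shape; in production they come from JsonStreamDecoder):

    processor = DiscourseAi::Completions::OpenAiMessageProcessor.new
    processor.process_streamed_message({ choices: [{ delta: { content: "Hello" } }] })
    #=> "Hello"
    processor.process_streamed_message(
      { choices: [{ delta: { tool_calls: [{ id: "call_1", function: { name: "echo", arguments: "" } }] } }] },
    ) #=> nil, a ToolCall named "echo" starts buffering
    processor.process_streamed_message(
      { choices: [{ delta: { tool_calls: [{ function: { arguments: "{\"text\":\"hi\"}" } }] } }] },
    ) #=> nil, arguments keep buffering
    processor.finish
    #=> [ToolCall(id: "call_1", name: "echo", parameters: { text: "hi" })]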

View File

@@ -0,0 +1,29 @@
# frozen_string_literal: true
module DiscourseAi
module Completions
class ToolCall
attr_reader :id, :name, :parameters
def initialize(id:, name:, parameters: nil)
@id = id
@name = name
self.parameters = parameters if parameters
@parameters ||= {}
end
def parameters=(parameters)
raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
@parameters = parameters.symbolize_keys
end
def ==(other)
id == other.id && name == other.name && parameters == other.parameters
end
def to_s
"#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
end
end
end
end
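
For illustration (values made up): parameters are symbolized on assignment and equality is structural.

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "google",
        parameters: { "query" => "sydney weather today" },
      )
    tool_call.parameters #=> { query: "sydney weather today" }
    tool_call == DiscourseAi::Completions::ToolCall.new(
      id: "tool_0", name: "google", parameters: { query: "sydney weather today" },
    ) #=> true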

View File

@@ -0,0 +1,124 @@
# frozen_string_literal: true
# This class can be used to process a stream of text that may contain XML tool
# calls.
# It will return either text or ToolCall objects.
module DiscourseAi
module Completions
class XmlToolProcessor
def initialize
@buffer = +""
@function_buffer = +""
@should_cancel = false
@in_tool = false
end
def <<(text)
@buffer << text
result = []
if !@in_tool
# double check if we are clearly in a tool
search_length = text.length + 20
search_string = @buffer[-search_length..-1] || @buffer
index = search_string.rindex("<function_calls>")
@in_tool = !!index
if @in_tool
@function_buffer = @buffer[index..-1]
text_index = text.rindex("<function_calls>")
result << text[0..text_index - 1].strip if text_index && text_index > 0
end
else
@function_buffer << text
end
if !@in_tool
if maybe_has_tool?(@buffer)
split_index = text.rindex("<").to_i - 1
if split_index >= 0
@function_buffer = text[split_index + 1..-1] || ""
text = text[0..split_index] || ""
else
@function_buffer << text
text = ""
end
else
if @function_buffer.length > 0
result << @function_buffer
@function_buffer = +""
end
end
result << text if text.length > 0
else
@should_cancel = true if text.include?("</function_calls>")
end
result
end
def finish
return [] if @function_buffer.blank?
xml = Nokogiri::HTML5.fragment(@function_buffer)
normalize_function_ids!(xml)
last_invoke = xml.at("invoke:last")
if last_invoke
last_invoke.next_sibling.remove while last_invoke.next_sibling
xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
end
xml
.css("invoke")
.map do |invoke|
tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
parameters = {}
invoke
.at("parameters")
&.children
&.each do |node|
next if node.text?
name = node.name
value = node.content.to_s
parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
end
ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
end
end
def should_cancel?
@should_cancel
end
private
def normalize_function_ids!(function_buffer)
function_buffer
.css("invoke")
.each_with_index do |invoke, index|
if invoke.at("tool_id")
invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
else
invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
end
end
end
def maybe_has_tool?(text)
# 16 is the length of function calls
substring = text[-16..-1] || text
split = substring.split("<")
if split.length > 1
match = "<" + split.last
"<function_calls>".start_with?(match)
else
substring.ends_with?("<")
end
end
end
end
end
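
Usage sketch (the XML follows the <function_calls>/<invoke> convention used throughout this diff):

    processor = DiscourseAi::Completions::XmlToolProcessor.new
    processor << "Hello "  #=> ["Hello "] plain text streams straight through
    processor << "<function_calls><invoke><tool_name>echo</tool_name>"  #=> [] buffering begins
    processor << "<parameters><text>hi</text></parameters></invoke></function_calls>"  #=> []
    processor.finish         #=> [ToolCall(id: "tool_0", name: "echo", parameters: { text: "hi" })]
    processor.should_cancel? #=> true, the closing tag was seen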

View File

@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
data: {"type":"message_stop"}
STRING
result = +""
result = []
body = body.scan(/.*\n/)
EndpointMock.with_chunk_array_support do
stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
end
expected = (<<~TEXT).strip
<function_calls>
<invoke>
<tool_name>search</tool_name>
<parameters><search_query>s&lt;a&gt;m sam</search_query>
<category>general</category></parameters>
<tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
</invoke>
</function_calls>
TEXT
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
parameters: {
search_query: "s<a>m sam",
category: "general",
},
)
expect(result.strip).to eq(expected)
expect(result).to eq([tool_call])
end
it "can stream a response" do
@@ -191,6 +190,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
expect(log.feature_name).to eq("testing")
expect(log.response_tokens).to eq(15)
expect(log.request_tokens).to eq(25)
expect(log.raw_request_payload).to eq(expected_body.to_json)
expect(log.raw_response_payload.strip).to eq(body.strip)
end
it "supports non streaming tool calls" do
@@ -242,17 +243,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
result = llm.generate(prompt, user: Discourse.system_user)
expected = <<~TEXT.strip
<function_calls>
<invoke>
<tool_name>calculate</tool_name>
<parameters><expression>2758975 + 21.11</expression></parameters>
<tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
</invoke>
</function_calls>
TEXT
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "calculate",
id: "toolu_012kBdhG4eHaV68W56p4N94h",
parameters: {
expression: "2758975 + 21.11",
},
)
expect(result.strip).to eq(expected)
expect(result).to eq(["Here is the calculation:", tool_call])
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(345)
expect(log.response_tokens).to eq(65)
end
it "can send images via a completion prompt" do

View File

@@ -79,7 +79,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
prompt.tools = [tool]
response = +""
response = []
proxy.generate(prompt, user: user) { |partial| response << partial }
expect(request.headers["Authorization"]).to be_present
@@ -90,21 +90,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
expect(parsed_body["tools"]).to eq(nil)
expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])
# note we now have a tool_id cause we were normalized
function_call = <<~XML.strip
hello
expected = [
"hello\n",
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "google",
parameters: {
query: "sydney weather today",
},
),
]
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters><query>sydney weather today</query></parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(response.strip).to eq(function_call)
expect(response).to eq(expected)
end
end
@@ -230,23 +227,23 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
}
prompt.tools = [tool]
response = +""
response = []
proxy.generate(prompt, user: user) { |partial| response << partial }
expect(request.headers["Authorization"]).to be_present
expect(request.headers["X-Amz-Content-Sha256"]).to be_present
expected_response = (<<~RESPONSE).strip
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters><query>sydney weather today</query></parameters>
<tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
</invoke>
</function_calls>
RESPONSE
expected_response = [
DiscourseAi::Completions::ToolCall.new(
id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
name: "google",
parameters: {
query: "sydney weather today",
},
),
]
expect(response.strip).to eq(expected_response)
expect(response).to eq(expected_response)
expected = {
"max_tokens" => 3000,

View File

@@ -66,7 +66,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
TEXT
parsed_body = nil
result = +""
result = []
sig = {
name: "google",
@@ -91,21 +91,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
},
).to_return(status: 200, body: body.split("|"))
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
llm.generate(prompt, user: user) { |partial, cancel| result << partial }
end
expected = <<~TEXT
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters><query>who is sam saffron</query>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
TEXT
text = "I will search for 'who is sam saffron' and relay the information to the user."
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "google",
parameters: {
query: "who is sam saffron",
},
)
expect(result.strip).to eq(expected.strip)
expect(result).to eq([text, tool_call])
expected = {
model: "command-r-plus",

View File

@@ -62,18 +62,14 @@ class EndpointMock
end
def invocation_response
<<~TEXT
<function_calls>
<invoke>
<tool_name>get_weather</tool_name>
<parameters>
<location>Sydney</location>
<unit>c</unit>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
TEXT
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "get_weather",
parameters: {
location: "Sydney",
unit: "c",
},
)
end
def tool_id
@@ -185,7 +181,7 @@ class EndpointsCompliance
mock.stub_tool_call(a_dialect.translate)
completion_response = endpoint.perform_completion!(a_dialect, user)
expect(completion_response.strip).to eq(mock.invocation_response.strip)
expect(completion_response).to eq(mock.invocation_response)
end
def streaming_mode_simple_prompt(mock)
@@ -205,6 +201,7 @@ class EndpointsCompliance
expect(log.raw_request_payload).to be_present
expect(log.raw_response_payload).to be_present
expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
expect(log.response_tokens).to eq(
endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
)
@@ -216,14 +213,14 @@ class EndpointsCompliance
a_dialect = dialect(prompt: prompt)
mock.stub_streamed_tool_call(a_dialect.translate) do
buffered_partial = +""
buffered_partial = []
endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
buffered_partial << partial
cancel.call if buffered_partial.include?("<function_calls>")
cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
end
expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
expect(buffered_partial).to eq([mock.invocation_response])
end
end

View File

@@ -195,19 +195,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
response = llm.generate(prompt, user: user)
expected = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>&lt;S&gt;ydney</text>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
tool =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "echo",
parameters: {
text: "<S>ydney",
},
)
expect(response.strip).to eq(expected)
expect(response).to eq(tool)
end
it "Supports Vision API" do
@@ -265,6 +262,68 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
expect(JSON.parse(req_body)).to eq(expected_prompt)
end
it "Can stream tool calls correctly" do
rows = [
{
candidates: [
{
content: {
parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
role: "model",
},
safetyRatings: [
{ category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
{ category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
{ category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
{ category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
],
},
],
usageMetadata: {
promptTokenCount: 625,
totalTokenCount: 625,
},
modelVersion: "gemini-1.5-pro-002",
},
{
candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
usageMetadata: {
promptTokenCount: 625,
candidatesTokenCount: 4,
totalTokenCount: 629,
},
modelVersion: "gemini-1.5-pro-002",
},
]
payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
output = []
stub_request(:post, url).to_return(status: 200, body: payload)
llm.generate(prompt, user: user) { |partial| output << partial }
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "echo",
parameters: {
text: "sam<>wh!s",
},
)
expect(output).to eq([tool_call])
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(625)
expect(log.response_tokens).to eq(4)
end
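# Gemini does not stream tool ids, so the endpoint has to synthesize them
# while mapping a functionCall part onto a ToolCall. A hedged sketch of that
# mapping (method and argument names are assumptions, not the shipped code):
def gemini_tool_call_from(parsed_row, index)
  part = parsed_row.dig(:candidates, 0, :content, :parts, 0)
  function = part && part[:functionCall]
  return nil unless function

  DiscourseAi::Completions::ToolCall.new(
    id: "tool_#{index}", # ids like "tool_0" are synthesized client side
    name: function[:name],
    parameters: function[:args] || {},
  )
end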
it "Can correctly handle streamed responses even if they are chunked badly" do
data = +""
data << "da|ta: |"
@ -279,12 +338,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
output = +""
output = []
gemini_mock.with_chunk_array_support do
stub_request(:post, url).to_return(status: 200, body: split)
llm.generate("Hello", user: user) { |partial| output << partial }
end
expect(output).to eq("Hello World Sam")
expect(output.join).to eq("Hello World Sam")
end
end

View File

@ -150,7 +150,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
end
describe "when using streaming mode" do
context "with simpel prompts" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.streaming_mode_simple_prompt(ollama_mock)
end

View File

@ -17,8 +17,8 @@ class OpenAiMock < EndpointMock
created: 1_678_464_820,
model: "gpt-3.5-turbo-0301",
usage: {
prompt_tokens: 337,
completion_tokens: 162,
prompt_tokens: 8,
completion_tokens: 13,
total_tokens: 499,
},
choices: [
@ -231,19 +231,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
result = llm.generate(prompt, user: user)
expected = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>hello</text>
</parameters>
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
</invoke>
</function_calls>
TXT
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "call_I8LKnoijVuhKOM85nnEQgWwd",
name: "echo",
parameters: {
text: "hello",
},
)
expect(result.strip).to eq(expected)
expect(result).to eq(tool_call)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json,
@ -320,19 +317,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })
expected = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>echo</tool_name>
<parameters>
<text>h&lt;e&gt;llo</text>
</parameters>
<tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
</invoke>
</function_calls>
TXT
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(55)
expect(log.response_tokens).to eq(13)
expect(result.strip).to eq(expected)
expected =
DiscourseAi::Completions::ToolCall.new(
id: "call_I8LKnoijVuhKOM85nnEQgWwd",
name: "echo",
parameters: {
text: "h<e>llo",
},
)
expect(result).to eq(expected)
stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
body: { choices: [message: { content: "OK" }] }.to_json,
@ -487,7 +485,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}
data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
@ -495,32 +493,30 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
TEXT
open_ai_mock.stub_raw(raw_data)
content = +""
response = []
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
endpoint.perform_completion!(dialect, user) { |partial| content << partial }
endpoint.perform_completion!(dialect, user) { |partial| response << partial }
expected = <<~TEXT
<function_calls>
<invoke>
<tool_name>search</tool_name>
<parameters>
<search_query>Discourse AI bot</search_query>
</parameters>
<tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
</invoke>
<invoke>
<tool_name>search</tool_name>
<parameters>
<query>Discourse AI bot</query>
</parameters>
<tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
</invoke>
</function_calls>
TEXT
tool_calls = [
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
parameters: {
search_query: "Discourse AI bot",
},
),
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "call_H7YkbgYurHpyJqzwUN4bghwN",
parameters: {
query: "Discourse AI bot2",
},
),
]
expect(content).to eq(expected)
expect(response).to eq(tool_calls)
end
it "uses proper token accounting" do
@ -593,21 +589,16 @@ TEXT
dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
endpoint.perform_completion!(dialect, user) { |partial| partials << partial }
expect(partials.length).to eq(1)
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "func_id",
name: "google",
parameters: {
query: "Adabas 9.1",
},
)
function_call = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>google</tool_name>
<parameters>
<query>Adabas 9.1</query>
</parameters>
<tool_id>func_id</tool_id>
</invoke>
</function_calls>
TXT
expect(partials[0].strip).to eq(function_call)
expect(partials).to eq([tool_call])
end
end
end
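# OpenAI streams tool calls as deltas keyed by `index`, with the JSON
# arguments split across many fragments that must be stitched together before
# a ToolCall can be built. A rough sketch of that accumulation (names are
# assumptions, not the endpoint's actual internals):
def accumulate_tool_delta(delta, partial_calls)
  Array(delta[:tool_calls]).each do |tool_delta|
    entry = partial_calls[tool_delta[:index]] ||= { id: nil, name: nil, arguments: +"" }
    entry[:id] ||= tool_delta[:id]
    entry[:name] ||= tool_delta.dig(:function, :name)
    entry[:arguments] << tool_delta.dig(:function, :arguments).to_s
  end
end
# once the stream finishes, each entry's arguments string would be JSON-parsed
# and wrapped in a DiscourseAi::Completions::ToolCall.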

View File

@ -22,10 +22,15 @@ data: [DONE]
},
).to_return(status: 200, body: body, headers: {})
response = +""
response = []
llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }
expect(response).to eq("I am a bot")
expect(response).to eq(["I am a bot"])
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(21)
expect(log.response_tokens).to eq(41)
end
it "can perform regular completions" do

View File

@ -51,7 +51,13 @@ class VllmMock < EndpointMock
WebMock
.stub_request(:post, "https://test.dev/v1/chat/completions")
.with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
.with(
body:
model
.default_options
.merge(messages: prompt, stream: true, stream_options: { include_usage: true })
.to_json,
)
.to_return(status: 200, body: chunks)
end
end
@ -136,29 +142,115 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do
result = llm.generate(prompt, user: Discourse.system_user)
expected = <<~TEXT
<function_calls>
<invoke>
<tool_name>calculate</tool_name>
<parameters>
<expression>1+1</expression></parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
TEXT
expected =
DiscourseAi::Completions::ToolCall.new(
name: "calculate",
id: "tool_0",
parameters: {
expression: "1+1",
},
)
expect(result.strip).to eq(expected.strip)
expect(result).to eq(expected)
end
end
it "correctly accounts for tokens in non streaming mode" do
body = (<<~TEXT).strip
{"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null}
TEXT
stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body)
result = llm.generate("generate a title", user: Discourse.system_user)
expect(result).to eq("Random Number Generator Produces Smallest Possible Result")
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(146)
expect(log.response_tokens).to eq(10)
end
it "can properly include usage in streaming mode" do
payload = <<~TEXT.strip
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}
data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}
data: [DONE]
TEXT
stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
status: 200,
body: payload,
)
response = []
llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial }
expect(response.join).to eq(
"Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?",
)
log = AiApiAuditLog.order(:id).last
expect(log.request_tokens).to eq(46)
expect(log.response_tokens).to eq(26)
end
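# Every chunk above repeats a cumulative usage object, so correct accounting
# only needs the last usage snapshot seen before [DONE]. A minimal sketch of
# that bookkeeping (the state hash is an assumption for illustration):
def track_usage(parsed_chunk, state)
  usage = parsed_chunk[:usage]
  state[:last_usage] = usage if usage # later chunks overwrite earlier counts
end

# at stream end the audit log reads from the final snapshot, e.g.
#   log.request_tokens  = state[:last_usage][:prompt_tokens]     # 46 above
#   log.response_tokens = state[:last_usage][:completion_tokens] # 26 above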
describe "#perform_completion!" do
context "when using regular mode" do
context "with simple prompts" do
it "completes a trivial prompt and logs the response" do
compliance.regular_mode_simple_prompt(vllm_mock)
end
end
context "with tools" do
it "returns a function invocation" do
compliance.regular_mode_tools(vllm_mock)

View File

@ -1,182 +0,0 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
let(:buffer) { +"" }
let(:normalizer) do
blk = ->(data, cancel) { buffer << data }
cancel = -> { @done = true }
DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
end
def pass_through!(data)
normalizer << data
expect(buffer[-data.length..-1]).to eq(data)
end
it "is usable in non streaming mode" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
</invoke>
XML
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
expect(text).to eq("hello")
expected_function_calls = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(function_calls).to eq(expected_function_calls)
end
it "strips junk from end of function calls" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
</invoke>
junk
XML
_text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
expected_function_calls = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
expect(function_calls).to eq(expected_function_calls)
end
it "returns nil for function calls if there are none" do
input = "hello world\n"
text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)
expect(text).to eq(input)
expect(function_calls).to eq(nil)
end
it "passes through data if there are no function calls detected" do
pass_through!("hello")
pass_through!("<tool_name>hello</tool_name>")
pass_through!("<parameters><hello>world</hello></parameters>")
pass_through!("<function_call>")
end
it "properly handles non English tools" do
normalizer << "hello<function"
expect(buffer).to eq("hello")
normalizer << "_calls>\n"
normalizer << (<<~XML).strip
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello></hello>
</parameters>
</invoke>
XML
expected = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello></hello>
</parameters>
<tool_id>tool_0</tool_id>
</invoke>
</function_calls>
XML
function_calls = normalizer.function_calls
expect(function_calls).to eq(expected)
end
it "works correctly even if you only give it 1 letter at a time" do
xml = (<<~XML).strip
abc
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>abc</tool_id>
</invoke>
<invoke>
<tool_name>hello2</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>aba</tool_id>
</invoke>
</function_calls>
XML
xml.each_char { |char| normalizer << char }
expect(buffer + normalizer.function_calls).to eq(xml)
end
it "supports multiple invokes" do
xml = (<<~XML).strip
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>abc</tool_id>
</invoke>
<invoke>
<tool_name>hello2</tool_name>
<parameters>
<hello>world</hello>
</parameters>
<tool_id>aba</tool_id>
</invoke>
</function_calls>
XML
normalizer << xml
expect(normalizer.function_calls).to eq(xml)
end
it "can will cancel if it encounteres </function_calls>" do
normalizer << "<function_calls>"
expect(normalizer.done).to eq(false)
normalizer << "</function_calls>"
expect(normalizer.done).to eq(true)
expect(@done).to eq(true)
expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
end
it "pauses on function call and starts buffering" do
normalizer << "hello<function_call"
expect(buffer).to eq("hello")
expect(normalizer.done).to eq(false)
normalizer << ">"
expect(buffer).to eq("hello<function_call>")
expect(normalizer.done).to eq(false)
end
end

View File

@ -0,0 +1,47 @@
# frozen_string_literal: true
describe DiscourseAi::Completions::JsonStreamDecoder do
let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
it "should be able to parse simple messages" do
result = decoder << "data: #{{ hello: "world" }.to_json}"
expect(result).to eq([{ hello: "world" }])
end
it "should handle anthropic mixed stlye streams" do
stream = (<<~TEXT).split("|")
event: |message_start|
data: |{"hel|lo": "world"}|
event: |message_start
data: {"foo": "bar"}
event: |message_start
data: {"ba|z": "qux"|}
[DONE]
TEXT
results = []
stream.each { |chunk| results << (decoder << chunk) }
expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
end
it "should be able to handle complex overlaps" do
stream = (<<~TEXT).split("|")
data: |{"hel|lo": "world"}
data: {"foo": "bar"}
data: {"ba|z": "qux"|}
[DONE]
TEXT
results = []
stream.each { |chunk| results << (decoder << chunk) }
expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
end
end
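# The contract pinned down above: `<<` accepts arbitrarily split chunks and
# returns whatever hashes that chunk completed, even when a payload arrives
# without a trailing newline. A buffer-and-retry decoder is enough to satisfy
# these specs; an illustrative sketch, not the shipped implementation:
require "json"

class LineBufferedDecoder
  def initialize
    @buffer = +""
  end

  def <<(chunk)
    @buffer << chunk
    rows = []
    # parse every complete `data:` payload currently buffered; a payload that
    # fails to parse is assumed incomplete and kept for the next chunk
    while (start = @buffer.index("data:"))
      payload_start = start + 5
      newline = @buffer.index("\n", payload_start)
      payload = @buffer[payload_start..(newline ? newline - 1 : -1)].strip
      break if payload == "[DONE]"
      begin
        rows << JSON.parse(payload, symbolize_names: true)
      rescue JSON::ParserError
        break # incomplete JSON, wait for more data
      end
      @buffer = newline ? @buffer[(newline + 1)..] : +""
    end
    rows
  end
end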

View File

@ -0,0 +1,188 @@
# frozen_string_literal: true
RSpec.describe DiscourseAi::Completions::XmlToolProcessor do
let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new }
it "can process simple text" do
result = []
result << (processor << "hello")
result << (processor << " world ")
expect(result).to eq([["hello"], [" world "]])
expect(processor.finish).to eq([])
expect(processor.should_cancel?).to eq(false)
end
it "is usable for simple single message mode" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>hello</tool_name>
<parameters>
<hello>world</hello>
<test>value</test>
</parameters>
</invoke>
XML
result = []
result << (processor << xml)
result << (processor.finish)
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "hello",
parameters: {
hello: "world",
test: "value",
},
)
expect(result).to eq([["hello"], [tool_call]])
expect(processor.should_cancel?).to eq(false)
end
it "handles multiple tool calls in sequence" do
xml = (<<~XML).strip
start
<function_calls>
<invoke>
<tool_name>first_tool</tool_name>
<parameters>
<param1>value1</param1>
</parameters>
</invoke>
<invoke>
<tool_name>second_tool</tool_name>
<parameters>
<param2>value2</param2>
</parameters>
</invoke>
</function_calls>
end
XML
result = []
result << (processor << xml)
result << (processor.finish)
first_tool =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "first_tool",
parameters: {
param1: "value1",
},
)
second_tool =
DiscourseAi::Completions::ToolCall.new(
id: "tool_1",
name: "second_tool",
parameters: {
param2: "value2",
},
)
expect(result).to eq([["start"], [first_tool, second_tool]])
expect(processor.should_cancel?).to eq(true)
end
it "handles non-English parameters correctly" do
xml = (<<~XML).strip
こんにちは
<function_calls>
<invoke>
<tool_name>translator</tool_name>
<parameters>
<text>世界</text>
</parameters>
</invoke>
XML
result = []
result << (processor << xml)
result << (processor.finish)
tool_call =
DiscourseAi::Completions::ToolCall.new(
id: "tool_0",
name: "translator",
parameters: {
text: "世界",
},
)
expect(result).to eq([["こんにちは"], [tool_call]])
end
it "processes input character by character" do
xml =
"hi<function_calls><invoke><tool_name>test</tool_name><parameters><p>v</p></parameters></invoke>"
result = []
xml.each_char { |char| result << (processor << char) }
result << processor.finish
tool_call =
DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })
filtered_result = result.reject(&:empty?)
expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
end
it "handles malformed XML gracefully" do
xml = (<<~XML).strip
text
<function_calls>
<invoke>
<tool_name>test</tool_name>
<parameters>
<param>value
</parameters>
</invoke>
malformed
XML
result = []
result << (processor << xml)
result << (processor.finish)
# Should just do its best to parse the XML
tool_call =
DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
expect(result).to eq([["text"], [tool_call]])
end
it "correctly processes empty parameter sets" do
xml = (<<~XML).strip
hello
<function_calls>
<invoke>
<tool_name>no_params</tool_name>
<parameters>
</parameters>
</invoke>
XML
result = []
result << (processor << xml)
result << (processor.finish)
tool_call =
DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})
expect(result).to eq([["hello"], [tool_call]])
end
it "properly handles cancelled processing" do
xml = "start<function_calls></function_calls>"
result = []
result << (processor << xml)
result << (processor << "more text")
result << processor.finish
expect(result).to eq([["start"], [], []])
expect(processor.should_cancel?).to eq(true)
end
end
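# These specs imply the streaming contract: `<<` returns the text partials
# unlocked by each chunk, `finish` flushes trailing tool calls, and
# `should_cancel?` signals that the model closed </function_calls>. A hedged
# sketch of how an endpoint might wire the processor into its streaming loop
# (`each_streamed_chunk` and `blk` stand in for the endpoint's internals):
processor = DiscourseAi::Completions::XmlToolProcessor.new
cancelled = false
cancel = -> { cancelled = true }

each_streamed_chunk do |chunk|
  (processor << chunk).each { |partial| blk.call(partial, cancel) }
  break if cancelled || processor.should_cancel?
end
processor.finish.each { |partial| blk.call(partial, cancel) }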

View File

@ -72,40 +72,27 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
it "can parse string that are wrapped in quotes" do
SiteSetting.ai_stability_api_key = "123"
xml = <<~XML
<function_calls>
<invoke>
<tool_name>image</tool_name>
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
<parameters>
<prompts>["cat oil painting", "big car"]</prompts>
<aspect_ratio>"16:9"</aspect_ratio>
</parameters>
</invoke>
<invoke>
<tool_name>image</tool_name>
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
<parameters>
<prompts>["cat oil painting", "big car"]</prompts>
<aspect_ratio>'16:9'</aspect_ratio>
</parameters>
</invoke>
</function_calls>
XML
image1, image2 =
tools =
DiscourseAi::AiBot::Personas::Artist.new.find_tools(
xml,
bot_user: nil,
llm: nil,
context: nil,
)
expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
expect(image1.parameters[:aspect_ratio]).to eq("16:9")
expect(image2.parameters[:aspect_ratio]).to eq("16:9")
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "image",
id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
prompts: ["cat oil painting", "big car"],
aspect_ratio: "16:9",
},
)
expect(tools.length).to eq(2)
tool_instance =
DiscourseAi::AiBot::Personas::Artist.new.find_tool(
tool_call,
bot_user: nil,
llm: nil,
context: nil,
)
expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
end
it "enforces enums" do
@ -132,42 +119,68 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
</function_calls>
XML
search1, search2 =
tools =
DiscourseAi::AiBot::Personas::General.new.find_tools(
xml,
bot_user: nil,
llm: nil,
context: nil,
)
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
max_posts: "3.2",
status: "cow",
foo: "bar",
},
)
expect(search1.parameters.key?(:status)).to eq(false)
expect(search2.parameters[:status]).to eq("open")
tool_instance =
DiscourseAi::AiBot::Personas::General.new.find_tool(
tool_call,
bot_user: nil,
llm: nil,
context: nil,
)
expect(tool_instance.parameters.key?(:status)).to eq(false)
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
max_posts: "3.2",
status: "open",
foo: "bar",
},
)
tool_instance =
DiscourseAi::AiBot::Personas::General.new.find_tool(
tool_call,
bot_user: nil,
llm: nil,
context: nil,
)
expect(tool_instance.parameters[:status]).to eq("open")
end
it "can coerce integers" do
xml = <<~XML
<function_calls>
<invoke>
<tool_name>search</tool_name>
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
<parameters>
<max_posts>"3.2"</max_posts>
<search_query>hello world</search_query>
<foo>bar</foo>
</parameters>
</invoke>
</function_calls>
XML
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
max_posts: "3.2",
search_query: "hello world",
foo: "bar",
},
)
search, =
tools =
DiscourseAi::AiBot::Personas::General.new.find_tools(
xml,
bot_user: nil,
llm: nil,
context: nil,
)
search =
DiscourseAi::AiBot::Personas::General.new.find_tool(
tool_call,
bot_user: nil,
llm: nil,
context: nil,
)
expect(search.parameters[:max_posts]).to eq(3)
expect(search.parameters[:search_query]).to eq("hello world")
@ -177,43 +190,23 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
it "can correctly parse arrays in tools" do
SiteSetting.ai_openai_api_key = "123"
# Dall E tool uses an array for params
xml = <<~XML
<function_calls>
<invoke>
<tool_name>dall_e</tool_name>
<tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
<parameters>
<prompts>["cat oil painting", "big car"]</prompts>
</parameters>
</invoke>
<invoke>
<tool_name>dall_e</tool_name>
<tool_id>abc</tool_id>
<parameters>
<prompts>["pic3"]</prompts>
</parameters>
</invoke>
<invoke>
<tool_name>unknown</tool_name>
<tool_id>abc</tool_id>
<parameters>
<prompts>["pic3"]</prompts>
</parameters>
</invoke>
</function_calls>
XML
dall_e1, dall_e2 =
tools =
DiscourseAi::AiBot::Personas::DallE3.new.find_tools(
xml,
bot_user: nil,
llm: nil,
context: nil,
)
expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
expect(tools.length).to eq(2)
tool_call =
DiscourseAi::Completions::ToolCall.new(
name: "dall_e",
id: "call_JtYQMful5QKqw97XFsHzPweB",
parameters: {
prompts: ["cat oil painting", "big car"],
},
)
tool_instance =
DiscourseAi::AiBot::Personas::DallE3.new.find_tool(
tool_call,
bot_user: nil,
llm: nil,
context: nil,
)
expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
end
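# The enum and integer specs above assert the coercion rules applied when a
# ToolCall is hydrated into a tool instance: out-of-enum values are dropped,
# numeric strings are truncated to integers. A simplified sketch of that rule,
# with the parameter-definition shape assumed for illustration:
def coerce_parameter(definition, value)
  if definition[:enum] && !definition[:enum].include?(value)
    nil # "cow" is not a valid status, so the parameter is omitted entirely
  elsif definition[:type] == "integer"
    value.to_f.to_i # "3.2" becomes 3
  else
    value
  end
end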
describe "custom personas" do

View File

@ -55,6 +55,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
)
end
before { SiteSetting.ai_embeddings_enabled = false }
after do
# we must reset the persona cache because data can be rolled back
AiPersona.persona_cache.flush!
@ -83,17 +85,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
let(:function_call) { (<<~XML).strip }
<function_calls>
<invoke>
<tool_name>search</tool_name>
<tool_id>666</tool_id>
<parameters>
<query>Can you use the custom tool</query>
</parameters>
</invoke>
</function_calls>",
XML
let(:tool_call) do
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "666",
parameters: {
query: "Can you use the custom tool",
},
)
end
let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
@ -115,7 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
reply_post = nil
prompts = nil
responses = [function_call]
responses = [tool_call]
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
reply_post = playground.reply_to(new_post)
@ -133,7 +133,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can force usage of a tool" do
tool_name = "custom-#{custom_tool.id}"
ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
responses = [function_call, "custom tool did stuff (maybe)"]
responses = [tool_call, "custom tool did stuff (maybe)"]
prompts = nil
reply_post = nil
@ -166,7 +166,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
responses = [function_call, "custom tool did stuff (maybe)"]
responses = [tool_call, "custom tool did stuff (maybe)"]
reply_post = nil
@ -206,13 +206,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
playground = DiscourseAi::AiBot::Playground.new(bot)
responses = ["custom tool did stuff (maybe)", tool_call]
# let's ensure the tool does not run...
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
new_post = Fabricate(:post, raw: "Can you use the custom tool?")
reply_post = playground.reply_to(new_post)
end
expect(reply_post.raw.strip).to eq(function_call)
expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)")
end
end
@ -452,10 +454,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can run tools" do
persona.update!(tools: ["Time"])
responses = [
"<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
"The time is 2023-12-14 17:24:00 -0300",
]
tool_call1 =
DiscourseAi::Completions::ToolCall.new(
name: "time",
id: "time",
parameters: {
timezone: "Buenos Aires",
},
)
tool_call2 =
DiscourseAi::Completions::ToolCall.new(
name: "time",
id: "time",
parameters: {
timezone: "Sydney",
},
)
responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
message =
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@ -470,7 +487,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# it also needs to have tool details now set on message
prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
expect(prompt.custom_prompt.length).to eq(3)
expect(prompt.custom_prompt.length).to eq(5)
# TODO in chat I am mixed on including this in the context, but I guess maybe?
# thinking about this
@ -782,30 +800,29 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
it "supports multiple function calls" do
response1 = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>search</tool_name>
<tool_id>search</tool_id>
<parameters>
<search_query>testing various things</search_query>
</parameters>
</invoke>
<invoke>
<tool_name>search</tool_name>
<tool_id>search</tool_id>
<parameters>
<search_query>another search</search_query>
</parameters>
</invoke>
</function_calls>
TXT
tool_call1 =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "search",
parameters: {
search_query: "testing various things",
},
)
tool_call2 =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "search",
parameters: {
search_query: "another search",
},
)
response2 = "I found stuff"
DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
playground.reply_to(third_post)
end
DiscourseAi::Completions::Llm.with_prepared_responses(
[[tool_call1, tool_call2], response2],
) { playground.reply_to(third_post) }
last_post = third_post.topic.reload.posts.order(:post_number).last
@ -819,17 +836,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new)
playground = described_class.new(bot)
response1 = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>search</tool_name>
<tool_id>search</tool_id>
<parameters>
<search_query>testing various things</search_query>
</parameters>
</invoke>
</function_calls>
TXT
response1 =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "search",
parameters: {
search_query: "testing various things",
},
)
response2 = "I found stuff"
@ -843,17 +857,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
end
it "does not include placeholders in conversation context but includes all completions" do
response1 = (<<~TXT).strip
<function_calls>
<invoke>
<tool_name>search</tool_name>
<tool_id>search</tool_id>
<parameters>
<search_query>testing various things</search_query>
</parameters>
</invoke>
</function_calls>
TXT
response1 =
DiscourseAi::Completions::ToolCall.new(
name: "search",
id: "search",
parameters: {
search_query: "testing various things",
},
)
response2 = "I found some really amazing stuff!"
@ -889,17 +900,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
[{ b64_json: image, revised_prompt: "a pink cow 1" }]
end
let(:response) { (<<~TXT).strip }
<function_calls>
<invoke>
<tool_name>dall_e</tool_name>
<tool_id>dall_e</tool_id>
<parameters>
<prompts>["a pink cow"]</prompts>
</parameters>
</invoke>
</function_calls>
TXT
let(:response) do
DiscourseAi::Completions::ToolCall.new(
name: "dall_e",
id: "dall_e",
parameters: {
prompts: ["a pink cow"],
},
)
end
it "properly returns an image when skipping tool details" do
persona.update!(tool_details: false)

View File

@ -541,16 +541,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
expect(topic.title).to eq("An amazing title")
expect(topic.posts.count).to eq(2)
# now let's try to make a reply with a tool call
function_call = <<~XML
<function_calls>
<invoke>
<tool_name>categories</tool_name>
</invoke>
</function_calls>
XML
tool_call =
DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1")
fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
fake_endpoint.fake_content = [tool_call, "this is the response after the tool"]
# this simplifies function calls
fake_endpoint.chunk_count = 1

View File

@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
fab!(:user)
fab!(:pm_topic) { Fabricate(:private_message_topic) }
fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }
before { sign_in(user) }
@ -22,15 +24,37 @@ RSpec.describe DiscourseAi::AiBot::BotController do
user = pm_topic.topic_allowed_users.first.user
sign_in(user)
AiApiAuditLog.create!(
post_id: pm_post.id,
provider_id: 1,
topic_id: pm_topic.id,
raw_request_payload: "request",
raw_response_payload: "response",
request_tokens: 1,
response_tokens: 2,
)
log1 =
AiApiAuditLog.create!(
provider_id: 1,
topic_id: pm_topic.id,
raw_request_payload: "request",
raw_response_payload: "response",
request_tokens: 1,
response_tokens: 2,
)
log2 =
AiApiAuditLog.create!(
post_id: pm_post.id,
provider_id: 1,
topic_id: pm_topic.id,
raw_request_payload: "request",
raw_response_payload: "response",
request_tokens: 1,
response_tokens: 2,
)
log3 =
AiApiAuditLog.create!(
post_id: pm_post2.id,
provider_id: 1,
topic_id: pm_topic.id,
raw_request_payload: "request",
raw_response_payload: "response",
request_tokens: 1,
response_tokens: 2,
)
Group.refresh_automatic_groups!
SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s
@ -38,18 +62,26 @@ RSpec.describe DiscourseAi::AiBot::BotController do
get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
expect(response.status).to eq(200)
expect(response.parsed_body["id"]).to eq(log2.id)
expect(response.parsed_body["next_log_id"]).to eq(log3.id)
expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2)
expect(response.parsed_body["raw_request_payload"]).to eq("request")
expect(response.parsed_body["raw_response_payload"]).to eq("response")
post2 = Fabricate(:post, topic: pm_topic)
# return previous post if current has no debug info
get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
expect(response.status).to eq(200)
expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2)
# can return debug info by id as well
get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
expect(response.status).to eq(200)
expect(response.parsed_body["id"]).to eq(log1.id)
end
end