FEATURE: improve tool support

This work-in-progress PR amends LLM completion so it returns
structured objects for tool calls instead of XML fragments.

This will enable future features such as parameter streaming.

The XML approach was error-prone; the object-based implementation is more robust.


Still very much in progress; a lot of code needs to change.

Only partially implemented for the Anthropic endpoint at the moment.
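
For illustration, a rough sketch of the shape change; the tool id and parameter values below are made-up examples, not taken from the diff:

    # before: tool invocations arrived as an XML fragment that every consumer re-parsed
    #
    #   <function_calls>
    #     <invoke>
    #       <tool_name>search</tool_name>
    #       <tool_id>toolu_123</tool_id>
    #       <parameters><search_query>sam</search_query></parameters>
    #     </invoke>
    #   </function_calls>
    #
    # after: completion yields a structured ToolCall object instead
    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        id: "toolu_123", # example id
        name: "search",
        parameters: { search_query: "sam" },
      )
    tool_call.parameters[:search_query] # => "sam"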
Sam Saffron 2024-11-08 14:58:54 +11:00
parent 1ad5321c09
commit bb6df426ae
18 changed files with 482 additions and 137 deletions

View File

@@ -6,6 +6,16 @@ module DiscourseAi
       requires_plugin ::DiscourseAi::PLUGIN_NAME
       requires_login

+      def show_debug_info_by_id
+        log = AiApiAuditLog.find(params[:id])
+        if !log.topic
+          raise Discourse::NotFound
+        end
+
+        guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+        render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+      end
+
       def show_debug_info
         post = Post.find(params[:post_id])
         guardian.ensure_can_debug_ai_bot_conversation!(post)

View File

@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
     Ollama = 7
     SambaNova = 8
   end

+  def next_log_id
+    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+  end
+
+  def prev_log_id
+    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+  end
 end

 # == Schema Information

View File

@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
     :post_id,
     :feature_name,
     :language_model,
-    :created_at
+    :created_at,
+    :prev_log_id,
+    :next_log_id
 end

View File

@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
 import DButton from "discourse/components/d-button";
 import DModal from "discourse/components/d-modal";
 import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
 import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
 import i18n from "discourse-common/helpers/i18n";
 import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
     this.copy(this.info.raw_response_payload);
   }

+  async loadLog(logId) {
+    try {
+      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+        (result) => {
+          this.info = result;
+        }
+      );
+    } catch (e) {
+      popupAjaxError(e);
+    }
+  }
+
+  @action
+  prevLog() {
+    this.loadLog(this.info.prev_log_id);
+  }
+
+  @action
+  nextLog() {
+    this.loadLog(this.info.next_log_id);
+  }
+
   copy(text) {
     clipboardCopy(text);
     this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -77,6 +100,8 @@ export default class DebugAiModal extends Component {
       `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
     ).then((result) => {
       this.info = result;
+    }).catch((e) => {
+      popupAjaxError(e);
     });
   }
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
           @action={{this.copyResponse}}
           @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
         />
+        {{#if this.info.prev_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-left"
+            @action={{this.prevLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+          />
+        {{/if}}
+        {{#if this.info.next_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-right"
+            @action={{this.nextLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+          />
+        {{/if}}
         <span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
       </:footer>
     </DModal>

View File

@@ -415,6 +415,8 @@ en:
           response_tokens: "Response tokens:"
           request: "Request"
           response: "Response"
+          next_log: "Next"
+          previous_log: "Previous"
         share_full_topic_modal:
           title: "Share Conversation Publicly"

View File

@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
     get "bot-username" => "bot#show_bot_username"
     get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
     post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
   end

View File

@@ -100,6 +100,7 @@ module DiscourseAi
         llm_kwargs[:top_p] = persona.top_p if persona.top_p

         needs_newlines = false
+        tools_ran = 0

         while total_completions <= MAX_COMPLETIONS && ongoing_chain
           tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi
           result =
             llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
-              tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+              tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+              tool = nil if tools_ran >= MAX_TOOLS

-              if (tools.present?)
+              if (tool.present?)
                 tool_found = true
                 # a bit hacky, but extra newlines do no harm
                 if needs_newlines
@@ -117,10 +119,9 @@ module DiscourseAi
                   needs_newlines = false
                 end

-                tools[0..MAX_TOOLS].each do |tool|
-                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
-                  ongoing_chain &&= tool.chain_next_response?
-                end
+                process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+                tools_ran += 1
+                ongoing_chain &&= tool.chain_next_response?
               else
                 needs_newlines = true
                 update_blk.call(partial, cancel)

View File

@@ -199,23 +199,17 @@ module DiscourseAi
         prompt
       end

-      def find_tools(partial, bot_user:, llm:, context:)
-        return [] if !partial.include?("</invoke>")
-
-        parsed_function = Nokogiri::HTML5.fragment(partial)
-        parsed_function
-          .css("invoke")
-          .map do |fragment|
-            tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
-          end
-          .compact
+      def find_tool(partial, bot_user:, llm:, context:)
+        return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+        tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
       end

       protected

-      def tool_instance(parsed_function, bot_user:, llm:, context:)
-        function_id = parsed_function.at("tool_id")&.text
-        function_name = parsed_function.at("tool_name")&.text
+      def tool_instance(tool_call, bot_user:, llm:, context:)
+        function_id = tool_call.id
+        function_name = tool_call.name
         return nil if function_name.nil?

         tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +218,7 @@ module DiscourseAi
         arguments = {}
         tool_klass.signature[:parameters].to_a.each do |param|
           name = param[:name]
-          value = parsed_function.at(name)&.text
+          value = tool_call.parameters[name.to_sym]

           if param[:type] == "array" && value
             value =

View File

@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
     def append(json)
       @raw_json << json
     end

+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+    end
   end

   attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,80 +25,70 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
   def initialize(streaming_mode:)
     @streaming_mode = streaming_mode
     @tool_calls = []
+    @current_tool_call = nil
   end

-  def to_xml_tool_calls(function_buffer)
-    return function_buffer if @tool_calls.blank?
-
-    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
-      <function_calls>
-      </function_calls>
-    TEXT
-
-    @tool_calls.each do |tool_call|
-      node =
-        function_buffer.at("function_calls").add_child(
-          Nokogiri::HTML5::DocumentFragment.parse(
-            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
-          ),
-        )
-
-      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
-      xml =
-        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
-      node.at("tool_name").content = tool_call.name
-      node.at("tool_id").content = tool_call.id
-      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
-    end
-
-    function_buffer
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
+  end
+
+  def process_streamed_message(parsed)
+    result = nil
+    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+      tool_name = parsed.dig(:content_block, :name)
+      tool_id = parsed.dig(:content_block, :id)
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+      end
+      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+      if @current_tool_call
+        tool_delta = parsed.dig(:delta, :partial_json).to_s
+        @current_tool_call.append(tool_delta)
+      else
+        result = parsed.dig(:delta, :text).to_s
+      end
+    elsif parsed[:type] == "content_block_stop"
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+        @current_tool_call = nil
+      end
+    elsif parsed[:type] == "message_start"
+      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+    elsif parsed[:type] == "message_delta"
+      @output_tokens =
+        parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+    elsif parsed[:type] == "message_stop"
+      # bedrock has this ...
+      if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
+        @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+        @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+      end
+    end
+    result
   end

   def process_message(payload)
     result = ""
     parsed = JSON.parse(payload, symbolize_names: true)

-    if @streaming_mode
-      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
-        tool_name = parsed.dig(:content_block, :name)
-        tool_id = parsed.dig(:content_block, :id)
-        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
-      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        if @tool_calls.present?
-          result = parsed.dig(:delta, :partial_json).to_s
-          @tool_calls.last.append(result)
-        else
-          result = parsed.dig(:delta, :text).to_s
-        end
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
-        @output_tokens =
-          parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
-      elsif parsed[:type] == "message_stop"
-        # bedrock has this ...
-        if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
-          @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
-          @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
-        end
-      end
-    else
-      content = parsed.dig(:content)
-      if content.is_a?(Array)
-        tool_call = content.find { |c| c[:type] == "tool_use" }
-        if tool_call
-          @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
-          @tool_calls.last.append(tool_call[:input].to_json)
-        else
-          result = parsed.dig(:content, 0, :text).to_s
-        end
-      end
-      @input_tokens = parsed.dig(:usage, :input_tokens)
-      @output_tokens = parsed.dig(:usage, :output_tokens)
-    end
+    content = parsed.dig(:content)
+    if content.is_a?(Array)
+      result =
+        content.map do |data|
+          if data[:type] == "tool_use"
+            call = AnthropicToolCall.new(data[:name], data[:id])
+            call.append(data[:input].to_json)
+            call.to_tool_call
+          else
+            data[:text]
+          end
+        end
+    end
+
+    @input_tokens = parsed.dig(:usage, :input_tokens)
+    @output_tokens = parsed.dig(:usage, :output_tokens)

     result
   end
 end
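
To show how the new streamed path fits together, here is a minimal usage sketch; the event hashes are abridged, made-up examples of the content_block_* events handled above:

    processor = DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: true)

    events = [
      { type: "content_block_start", content_block: { type: "tool_use", name: "search", id: "toolu_1" } },
      { type: "content_block_delta", delta: { partial_json: "{\"search_query\":" } },
      { type: "content_block_delta", delta: { partial_json: " \"sam\"}" } },
      { type: "content_block_stop" },
    ]

    # tool deltas are buffered; a completed ToolCall is emitted on content_block_stop,
    # while plain text deltas would come back as strings
    results = events.map { |event| processor.process_streamed_message(event) }.compact
    results.first.parameters # => { search_query: "sam" }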

View File

@@ -90,15 +90,18 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end

+        def decode_chunk(partial_data)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << partial_data).map do |parsed_json|
+            processor.process_streamed_message(parsed_json)
+          end.compact
+        end
+
         def processor
           @processor ||=
             DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
         end

-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
         def extract_completion_from(response_raw)
           processor.process_message(response_raw)
         end
@@ -107,6 +110,10 @@ module DiscourseAi
           processor.tool_calls.present?
         end

+        def tool_calls
+          processor.to_tool_calls
+        end
+
         def final_log_update(log)
           log.request_tokens = processor.input_tokens if processor.input_tokens
           log.response_tokens = processor.output_tokens if processor.output_tokens

View File

@@ -126,22 +126,129 @@ module DiscourseAi
             response_data = extract_completion_from(response_raw)
             partials_raw = response_data.to_s

-            if native_tool_support?
-              if allow_tools && has_tool?(response_data)
-                function_buffer = build_buffer # Nokogiri document
-                function_buffer =
-                  add_to_function_buffer(function_buffer, payload: response_data)
-                FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
-                response_data = +function_buffer.at("function_calls").to_s
-                response_data << "\n"
-              end
-            else
-              if allow_tools
-                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
-                response_data = function_calls if function_calls.present?
-              end
-            end
+            if allow_tools && !native_tool_support?
+              response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+              response_data = function_calls if function_calls.present?
+            end
+
+            if response_data.is_a?(Array) && response_data.length == 1
+              response_data = response_data.first
+            end
+
+            return response_data
+          end
+
+          begin
+            cancelled = false
+            cancel = -> { cancelled = true }
+
+            response.read_body do |chunk|
+              if cancelled
+                http.finish
+                break
+              end
+
+              decode_chunk(chunk).each do |partial|
+                yield partial, cancel
+              end
+            end
+          rescue IOError, StandardError
+            raise if !cancelled
+          end
+
+          return response_data
+        ensure
+          if log
+            log.raw_response_payload = response_raw
+            final_log_update(log)
+            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
+            log.save!
+
+            if Rails.env.development?
+              puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
+            end
+          end
+        end
+      end
+
+      def perform_completionx!(
+        dialect,
+        user,
+        model_params = {},
+        feature_name: nil,
+        feature_context: nil,
+        &blk
+      )
+        allow_tools = dialect.prompt.has_tools?
+        model_params = normalize_model_params(model_params)
+        orig_blk = blk
+
+        @streaming_mode = block_given?
+        to_strip = xml_tags_to_strip(dialect)
+        @xml_stripper =
+          DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+        if @streaming_mode && @xml_stripper
+          blk =
+            lambda do |partial, cancel|
+              partial = @xml_stripper << partial
+              orig_blk.call(partial, cancel) if partial
+            end
+        end
+
+        prompt = dialect.translate
+
+        FinalDestination::HTTP.start(
+          model_uri.host,
+          model_uri.port,
+          use_ssl: use_ssl?,
+          read_timeout: TIMEOUT,
+          open_timeout: TIMEOUT,
+          write_timeout: TIMEOUT,
+        ) do |http|
+          response_data = +""
+          response_raw = +""
+
+          # Needed to response token calculations. Cannot rely on response_data due to function buffering.
+          partials_raw = +""
+          request_body = prepare_payload(prompt, model_params, dialect).to_json
+          request = prepare_request(request_body)
+
+          http.request(request) do |response|
+            if response.code.to_i != 200
+              Rails.logger.error(
+                "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
+              )
+              raise CompletionFailed, response.body
+            end
+
+            log =
+              AiApiAuditLog.new(
+                provider_id: provider_id,
+                user_id: user&.id,
+                raw_request_payload: request_body,
+                request_tokens: prompt_size(prompt),
+                topic_id: dialect.prompt.topic_id,
+                post_id: dialect.prompt.post_id,
+                feature_name: feature_name,
+                language_model: llm_model.name,
+                feature_context: feature_context.present? ? feature_context.as_json : nil,
+              )
+
+            if !@streaming_mode
+              response_raw = response.read_body
+              response_data = extract_completion_from(response_raw)
+              partials_raw = response_data.to_s
+
+              if allow_tools && !native_tool_support?
+                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+                response_data = function_calls if function_calls.present?
+              end
+
+              if response_data.is_a?(Array) && response_data.length == 1
+                response_data = response_data.first
+              end

               return response_data
             end
@@ -277,8 +384,9 @@ module DiscourseAi
         ensure
           if log
             log.raw_response_payload = response_raw
-            log.response_tokens = tokenizer.size(partials_raw)
             final_log_update(log)
+            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
             log.save!

             if Rails.env.development?

View File

@@ -45,12 +45,16 @@ module DiscourseAi
             cancel_fn = lambda { cancelled = true }

             # We buffer and return tool invocations in one go.
-            if is_tool?(response)
-              yield(response, cancel_fn)
-            else
-              response.each_char do |char|
-                break if cancelled
-                yield(char, cancel_fn)
+            as_array = response.is_a?(Array) ? response : [response]
+
+            as_array.each do |response|
+              if is_tool?(response)
+                yield(response, cancel_fn)
+              else
+                response.each_char do |char|
+                  break if cancelled
+                  yield(char, cancel_fn)
+                end
               end
             end
           else
@@ -65,7 +69,7 @@ module DiscourseAi
       private

       def is_tool?(response)
-        Nokogiri::HTML5.fragment(response).at("function_calls").present?
+        response.is_a?(DiscourseAi::Completions::ToolCall)
       end
     end
   end

View File

@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    # will work for anthropic and open ai compatible
+    class JsonStreamDecoder
+      attr_reader :buffer
+
+      LINE_REGEX = /data: ({.*})\s*$/
+
+      def initialize(symbolize_keys: true)
+        @symbolize_keys = symbolize_keys
+        @buffer = +""
+      end
+
+      def <<(raw)
+        @buffer << raw.to_s
+        rval = []
+
+        split = @buffer.scan(/.*\n?/)
+        split.pop if split.last.blank?
+
+        @buffer = +(split.pop.to_s)
+
+        split.each do |line|
+          matches = line.match(LINE_REGEX)
+          next if !matches
+          rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+        end
+
+        if @buffer.present?
+          matches = @buffer.match(LINE_REGEX)
+          if matches
+            begin
+              rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+              @buffer = +""
+            rescue JSON::ParserError
+              # maybe it is a partial line
+            end
+          end
+        end
+
+        rval
+      end
+    end
+  end
+end
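
A quick usage sketch for the decoder; the chunk boundaries are arbitrary, to show how a "data:" line split across reads is buffered until it parses:

    decoder = DiscourseAi::Completions::JsonStreamDecoder.new

    parsed = []
    parsed.concat(decoder << "data: {\"type\": \"message") # incomplete JSON, stays buffered
    parsed.concat(decoder << "_start\"}\n")                # completes the line
    parsed.concat(decoder << "data: {\"type\": \"ping\"}\n")

    parsed # => [{ type: "message_start" }, { type: "ping" }]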

View File

@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class ToolCall
+      attr_reader :id, :name, :parameters
+
+      def initialize(id:, name:, parameters: nil)
+        @id = id
+        @name = name
+        @parameters = parameters || {}
+      end
+
+      def ==(other)
+        id == other.id && name == other.name && parameters == other.parameters
+      end
+
+      def to_s
+        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+      end
+    end
+  end
+end
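
ToolCall compares by value, which is what lets the updated specs below assert streamed results directly against expected instances:

    a = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "time", parameters: { timezone: "Sydney" })
    b = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "time", parameters: { timezone: "Sydney" })
    a == b # => true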

View File

@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       data: {"type":"message_stop"}
     STRING

-    result = +""
+    result = []
     body = body.scan(/.*\n/)
     EndpointMock.with_chunk_array_support do
       stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       end
     end

-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters><search_query>s&lt;a&gt;m sam</search_query>
-      <category>general</category></parameters>
-      <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+        parameters: {
+          search_query: "s<a>m sam",
+          category: "general",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq([tool_call])
   end

   it "can stream a response" do
@@ -242,17 +241,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
     result = llm.generate(prompt, user: Discourse.system_user)

-    expected = <<~TEXT.strip
-      <function_calls>
-      <invoke>
-      <tool_name>calculate</tool_name>
-      <parameters><expression>2758975 + 21.11</expression></parameters>
-      <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "calculate",
+        id: "toolu_012kBdhG4eHaV68W56p4N94h",
+        parameters: {
+          expression: "2758975 + 21.11",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq(["Here is the calculation:", tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+
+    expect(log.request_tokens).to eq(345)
+    expect(log.response_tokens).to eq(65)
   end

   it "can send images via a completion prompt" do

View File

@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Completions::JsonStreamDecoder do
+  let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
+
+  it "should be able to parse simple messages" do
+    result = decoder << "data: #{{ hello: "world" }.to_json}"
+    expect(result).to eq([{ hello: "world" }])
+  end
+
+  it "should handle anthropic mixed style streams" do
+    stream = (<<~TEXT).split("|")
+      event: |message_start|
+      data: |{"hel|lo": "world"}|
+      event: |message_start
+      data: {"foo": "bar"}
+      event: |message_start
+      data: {"ba|z": "qux"|}
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each do |chunk|
+      results << (decoder << chunk)
+    end
+
+    expect(results.flatten.compact).to eq([{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }])
+  end
+
+  it "should be able to handle complex overlaps" do
+    stream = (<<~TEXT).split("|")
+      data: |{"hel|lo": "world"}
+      data: {"foo": "bar"}
+      data: {"ba|z": "qux"|}
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each do |chunk|
+      results << (decoder << chunk)
+    end
+
+    expect(results.flatten.compact).to eq([{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }])
+  end
+end

View File

@@ -220,6 +220,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     before do
       Jobs.run_immediately!
      SiteSetting.ai_bot_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
+      SiteSetting.ai_embeddings_enabled = false
     end

     fab!(:persona) do
@@ -452,10 +453,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     it "can run tools" do
       persona.update!(tools: ["Time"])

-      responses = [
-        "<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
-        "The time is 2023-12-14 17:24:00 -0300",
-      ]
+      tool_call1 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "time",
+          id: "time",
+          parameters: {
+            timezone: "Buenos Aires",
+          },
+        )
+
+      tool_call2 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "time",
+          id: "time",
+          parameters: {
+            timezone: "Sydney",
+          },
+        )
+
+      responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]

       message =
         DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@@ -470,7 +486,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
       # it also needs to have tool details now set on message
       prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
-      expect(prompt.custom_prompt.length).to eq(3)
+      expect(prompt.custom_prompt.length).to eq(5)

       # TODO in chat I am mixed on including this in the context, but I guess maybe?
       # thinking about this

View File

@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
   fab!(:user)
   fab!(:pm_topic) { Fabricate(:private_message_topic) }
   fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }

   before { sign_in(user) }

@@ -22,7 +24,17 @@ RSpec.describe DiscourseAi::AiBot::BotController do
      user = pm_topic.topic_allowed_users.first.user
      sign_in(user)

-      AiApiAuditLog.create!(
+      log1 = AiApiAuditLog.create!(
+        provider_id: 1,
+        topic_id: pm_topic.id,
+        raw_request_payload: "request",
+        raw_response_payload: "response",
+        request_tokens: 1,
+        response_tokens: 2,
+      )
+
+      log2 = AiApiAuditLog.create!(
         post_id: pm_post.id,
         provider_id: 1,
         topic_id: pm_topic.id,
@@ -32,24 +44,43 @@ RSpec.describe DiscourseAi::AiBot::BotController do
         response_tokens: 2,
       )

+      log3 = AiApiAuditLog.create!(
+        post_id: pm_post2.id,
+        provider_id: 1,
+        topic_id: pm_topic.id,
+        raw_request_payload: "request",
+        raw_response_payload: "response",
+        request_tokens: 1,
+        response_tokens: 2,
+      )
+
      Group.refresh_automatic_groups!
      SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s

      get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"

      expect(response.status).to eq(200)
+      expect(response.parsed_body["id"]).to eq(log2.id)
+      expect(response.parsed_body["next_log_id"]).to eq(log3.id)
+      expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
+      expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
      expect(response.parsed_body["request_tokens"]).to eq(1)
      expect(response.parsed_body["response_tokens"]).to eq(2)
      expect(response.parsed_body["raw_request_payload"]).to eq("request")
      expect(response.parsed_body["raw_response_payload"]).to eq("response")

-      post2 = Fabricate(:post, topic: pm_topic)
       # return previous post if current has no debug info
-      get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
+      get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
      expect(response.status).to eq(200)
      expect(response.parsed_body["request_tokens"]).to eq(1)
      expect(response.parsed_body["response_tokens"]).to eq(2)
+
+      # can return debug info by id as well
+      get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
+      expect(response.status).to eq(200)
+      expect(response.parsed_body["id"]).to eq(log1.id)
    end
  end