diff --git a/app/controllers/discourse_ai/ai_bot/bot_controller.rb b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
index e5d5bcf0..37570380 100644
--- a/app/controllers/discourse_ai/ai_bot/bot_controller.rb
+++ b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
@@ -6,6 +6,16 @@ module DiscourseAi
requires_plugin ::DiscourseAi::PLUGIN_NAME
requires_login
+ def show_debug_info_by_id
+ log = AiApiAuditLog.find(params[:id])
+ if !log.topic
+ raise Discourse::NotFound
+ end
+
+ guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+ render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+ end
+
def show_debug_info
post = Post.find(params[:post_id])
guardian.ensure_can_debug_ai_bot_conversation!(post)
diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 2fa9f5c3..2fa0a214 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
Ollama = 7
SambaNova = 8
end
+
+ def next_log_id
+ self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+ end
+
+ def prev_log_id
+ self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+ end
end
# == Schema Information
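Note on the model helpers: next_log_id and prev_log_id give the debug modal an id-ordered chain of audit logs within a topic, which the serializer below exposes. A minimal console sketch of the contract (topic id and data are hypothetical):

```ruby
# Hypothetical data: three audit logs recorded for topic 42, ordered by id.
log = AiApiAuditLog.where(topic_id: 42).order(:id).second

log.prev_log_id # => id of the first log in the topic
log.next_log_id # => id of the third log, or nil at the end of the chain
```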
diff --git a/app/serializers/ai_api_audit_log_serializer.rb b/app/serializers/ai_api_audit_log_serializer.rb
index 0c438a7b..eeb3843a 100644
--- a/app/serializers/ai_api_audit_log_serializer.rb
+++ b/app/serializers/ai_api_audit_log_serializer.rb
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
:post_id,
:feature_name,
:language_model,
- :created_at
+ :created_at,
+ :prev_log_id,
+ :next_log_id
end
diff --git a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
index 5d0cdf69..e959c123 100644
--- a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
+++ b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
import DButton from "discourse/components/d-button";
import DModal from "discourse/components/d-modal";
import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
import i18n from "discourse-common/helpers/i18n";
import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
this.copy(this.info.raw_response_payload);
}
+ async loadLog(logId) {
+ try {
+ await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+ (result) => {
+ this.info = result;
+ }
+ );
+ } catch (e) {
+ popupAjaxError(e);
+ }
+ }
+
+ @action
+ prevLog() {
+ this.loadLog(this.info.prev_log_id);
+ }
+
+ @action
+ nextLog() {
+ this.loadLog(this.info.next_log_id);
+ }
+
copy(text) {
clipboardCopy(text);
this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -77,6 +100,8 @@ export default class DebugAiModal extends Component {
`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
).then((result) => {
this.info = result;
+ }).catch((e) => {
+ popupAjaxError(e);
});
}
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
@action={{this.copyResponse}}
@label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
/>
+ {{#if this.info.prev_log_id}}
+ <DButton
+ class="btn-flat"
+ @action={{this.prevLog}}
+ @icon="angles-left"
+ @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+ />
+ {{/if}}
+ {{#if this.info.next_log_id}}
+ <DButton
+ class="btn-flat"
+ @action={{this.nextLog}}
+ @icon="angles-right"
+ @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+ />
+ {{/if}}
{{this.justCopiedText}}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 2d517946..82898c91 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -415,6 +415,8 @@ en:
response_tokens: "Response tokens:"
request: "Request"
response: "Response"
+ next_log: "Next"
+ previous_log: "Previous"
share_full_topic_modal:
title: "Share Conversation Publicly"
diff --git a/config/routes.rb b/config/routes.rb
index a5c009ff..322e67ce 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
get "bot-username" => "bot#show_bot_username"
get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+ get "show-debug-info/:id" => "bot#show_debug_info_by_id"
post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 834ae059..0b05b567 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -100,6 +100,7 @@ module DiscourseAi
llm_kwargs[:top_p] = persona.top_p if persona.top_p
needs_newlines = false
+ tools_ran = 0
while total_completions <= MAX_COMPLETIONS && ongoing_chain
tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi
result =
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
- tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+ tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+ tool = nil if tools_ran >= MAX_TOOLS
- if (tools.present?)
+ if (tool.present?)
tool_found = true
# a bit hacky, but extra newlines do no harm
if needs_newlines
@@ -117,10 +119,9 @@ module DiscourseAi
needs_newlines = false
end
- tools[0..MAX_TOOLS].each do |tool|
- process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
- ongoing_chain &&= tool.chain_next_response?
- end
+ process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+ tools_ran += 1
+ ongoing_chain &&= tool.chain_next_response?
else
needs_newlines = true
update_blk.call(partial, cancel)
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index 73224808..23cefe56 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -199,23 +199,17 @@ module DiscourseAi
prompt
end
- def find_tools(partial, bot_user:, llm:, context:)
- return [] if !partial.include?("</invoke>")
-
- parsed_function = Nokogiri::HTML5.fragment(partial)
- parsed_function
- .css("invoke")
- .map do |fragment|
- tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
- end
- .compact
+ def find_tool(partial, bot_user:, llm:, context:)
+ return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+ tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
end
protected
- def tool_instance(parsed_function, bot_user:, llm:, context:)
- function_id = parsed_function.at("tool_id")&.text
- function_name = parsed_function.at("tool_name")&.text
+ def tool_instance(tool_call, bot_user:, llm:, context:)
+ function_id = tool_call.id
+ function_name = tool_call.name
return nil if function_name.nil?
tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +218,7 @@ module DiscourseAi
arguments = {}
tool_klass.signature[:parameters].to_a.each do |param|
name = param[:name]
- value = parsed_function.at(name)&.text
+ value = tool_call.parameters[name.to_sym]
if param[:type] == "array" && value
value =
diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb
index 1d1516fa..413990d7 100644
--- a/lib/completions/anthropic_message_processor.rb
+++ b/lib/completions/anthropic_message_processor.rb
@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
def append(json)
@raw_json << json
end
+
+ def to_tool_call
+ parameters = JSON.parse(raw_json, symbolize_names: true)
+ DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+ end
end
attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,80 +25,70 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
def initialize(streaming_mode:)
@streaming_mode = streaming_mode
@tool_calls = []
+ @current_tool_call = nil
end
- def to_xml_tool_calls(function_buffer)
- return function_buffer if @tool_calls.blank?
+ def to_tool_calls
+ @tool_calls.map { |tool_call| tool_call.to_tool_call }
+ end
- function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
- <function_calls>
- </function_calls>
- TEXT
-
- @tool_calls.each do |tool_call|
- node =
- function_buffer.at("function_calls").add_child(
- Nokogiri::HTML5::DocumentFragment.parse(
- DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
- ),
- )
-
- params = JSON.parse(tool_call.raw_json, symbolize_names: true)
- xml =
- params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
- node.at("tool_name").content = tool_call.name
- node.at("tool_id").content = tool_call.id
- node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
+ def process_streamed_message(parsed)
+ result = nil
+ if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+ tool_name = parsed.dig(:content_block, :name)
+ tool_id = parsed.dig(:content_block, :id)
+ if @current_tool_call
+ result = @current_tool_call.to_tool_call
+ end
+ @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+ elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+ if @current_tool_call
+ tool_delta = parsed.dig(:delta, :partial_json).to_s
+ @current_tool_call.append(tool_delta)
+ else
+ result = parsed.dig(:delta, :text).to_s
+ end
+ elsif parsed[:type] == "content_block_stop"
+ if @current_tool_call
+ result = @current_tool_call.to_tool_call
+ @current_tool_call = nil
+ end
+ elsif parsed[:type] == "message_start"
+ @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+ elsif parsed[:type] == "message_delta"
+ @output_tokens =
+ parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+ elsif parsed[:type] == "message_stop"
+ # bedrock has this ...
+ if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
+ @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+ @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+ end
end
-
- function_buffer
+ result
end
def process_message(payload)
result = ""
parsed = JSON.parse(payload, symbolize_names: true)
- if @streaming_mode
- if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
- tool_name = parsed.dig(:content_block, :name)
- tool_id = parsed.dig(:content_block, :id)
- @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
- elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
- if @tool_calls.present?
- result = parsed.dig(:delta, :partial_json).to_s
- @tool_calls.last.append(result)
- else
- result = parsed.dig(:delta, :text).to_s
+ content = parsed.dig(:content)
+ if content.is_a?(Array)
+ result =
+ content.map do |data|
+ if data[:type] == "tool_use"
+ call = AnthropicToolCall.new(data[:name], data[:id])
+ call.append(data[:input].to_json)
+ call.to_tool_call
+ else
+ data[:text]
+ end
end
- elsif parsed[:type] == "message_start"
- @input_tokens = parsed.dig(:message, :usage, :input_tokens)
- elsif parsed[:type] == "message_delta"
- @output_tokens =
- parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
- elsif parsed[:type] == "message_stop"
- # bedrock has this ...
- if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
- @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
- @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
- end
- end
- else
- content = parsed.dig(:content)
- if content.is_a?(Array)
- tool_call = content.find { |c| c[:type] == "tool_use" }
- if tool_call
- @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
- @tool_calls.last.append(tool_call[:input].to_json)
- else
- result = parsed.dig(:content, 0, :text).to_s
- end
- end
-
- @input_tokens = parsed.dig(:usage, :input_tokens)
- @output_tokens = parsed.dig(:usage, :output_tokens)
end
+ @input_tokens = parsed.dig(:usage, :input_tokens)
+ @output_tokens = parsed.dig(:usage, :output_tokens)
+
result
end
end
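For orientation: process_streamed_message is a small state machine over Anthropic SSE events. Text deltas pass through as plain strings, while tool_use blocks buffer partial_json until content_block_stop and then surface as a single ToolCall. A sketch with hand-built event hashes (shapes mirror Anthropic's stream; names and values are hypothetical):

```ruby
processor =
  DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: true)

events = [
  { type: "content_block_start", content_block: { type: "tool_use", name: "search", id: "toolu_1" } },
  { type: "content_block_delta", delta: { partial_json: "{\"search_query\":" } },
  { type: "content_block_delta", delta: { partial_json: " \"sam\"}" } },
  { type: "content_block_stop" },
]

# Only the stop event produces a value: one fully parsed ToolCall.
events.map { |event| processor.process_streamed_message(event) }.compact
# => [ToolCall(id: "toolu_1", name: "search", parameters: { search_query: "sam" })]
```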
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 44762b88..490ae9ed 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -90,15 +90,18 @@ module DiscourseAi
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end
+ def decode_chunk(partial_data)
+ @decoder ||= JsonStreamDecoder.new
+ (@decoder << partial_data).map do |parsed_json|
+ processor.process_streamed_message(parsed_json)
+ end.compact
+ end
+
def processor
@processor ||=
DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
end
- def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
- processor.to_xml_tool_calls(function_buffer) if !partial
- end
-
def extract_completion_from(response_raw)
processor.process_message(response_raw)
end
@@ -107,6 +110,10 @@ module DiscourseAi
processor.tool_calls.present?
end
+ def tool_calls
+ processor.to_tool_calls
+ end
+
def final_log_update(log)
log.request_tokens = processor.input_tokens if processor.input_tokens
log.response_tokens = processor.output_tokens if processor.output_tokens
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index a0405b42..e525590f 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -126,22 +126,129 @@ module DiscourseAi
response_data = extract_completion_from(response_raw)
partials_raw = response_data.to_s
- if native_tool_support?
- if allow_tools && has_tool?(response_data)
- function_buffer = build_buffer # Nokogiri document
- function_buffer =
- add_to_function_buffer(function_buffer, payload: response_data)
- FunctionCallNormalizer.normalize_function_ids!(function_buffer)
+ if allow_tools && !native_tool_support?
+ response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+ response_data = function_calls if function_calls.present?
+ end
- response_data = +function_buffer.at("function_calls").to_s
- response_data << "\n"
- end
- else
- if allow_tools
- response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
- response_data = function_calls if function_calls.present?
+ if response_data.is_a?(Array) && response_data.length == 1
+ response_data = response_data.first
+ end
+
+ return response_data
+ end
+
+ begin
+ cancelled = false
+ cancel = -> { cancelled = true }
+
+ response.read_body do |chunk|
+ if cancelled
+ http.finish
+ break
+ end
+
+ decode_chunk(chunk).each do |partial|
+ blk.call(partial, cancel)
+ end
+ end
+ rescue IOError, StandardError
+ raise if !cancelled
+ end
+ return response_data
+ ensure
+ if log
+ log.raw_response_payload = response_raw
+ final_log_update(log)
+
+ log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
+ log.save!
+
+ if Rails.env.development?
+ puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
+ end
+ end
+ end
+ end
+ end
+
+ def perform_completion!(
+ dialect,
+ user,
+ model_params = {},
+ feature_name: nil,
+ feature_context: nil,
+ &blk
+ )
+ allow_tools = dialect.prompt.has_tools?
+ model_params = normalize_model_params(model_params)
+ orig_blk = blk
+
+ @streaming_mode = block_given?
+ to_strip = xml_tags_to_strip(dialect)
+ @xml_stripper =
+ DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+ if @streaming_mode && @xml_stripper
+ blk =
+ lambda do |partial, cancel|
+ partial = @xml_stripper << partial
+ orig_blk.call(partial, cancel) if partial
+ end
+ end
+
+ prompt = dialect.translate
+
+ FinalDestination::HTTP.start(
+ model_uri.host,
+ model_uri.port,
+ use_ssl: use_ssl?,
+ read_timeout: TIMEOUT,
+ open_timeout: TIMEOUT,
+ write_timeout: TIMEOUT,
+ ) do |http|
+ response_data = +""
+ response_raw = +""
+
+ # Needed for response token calculations. Cannot rely on response_data due to function buffering.
+ partials_raw = +""
+ request_body = prepare_payload(prompt, model_params, dialect).to_json
+
+ request = prepare_request(request_body)
+
+ http.request(request) do |response|
+ if response.code.to_i != 200
+ Rails.logger.error(
+ "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
+ )
+ raise CompletionFailed, response.body
+ end
+
+ log =
+ AiApiAuditLog.new(
+ provider_id: provider_id,
+ user_id: user&.id,
+ raw_request_payload: request_body,
+ request_tokens: prompt_size(prompt),
+ topic_id: dialect.prompt.topic_id,
+ post_id: dialect.prompt.post_id,
+ feature_name: feature_name,
+ language_model: llm_model.name,
+ feature_context: feature_context.present? ? feature_context.as_json : nil,
+ )
+
+ if !@streaming_mode
+ response_raw = response.read_body
+ response_data = extract_completion_from(response_raw)
+ partials_raw = response_data.to_s
+
+ if allow_tools && !native_tool_support?
+ response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+ response_data = function_calls if function_calls.present?
+ end
+
+ if response_data.is_a?(Array) && response_data.length == 1
+ response_data = response_data.first
+ end
return response_data
end
@@ -277,8 +384,9 @@ module DiscourseAi
ensure
if log
log.raw_response_payload = response_raw
- log.response_tokens = tokenizer.size(partials_raw)
final_log_update(log)
+
+ log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
log.save!
if Rails.env.development?
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index eaef21da..eb924b7f 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -45,12 +45,16 @@ module DiscourseAi
cancel_fn = lambda { cancelled = true }
# We buffer and return tool invocations in one go.
- if is_tool?(response)
- yield(response, cancel_fn)
- else
- response.each_char do |char|
- break if cancelled
- yield(char, cancel_fn)
+ as_array = response.is_a?(Array) ? response : [response]
+
+ as_array.each do |response|
+ if is_tool?(response)
+ yield(response, cancel_fn)
+ else
+ response.each_char do |char|
+ break if cancelled
+ yield(char, cancel_fn)
+ end
end
end
else
@@ -65,7 +69,7 @@ module DiscourseAi
private
def is_tool?(response)
- Nokogiri::HTML5.fragment(response).at("function_calls").present?
+ response.is_a?(DiscourseAi::Completions::ToolCall)
end
end
end
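Because a canned response may now be an array mixing strings (streamed character by character) and ToolCall objects (yielded whole), a spec can script an entire tool round trip in one prepared response. A hedged sketch (tool name, parameters, and text are hypothetical):

```ruby
tool_call =
  DiscourseAi::Completions::ToolCall.new(
    id: "time",
    name: "time",
    parameters: { timezone: "UTC" },
  )

DiscourseAi::Completions::Llm.with_prepared_responses(
  [["Let me check the clock.", tool_call], "It is noon UTC."],
) do
  # exercise the bot or playground under test here
end
```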
diff --git a/lib/completions/json_stream_decoder.rb b/lib/completions/json_stream_decoder.rb
new file mode 100644
index 00000000..b3d68fc0
--- /dev/null
+++ b/lib/completions/json_stream_decoder.rb
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ # works for Anthropic and OpenAI-compatible SSE streams
+ class JsonStreamDecoder
+ attr_reader :buffer
+
+ LINE_REGEX = /data: ({.*})\s*$/
+
+ def initialize(symbolize_keys: true)
+ @symbolize_keys = symbolize_keys
+ @buffer = +""
+ end
+
+ def <<(raw)
+ @buffer << raw.to_s
+ rval = []
+
+ split = @buffer.scan(/.*\n?/)
+ split.pop if split.last.blank?
+
+ @buffer = +(split.pop.to_s)
+
+ split.each do |line|
+ matches = line.match(LINE_REGEX)
+ next if !matches
+ rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+ end
+
+ if @buffer.present?
+ matches = @buffer.match(LINE_REGEX)
+ if matches
+ begin
+ rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+ @buffer = +""
+ rescue JSON::ParserError
+ # maybe it is a partial line
+ end
+ end
+ end
+
+ rval
+ end
+ end
+ end
+end
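The decoder's contract in brief: complete "data: {...}" lines parse immediately, and a line split across network chunks is buffered until its closing brace arrives. For example:

```ruby
decoder = DiscourseAi::Completions::JsonStreamDecoder.new

decoder << %(data: {"type": "message_start"}\n)
# => [{ type: "message_start" }]

decoder << %(data: {"type": "content_)
# => [] (the partial line stays in the buffer)

decoder << %(block_stop"}\n)
# => [{ type: "content_block_stop" }]
```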
diff --git a/lib/completions/tool_call.rb b/lib/completions/tool_call.rb
new file mode 100644
index 00000000..a1daf6f4
--- /dev/null
+++ b/lib/completions/tool_call.rb
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+ module Completions
+ class ToolCall
+ attr_reader :id, :name, :parameters
+
+ def initialize(id:, name:, parameters: nil)
+ @id = id
+ @name = name
+ @parameters = parameters || {}
+ end
+
+ def ==(other)
+ id == other.id && name == other.name && parameters == other.parameters
+ end
+
+ def to_s
+ "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+ end
+ end
+ end
+end
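ToolCall is a plain value object: == compares id, name, and parameters, which is what lets the updated specs below assert directly on whole calls. For example:

```ruby
a = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "search", parameters: { q: "sam" })
b = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "search", parameters: { q: "sam" })

a == b # => true, so expect(result).to eq([tool_call]) needs no custom matcher
```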
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 94d5d655..2ae488b5 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
data: {"type":"message_stop"}
STRING
- result = +""
+ result = []
body = body.scan(/.*\n/)
EndpointMock.with_chunk_array_support do
stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
end
end
- expected = (<<~TEXT).strip
- <function_calls>
- <invoke>
- <tool_name>search</tool_name>
- <parameters>
- <search_query>s&lt;a&gt;m sam</search_query>
- <category>general</category>
- </parameters>
- <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "search",
+ id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+ parameters: {
+ search_query: "s<a>m sam",
+ category: "general",
+ },
+ )
- expect(result.strip).to eq(expected)
+ expect(result).to eq([tool_call])
end
it "can stream a response" do
@@ -242,17 +241,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
result = llm.generate(prompt, user: Discourse.system_user)
- expected = <<~TEXT.strip
- <function_calls>
- <invoke>
- <tool_name>calculate</tool_name>
- <parameters>
- <expression>2758975 + 21.11</expression>
- </parameters>
- <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
- </invoke>
- </function_calls>
- TEXT
+ tool_call =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "calculate",
+ id: "toolu_012kBdhG4eHaV68W56p4N94h",
+ parameters: {
+ expression: "2758975 + 21.11",
+ },
+ )
- expect(result.strip).to eq(expected)
+ expect(result).to eq(["Here is the calculation:", tool_call])
+
+ log = AiApiAuditLog.order(:id).last
+ expect(log.request_tokens).to eq(345)
+ expect(log.response_tokens).to eq(65)
end
it "can send images via a completion prompt" do
diff --git a/spec/lib/completions/json_stream_decoder_spec.rb b/spec/lib/completions/json_stream_decoder_spec.rb
new file mode 100644
index 00000000..dd0fb139
--- /dev/null
+++ b/spec/lib/completions/json_stream_decoder_spec.rb
@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Completions::JsonStreamDecoder do
+ let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
+
+ it "should be able to parse simple messages" do
+ result = decoder << "data: #{{ hello: "world" }.to_json}"
+ expect(result).to eq([{ hello: "world" }])
+ end
+
+ it "should handle anthropic mixed style streams" do
+ stream = (<<~TEXT).split("|")
+ event: |message_start|
+ data: |{"hel|lo": "world"}|
+
+ event: |message_start
+ data: {"foo": "bar"}
+
+ event: |message_start
+ data: {"ba|z": "qux"|}
+
+ [DONE]
+ TEXT
+
+ results = []
+ stream.each do |chunk|
+ results << (decoder << chunk)
+ end
+
+ expect(results.flatten.compact).to eq([{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }])
+ end
+
+ it "should be able to handle complex overlaps" do
+ stream = (<<~TEXT).split("|")
+ data: |{"hel|lo": "world"}
+
+ data: {"foo": "bar"}
+
+ data: {"ba|z": "qux"|}
+
+ [DONE]
+ TEXT
+
+ results = []
+ stream.each do |chunk|
+ results << (decoder << chunk)
+ end
+
+ expect(results.flatten.compact).to eq([{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }])
+ end
+end
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 9c98a08a..641734e6 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -220,6 +220,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
before do
Jobs.run_immediately!
SiteSetting.ai_bot_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
+ SiteSetting.ai_embeddings_enabled = false
end
fab!(:persona) do
@@ -452,10 +453,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
it "can run tools" do
persona.update!(tools: ["Time"])
- responses = [
- "<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
- "The time is 2023-12-14 17:24:00 -0300",
- ]
+ tool_call1 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "time",
+ id: "time",
+ parameters: {
+ timezone: "Buenos Aires",
+ },
+ )
+
+ tool_call2 =
+ DiscourseAi::Completions::ToolCall.new(
+ name: "time",
+ id: "time",
+ parameters: {
+ timezone: "Sydney",
+ },
+ )
+
+ responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
message =
DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@@ -470,7 +486,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
# it also needs to have tool details now set on message
prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
- expect(prompt.custom_prompt.length).to eq(3)
+
+ expect(prompt.custom_prompt.length).to eq(5)
# TODO in chat I am mixed on including this in the context, but I guess maybe?
# thinking about this
diff --git a/spec/requests/ai_bot/bot_controller_spec.rb b/spec/requests/ai_bot/bot_controller_spec.rb
index e7430185..130d33eb 100644
--- a/spec/requests/ai_bot/bot_controller_spec.rb
+++ b/spec/requests/ai_bot/bot_controller_spec.rb
@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
fab!(:user)
fab!(:pm_topic) { Fabricate(:private_message_topic) }
fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
+ fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
+ fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }
before { sign_in(user) }
@@ -22,7 +24,17 @@ RSpec.describe DiscourseAi::AiBot::BotController do
user = pm_topic.topic_allowed_users.first.user
sign_in(user)
- AiApiAuditLog.create!(
+
+ log1 = AiApiAuditLog.create!(
+ provider_id: 1,
+ topic_id: pm_topic.id,
+ raw_request_payload: "request",
+ raw_response_payload: "response",
+ request_tokens: 1,
+ response_tokens: 2,
+ )
+
+ log2 = AiApiAuditLog.create!(
post_id: pm_post.id,
provider_id: 1,
topic_id: pm_topic.id,
@@ -32,24 +44,43 @@ RSpec.describe DiscourseAi::AiBot::BotController do
response_tokens: 2,
)
+ log3 = AiApiAuditLog.create!(
+ post_id: pm_post2.id,
+ provider_id: 1,
+ topic_id: pm_topic.id,
+ raw_request_payload: "request",
+ raw_response_payload: "response",
+ request_tokens: 1,
+ response_tokens: 2,
+ )
+
Group.refresh_automatic_groups!
SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s
get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
expect(response.status).to eq(200)
+ expect(response.parsed_body["id"]).to eq(log2.id)
+ expect(response.parsed_body["next_log_id"]).to eq(log3.id)
+ expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
+ expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
+
expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2)
expect(response.parsed_body["raw_request_payload"]).to eq("request")
expect(response.parsed_body["raw_response_payload"]).to eq("response")
- post2 = Fabricate(:post, topic: pm_topic)
# return previous post if current has no debug info
- get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
+ get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
expect(response.status).to eq(200)
expect(response.parsed_body["request_tokens"]).to eq(1)
expect(response.parsed_body["response_tokens"]).to eq(2)
+
+ # can return debug info by id as well
+ get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
+ expect(response.status).to eq(200)
+ expect(response.parsed_body["id"]).to eq(log1.id)
end
end