diff --git a/app/controllers/discourse_ai/ai_bot/bot_controller.rb b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
index e5d5bcf0..37570380 100644
--- a/app/controllers/discourse_ai/ai_bot/bot_controller.rb
+++ b/app/controllers/discourse_ai/ai_bot/bot_controller.rb
@@ -6,6 +6,16 @@ module DiscourseAi
     requires_plugin ::DiscourseAi::PLUGIN_NAME
     requires_login
 
+    def show_debug_info_by_id
+      log = AiApiAuditLog.find(params[:id])
+      if !log.topic
+        raise Discourse::NotFound
+      end
+
+      guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+      render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+    end
+
     def show_debug_info
       post = Post.find(params[:post_id])
       guardian.ensure_can_debug_ai_bot_conversation!(post)
diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb
index 2fa9f5c3..2fa0a214 100644
--- a/app/models/ai_api_audit_log.rb
+++ b/app/models/ai_api_audit_log.rb
@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
     Ollama = 7
     SambaNova = 8
   end
+
+  def next_log_id
+    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+  end
+
+  def prev_log_id
+    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+  end
 end
 
 # == Schema Information
diff --git a/app/serializers/ai_api_audit_log_serializer.rb b/app/serializers/ai_api_audit_log_serializer.rb
index 0c438a7b..eeb3843a 100644
--- a/app/serializers/ai_api_audit_log_serializer.rb
+++ b/app/serializers/ai_api_audit_log_serializer.rb
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
              :post_id,
              :feature_name,
              :language_model,
-             :created_at
+             :created_at,
+             :prev_log_id,
+             :next_log_id
 end
diff --git a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
index 5d0cdf69..e959c123 100644
--- a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
+++ b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
 import DButton from "discourse/components/d-button";
 import DModal from "discourse/components/d-modal";
 import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
 import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
 import i18n from "discourse-common/helpers/i18n";
 import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
     this.copy(this.info.raw_response_payload);
   }
 
+  async loadLog(logId) {
+    try {
+      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+        (result) => {
+          this.info = result;
+        }
+      );
+    } catch (e) {
+      popupAjaxError(e);
+    }
+  }
+
+  @action
+  prevLog() {
+    this.loadLog(this.info.prev_log_id);
+  }
+
+  @action
+  nextLog() {
+    this.loadLog(this.info.next_log_id);
+  }
+
   copy(text) {
     clipboardCopy(text);
     this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -77,6 +100,8 @@ export default class DebugAiModal extends Component {
       `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
     ).then((result) => {
       this.info = result;
+    }).catch((e) => {
+      popupAjaxError(e);
     });
   }
 
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
           @action={{this.copyResponse}}
           @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
         />
+        {{#if this.info.prev_log_id}}
+          <DButton
+            class="btn"
+            @action={{this.prevLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+          />
+        {{/if}}
+        {{#if this.info.next_log_id}}
+          <DButton
+            class="btn"
+            @action={{this.nextLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+          />
+        {{/if}}
         {{this.justCopiedText}}
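For reviewers, a minimal sketch of the navigation contract the new model methods, serializer fields, and route establish (the id is illustrative, not from the commit):

```ruby
log = AiApiAuditLog.find(42) # any audit log attached to a topic
log.next_log_id # => id of the next log in the same topic, or nil
log.prev_log_id # => id of the previous log in the same topic, or nil

# The modal's new buttons follow these ids through the new endpoint:
# GET /discourse-ai/ai-bot/show-debug-info/:id.json
```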
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 2d517946..82898c91 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -415,6 +415,8 @@ en:
             response_tokens: "Response tokens:"
             request: "Request"
             response: "Response"
+            next_log: "Next"
+            previous_log: "Previous"
         share_full_topic_modal:
           title: "Share Conversation Publicly"
diff --git a/config/routes.rb b/config/routes.rb
index a5c009ff..322e67ce 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
     get "bot-username" => "bot#show_bot_username"
     get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
     post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
   end
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 834ae059..0b05b567 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -100,6 +100,7 @@ module DiscourseAi
         llm_kwargs[:top_p] = persona.top_p if persona.top_p
 
         needs_newlines = false
+        tools_ran = 0
 
         while total_completions <= MAX_COMPLETIONS && ongoing_chain
           tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi
 
           result =
             llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
-              tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+              tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+              tool = nil if tools_ran >= MAX_TOOLS
 
-              if (tools.present?)
+              if (tool.present?)
                 tool_found = true
                 # a bit hacky, but extra newlines do no harm
                 if needs_newlines
@@ -117,10 +119,9 @@ module DiscourseAi
                   needs_newlines = false
                 end
 
-                tools[0..MAX_TOOLS].each do |tool|
-                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
-                  ongoing_chain &&= tool.chain_next_response?
-                end
+                process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+                tools_ran += 1
+                ongoing_chain &&= tool.chain_next_response?
               else
                 needs_newlines = true
                 update_blk.call(partial, cancel)
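A sketch (not the actual implementation) of the streaming contract `Bot#reply` now assumes: `llm.generate` yields partials that are either plain text or a `DiscourseAi::Completions::ToolCall`, and `tools_ran` caps tool execution at `MAX_TOOLS` per reply:

```ruby
llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
  if partial.is_a?(DiscourseAi::Completions::ToolCall)
    # a fully parsed tool invocation; processed at most MAX_TOOLS times per reply
  else
    # a plain text fragment, streamed straight into the bot's post
  end
end
```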
diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb
index 73224808..23cefe56 100644
--- a/lib/ai_bot/personas/persona.rb
+++ b/lib/ai_bot/personas/persona.rb
@@ -199,23 +199,17 @@ module DiscourseAi
         prompt
       end
 
-      def find_tools(partial, bot_user:, llm:, context:)
-        return [] if !partial.include?("</invoke>")
-
-        parsed_function = Nokogiri::HTML5.fragment(partial)
-        parsed_function
-          .css("invoke")
-          .map do |fragment|
-            tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
-          end
-          .compact
+      def find_tool(partial, bot_user:, llm:, context:)
+        return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+        tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
       end
 
       protected
 
-      def tool_instance(parsed_function, bot_user:, llm:, context:)
-        function_id = parsed_function.at("tool_id")&.text
-        function_name = parsed_function.at("tool_name")&.text
+      def tool_instance(tool_call, bot_user:, llm:, context:)
+        function_id = tool_call.id
+        function_name = tool_call.name
         return nil if function_name.nil?
 
         tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +218,7 @@ module DiscourseAi
         arguments = {}
         tool_klass.signature[:parameters].to_a.each do |param|
           name = param[:name]
-          value = parsed_function.at(name)&.text
+          value = tool_call.parameters[name.to_sym]
 
           if param[:type] == "array" && value
             value =
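With the XML scanning gone, `find_tool` reduces to a type check plus a lookup against the persona's tool registry. An illustrative call (tool name and parameters hypothetical):

```ruby
tool_call =
  DiscourseAi::Completions::ToolCall.new(
    id: "toolu_123",
    name: "search",
    parameters: { search_query: "sam" },
  )

persona.find_tool(tool_call, bot_user: user, llm: llm, context: {})
# => an instance of the persona's Search tool, or nil if no tool matches
persona.find_tool("some text partial", bot_user: user, llm: llm, context: {})
# => nil
```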
diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb
index 1d1516fa..413990d7 100644
--- a/lib/completions/anthropic_message_processor.rb
+++ b/lib/completions/anthropic_message_processor.rb
@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
     def append(json)
       @raw_json << json
     end
+
+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+    end
   end
 
   attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,80 +25,70 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
   def initialize(streaming_mode:)
     @streaming_mode = streaming_mode
     @tool_calls = []
+    @current_tool_call = nil
   end
 
-  def to_xml_tool_calls(function_buffer)
-    return function_buffer if @tool_calls.blank?
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
+  end
 
-    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
-      <function_calls>
-      </function_calls>
-    TEXT
-
-    @tool_calls.each do |tool_call|
-      node =
-        function_buffer.at("function_calls").add_child(
-          Nokogiri::HTML5::DocumentFragment.parse(
-            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
-          ),
-        )
-
-      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
-      xml =
-        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
-      node.at("tool_name").content = tool_call.name
-      node.at("tool_id").content = tool_call.id
-      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
+  def process_streamed_message(parsed)
+    result = nil
+
+    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+      tool_name = parsed.dig(:content_block, :name)
+      tool_id = parsed.dig(:content_block, :id)
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+      end
+      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+      if @current_tool_call
+        tool_delta = parsed.dig(:delta, :partial_json).to_s
+        @current_tool_call.append(tool_delta)
+      else
+        result = parsed.dig(:delta, :text).to_s
+      end
+    elsif parsed[:type] == "content_block_stop"
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+        @current_tool_call = nil
+      end
+    elsif parsed[:type] == "message_start"
+      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+    elsif parsed[:type] == "message_delta"
+      @output_tokens =
+        parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+    elsif parsed[:type] == "message_stop"
+      # bedrock has this ...
+      if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
+        @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+        @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+      end
     end
 
-    function_buffer
+    result
   end
 
   def process_message(payload)
     result = ""
     parsed = JSON.parse(payload, symbolize_names: true)
 
-    if @streaming_mode
-      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
-        tool_name = parsed.dig(:content_block, :name)
-        tool_id = parsed.dig(:content_block, :id)
-        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
-      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        if @tool_calls.present?
-          result = parsed.dig(:delta, :partial_json).to_s
-          @tool_calls.last.append(result)
-        else
-          result = parsed.dig(:delta, :text).to_s
-        end
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
-        @output_tokens =
-          parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
-      elsif parsed[:type] == "message_stop"
-        # bedrock has this ...
-        if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
-          @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
-          @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
-        end
-      end
-    else
-      content = parsed.dig(:content)
-      if content.is_a?(Array)
-        tool_call = content.find { |c| c[:type] == "tool_use" }
-        if tool_call
-          @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
-          @tool_calls.last.append(tool_call[:input].to_json)
-        else
-          result = parsed.dig(:content, 0, :text).to_s
-        end
-      end
-
-      @input_tokens = parsed.dig(:usage, :input_tokens)
-      @output_tokens = parsed.dig(:usage, :output_tokens)
+    content = parsed.dig(:content)
+    if content.is_a?(Array)
+      result =
+        content.map do |data|
+          if data[:type] == "tool_use"
+            call = AnthropicToolCall.new(data[:name], data[:id])
+            call.append(data[:input].to_json)
+            call.to_tool_call
+          else
+            data[:text]
+          end
+        end
     end
 
+    @input_tokens = parsed.dig(:usage, :input_tokens)
+    @output_tokens = parsed.dig(:usage, :output_tokens)
+
     result
   end
 end
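A sketch of the event mapping the new `process_streamed_message` implements (event shapes follow Anthropic's Messages API; values illustrative): text deltas pass through as strings, while a `tool_use` block is buffered until `content_block_stop` and then emitted as a single `ToolCall`:

```ruby
processor = DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: true)

processor.process_streamed_message({ type: "content_block_delta", delta: { text: "Hello" } })
# => "Hello"

processor.process_streamed_message(
  { type: "content_block_start", content_block: { type: "tool_use", name: "search", id: "toolu_1" } },
)
processor.process_streamed_message({ type: "content_block_delta", delta: { partial_json: "{\"q\":\"x\"}" } })
processor.process_streamed_message({ type: "content_block_stop" })
# => a ToolCall with id: "toolu_1", name: "search", parameters: { q: "x" }
```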
diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb
index 44762b88..490ae9ed 100644
--- a/lib/completions/endpoints/anthropic.rb
+++ b/lib/completions/endpoints/anthropic.rb
@@ -90,15 +90,18 @@ module DiscourseAi
         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
       end
 
+      def decode_chunk(partial_data)
+        @decoder ||= JsonStreamDecoder.new
+        (@decoder << partial_data).map do |parsed_json|
+          processor.process_streamed_message(parsed_json)
+        end.compact
+      end
+
       def processor
         @processor ||=
           DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
       end
 
-      def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-        processor.to_xml_tool_calls(function_buffer) if !partial
-      end
-
       def extract_completion_from(response_raw)
         processor.process_message(response_raw)
       end
@@ -107,6 +110,10 @@ module DiscourseAi
         processor.tool_calls.present?
       end
 
+      def tool_calls
+        processor.to_tool_calls
+      end
+
       def final_log_update(log)
         log.request_tokens = processor.input_tokens if processor.input_tokens
         log.response_tokens = processor.output_tokens if processor.output_tokens
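`decode_chunk` is the seam between the two new pieces: raw SSE bytes go through `JsonStreamDecoder` into parsed events, which the processor turns into text or `ToolCall` partials. A sketch, with `endpoint` standing in for an Anthropic endpoint instance:

```ruby
chunk = <<~SSE
  event: content_block_delta
  data: {"type":"content_block_delta","delta":{"text":"Hello"}}
SSE

endpoint.decode_chunk(chunk)
# => ["Hello"] -- a tool invocation in the stream would surface here as a ToolCall
```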
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index a0405b42..e525590f 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -126,22 +126,129 @@ module DiscourseAi
             response_data = extract_completion_from(response_raw)
             partials_raw = response_data.to_s
 
-            if native_tool_support?
-              if allow_tools && has_tool?(response_data)
-                function_buffer = build_buffer # Nokogiri document
-                function_buffer =
-                  add_to_function_buffer(function_buffer, payload: response_data)
-                FunctionCallNormalizer.normalize_function_ids!(function_buffer)
+            if allow_tools && !native_tool_support?
+              response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+              response_data = function_calls if function_calls.present?
+            end
 
-                response_data = +function_buffer.at("function_calls").to_s
-                response_data << "\n"
-              end
-            else
-              if allow_tools
-                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
-                response_data = function_calls if function_calls.present?
-              end
-            end
+            if response_data.is_a?(Array) && response_data.length == 1
+              response_data = response_data.first
+            end
+
+            return response_data
+          end
+
+          begin
+            cancelled = false
+            cancel = -> { cancelled = true }
+
+            response.read_body do |chunk|
+              if cancelled
+                http.finish
+                break
+              end
+
+              decode_chunk(chunk).each do |partial|
+                yield partial, cancel
+              end
+            end
+          rescue IOError, StandardError
+            raise if !cancelled
+          end
+
+          return response_data
+        ensure
+          if log
+            log.raw_response_payload = response_raw
+            final_log_update(log)
+
+            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
+            log.save!
+
+            if Rails.env.development?
+              puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
+            end
+          end
+        end
+      end
+    end
+
+    def perform_completionx!(
+      dialect,
+      user,
+      model_params = {},
+      feature_name: nil,
+      feature_context: nil,
+      &blk
+    )
+      allow_tools = dialect.prompt.has_tools?
+      model_params = normalize_model_params(model_params)
+      orig_blk = blk
+
+      @streaming_mode = block_given?
+      to_strip = xml_tags_to_strip(dialect)
+      @xml_stripper =
+        DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+      if @streaming_mode && @xml_stripper
+        blk =
+          lambda do |partial, cancel|
+            partial = @xml_stripper << partial
+            orig_blk.call(partial, cancel) if partial
+          end
+      end
+
+      prompt = dialect.translate
+
+      FinalDestination::HTTP.start(
+        model_uri.host,
+        model_uri.port,
+        use_ssl: use_ssl?,
+        read_timeout: TIMEOUT,
+        open_timeout: TIMEOUT,
+        write_timeout: TIMEOUT,
+      ) do |http|
+        response_data = +""
+        response_raw = +""
+
+        # Needed for response token calculations. Cannot rely on response_data due to function buffering.
+        partials_raw = +""
+        request_body = prepare_payload(prompt, model_params, dialect).to_json
+
+        request = prepare_request(request_body)
+
+        http.request(request) do |response|
+          if response.code.to_i != 200
+            Rails.logger.error(
+              "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
+            )
+            raise CompletionFailed, response.body
+          end
+
+          log =
+            AiApiAuditLog.new(
+              provider_id: provider_id,
+              user_id: user&.id,
+              raw_request_payload: request_body,
+              request_tokens: prompt_size(prompt),
+              topic_id: dialect.prompt.topic_id,
+              post_id: dialect.prompt.post_id,
+              feature_name: feature_name,
+              language_model: llm_model.name,
+              feature_context: feature_context.present? ? feature_context.as_json : nil,
+            )
+
+          if !@streaming_mode
+            response_raw = response.read_body
+            response_data = extract_completion_from(response_raw)
+            partials_raw = response_data.to_s
+
+            if allow_tools && !native_tool_support?
+              response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+              response_data = function_calls if function_calls.present?
+            end
+
+            if response_data.is_a?(Array) && response_data.length == 1
+              response_data = response_data.first
+            end
 
             return response_data
           end
@@ -277,8 +384,9 @@ module DiscourseAi
         ensure
           if log
             log.raw_response_payload = response_raw
-            log.response_tokens = tokenizer.size(partials_raw)
             final_log_update(log)
+
+            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
             log.save!
 
             if Rails.env.development?
diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb
index eaef21da..eb924b7f 100644
--- a/lib/completions/endpoints/canned_response.rb
+++ b/lib/completions/endpoints/canned_response.rb
@@ -45,12 +45,16 @@ module DiscourseAi
           cancel_fn = lambda { cancelled = true }
 
           # We buffer and return tool invocations in one go.
-          if is_tool?(response)
-            yield(response, cancel_fn)
-          else
-            response.each_char do |char|
-              break if cancelled
-              yield(char, cancel_fn)
+          as_array = response.is_a?(Array) ? response : [response]
+
+          as_array.each do |response|
+            if is_tool?(response)
+              yield(response, cancel_fn)
+            else
+              response.each_char do |char|
+                break if cancelled
+                yield(char, cancel_fn)
+              end
             end
           end
         else
@@ -65,7 +69,7 @@ module DiscourseAi
       private
 
       def is_tool?(response)
-        Nokogiri::HTML5.fragment(response).at("function_calls").present?
+        response.is_a?(DiscourseAi::Completions::ToolCall)
      end
    end
  end
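Two behaviors implied by the hunks above, written out as a sketch (the `with_prepared_responses` usage mirrors the specs later in this diff): a non-streamed response containing exactly one partial is unwrapped to that partial, and a canned response may now mix text with tool calls:

```ruby
canned = [
  "Here is the calculation:",
  DiscourseAi::Completions::ToolCall.new(id: "t1", name: "calculate", parameters: { expression: "1 + 1" }),
]

DiscourseAi::Completions::Llm.with_prepared_responses([canned]) do |llm|
  # the canned endpoint yields the text character by character,
  # then the ToolCall as a single partial
end
```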
diff --git a/lib/completions/json_stream_decoder.rb b/lib/completions/json_stream_decoder.rb
new file mode 100644
index 00000000..b3d68fc0
--- /dev/null
+++ b/lib/completions/json_stream_decoder.rb
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    # will work for Anthropic and OpenAI-compatible streams
+    class JsonStreamDecoder
+      attr_reader :buffer
+
+      LINE_REGEX = /data: ({.*})\s*$/
+
+      def initialize(symbolize_keys: true)
+        @symbolize_keys = symbolize_keys
+        @buffer = +""
+      end
+
+      def <<(raw)
+        @buffer << raw.to_s
+        rval = []
+
+        split = @buffer.scan(/.*\n?/)
+        split.pop if split.last.blank?
+
+        @buffer = +(split.pop.to_s)
+
+        split.each do |line|
+          matches = line.match(LINE_REGEX)
+          next if !matches
+          rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+        end
+
+        if @buffer.present?
+          matches = @buffer.match(LINE_REGEX)
+          if matches
+            begin
+              rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+              @buffer = +""
+            rescue JSON::ParserError
+              # maybe it is a partial line
+            end
+          end
+        end
+
+        rval
+      end
+    end
+  end
+end
diff --git a/lib/completions/tool_call.rb b/lib/completions/tool_call.rb
new file mode 100644
index 00000000..a1daf6f4
--- /dev/null
+++ b/lib/completions/tool_call.rb
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class ToolCall
+      attr_reader :id, :name, :parameters
+
+      def initialize(id:, name:, parameters: nil)
+        @id = id
+        @name = name
+        @parameters = parameters || {}
+      end
+
+      def ==(other)
+        id == other.id && name == other.name && parameters == other.parameters
+      end
+
+      def to_s
+        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+      end
+    end
+  end
+end
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 94d5d655..2ae488b5 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       data: {"type":"message_stop"}
     STRING
 
-    result = +""
+    result = []
     body = body.scan(/.*\n/)
     EndpointMock.with_chunk_array_support do
       stub_request(:post, url).to_return(status: 200, body: body)
      end
    end
 
-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters>
-      <search_query>s&lt;a&gt;m sam</search_query>
-      <category>general</category>
-      </parameters>
-      <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+        parameters: {
+          search_query: "s<a>m sam",
+          category: "general",
+        },
+      )
 
-    expect(result.strip).to eq(expected)
+    expect(result).to eq([tool_call])
   end
 
   it "can stream a response" do
@@ -242,17 +241,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
 
     result = llm.generate(prompt, user: Discourse.system_user)
 
-    expected = <<~TEXT.strip
-      <function_calls>
-      <invoke>
-      <tool_name>calculate</tool_name>
-      <parameters>
-      <expression>2758975 + 21.11</expression>
-      </parameters>
-      <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "calculate",
+        id: "toolu_012kBdhG4eHaV68W56p4N94h",
+        parameters: {
+          expression: "2758975 + 21.11",
+        },
+      )
 
-    expect(result.strip).to eq(expected)
+    expect(result).to eq(["Here is the calculation:", tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(345)
+    expect(log.response_tokens).to eq(65)
   end
 
   it "can send images via a completion prompt" do
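Before the decoder spec, a minimal usage sketch of `JsonStreamDecoder`: incomplete `data:` lines stay buffered until a full JSON payload parses, and keys are symbolized by default:

```ruby
decoder = DiscourseAi::Completions::JsonStreamDecoder.new
decoder << "data: {\"a\"" # => []           (partial JSON is buffered)
decoder << ": 1}\n"       # => [{ a: 1 }]   (emitted once the payload completes)
```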
diff --git a/spec/lib/completions/json_stream_decoder_spec.rb b/spec/lib/completions/json_stream_decoder_spec.rb
new file mode 100644
index 00000000..dd0fb139
--- /dev/null
+++ b/spec/lib/completions/json_stream_decoder_spec.rb
@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Completions::JsonStreamDecoder do
+  let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
+
+  it "should be able to parse simple messages" do
+    result = decoder << "data: #{{ hello: "world" }.to_json}"
+    expect(result).to eq([{ hello: "world" }])
+  end
+
+  it "should handle anthropic mixed style streams" do
+    stream = (<<~TEXT).split("|")
+      event: |message_start|
+      data: |{"hel|lo": "world"}|
+
+      event: |message_start
+      data: {"foo": "bar"}
+
+      event: |message_start
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each do |chunk|
+      results << (decoder << chunk)
+    end
+
+    expect(results.flatten.compact).to eq(
+      [{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }],
+    )
+  end
+
+  it "should be able to handle complex overlaps" do
+    stream = (<<~TEXT).split("|")
+      data: |{"hel|lo": "world"}
+
+      data: {"foo": "bar"}
+
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each do |chunk|
+      results << (decoder << chunk)
+    end
+
+    expect(results.flatten.compact).to eq(
+      [{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }],
+    )
+  end
+end
diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb
index 9c98a08a..641734e6 100644
--- a/spec/lib/modules/ai_bot/playground_spec.rb
+++ b/spec/lib/modules/ai_bot/playground_spec.rb
@@ -220,6 +220,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   before do
     Jobs.run_immediately!
     SiteSetting.ai_bot_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
+    SiteSetting.ai_embeddings_enabled = false
   end
 
   fab!(:persona) do
@@ -452,10 +453,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     it "can run tools" do
       persona.update!(tools: ["Time"])
 
-      responses = [
-        "<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
-        "The time is 2023-12-14 17:24:00 -0300",
-      ]
+      tool_call1 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "time",
+          id: "time",
+          parameters: {
+            timezone: "Buenos Aires",
+          },
+        )
+
+      tool_call2 =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "time",
+          id: "time",
+          parameters: {
+            timezone: "Sydney",
+          },
+        )
+
+      responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
 
       message =
         DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@@ -470,7 +486,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
 
       # it also needs to have tool details now set on message
       prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
-      expect(prompt.custom_prompt.length).to eq(3)
+
+      expect(prompt.custom_prompt.length).to eq(5)
 
       # TODO in chat I am mixed on including this in the context, but I guess maybe?
       # thinking about this
diff --git a/spec/requests/ai_bot/bot_controller_spec.rb b/spec/requests/ai_bot/bot_controller_spec.rb
index e7430185..130d33eb 100644
--- a/spec/requests/ai_bot/bot_controller_spec.rb
+++ b/spec/requests/ai_bot/bot_controller_spec.rb
@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
   fab!(:user)
   fab!(:pm_topic) { Fabricate(:private_message_topic) }
   fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }
 
   before { sign_in(user) }
 
@@ -22,7 +24,17 @@ RSpec.describe DiscourseAi::AiBot::BotController do
       user = pm_topic.topic_allowed_users.first.user
       sign_in(user)
 
-      AiApiAuditLog.create!(
+      log1 =
+        AiApiAuditLog.create!(
+          provider_id: 1,
+          topic_id: pm_topic.id,
+          raw_request_payload: "request",
+          raw_response_payload: "response",
+          request_tokens: 1,
+          response_tokens: 2,
+        )
+
+      log2 =
+        AiApiAuditLog.create!(
           post_id: pm_post.id,
           provider_id: 1,
           topic_id: pm_topic.id,
           raw_request_payload: "request",
           raw_response_payload: "response",
           request_tokens: 1,
           response_tokens: 2,
         )
 
+      log3 =
+        AiApiAuditLog.create!(
+          post_id: pm_post2.id,
+          provider_id: 1,
+          topic_id: pm_topic.id,
+          raw_request_payload: "request",
+          raw_response_payload: "response",
+          request_tokens: 1,
+          response_tokens: 2,
+        )
+
       Group.refresh_automatic_groups!
       SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s
 
       get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
 
       expect(response.status).to eq(200)
+      expect(response.parsed_body["id"]).to eq(log2.id)
+      expect(response.parsed_body["next_log_id"]).to eq(log3.id)
+      expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
+      expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
+
       expect(response.parsed_body["request_tokens"]).to eq(1)
       expect(response.parsed_body["response_tokens"]).to eq(2)
       expect(response.parsed_body["raw_request_payload"]).to eq("request")
       expect(response.parsed_body["raw_response_payload"]).to eq("response")
 
-      post2 = Fabricate(:post, topic: pm_topic)
       # return previous post if current has no debug info
-      get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
+      get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
       expect(response.status).to eq(200)
 
       expect(response.parsed_body["request_tokens"]).to eq(1)
       expect(response.parsed_body["response_tokens"]).to eq(2)
+
+      # can return debug info by id as well
+      get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
+      expect(response.status).to eq(200)
+      expect(response.parsed_body["id"]).to eq(log1.id)
     end
   end