From e817b7dc110ddc0749a9968e2a90a049e3d44137 Mon Sep 17 00:00:00 2001
From: Sam
Date: Tue, 12 Nov 2024 08:14:30 +1100
Subject: [PATCH] FEATURE: improve tool support (#904)

This re-implements tool support in DiscourseAi::Completions::Llm#generate.

Previously, tool calls were always returned as XML and it was the caller's
responsibility to parse that XML. The new implementation has the endpoints
return ToolCall objects instead.

Additionally, this simplifies the Llm endpoint interface and makes it
clearer: endpoints must implement decode and decode_chunk (for streaming).
It is now the implementer's responsibility to decode chunks; the base class
no longer does so. To make this easy we ship a flexible JSON stream decoder
that is simple to wire up.

Also new: better debugging for PMs. Next/previous buttons let you step
through all the Llm messages associated with a PM.

Token accounting is fixed for vLLM (we were not counting tokens correctly).
---
 .../discourse_ai/ai_bot/bot_controller.rb | 8 +
 app/models/ai_api_audit_log.rb | 8 +
 .../ai_api_audit_log_serializer.rb | 4 +-
 .../components/modal/debug-ai-modal.gjs | 51 ++-
 config/locales/client.en.yml | 2 +
 config/routes.rb | 1 +
 lib/ai_bot/bot.rb | 19 +-
 lib/ai_bot/personas/persona.rb | 21 +-
 .../anthropic_message_processor.rb | 118 ++++---
 lib/completions/dialects/ollama.rb | 19 +-
 lib/completions/endpoints/anthropic.rb | 35 +-
 lib/completions/endpoints/aws_bedrock.rb | 48 ++-
 lib/completions/endpoints/base.rb | 316 +++++++-----------
 lib/completions/endpoints/canned_response.rb | 22 +-
 lib/completions/endpoints/cohere.rb | 89 +++--
 lib/completions/endpoints/fake.rb | 46 +--
 lib/completions/endpoints/gemini.rb | 116 ++++---
 lib/completions/endpoints/hugging_face.rb | 36 +-
 lib/completions/endpoints/ollama.rb | 88 ++---
 lib/completions/endpoints/open_ai.rb | 96 +-----
 lib/completions/endpoints/samba_nova.rb | 40 ++-
 lib/completions/endpoints/vllm.rb | 51 ++-
 lib/completions/function_call_normalizer.rb | 113 -------
 lib/completions/json_stream_decoder.rb | 48 +++
 lib/completions/open_ai_message_processor.rb | 103 ++++++
 lib/completions/tool_call.rb | 29 ++
 lib/completions/xml_tool_processor.rb | 124 +++++++
 .../completions/endpoints/anthropic_spec.rb | 48 +--
 .../completions/endpoints/aws_bedrock_spec.rb | 49 ++-
 spec/lib/completions/endpoints/cohere_spec.rb | 25 +-
 .../endpoints/endpoint_compliance.rb | 29 +-
 spec/lib/completions/endpoints/gemini_spec.rb | 87 ++++-
 spec/lib/completions/endpoints/ollama_spec.rb | 2 +-
 .../lib/completions/endpoints/open_ai_spec.rb | 115 +++----
 .../completions/endpoints/samba_nova_spec.rb | 9 +-
 spec/lib/completions/endpoints/vllm_spec.rb | 128 ++++++-
 .../function_call_normalizer_spec.rb | 182 ----------
 .../completions/json_stream_decoder_spec.rb | 47 +++
 .../completions/xml_tool_processor_spec.rb | 188 +++++++++++
 .../modules/ai_bot/personas/persona_spec.rb | 193 ++++++-----
 spec/lib/modules/ai_bot/playground_spec.rb | 157 +++++----
 .../admin/ai_personas_controller_spec.rb | 12 +-
 spec/requests/ai_bot/bot_controller_spec.rb | 56 +++-
 43 files changed, 1685 insertions(+), 1293 deletions(-)
 delete mode 100644 lib/completions/function_call_normalizer.rb
 create mode 100644 lib/completions/json_stream_decoder.rb
 create mode 100644 lib/completions/open_ai_message_processor.rb
 create mode 100644 lib/completions/tool_call.rb
 create mode 100644 lib/completions/xml_tool_processor.rb
 delete mode 100644 spec/lib/completions/function_call_normalizer_spec.rb
 create mode 100644 spec/lib/completions/json_stream_decoder_spec.rb
 create mode
100644 spec/lib/completions/xml_tool_processor_spec.rb diff --git a/app/controllers/discourse_ai/ai_bot/bot_controller.rb b/app/controllers/discourse_ai/ai_bot/bot_controller.rb index e5d5bcf0..5ea13795 100644 --- a/app/controllers/discourse_ai/ai_bot/bot_controller.rb +++ b/app/controllers/discourse_ai/ai_bot/bot_controller.rb @@ -6,6 +6,14 @@ module DiscourseAi requires_plugin ::DiscourseAi::PLUGIN_NAME requires_login + def show_debug_info_by_id + log = AiApiAuditLog.find(params[:id]) + raise Discourse::NotFound if !log.topic + + guardian.ensure_can_debug_ai_bot_conversation!(log.topic) + render json: AiApiAuditLogSerializer.new(log, root: false), status: 200 + end + def show_debug_info post = Post.find(params[:post_id]) guardian.ensure_can_debug_ai_bot_conversation!(post) diff --git a/app/models/ai_api_audit_log.rb b/app/models/ai_api_audit_log.rb index 2fa9f5c3..2fa0a214 100644 --- a/app/models/ai_api_audit_log.rb +++ b/app/models/ai_api_audit_log.rb @@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base Ollama = 7 SambaNova = 8 end + + def next_log_id + self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first + end + + def prev_log_id + self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first + end end # == Schema Information diff --git a/app/serializers/ai_api_audit_log_serializer.rb b/app/serializers/ai_api_audit_log_serializer.rb index 0c438a7b..eeb3843a 100644 --- a/app/serializers/ai_api_audit_log_serializer.rb +++ b/app/serializers/ai_api_audit_log_serializer.rb @@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer :post_id, :feature_name, :language_model, - :created_at + :created_at, + :prev_log_id, + :next_log_id end diff --git a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs index 5d0cdf69..c21e8df3 100644 --- a/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs +++ b/assets/javascripts/discourse/components/modal/debug-ai-modal.gjs @@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template"; import DButton from "discourse/components/d-button"; import DModal from "discourse/components/d-modal"; import { ajax } from "discourse/lib/ajax"; +import { popupAjaxError } from "discourse/lib/ajax-error"; import { clipboardCopy, escapeExpression } from "discourse/lib/utilities"; import i18n from "discourse-common/helpers/i18n"; import discourseLater from "discourse-common/lib/later"; @@ -63,6 +64,28 @@ export default class DebugAiModal extends Component { this.copy(this.info.raw_response_payload); } + async loadLog(logId) { + try { + await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then( + (result) => { + this.info = result; + } + ); + } catch (e) { + popupAjaxError(e); + } + } + + @action + prevLog() { + this.loadLog(this.info.prev_log_id); + } + + @action + nextLog() { + this.loadLog(this.info.next_log_id); + } + copy(text) { clipboardCopy(text); this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared"); @@ -73,11 +96,13 @@ export default class DebugAiModal extends Component { } loadApiRequestInfo() { - ajax( - `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json` - ).then((result) => { - this.info = result; - }); + ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`) + .then((result) => { + this.info = result; + }) + .catch((e) => { + popupAjaxError(e); + }); } get requestActive() { @@ -147,6 +172,22 @@ export default class 
DebugAiModal extends Component { @action={{this.copyResponse}} @label="discourse_ai.ai_bot.debug_ai_modal.copy_response" /> + {{#if this.info.prev_log_id}} + + {{/if}} + {{#if this.info.next_log_id}} + + {{/if}} {{this.justCopiedText}} diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 2d517946..82898c91 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -415,6 +415,8 @@ en: response_tokens: "Response tokens:" request: "Request" response: "Response" + next_log: "Next" + previous_log: "Previous" share_full_topic_modal: title: "Share Conversation Publicly" diff --git a/config/routes.rb b/config/routes.rb index a5c009ff..322e67ce 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do get "bot-username" => "bot#show_bot_username" get "post/:post_id/show-debug-info" => "bot#show_debug_info" + get "show-debug-info/:id" => "bot#show_debug_info_by_id" post "post/:post_id/stop-streaming" => "bot#stop_streaming_response" end diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 834ae059..b965b1f6 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -100,6 +100,7 @@ module DiscourseAi llm_kwargs[:top_p] = persona.top_p if persona.top_p needs_newlines = false + tools_ran = 0 while total_completions <= MAX_COMPLETIONS && ongoing_chain tool_found = false @@ -107,9 +108,10 @@ module DiscourseAi result = llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel| - tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context) + tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context) + tool = nil if tools_ran >= MAX_TOOLS - if (tools.present?) + if tool.present? tool_found = true # a bit hacky, but extra newlines do no harm if needs_newlines @@ -117,13 +119,16 @@ module DiscourseAi needs_newlines = false end - tools[0..MAX_TOOLS].each do |tool| - process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) - ongoing_chain &&= tool.chain_next_response? - end + process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context) + tools_ran += 1 + ongoing_chain &&= tool.chain_next_response? 
else needs_newlines = true - update_blk.call(partial, cancel) + if partial.is_a?(DiscourseAi::Completions::ToolCall) + Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}") + else + update_blk.call(partial, cancel) + end end end diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 73224808..63255a17 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -199,23 +199,16 @@ module DiscourseAi prompt end - def find_tools(partial, bot_user:, llm:, context:) - return [] if !partial.include?("") - - parsed_function = Nokogiri::HTML5.fragment(partial) - parsed_function - .css("invoke") - .map do |fragment| - tool_instance(fragment, bot_user: bot_user, llm: llm, context: context) - end - .compact + def find_tool(partial, bot_user:, llm:, context:) + return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall) + tool_instance(partial, bot_user: bot_user, llm: llm, context: context) end protected - def tool_instance(parsed_function, bot_user:, llm:, context:) - function_id = parsed_function.at("tool_id")&.text - function_name = parsed_function.at("tool_name")&.text + def tool_instance(tool_call, bot_user:, llm:, context:) + function_id = tool_call.id + function_name = tool_call.name return nil if function_name.nil? tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name } @@ -224,7 +217,7 @@ module DiscourseAi arguments = {} tool_klass.signature[:parameters].to_a.each do |param| name = param[:name] - value = parsed_function.at(name)&.text + value = tool_call.parameters[name.to_sym] if param[:type] == "array" && value value = diff --git a/lib/completions/anthropic_message_processor.rb b/lib/completions/anthropic_message_processor.rb index 1d1516fa..5d5602ef 100644 --- a/lib/completions/anthropic_message_processor.rb +++ b/lib/completions/anthropic_message_processor.rb @@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor def append(json) @raw_json << json end + + def to_tool_call + parameters = JSON.parse(raw_json, symbolize_names: true) + DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters) + end end attr_reader :tool_calls, :input_tokens, :output_tokens @@ -20,80 +25,69 @@ class DiscourseAi::Completions::AnthropicMessageProcessor def initialize(streaming_mode:) @streaming_mode = streaming_mode @tool_calls = [] + @current_tool_call = nil end - def to_xml_tool_calls(function_buffer) - return function_buffer if @tool_calls.blank? + def to_tool_calls + @tool_calls.map { |tool_call| tool_call.to_tool_call } + end - function_buffer = Nokogiri::HTML5.fragment(<<~TEXT) - - - TEXT - - @tool_calls.each do |tool_call| - node = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse( - DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n", - ), - ) - - params = JSON.parse(tool_call.raw_json, symbolize_names: true) - xml = - params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}" }.join("\n") - - node.at("tool_name").content = tool_call.name - node.at("tool_id").content = tool_call.id - node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present? 
+ def process_streamed_message(parsed) + result = nil + if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use" + tool_name = parsed.dig(:content_block, :name) + tool_id = parsed.dig(:content_block, :id) + result = @current_tool_call.to_tool_call if @current_tool_call + @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name + elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" + if @current_tool_call + tool_delta = parsed.dig(:delta, :partial_json).to_s + @current_tool_call.append(tool_delta) + else + result = parsed.dig(:delta, :text).to_s + end + elsif parsed[:type] == "content_block_stop" + if @current_tool_call + result = @current_tool_call.to_tool_call + @current_tool_call = nil + end + elsif parsed[:type] == "message_start" + @input_tokens = parsed.dig(:message, :usage, :input_tokens) + elsif parsed[:type] == "message_delta" + @output_tokens = + parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens) + elsif parsed[:type] == "message_stop" + # bedrock has this ... + if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym) + @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens + @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens + end end - - function_buffer + result end def process_message(payload) result = "" - parsed = JSON.parse(payload, symbolize_names: true) + parsed = payload + parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String) - if @streaming_mode - if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use" - tool_name = parsed.dig(:content_block, :name) - tool_id = parsed.dig(:content_block, :id) - @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name - elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta" - if @tool_calls.present? - result = parsed.dig(:delta, :partial_json).to_s - @tool_calls.last.append(result) - else - result = parsed.dig(:delta, :text).to_s + content = parsed.dig(:content) + if content.is_a?(Array) + result = + content.map do |data| + if data[:type] == "tool_use" + call = AnthropicToolCall.new(data[:name], data[:id]) + call.append(data[:input].to_json) + call.to_tool_call + else + data[:text] + end end - elsif parsed[:type] == "message_start" - @input_tokens = parsed.dig(:message, :usage, :input_tokens) - elsif parsed[:type] == "message_delta" - @output_tokens = - parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens) - elsif parsed[:type] == "message_stop" - # bedrock has this ... 
- if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym) - @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens - @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens - end - end - else - content = parsed.dig(:content) - if content.is_a?(Array) - tool_call = content.find { |c| c[:type] == "tool_use" } - if tool_call - @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id]) - @tool_calls.last.append(tool_call[:input].to_json) - else - result = parsed.dig(:content, 0, :text).to_s - end - end - - @input_tokens = parsed.dig(:usage, :input_tokens) - @output_tokens = parsed.dig(:usage, :output_tokens) end + @input_tokens = parsed.dig(:usage, :input_tokens) + @output_tokens = parsed.dig(:usage, :output_tokens) + result end end diff --git a/lib/completions/dialects/ollama.rb b/lib/completions/dialects/ollama.rb index 541d0e73..3a32e592 100644 --- a/lib/completions/dialects/ollama.rb +++ b/lib/completions/dialects/ollama.rb @@ -63,8 +63,23 @@ module DiscourseAi def user_msg(msg) user_message = { role: "user", content: msg[:content] } - # TODO: Add support for user messages with empbeded user ids - # TODO: Add support for user messages with attachments + encoded_uploads = prompt.encoded_uploads(msg) + if encoded_uploads.present? + images = + encoded_uploads + .map do |upload| + if upload[:mime_type].start_with?("image/") + upload[:base64] + else + nil + end + end + .compact + + user_message[:images] = images if images.present? + end + + # TODO: Add support for user messages with embedded user ids user_message end diff --git a/lib/completions/endpoints/anthropic.rb b/lib/completions/endpoints/anthropic.rb index 44762b88..6576ef3b 100644 --- a/lib/completions/endpoints/anthropic.rb +++ b/lib/completions/endpoints/anthropic.rb @@ -63,6 +63,10 @@ module DiscourseAi URI(llm_model.url) end + def xml_tools_enabled? + !@native_tool_support + end + def prepare_payload(prompt, model_params, dialect) @native_tool_support = dialect.native_tool_support? @@ -90,35 +94,34 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end + def decode_chunk(partial_data) + @decoder ||= JsonStreamDecoder.new + (@decoder << partial_data) + .map { |parsed_json| processor.process_streamed_message(parsed_json) } + .compact + end + + def decode(response_data) + processor.process_message(response_data) + end + def processor @processor ||= DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) end - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - processor.to_xml_tool_calls(function_buffer) if !partial - end - - def extract_completion_from(response_raw) - processor.process_message(response_raw) - end - def has_tool?(_response_data) processor.tool_calls.present? end + def tool_calls + processor.to_tool_calls + end + def final_log_update(log) log.request_tokens = processor.input_tokens if processor.input_tokens log.response_tokens = processor.output_tokens if processor.output_tokens end - - def native_tool_support? 
- @native_tool_support - end - - def partials_from(decoded_chunk) - decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact - end end end end diff --git a/lib/completions/endpoints/aws_bedrock.rb b/lib/completions/endpoints/aws_bedrock.rb index f3146c2d..c17a051f 100644 --- a/lib/completions/endpoints/aws_bedrock.rb +++ b/lib/completions/endpoints/aws_bedrock.rb @@ -117,7 +117,24 @@ module DiscourseAi end end - def decode(chunk) + def decode_chunk(partial_data) + bedrock_decode(partial_data) + .map do |decoded_partial_data| + @raw_response ||= +"" + @raw_response << decoded_partial_data + @raw_response << "\n" + + parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true) + processor.process_streamed_message(parsed_json) + end + .compact + end + + def decode(response_data) + processor.process_message(response_data) + end + + def bedrock_decode(chunk) @decoder ||= Aws::EventStream::Decoder.new decoded, _done = @decoder.decode_chunk(chunk) @@ -147,12 +164,13 @@ module DiscourseAi Aws::EventStream::Errors::MessageChecksumError, Aws::EventStream::Errors::PreludeChecksumError => e Rails.logger.error("#{self.class.name}: #{e.message}") - nil + [] end def final_log_update(log) log.request_tokens = processor.input_tokens if processor.input_tokens log.response_tokens = processor.output_tokens if processor.output_tokens + log.raw_response_payload = @raw_response end def processor @@ -160,30 +178,8 @@ module DiscourseAi DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode) end - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - processor.to_xml_tool_calls(function_buffer) if !partial - end - - def extract_completion_from(response_raw) - processor.process_message(response_raw) - end - - def has_tool?(_response_data) - processor.tool_calls.present? - end - - def partials_from(decoded_chunks) - decoded_chunks - end - - def native_tool_support? - @native_tool_support - end - - def chunk_to_string(chunk) - joined = +chunk.join("\n") - joined << "\n" if joined.length > 0 - joined + def xml_tools_enabled? + !@native_tool_support end end end diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb index a0405b42..c78fcdd9 100644 --- a/lib/completions/endpoints/base.rb +++ b/lib/completions/endpoints/base.rb @@ -40,10 +40,6 @@ module DiscourseAi @llm_model = llm_model end - def native_tool_support? - false - end - def use_ssl? if model_uri&.scheme.present? model_uri.scheme == "https" @@ -64,22 +60,10 @@ module DiscourseAi feature_context: nil, &blk ) - allow_tools = dialect.prompt.has_tools? model_params = normalize_model_params(model_params) orig_blk = blk @streaming_mode = block_given? - to_strip = xml_tags_to_strip(dialect) - @xml_stripper = - DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present? - - if @streaming_mode && @xml_stripper - blk = - lambda do |partial, cancel| - partial = @xml_stripper << partial - orig_blk.call(partial, cancel) if partial - end - end prompt = dialect.translate @@ -108,177 +92,91 @@ module DiscourseAi raise CompletionFailed, response.body end + xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? && + dialect.prompt.has_tools? + + to_strip = xml_tags_to_strip(dialect) + xml_stripper = + DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present? 
+ + if @streaming_mode && xml_stripper + blk = + lambda do |partial, cancel| + partial = xml_stripper << partial if partial.is_a?(String) + orig_blk.call(partial, cancel) if partial + end + end + log = - AiApiAuditLog.new( + start_log( provider_id: provider_id, - user_id: user&.id, - raw_request_payload: request_body, - request_tokens: prompt_size(prompt), - topic_id: dialect.prompt.topic_id, - post_id: dialect.prompt.post_id, + request_body: request_body, + dialect: dialect, + prompt: prompt, + user: user, feature_name: feature_name, - language_model: llm_model.name, - feature_context: feature_context.present? ? feature_context.as_json : nil, + feature_context: feature_context, ) if !@streaming_mode - response_raw = response.read_body - response_data = extract_completion_from(response_raw) - partials_raw = response_data.to_s - - if native_tool_support? - if allow_tools && has_tool?(response_data) - function_buffer = build_buffer # Nokogiri document - function_buffer = - add_to_function_buffer(function_buffer, payload: response_data) - FunctionCallNormalizer.normalize_function_ids!(function_buffer) - - response_data = +function_buffer.at("function_calls").to_s - response_data << "\n" - end - else - if allow_tools - response_data, function_calls = FunctionCallNormalizer.normalize(response_data) - response_data = function_calls if function_calls.present? - end - end - - return response_data + return( + non_streaming_response( + response: response, + xml_tool_processor: xml_tool_processor, + xml_stripper: xml_stripper, + partials_raw: partials_raw, + response_raw: response_raw, + ) + ) end - has_tool = false - begin cancelled = false cancel = -> { cancelled = true } - - wrapped_blk = ->(partial, inner_cancel) do - response_data << partial - blk.call(partial, inner_cancel) + if cancelled + http.finish + break end - normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel) - - leftover = "" - function_buffer = build_buffer # Nokogiri document - prev_processed_partials = 0 - response.read_body do |chunk| - if cancelled - http.finish - break - end - - decoded_chunk = decode(chunk) - if decoded_chunk.nil? - raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion" - end - response_raw << chunk_to_string(decoded_chunk) - - if decoded_chunk.is_a?(String) - redo_chunk = leftover + decoded_chunk - else - # custom implementation for endpoint - # no implicit leftover support - redo_chunk = decoded_chunk - end - - raw_partials = partials_from(redo_chunk) - - raw_partials = - raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0 - - if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?) - leftover = redo_chunk - next - end - - json_error = false - - raw_partials.each do |raw_partial| - json_error = false - prev_processed_partials += 1 - - next if cancelled - next if raw_partial.blank? - - begin - partial = extract_completion_from(raw_partial) - next if partial.nil? - # empty vs blank... we still accept " " - next if response_data.empty? && partial.empty? - partials_raw << partial.to_s - - if native_tool_support? - # Stop streaming the response as soon as you find a tool. - # We'll buffer and yield it later. 
- has_tool = true if allow_tools && has_tool?(partials_raw) - - if has_tool - function_buffer = - add_to_function_buffer(function_buffer, partial: partial) - else - response_data << partial - blk.call(partial, cancel) if partial - end - else - if allow_tools - normalizer << partial - else - response_data << partial - blk.call(partial, cancel) if partial - end + response_raw << chunk + decode_chunk(chunk).each do |partial| + partials_raw << partial.to_s + response_data << partial if partial.is_a?(String) + partials = [partial] + if xml_tool_processor && partial.is_a?(String) + partials = (xml_tool_processor << partial) + if xml_tool_processor.should_cancel? + cancel.call + break end - rescue JSON::ParserError - leftover = redo_chunk - json_error = true end + partials.each { |inner_partial| blk.call(inner_partial, cancel) } end - - if json_error - prev_processed_partials -= 1 - else - leftover = "" - end - - prev_processed_partials = 0 if leftover.blank? end rescue IOError, StandardError raise if !cancelled end - - has_tool ||= has_tool?(partials_raw) - # Once we have the full response, try to return the tool as a XML doc. - if has_tool && native_tool_support? - function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw) - - if function_buffer.at("tool_name").text.present? - FunctionCallNormalizer.normalize_function_ids!(function_buffer) - - invocation = +function_buffer.at("function_calls").to_s - invocation << "\n" - - response_data << invocation - blk.call(invocation, cancel) + if xml_stripper + stripped = xml_stripper.finish + if stripped.present? + response_data << stripped + result = [] + result = (xml_tool_processor << stripped) if xml_tool_processor + result.each { |partial| blk.call(partial, cancel) } end end - - if !native_tool_support? && function_calls = normalizer.function_calls - response_data << function_calls - blk.call(function_calls, cancel) + if xml_tool_processor + xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) } end - - if @xml_stripper - leftover = @xml_stripper.finish - orig_blk.call(leftover, cancel) if leftover.present? - end - + decode_chunk_finish.each { |partial| blk.call(partial, cancel) } return response_data ensure if log log.raw_response_payload = response_raw - log.response_tokens = tokenizer.size(partials_raw) final_log_update(log) + + log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank? log.save! if Rails.env.development? @@ -330,15 +228,15 @@ module DiscourseAi raise NotImplementedError end - def extract_completion_from(_response_raw) + def decode(_response_raw) raise NotImplementedError end - def decode(chunk) - chunk + def decode_chunk_finish + [] end - def partials_from(_decoded_chunk) + def decode_chunk(_chunk) raise NotImplementedError end @@ -346,49 +244,73 @@ module DiscourseAi prompt.map { |message| message[:content] || message["content"] || "" }.join("\n") end - def build_buffer - Nokogiri::HTML5.fragment(<<~TEXT) - - #{noop_function_call_text} - - TEXT + def xml_tools_enabled? 
+ raise NotImplementedError end - def self.noop_function_call_text - (<<~TEXT).strip - - - - - - - TEXT + private + + def start_log( + provider_id:, + request_body:, + dialect:, + prompt:, + user:, + feature_name:, + feature_context: + ) + AiApiAuditLog.new( + provider_id: provider_id, + user_id: user&.id, + raw_request_payload: request_body, + request_tokens: prompt_size(prompt), + topic_id: dialect.prompt.topic_id, + post_id: dialect.prompt.post_id, + feature_name: feature_name, + language_model: llm_model.name, + feature_context: feature_context.present? ? feature_context.as_json : nil, + ) end - def noop_function_call_text - self.class.noop_function_call_text - end + def non_streaming_response( + response:, + xml_tool_processor:, + xml_stripper:, + partials_raw:, + response_raw: + ) + response_raw << response.read_body + response_data = decode(response_raw) - def has_tool?(response) - response.include?("") - end + response_data.each { |partial| partials_raw << partial.to_s } - def chunk_to_string(chunk) - if chunk.is_a?(String) - chunk - else - chunk.to_s - end - end - - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - if payload&.include?("") - matches = payload.match(%r{.*}m) - function_buffer = - Nokogiri::HTML5.fragment(matches[0] + "\n") if matches + if xml_tool_processor + response_data.each do |partial| + processed = (xml_tool_processor << partial) + processed << xml_tool_processor.finish + response_data = [] + processed.flatten.compact.each { |inner| response_data << inner } + end end - function_buffer + if xml_stripper + response_data.map! do |partial| + stripped = (xml_stripper << partial) if partial.is_a?(String) + if stripped.present? + stripped + else + partial + end + end + response_data << xml_stripper.finish + end + + response_data.reject!(&:blank?) + + # this is to keep stuff backwards compatible + response_data = response_data.first if response_data.length == 1 + + response_data end end end diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index eaef21da..bd3ae4ea 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -45,17 +45,21 @@ module DiscourseAi cancel_fn = lambda { cancelled = true } # We buffer and return tool invocations in one go. - if is_tool?(response) - yield(response, cancel_fn) - else - response.each_char do |char| - break if cancelled - yield(char, cancel_fn) + as_array = response.is_a?(Array) ? response : [response] + as_array.each do |response| + if is_tool?(response) + yield(response, cancel_fn) + else + response.each_char do |char| + break if cancelled + yield(char, cancel_fn) + end end end - else - response end + + response = response.first if response.is_a?(Array) && response.length == 1 + response end def tokenizer @@ -65,7 +69,7 @@ module DiscourseAi private def is_tool?(response) - Nokogiri::HTML5.fragment(response).at("function_calls").present? + response.is_a?(DiscourseAi::Completions::ToolCall) end end end diff --git a/lib/completions/endpoints/cohere.rb b/lib/completions/endpoints/cohere.rb index 180c27c8..258062a1 100644 --- a/lib/completions/endpoints/cohere.rb +++ b/lib/completions/endpoints/cohere.rb @@ -49,6 +49,47 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end + def decode(response_raw) + rval = [] + + parsed = JSON.parse(response_raw, symbolize_names: true) + + text = parsed[:text] + rval << parsed[:text] if !text.to_s.empty? 
# also allow " " + + # TODO tool calls + + update_usage(parsed) + + rval + end + + def decode_chunk(chunk) + @tool_idx ||= -1 + @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/) + (@json_decoder << chunk) + .map do |parsed| + update_usage(parsed) + rval = [] + + rval << parsed[:text] if !parsed[:text].to_s.empty? + + if tool_calls = parsed[:tool_calls] + tool_calls&.each do |tool_call| + @tool_idx += 1 + tool_name = tool_call[:name] + tool_params = tool_call[:parameters] + tool_id = "tool_#{@tool_idx}" + rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params) + end + end + + rval + end + .flatten + .compact + end + def extract_completion_from(response_raw) parsed = JSON.parse(response_raw, symbolize_names: true) @@ -77,36 +118,8 @@ module DiscourseAi end end - def has_tool?(_ignored) - @has_tool - end - - def native_tool_support? - true - end - - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - if partial - tools = JSON.parse(partial) - tools.each do |tool| - name = tool["name"] - parameters = tool["parameters"] - xml_params = parameters.map { |k, v| "<#{k}>#{v}\n" }.join - - current_function = function_buffer.at("invoke") - if current_function.nil? || current_function.at("tool_name").content.present? - current_function = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"), - ) - end - - current_function.at("tool_name").content = name == "search_local" ? "search" : name - current_function.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(xml_params) - end - end - function_buffer + def xml_tools_enabled? + false end def final_log_update(log) @@ -114,10 +127,6 @@ module DiscourseAi log.response_tokens = @output_tokens if @output_tokens end - def partials_from(decoded_chunk) - decoded_chunk.split("\n").compact - end - def extract_prompt_for_tokenizer(prompt) text = +"" if prompt[:chat_history] @@ -131,6 +140,18 @@ module DiscourseAi text end + + private + + def update_usage(parsed) + input_tokens = parsed.dig(:meta, :billed_units, :input_tokens) + input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens) + @input_tokens = input_tokens if input_tokens.present? + + output_tokens = parsed.dig(:meta, :billed_units, :output_tokens) + output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens) + @output_tokens = output_tokens if output_tokens.present? + end end end end diff --git a/lib/completions/endpoints/fake.rb b/lib/completions/endpoints/fake.rb index a51ff3ac..15cc254d 100644 --- a/lib/completions/endpoints/fake.rb +++ b/lib/completions/endpoints/fake.rb @@ -133,31 +133,35 @@ module DiscourseAi content = content.shift if content.is_a?(Array) if block_given? - split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort - indexes = [0, *split_indices, content.length] + if content.is_a?(DiscourseAi::Completions::ToolCall) + yield(content, -> {}) + else + split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort + indexes = [0, *split_indices, content.length] - original_content = content - content = +"" + original_content = content + content = +"" - cancel = false - cancel_proc = -> { cancel = true } + cancel = false + cancel_proc = -> { cancel = true } - i = 0 - indexes - .each_cons(2) - .map { |start, finish| original_content[start...finish] } - .each do |chunk| - break if cancel - if self.class.delays.present? 
&& - (delay = self.class.delays[i % self.class.delays.length]) - sleep(delay) - i += 1 + i = 0 + indexes + .each_cons(2) + .map { |start, finish| original_content[start...finish] } + .each do |chunk| + break if cancel + if self.class.delays.present? && + (delay = self.class.delays[i % self.class.delays.length]) + sleep(delay) + i += 1 + end + break if cancel + + content << chunk + yield(chunk, cancel_proc) end - break if cancel - - content << chunk - yield(chunk, cancel_proc) - end + end end content diff --git a/lib/completions/endpoints/gemini.rb b/lib/completions/endpoints/gemini.rb index ddf607b2..2450dc99 100644 --- a/lib/completions/endpoints/gemini.rb +++ b/lib/completions/endpoints/gemini.rb @@ -103,15 +103,7 @@ module DiscourseAi end end - def partials_from(decoded_chunk) - decoded_chunk - end - - def chunk_to_string(chunk) - chunk.to_s - end - - class Decoder + class GeminiStreamingDecoder def initialize @buffer = +"" end @@ -151,43 +143,87 @@ module DiscourseAi end def decode(chunk) - @decoder ||= Decoder.new - @decoder.decode(chunk) + json = JSON.parse(chunk, symbolize_names: true) + idx = -1 + json + .dig(:candidates, 0, :content, :parts) + .map do |part| + if part[:functionCall] + idx += 1 + ToolCall.new( + id: "tool_#{idx}", + name: part[:functionCall][:name], + parameters: part[:functionCall][:args], + ) + else + part = part[:text] + if part != "" + part + else + nil + end + end + end + end + + def decode_chunk(chunk) + @tool_index ||= -1 + + streaming_decoder + .decode(chunk) + .map do |parsed| + update_usage(parsed) + parsed + .dig(:candidates, 0, :content, :parts) + .map do |part| + if part[:text] + part = part[:text] + if part != "" + part + else + nil + end + elsif part[:functionCall] + @tool_index += 1 + ToolCall.new( + id: "tool_#{@tool_index}", + name: part[:functionCall][:name], + parameters: part[:functionCall][:args], + ) + end + end + end + .flatten + .compact + end + + def update_usage(parsed) + usage = parsed.dig(:usageMetadata) + if usage + if prompt_token_count = usage[:promptTokenCount] + @prompt_token_count = prompt_token_count + end + if candidate_token_count = usage[:candidatesTokenCount] + @candidate_token_count = candidate_token_count + end + end + end + + def final_log_update(log) + log.request_tokens = @prompt_token_count if @prompt_token_count + log.response_tokens = @candidate_token_count if @candidate_token_count + end + + def streaming_decoder + @decoder ||= GeminiStreamingDecoder.new end def extract_prompt_for_tokenizer(prompt) prompt.to_s end - def has_tool?(_response_data) - @has_function_call - end - - def native_tool_support? - true - end - - def add_to_function_buffer(function_buffer, payload: nil, partial: nil) - if @streaming_mode - return function_buffer if !partial - else - partial = payload - end - - function_buffer.at("tool_name").content = partial[:name] if partial[:name].present? - - if partial[:args] - argument_fragments = - partial[:args].reduce(+"") do |memo, (arg_name, value)| - memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}" - end - argument_fragments << "\n" - - function_buffer.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) - end - - function_buffer + def xml_tools_enabled? 
+ false end end end diff --git a/lib/completions/endpoints/hugging_face.rb b/lib/completions/endpoints/hugging_face.rb index bd7edc06..b0b14722 100644 --- a/lib/completions/endpoints/hugging_face.rb +++ b/lib/completions/endpoints/hugging_face.rb @@ -59,22 +59,30 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - # half a line sent here - return if !parsed - - response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) - - response_h.dig(:content) + def xml_tools_enabled? + true end - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data:", 2)[1] - data&.squish == "[DONE]" ? nil : data + def decode(response_raw) + parsed = JSON.parse(response_raw, symbolize_names: true) + text = parsed.dig(:choices, 0, :message, :content) + if text.to_s.empty? + [""] + else + [text] + end + end + + def decode_chunk(chunk) + @json_decoder ||= JsonStreamDecoder.new + (@json_decoder << chunk) + .map do |parsed| + text = parsed.dig(:choices, 0, :delta, :content) + if text.to_s.empty? + nil + else + text + end end .compact end diff --git a/lib/completions/endpoints/ollama.rb b/lib/completions/endpoints/ollama.rb index cc58006a..dd4ca2c7 100644 --- a/lib/completions/endpoints/ollama.rb +++ b/lib/completions/endpoints/ollama.rb @@ -37,12 +37,8 @@ module DiscourseAi URI(llm_model.url) end - def native_tool_support? - @native_tool_support - end - - def has_tool?(_response_data) - @has_function_call + def xml_tools_enabled? + !@native_tool_support end def prepare_payload(prompt, model_params, dialect) @@ -67,74 +63,30 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def partials_from(decoded_chunk) - decoded_chunk.split("\n").compact + def decode_chunk(chunk) + # Native tool calls are not working right in streaming mode, use XML + @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/) + (@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact end - def extract_completion_from(response_raw) + def decode(response_raw) + rval = [] parsed = JSON.parse(response_raw, symbolize_names: true) - return if !parsed + content = parsed.dig(:message, :content) + rval << content if !content.to_s.empty? - response_h = parsed.dig(:message) - - @has_function_call ||= response_h.dig(:tool_calls).present? - @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content) - end - - def add_to_function_buffer(function_buffer, payload: nil, partial: nil) - @args_buffer ||= +"" - - if @streaming_mode - return function_buffer if !partial - else - partial = payload - end - - f_name = partial.dig(:function, :name) - - @current_function ||= function_buffer.at("invoke") - - if f_name - current_name = function_buffer.at("tool_name").content - - if current_name.blank? 
- # first call - else - # we have a previous function, so we need to add a noop - @args_buffer = +"" - @current_function = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"), - ) + idx = -1 + parsed + .dig(:message, :tool_calls) + &.each do |tool_call| + idx += 1 + id = "tool_#{idx}" + name = tool_call.dig(:function, :name) + args = tool_call.dig(:function, :arguments) + rval << ToolCall.new(id: id, name: name, parameters: args) end - end - @current_function.at("tool_name").content = f_name if f_name - @current_function.at("tool_id").content = partial[:id] if partial[:id] - - args = partial.dig(:function, :arguments) - - # allow for SPACE within arguments - if args && args != "" - @args_buffer << args.to_json - - begin - json_args = JSON.parse(@args_buffer, symbolize_names: true) - - argument_fragments = - json_args.reduce(+"") do |memo, (arg_name, value)| - memo << "\n<#{arg_name}>#{value}" - end - argument_fragments << "\n" - - @current_function.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) - rescue JSON::ParserError - return function_buffer - end - end - - function_buffer + rval end end end diff --git a/lib/completions/endpoints/open_ai.rb b/lib/completions/endpoints/open_ai.rb index 92315ed5..a185a840 100644 --- a/lib/completions/endpoints/open_ai.rb +++ b/lib/completions/endpoints/open_ai.rb @@ -93,98 +93,34 @@ module DiscourseAi end def final_log_update(log) - log.request_tokens = @prompt_tokens if @prompt_tokens - log.response_tokens = @completion_tokens if @completion_tokens + log.request_tokens = processor.prompt_tokens if processor.prompt_tokens + log.response_tokens = processor.completion_tokens if processor.completion_tokens end - def extract_completion_from(response_raw) - json = JSON.parse(response_raw, symbolize_names: true) - - if @streaming_mode - @prompt_tokens ||= json.dig(:usage, :prompt_tokens) - @completion_tokens ||= json.dig(:usage, :completion_tokens) - end - - parsed = json.dig(:choices, 0) - return if !parsed - - response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) - @has_function_call ||= response_h.dig(:tool_calls).present? - @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content) + def decode(response_raw) + processor.process_message(JSON.parse(response_raw, symbolize_names: true)) end - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data: ", 2)[1] - data == "[DONE]" ? nil : data - end + def decode_chunk(chunk) + @decoder ||= JsonStreamDecoder.new + (@decoder << chunk) + .map { |parsed_json| processor.process_streamed_message(parsed_json) } + .flatten .compact end - def has_tool?(_response_data) - @has_function_call + def decode_chunk_finish + @processor.finish end - def native_tool_support? - true + def xml_tools_enabled? + false end - def add_to_function_buffer(function_buffer, partial: nil, payload: nil) - if @streaming_mode - return function_buffer if !partial - else - partial = payload - end + private - @args_buffer ||= +"" - - f_name = partial.dig(:function, :name) - - @current_function ||= function_buffer.at("invoke") - - if f_name - current_name = function_buffer.at("tool_name").content - - if current_name.blank? 
- # first call - else - # we have a previous function, so we need to add a noop - @args_buffer = +"" - @current_function = - function_buffer.at("function_calls").add_child( - Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"), - ) - end - end - - @current_function.at("tool_name").content = f_name if f_name - @current_function.at("tool_id").content = partial[:id] if partial[:id] - - args = partial.dig(:function, :arguments) - - # allow for SPACE within arguments - if args && args != "" - @args_buffer << args - - begin - json_args = JSON.parse(@args_buffer, symbolize_names: true) - - argument_fragments = - json_args.reduce(+"") do |memo, (arg_name, value)| - memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}" - end - argument_fragments << "\n" - - @current_function.at("parameters").children = - Nokogiri::HTML5::DocumentFragment.parse(argument_fragments) - rescue JSON::ParserError - return function_buffer - end - end - - function_buffer + def processor + @processor ||= OpenAiMessageProcessor.new end end end diff --git a/lib/completions/endpoints/samba_nova.rb b/lib/completions/endpoints/samba_nova.rb index ccb883cc..cc81e786 100644 --- a/lib/completions/endpoints/samba_nova.rb +++ b/lib/completions/endpoints/samba_nova.rb @@ -55,27 +55,31 @@ module DiscourseAi log.response_tokens = @completion_tokens if @completion_tokens end - def extract_completion_from(response_raw) - json = JSON.parse(response_raw, symbolize_names: true) - - if @streaming_mode - @prompt_tokens ||= json.dig(:usage, :prompt_tokens) - @completion_tokens ||= json.dig(:usage, :completion_tokens) - end - - parsed = json.dig(:choices, 0) - return if !parsed - - @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content) + def xml_tools_enabled? + true end - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data: ", 2)[1] - data == "[DONE]" ? nil : data + def decode(response_raw) + json = JSON.parse(response_raw, symbolize_names: true) + [json.dig(:choices, 0, :message, :content)] + end + + def decode_chunk(chunk) + @json_decoder ||= JsonStreamDecoder.new + (@json_decoder << chunk) + .map do |json| + text = json.dig(:choices, 0, :delta, :content) + + @prompt_tokens ||= json.dig(:usage, :prompt_tokens) + @completion_tokens ||= json.dig(:usage, :completion_tokens) + + if !text.to_s.empty? + text + else + nil + end end + .flatten .compact end end diff --git a/lib/completions/endpoints/vllm.rb b/lib/completions/endpoints/vllm.rb index 57fcf051..6b371a09 100644 --- a/lib/completions/endpoints/vllm.rb +++ b/lib/completions/endpoints/vllm.rb @@ -42,7 +42,10 @@ module DiscourseAi def prepare_payload(prompt, model_params, dialect) payload = default_options.merge(model_params).merge(messages: prompt) - payload[:stream] = true if @streaming_mode + if @streaming_mode + payload[:stream] = true if @streaming_mode + payload[:stream_options] = { include_usage: true } + end payload end @@ -56,24 +59,42 @@ module DiscourseAi Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } end - def partials_from(decoded_chunk) - decoded_chunk - .split("\n") - .map do |line| - data = line.split("data: ", 2)[1] - data == "[DONE]" ? nil : data - end - .compact + def xml_tools_enabled? 
+ true end - def extract_completion_from(response_raw) - parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0) - # half a line sent here - return if !parsed + def final_log_update(log) + log.request_tokens = @prompt_tokens if @prompt_tokens + log.response_tokens = @completion_tokens if @completion_tokens + end - response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message) + def decode(response_raw) + json = JSON.parse(response_raw, symbolize_names: true) + @prompt_tokens = json.dig(:usage, :prompt_tokens) + @completion_tokens = json.dig(:usage, :completion_tokens) + [json.dig(:choices, 0, :message, :content)] + end - response_h.dig(:content) + def decode_chunk(chunk) + @json_decoder ||= JsonStreamDecoder.new + (@json_decoder << chunk) + .map do |parsed| + # vLLM keeps sending usage over and over again + prompt_tokens = parsed.dig(:usage, :prompt_tokens) + completion_tokens = parsed.dig(:usage, :completion_tokens) + + @prompt_tokens = prompt_tokens if prompt_tokens + + @completion_tokens = completion_tokens if completion_tokens + + text = parsed.dig(:choices, 0, :delta, :content) + if text.to_s.empty? + nil + else + text + end + end + .compact end end end diff --git a/lib/completions/function_call_normalizer.rb b/lib/completions/function_call_normalizer.rb deleted file mode 100644 index ef40809c..00000000 --- a/lib/completions/function_call_normalizer.rb +++ /dev/null @@ -1,113 +0,0 @@ -# frozen_string_literal: true - -class DiscourseAi::Completions::FunctionCallNormalizer - attr_reader :done - - # blk is the block to call with filtered data - def initialize(blk, cancel) - @blk = blk - @cancel = cancel - @done = false - - @in_tool = false - - @buffer = +"" - @function_buffer = +"" - end - - def self.normalize(data) - text = +"" - cancel = -> {} - blk = ->(partial, _) { text << partial } - - normalizer = self.new(blk, cancel) - normalizer << data - - [text, normalizer.function_calls] - end - - def function_calls - return nil if @function_buffer.blank? 
- - xml = Nokogiri::HTML5.fragment(@function_buffer) - self.class.normalize_function_ids!(xml) - last_invoke = xml.at("invoke:last") - if last_invoke - last_invoke.next_sibling.remove while last_invoke.next_sibling - xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling - end - xml.at("function_calls").to_s.dup.force_encoding("UTF-8") - end - - def <<(text) - @buffer << text - - if !@in_tool - # double check if we are clearly in a tool - search_length = text.length + 20 - search_string = @buffer[-search_length..-1] || @buffer - - index = search_string.rindex("") - @in_tool = !!index - if @in_tool - @function_buffer = @buffer[index..-1] - text_index = text.rindex("") - @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0 - end - else - @function_buffer << text - end - - if !@in_tool - if maybe_has_tool?(@buffer) - split_index = text.rindex("<").to_i - 1 - if split_index >= 0 - @function_buffer = text[split_index + 1..-1] || "" - text = text[0..split_index] || "" - else - @function_buffer << text - text = "" - end - else - if @function_buffer.length > 0 - @blk.call(@function_buffer, @cancel) - @function_buffer = +"" - end - end - - @blk.call(text, @cancel) if text.length > 0 - else - if text.include?("") - @done = true - @cancel.call - end - end - end - - def self.normalize_function_ids!(function_buffer) - function_buffer - .css("invoke") - .each_with_index do |invoke, index| - if invoke.at("tool_id") - invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank? - else - invoke.add_child("tool_#{index}\n") if !invoke.at("tool_id") - end - end - end - - private - - def maybe_has_tool?(text) - # 16 is the length of function calls - substring = text[-16..-1] || text - split = substring.split("<") - - if split.length > 1 - match = "<" + split.last - "".start_with?(match) - else - substring.ends_with?("<") - end - end -end diff --git a/lib/completions/json_stream_decoder.rb b/lib/completions/json_stream_decoder.rb new file mode 100644 index 00000000..e575a3b7 --- /dev/null +++ b/lib/completions/json_stream_decoder.rb @@ -0,0 +1,48 @@ +# frozen_string_literal: true + +module DiscourseAi + module Completions + # will work for anthropic and open ai compatible + class JsonStreamDecoder + attr_reader :buffer + + LINE_REGEX = /data: ({.*})\s*$/ + + def initialize(symbolize_keys: true, line_regex: LINE_REGEX) + @symbolize_keys = symbolize_keys + @buffer = +"" + @line_regex = line_regex + end + + def <<(raw) + @buffer << raw.to_s + rval = [] + + split = @buffer.scan(/.*\n?/) + split.pop if split.last.blank? + + @buffer = +(split.pop.to_s) + + split.each do |line| + matches = line.match(@line_regex) + next if !matches + rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys) + end + + if @buffer.present? 
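+          # the leftover buffer may already hold one complete `data: {...}` line
+          # whose trailing newline has not arrived yet; parse it eagerly and keep
+          # buffering (JSON::ParserError below) while the JSON is still partial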
+ matches = @buffer.match(@line_regex) + if matches + begin + rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys) + @buffer = +"" + rescue JSON::ParserError + # maybe it is a partial line + end + end + end + + rval + end + end + end +end diff --git a/lib/completions/open_ai_message_processor.rb b/lib/completions/open_ai_message_processor.rb new file mode 100644 index 00000000..02369bec --- /dev/null +++ b/lib/completions/open_ai_message_processor.rb @@ -0,0 +1,103 @@ +# frozen_string_literal: true +module DiscourseAi::Completions + class OpenAiMessageProcessor + attr_reader :prompt_tokens, :completion_tokens + + def initialize + @tool = nil + @tool_arguments = +"" + @prompt_tokens = nil + @completion_tokens = nil + end + + def process_message(json) + result = [] + tool_calls = json.dig(:choices, 0, :message, :tool_calls) + + message = json.dig(:choices, 0, :message, :content) + result << message if message.present? + + if tool_calls.present? + tool_calls.each do |tool_call| + id = tool_call.dig(:id) + name = tool_call.dig(:function, :name) + arguments = tool_call.dig(:function, :arguments) + parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {} + result << ToolCall.new(id: id, name: name, parameters: parameters) + end + end + + update_usage(json) + + result + end + + def process_streamed_message(json) + rval = nil + + tool_calls = json.dig(:choices, 0, :delta, :tool_calls) + content = json.dig(:choices, 0, :delta, :content) + + finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == [] + + if tool_calls.present? + id = tool_calls.dig(0, :id) + name = tool_calls.dig(0, :function, :name) + arguments = tool_calls.dig(0, :function, :arguments) + + # TODO: multiple tool support may require index + #index = tool_calls[0].dig(:index) + + if id.present? && @tool && @tool.id != id + process_arguments + rval = @tool + @tool = nil + end + + if id.present? && name.present? + @tool_arguments = +"" + @tool = ToolCall.new(id: id, name: name) + end + + @tool_arguments << arguments.to_s + elsif finished_tools && @tool + parsed_args = JSON.parse(@tool_arguments, symbolize_names: true) + @tool.parameters = parsed_args + rval = @tool + @tool = nil + elsif content.present? + rval = content + end + + update_usage(json) + + rval + end + + def finish + rval = [] + if @tool + process_arguments + rval << @tool + @tool = nil + end + + rval + end + + private + + def process_arguments + if @tool_arguments.present? 
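+        # @tool_arguments accumulated raw JSON fragments while streaming;
+        # parse them now and attach the result to the pending tool call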
+        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+        @tool.parameters = parsed_args
+        @tool_arguments = nil
+      end
+    end
+
+    def update_usage(json)
+      @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+      @completion_tokens ||= json.dig(:usage, :completion_tokens)
+    end
+  end
+end
diff --git a/lib/completions/tool_call.rb b/lib/completions/tool_call.rb
new file mode 100644
index 00000000..15be7b3f
--- /dev/null
+++ b/lib/completions/tool_call.rb
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class ToolCall
+      attr_reader :id, :name, :parameters
+
+      def initialize(id:, name:, parameters: nil)
+        @id = id
+        @name = name
+        self.parameters = parameters if parameters
+        @parameters ||= {}
+      end
+
+      def parameters=(parameters)
+        raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
+        @parameters = parameters.symbolize_keys
+      end
+
+      def ==(other)
+        id == other.id && name == other.name && parameters == other.parameters
+      end
+
+      def to_s
+        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+      end
+    end
+  end
+end
diff --git a/lib/completions/xml_tool_processor.rb b/lib/completions/xml_tool_processor.rb
new file mode 100644
index 00000000..1b42b333
--- /dev/null
+++ b/lib/completions/xml_tool_processor.rb
@@ -0,0 +1,124 @@
+# frozen_string_literal: true
+
+# This class can be used to process a stream of text that may contain XML tool
+# calls.
+# It will return either text or ToolCall objects.
+
+module DiscourseAi
+  module Completions
+    class XmlToolProcessor
+      def initialize
+        @buffer = +""
+        @function_buffer = +""
+        @should_cancel = false
+        @in_tool = false
+      end
+
+      def <<(text)
+        @buffer << text
+        result = []
+
+        if !@in_tool
+          # double check if we are clearly in a tool
+          search_length = text.length + 20
+          search_string = @buffer[-search_length..-1] || @buffer
+
+          index = search_string.rindex("<function_calls>")
+          @in_tool = !!index
+          if @in_tool
+            @function_buffer = @buffer[index..-1]
+            text_index = text.rindex("<function_calls>")
+            result << text[0..text_index - 1].strip if text_index && text_index > 0
+          end
+        else
+          @function_buffer << text
+        end
+
+        if !@in_tool
+          if maybe_has_tool?(@buffer)
+            split_index = text.rindex("<").to_i - 1
+            if split_index >= 0
+              @function_buffer = text[split_index + 1..-1] || ""
+              text = text[0..split_index] || ""
+            else
+              @function_buffer << text
+              text = ""
+            end
+          else
+            if @function_buffer.length > 0
+              result << @function_buffer
+              @function_buffer = +""
+            end
+          end
+
+          result << text if text.length > 0
+        else
+          @should_cancel = true if text.include?("</function_calls>")
+        end
+
+        result
+      end
+
+      def finish
+        return [] if @function_buffer.blank?
+
+        xml = Nokogiri::HTML5.fragment(@function_buffer)
+        normalize_function_ids!(xml)
+        last_invoke = xml.at("invoke:last")
+        if last_invoke
+          last_invoke.next_sibling.remove while last_invoke.next_sibling
+          xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
+        end
+
+        xml
+          .css("invoke")
+          .map do |invoke|
+            tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
+            tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
+            parameters = {}
+            invoke
+              .at("parameters")
+              &.children
+              &.each do |node|
+                next if node.text?
+                name = node.name
+                value = node.content.to_s
+                parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
+              end
+            ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
+          end
+      end
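+
+      # Usage sketch (editor's illustrative example, not part of the original
+      # file; "tool_0" is the id auto-assigned by normalize_function_ids!):
+      #
+      #   processor = XmlToolProcessor.new
+      #   text = processor << "hello <function_calls><invoke><tool_name>search</tool_name>"
+      #   text += processor << "<parameters><query>sam</query></parameters></invoke></function_calls>"
+      #   text             # => ["hello"]
+      #   processor.finish # => [ToolCall(id: "tool_0", name: "search", parameters: { query: "sam" })]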
diff --git a/spec/lib/completions/endpoints/anthropic_spec.rb b/spec/lib/completions/endpoints/anthropic_spec.rb
index 94d5d655..40eca30f 100644
--- a/spec/lib/completions/endpoints/anthropic_spec.rb
+++ b/spec/lib/completions/endpoints/anthropic_spec.rb
@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       data: {"type":"message_stop"}
     STRING

-    result = +""
+    result = []
     body = body.scan(/.*\n/)
     EndpointMock.with_chunk_array_support do
       stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       end
     end

-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <search_query>s<a>m sam</search_query>
-      <category>general</category>
-      <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+        parameters: {
+          search_query: "s<a>m sam",
+          category: "general",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq([tool_call])
   end

   it "can stream a response" do
@@ -191,6 +190,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
     expect(log.feature_name).to eq("testing")
     expect(log.response_tokens).to eq(15)
     expect(log.request_tokens).to eq(25)
+    expect(log.raw_request_payload).to eq(expected_body.to_json)
+    expect(log.raw_response_payload.strip).to eq(body.strip)
   end

   it "supports non streaming tool calls" do
@@ -242,17 +243,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
     result = llm.generate(prompt, user: Discourse.system_user)

-    expected = <<~TEXT.strip
-      <function_calls>
-      <invoke>
-      <tool_name>calculate</tool_name>
-      <expression>2758975 + 21.11</expression>
-      <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "calculate",
+        id: "toolu_012kBdhG4eHaV68W56p4N94h",
+        parameters: {
+          expression: "2758975 + 21.11",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq(["Here is the calculation:", tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(345)
+    expect(log.response_tokens).to eq(65)
   end

   it "can send images via a completion prompt" do
diff --git a/spec/lib/completions/endpoints/aws_bedrock_spec.rb b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
index d9519344..2a9cc77f 100644
--- a/spec/lib/completions/endpoints/aws_bedrock_spec.rb
+++ b/spec/lib/completions/endpoints/aws_bedrock_spec.rb
@@ -79,7 +79,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
       }
       prompt.tools = [tool]

-      response = +""
+      response = []
       proxy.generate(prompt, user: user) { |partial| response << partial }

       expect(request.headers["Authorization"]).to be_present
@@ -90,21 +90,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
      expect(parsed_body["tools"]).to eq(nil)
      expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])

-      # note we now have a tool_id cause we were normalized
-      function_call = <<~XML.strip
-        hello
+      expected = [
"hello\n", + DiscourseAi::Completions::ToolCall.new( + id: "tool_0", + name: "google", + parameters: { + query: "sydney weather today", + }, + ), + ] - - - - google - sydney weather today - tool_0 - - - XML - - expect(response.strip).to eq(function_call) + expect(response).to eq(expected) end end @@ -230,23 +227,23 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do } prompt.tools = [tool] - response = +"" + response = [] proxy.generate(prompt, user: user) { |partial| response << partial } expect(request.headers["Authorization"]).to be_present expect(request.headers["X-Amz-Content-Sha256"]).to be_present - expected_response = (<<~RESPONSE).strip - - - google - sydney weather today - toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7 - - - RESPONSE + expected_response = [ + DiscourseAi::Completions::ToolCall.new( + id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7", + name: "google", + parameters: { + query: "sydney weather today", + }, + ), + ] - expect(response.strip).to eq(expected_response) + expect(response).to eq(expected_response) expected = { "max_tokens" => 3000, diff --git a/spec/lib/completions/endpoints/cohere_spec.rb b/spec/lib/completions/endpoints/cohere_spec.rb index 4bb213ff..bdff8fc3 100644 --- a/spec/lib/completions/endpoints/cohere_spec.rb +++ b/spec/lib/completions/endpoints/cohere_spec.rb @@ -66,7 +66,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do TEXT parsed_body = nil - result = +"" + result = [] sig = { name: "google", @@ -91,21 +91,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do }, ).to_return(status: 200, body: body.split("|")) - result = llm.generate(prompt, user: user) { |partial, cancel| result << partial } + llm.generate(prompt, user: user) { |partial, cancel| result << partial } end - expected = <<~TEXT - - - google - who is sam saffron - - tool_0 - - - TEXT + text = "I will search for 'who is sam saffron' and relay the information to the user." 
+ tool_call = + DiscourseAi::Completions::ToolCall.new( + id: "tool_0", + name: "google", + parameters: { + query: "who is sam saffron", + }, + ) - expect(result.strip).to eq(expected.strip) + expect(result).to eq([text, tool_call]) expected = { model: "command-r-plus", diff --git a/spec/lib/completions/endpoints/endpoint_compliance.rb b/spec/lib/completions/endpoints/endpoint_compliance.rb index 372c529b..130c735b 100644 --- a/spec/lib/completions/endpoints/endpoint_compliance.rb +++ b/spec/lib/completions/endpoints/endpoint_compliance.rb @@ -62,18 +62,14 @@ class EndpointMock end def invocation_response - <<~TEXT - - - get_weather - - Sydney - c - - tool_0 - - - TEXT + DiscourseAi::Completions::ToolCall.new( + id: "tool_0", + name: "get_weather", + parameters: { + location: "Sydney", + unit: "c", + }, + ) end def tool_id @@ -185,7 +181,7 @@ class EndpointsCompliance mock.stub_tool_call(a_dialect.translate) completion_response = endpoint.perform_completion!(a_dialect, user) - expect(completion_response.strip).to eq(mock.invocation_response.strip) + expect(completion_response).to eq(mock.invocation_response) end def streaming_mode_simple_prompt(mock) @@ -205,6 +201,7 @@ class EndpointsCompliance expect(log.raw_request_payload).to be_present expect(log.raw_response_payload).to be_present expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate)) + expect(log.response_tokens).to eq( endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join), ) @@ -216,14 +213,14 @@ class EndpointsCompliance a_dialect = dialect(prompt: prompt) mock.stub_streamed_tool_call(a_dialect.translate) do - buffered_partial = +"" + buffered_partial = [] endpoint.perform_completion!(a_dialect, user) do |partial, cancel| buffered_partial << partial - cancel.call if buffered_partial.include?("") + cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall) end - expect(buffered_partial.strip).to eq(mock.invocation_response.strip) + expect(buffered_partial).to eq([mock.invocation_response]) end end diff --git a/spec/lib/completions/endpoints/gemini_spec.rb b/spec/lib/completions/endpoints/gemini_spec.rb index 2f602d3a..18933843 100644 --- a/spec/lib/completions/endpoints/gemini_spec.rb +++ b/spec/lib/completions/endpoints/gemini_spec.rb @@ -195,19 +195,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do response = llm.generate(prompt, user: user) - expected = (<<~XML).strip - - - echo - - <S>ydney - - tool_0 - - - XML + tool = + DiscourseAi::Completions::ToolCall.new( + id: "tool_0", + name: "echo", + parameters: { + text: "ydney", + }, + ) - expect(response.strip).to eq(expected) + expect(response).to eq(tool) end it "Supports Vision API" do @@ -265,6 +262,68 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do expect(JSON.parse(req_body)).to eq(expected_prompt) end + it "Can stream tool calls correctly" do + rows = [ + { + candidates: [ + { + content: { + parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }], + role: "model", + }, + safetyRatings: [ + { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" }, + { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" }, + { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" }, + { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" }, + ], + }, + ], + usageMetadata: { + promptTokenCount: 625, + totalTokenCount: 625, + }, + modelVersion: "gemini-1.5-pro-002", + }, + { + candidates: [{ content: { parts: [{ text: "" }], role: 
"model" }, finishReason: "STOP" }], + usageMetadata: { + promptTokenCount: 625, + candidatesTokenCount: 4, + totalTokenCount: 629, + }, + modelVersion: "gemini-1.5-pro-002", + }, + ] + + payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join + + llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") + url = "#{model.url}:streamGenerateContent?alt=sse&key=123" + + prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool]) + + output = [] + + stub_request(:post, url).to_return(status: 200, body: payload) + llm.generate(prompt, user: user) { |partial| output << partial } + + tool_call = + DiscourseAi::Completions::ToolCall.new( + id: "tool_0", + name: "echo", + parameters: { + text: "sam<>wh!s", + }, + ) + + expect(output).to eq([tool_call]) + + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(625) + expect(log.response_tokens).to eq(4) + end + it "Can correctly handle streamed responses even if they are chunked badly" do data = +"" data << "da|ta: |" @@ -279,12 +338,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}") url = "#{model.url}:streamGenerateContent?alt=sse&key=123" - output = +"" + output = [] gemini_mock.with_chunk_array_support do stub_request(:post, url).to_return(status: 200, body: split) llm.generate("Hello", user: user) { |partial| output << partial } end - expect(output).to eq("Hello World Sam") + expect(output.join).to eq("Hello World Sam") end end diff --git a/spec/lib/completions/endpoints/ollama_spec.rb b/spec/lib/completions/endpoints/ollama_spec.rb index eb6bc63c..4f458283 100644 --- a/spec/lib/completions/endpoints/ollama_spec.rb +++ b/spec/lib/completions/endpoints/ollama_spec.rb @@ -150,7 +150,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do end describe "when using streaming mode" do - context "with simpel prompts" do + context "with simple prompts" do it "completes a trivial prompt and logs the response" do compliance.streaming_mode_simple_prompt(ollama_mock) end diff --git a/spec/lib/completions/endpoints/open_ai_spec.rb b/spec/lib/completions/endpoints/open_ai_spec.rb index 60df1d67..c4d7758a 100644 --- a/spec/lib/completions/endpoints/open_ai_spec.rb +++ b/spec/lib/completions/endpoints/open_ai_spec.rb @@ -17,8 +17,8 @@ class OpenAiMock < EndpointMock created: 1_678_464_820, model: "gpt-3.5-turbo-0301", usage: { - prompt_tokens: 337, - completion_tokens: 162, + prompt_tokens: 8, + completion_tokens: 13, total_tokens: 499, }, choices: [ @@ -231,19 +231,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do result = llm.generate(prompt, user: user) - expected = (<<~TXT).strip - - - echo - - hello - - call_I8LKnoijVuhKOM85nnEQgWwd - - - TXT + tool_call = + DiscourseAi::Completions::ToolCall.new( + id: "call_I8LKnoijVuhKOM85nnEQgWwd", + name: "echo", + parameters: { + text: "hello", + }, + ) - expect(result.strip).to eq(expected) + expect(result).to eq(tool_call) stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( body: { choices: [message: { content: "OK" }] }.to_json, @@ -320,19 +317,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } }) - expected = (<<~TXT).strip - - - echo - - h<e>llo - - call_I8LKnoijVuhKOM85nnEQgWwd - - - TXT + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(55) + expect(log.response_tokens).to eq(13) - expect(result.strip).to 
eq(expected) + expected = + DiscourseAi::Completions::ToolCall.new( + id: "call_I8LKnoijVuhKOM85nnEQgWwd", + name: "echo", + parameters: { + text: "hllo", + }, + ) + + expect(result).to eq(expected) stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return( body: { choices: [message: { content: "OK" }] }.to_json, @@ -487,7 +485,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]} - data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]} + data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]} data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]} @@ -495,32 +493,30 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do TEXT open_ai_mock.stub_raw(raw_data) - content = +"" + response = [] dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools)) - endpoint.perform_completion!(dialect, user) { |partial| content << partial } + endpoint.perform_completion!(dialect, user) { |partial| response << partial } - expected = <<~TEXT - - - search - - Discourse AI bot - - call_3Gyr3HylFJwfrtKrL6NaIit1 - - - search - - Discourse AI bot - - call_H7YkbgYurHpyJqzwUN4bghwN - - - TEXT + tool_calls = [ + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "call_3Gyr3HylFJwfrtKrL6NaIit1", + parameters: { + search_query: "Discourse AI bot", + }, + ), + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "call_H7YkbgYurHpyJqzwUN4bghwN", + parameters: { + query: "Discourse AI bot2", + }, + ), + ] - expect(content).to eq(expected) + expect(response).to eq(tool_calls) end it "uses proper token accounting" do @@ -593,21 +589,16 @@ TEXT dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools)) endpoint.perform_completion!(dialect, user) { |partial| partials << partial } - expect(partials.length).to eq(1) + tool_call = + DiscourseAi::Completions::ToolCall.new( + id: "func_id", + name: "google", + parameters: { + query: "Adabas 9.1", + }, + ) - function_call = (<<~TXT).strip - - - google - - Adabas 9.1 - - func_id - - - TXT - - expect(partials[0].strip).to eq(function_call) + expect(partials).to eq([tool_call]) end end end diff --git a/spec/lib/completions/endpoints/samba_nova_spec.rb b/spec/lib/completions/endpoints/samba_nova_spec.rb index 0f1f68ac..83839bf4 100644 --- a/spec/lib/completions/endpoints/samba_nova_spec.rb +++ b/spec/lib/completions/endpoints/samba_nova_spec.rb @@ -22,10 +22,15 @@ data: [DONE] }, ).to_return(status: 200, body: body, headers: {}) - response = +"" + response = [] llm.generate("who are you?", user: 
Discourse.system_user) { |partial| response << partial } - expect(response).to eq("I am a bot") + expect(response).to eq(["I am a bot"]) + + log = AiApiAuditLog.order(:id).last + + expect(log.request_tokens).to eq(21) + expect(log.response_tokens).to eq(41) end it "can perform regular completions" do diff --git a/spec/lib/completions/endpoints/vllm_spec.rb b/spec/lib/completions/endpoints/vllm_spec.rb index 6f5387c0..824bcbe0 100644 --- a/spec/lib/completions/endpoints/vllm_spec.rb +++ b/spec/lib/completions/endpoints/vllm_spec.rb @@ -51,7 +51,13 @@ class VllmMock < EndpointMock WebMock .stub_request(:post, "https://test.dev/v1/chat/completions") - .with(body: model.default_options.merge(messages: prompt, stream: true).to_json) + .with( + body: + model + .default_options + .merge(messages: prompt, stream: true, stream_options: { include_usage: true }) + .to_json, + ) .to_return(status: 200, body: chunks) end end @@ -136,29 +142,115 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do result = llm.generate(prompt, user: Discourse.system_user) - expected = <<~TEXT - - - calculate - - 1+1 - tool_0 - - - TEXT + expected = + DiscourseAi::Completions::ToolCall.new( + name: "calculate", + id: "tool_0", + parameters: { + expression: "1+1", + }, + ) - expect(result.strip).to eq(expected.strip) + expect(result).to eq(expected) end end + it "correctly accounts for tokens in non streaming mode" do + body = (<<~TEXT).strip + {"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null} + TEXT + + stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body) + + result = llm.generate("generate a title", user: Discourse.system_user) + + expect(result).to eq("Random Number Generator Produces Smallest Possible Result") + + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(146) + expect(log.response_tokens).to eq(10) + end + + it "can properly include usage in streaming mode" do + payload = <<~TEXT.strip + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}} + + data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}} + + data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}} + + data: 
{"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}} + + data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}} + + data: [DONE] + TEXT + + stub_request(:post, "https://test.dev/v1/chat/completions").to_return( + status: 200, + body: payload, + ) + + response = [] + llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial } + + expect(response.join).to eq( + "Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?", + ) + + log = AiApiAuditLog.order(:id).last + expect(log.request_tokens).to eq(46) + expect(log.response_tokens).to eq(26) + end + describe "#perform_completion!" do context "when using regular mode" do - context "with simple prompts" do - it "completes a trivial prompt and logs the response" do - compliance.regular_mode_simple_prompt(vllm_mock) - end - end - context "with tools" do it "returns a function invocation" do compliance.regular_mode_tools(vllm_mock) diff --git a/spec/lib/completions/function_call_normalizer_spec.rb b/spec/lib/completions/function_call_normalizer_spec.rb deleted file mode 100644 index dd78ed7f..00000000 --- a/spec/lib/completions/function_call_normalizer_spec.rb +++ /dev/null @@ -1,182 +0,0 @@ -# frozen_string_literal: true - -RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do - let(:buffer) { +"" } - - let(:normalizer) do - blk = ->(data, cancel) { buffer << data } - cancel = -> { @done = true } - DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel) - end - - def pass_through!(data) - normalizer << data - expect(buffer[-data.length..-1]).to eq(data) - end - - it "is usable in non streaming mode" do - xml = (<<~XML).strip - hello - - - hello - - XML - - text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) - - expect(text).to eq("hello") - - expected_function_calls = (<<~XML).strip - - - hello - tool_0 - - - XML - - expect(function_calls).to eq(expected_function_calls) - end - - it "strips junk from end of function calls" do - xml = (<<~XML).strip - hello - - - hello - - junk - XML - - _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml) - - expected_function_calls = (<<~XML).strip - - - hello - tool_0 - - - XML - - expect(function_calls).to eq(expected_function_calls) - end - - it "returns nil for function calls if there are none" do - input = "hello world\n" - text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input) - - expect(text).to eq(input) - expect(function_calls).to eq(nil) - end - - it "passes through data if there are no function calls detected" do - pass_through!("hello") - pass_through!("hello") - pass_through!("world") - pass_through!("") - end - - it "properly handles non English 
tools" do
-    normalizer << "hello<function_calls>\n"
-
-    normalizer << (<<~XML).strip
-      <invoke>
-      <tool_name>hello</tool_name>
-      <parameters>
-        <hello>世界</hello>
-      </parameters>
-      </invoke>
-    XML
-
-    expected = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      <parameters>
-        <hello>世界</hello>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    function_calls = normalizer.function_calls
-    expect(function_calls).to eq(expected)
-  end
-
-  it "works correctly even if you only give it 1 letter at a time" do
-    xml = (<<~XML).strip
-      abc
-      <function_calls>
-      <invoke>
-        <tool_name>hello</tool_name>
-        <parameters>
-        <hello>world</hello>
-        </parameters>
-      <tool_id>abc</tool_id>
-      </invoke>
-      <invoke>
-        <tool_name>hello2</tool_name>
-        <parameters>
-        <hello>world</hello>
-        </parameters>
-      <tool_id>aba</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    xml.each_char { |char| normalizer << char }
-
-    expect(buffer + normalizer.function_calls).to eq(xml)
-  end
-
-  it "supports multiple invokes" do
-    xml = (<<~XML).strip
-      <function_calls>
-      <invoke>
-        <tool_name>hello</tool_name>
-        <parameters>
-        <hello>world</hello>
-        </parameters>
-      <tool_id>abc</tool_id>
-      </invoke>
-      <invoke>
-        <tool_name>hello2</tool_name>
-        <parameters>
-        <hello>world</hello>
-        </parameters>
-      <tool_id>aba</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    normalizer << xml
-
-    expect(normalizer.function_calls).to eq(xml)
-  end
-
-  it "can will cancel if it encounteres </function_calls>" do
-    normalizer << "<function_calls>"
-    expect(normalizer.done).to eq(false)
-    normalizer << "</function_calls>"
-    expect(normalizer.done).to eq(true)
-    expect(@done).to eq(true)
-
-    expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
-  end
-
-  it "pauses on function call and starts buffering" do
-    normalizer << "hello<function_calls>"
-    expect(buffer).to eq("hello")
-    expect(normalizer.done).to eq(false)
-  end
-end
diff --git a/spec/lib/completions/json_stream_decoder_spec.rb b/spec/lib/completions/json_stream_decoder_spec.rb
new file mode 100644
index 00000000..831bad6f
--- /dev/null
+++ b/spec/lib/completions/json_stream_decoder_spec.rb
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Completions::JsonStreamDecoder do
+  let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
+
+  it "should be able to parse simple messages" do
+    result = decoder << "data: #{{ hello: "world" }.to_json}"
+    expect(result).to eq([{ hello: "world" }])
+  end
+
+  it "should handle anthropic mixed style streams" do
+    stream = (<<~TEXT).split("|")
+      event: |message_start|
+      data: |{"hel|lo": "world"}|
+
+      event: |message_start
+      data: {"foo": "bar"}
+
+      event: |message_start
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each { |chunk| results << (decoder << chunk) }
+
+    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
+  end
+
+  it "should be able to handle complex overlaps" do
+    stream = (<<~TEXT).split("|")
+      data: |{"hel|lo": "world"}
+
+      data: {"foo": "bar"}
+
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each { |chunk| results << (decoder << chunk) }
+
+    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
+  end
+end
diff --git a/spec/lib/completions/xml_tool_processor_spec.rb b/spec/lib/completions/xml_tool_processor_spec.rb
new file mode 100644
index 00000000..003f4356
--- /dev/null
+++ b/spec/lib/completions/xml_tool_processor_spec.rb
@@ -0,0 +1,188 @@
+# frozen_string_literal: true
+
+RSpec.describe DiscourseAi::Completions::XmlToolProcessor do
+  let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new }
+
+  it "can process simple text" do
+    result = []
+    result << (processor << "hello")
+    result << (processor << " world ")
+    expect(result).to eq([["hello"], [" world "]])
+    expect(processor.finish).to eq([])
+    expect(processor.should_cancel?).to eq(false)
+  end
+
+  it "is usable for simple single message mode" do
+    xml = (<<~XML).strip
+      hello
+      <function_calls>
+      <invoke>
+      <tool_name>hello</tool_name>
+      <parameters>
+      <hello>world</hello>
+      <test>value</test>
+      </parameters>
+      </invoke>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "hello",
+        parameters: {
+          hello: "world",
+          test: "value",
+        },
+      )
+
+    expect(result).to eq([["hello"], [tool_call]])
+    expect(processor.should_cancel?).to eq(false)
+  end
+
+  it "handles multiple tool calls in sequence" do
+    xml = (<<~XML).strip
+      start
+      <function_calls>
+      <invoke>
+      <tool_name>first_tool</tool_name>
+      <parameters>
+      <param1>value1</param1>
+      </parameters>
+      </invoke>
+      <invoke>
+      <tool_name>second_tool</tool_name>
+      <parameters>
+      <param2>value2</param2>
+      </parameters>
+      </invoke>
+      </function_calls>
+      end
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    first_tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "first_tool",
+        parameters: {
+          param1: "value1",
+        },
+      )
+
+    second_tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_1",
+        name: "second_tool",
+        parameters: {
+          param2: "value2",
+        },
+      )
+
+    expect(result).to eq([["start"], [first_tool, second_tool]])
+    expect(processor.should_cancel?).to eq(true)
+  end
+
+  it "handles non-English parameters correctly" do
+    xml = (<<~XML).strip
+      こんにちは
+      <function_calls>
+      <invoke>
+      <tool_name>translator</tool_name>
+      <parameters>
+      <text>世界</text>
+      </parameters>
+      </invoke>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "translator",
+        parameters: {
+          text: "世界",
+        },
+      )
+
+    expect(result).to eq([["こんにちは"], [tool_call]])
+  end
+
+  it "processes input character by character" do
+    xml =
+      "hi<function_calls><invoke><tool_name>test</tool_name><parameters><p>v</p></parameters></invoke></function_calls>"
+
+    result = []
+    xml.each_char { |char| result << (processor << char) }
+    result << processor.finish
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })
+
+    filtered_result = result.reject(&:empty?)
+    expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
+  end
+
+  it "handles malformed XML gracefully" do
+    xml = (<<~XML).strip
+      text
+      <function_calls>
+      <invoke>
+      <tool_name>test</tool_name>
+      <parameters>
+      <param>value</param>
+      </parameters>
+      </invoke>
+      malformed
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    # Should just do its best to parse the XML
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
+    expect(result).to eq([["text"], [tool_call]])
+  end
+
+  it "correctly processes empty parameter sets" do
+    xml = (<<~XML).strip
+      hello
+      <function_calls>
+      <invoke>
+      <tool_name>no_params</tool_name>
+      <parameters>
+      </parameters>
+      </invoke>
+    XML
+
+    result = []
+    result << (processor << xml)
+    result << (processor.finish)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})
+
+    expect(result).to eq([["hello"], [tool_call]])
+  end
+
+  it "properly handles cancelled processing" do
+    xml = "start<function_calls></function_calls>"
+    result = []
+    result << (processor << xml)
+    result << (processor << "more text")
+    result << processor.finish
+
+    expect(result).to eq([["start"], [], []])
+    expect(processor.should_cancel?).to eq(true)
+  end
+end
diff --git a/spec/lib/modules/ai_bot/personas/persona_spec.rb b/spec/lib/modules/ai_bot/personas/persona_spec.rb
index 2fb95d19..5271374d 100644
--- a/spec/lib/modules/ai_bot/personas/persona_spec.rb
+++ b/spec/lib/modules/ai_bot/personas/persona_spec.rb
@@ -72,40 +72,27 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
   it "can parse string that are wrapped in quotes" do
     SiteSetting.ai_stability_api_key = "123"

-    xml = <<~XML
-      <function_calls>
-      <invoke>
-        <tool_name>image</tool_name>
-        <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-        <parameters>
-        <prompts>["cat oil painting", "big car"]</prompts>
-        <aspect_ratio>"16:9"</aspect_ratio>
-        </parameters>
-      </invoke>
-      <invoke>
-        <tool_name>image</tool_name>
-        <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-        <parameters>
-        <prompts>["cat oil painting", "big car"]</prompts>
-        <aspect_ratio>'16:9'</aspect_ratio>
-        </parameters>
-      </invoke>
-      </function_calls>
-    XML
-    image1, image2 =
-      tools =
-        DiscourseAi::AiBot::Personas::Artist.new.find_tools(
-          xml,
-          bot_user: nil,
-          llm: nil,
-          context: nil,
-        )
-    expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
-    expect(image1.parameters[:aspect_ratio]).to eq("16:9")
-    expect(image2.parameters[:aspect_ratio]).to eq("16:9")
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "image",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          prompts: ["cat oil painting", "big car"],
+          aspect_ratio: "16:9",
+        },
+      )

-    expect(tools.length).to eq(2)
+    tool_instance =
+      DiscourseAi::AiBot::Personas::Artist.new.find_tool(
+        tool_call,
+        bot_user: nil,
+        llm: nil,
+        context: nil,
+      )
+
+    expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
+    expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
   end

   it "enforces enums" do
@@ -132,42 +119,68 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
     XML

-    search1, search2 =
-      tools =
-        DiscourseAi::AiBot::Personas::General.new.find_tools(
-          xml,
-          bot_user: nil,
-          llm: nil,
-          context: nil,
-        )
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          max_posts: "3.2",
+          status: "cow",
+          foo: "bar",
+        },
+      )

-    expect(search1.parameters.key?(:status)).to eq(false)
-    expect(search2.parameters[:status]).to eq("open")
+    tool_instance =
+      DiscourseAi::AiBot::Personas::General.new.find_tool(
+        tool_call,
+        bot_user: nil,
+        llm: nil,
+        context: nil,
+      )
+
expect(tool_instance.parameters.key?(:status)).to eq(false) + + tool_call = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "call_JtYQMful5QKqw97XFsHzPweB", + parameters: { + max_posts: "3.2", + status: "open", + foo: "bar", + }, + ) + + tool_instance = + DiscourseAi::AiBot::Personas::General.new.find_tool( + tool_call, + bot_user: nil, + llm: nil, + context: nil, + ) + + expect(tool_instance.parameters[:status]).to eq("open") end it "can coerce integers" do - xml = <<~XML - - - search - call_JtYQMful5QKqw97XFsHzPweB - - "3.2" - hello world - bar - - - - XML + tool_call = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "call_JtYQMful5QKqw97XFsHzPweB", + parameters: { + max_posts: "3.2", + search_query: "hello world", + foo: "bar", + }, + ) - search, = - tools = - DiscourseAi::AiBot::Personas::General.new.find_tools( - xml, - bot_user: nil, - llm: nil, - context: nil, - ) + search = + DiscourseAi::AiBot::Personas::General.new.find_tool( + tool_call, + bot_user: nil, + llm: nil, + context: nil, + ) expect(search.parameters[:max_posts]).to eq(3) expect(search.parameters[:search_query]).to eq("hello world") @@ -177,43 +190,23 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do it "can correctly parse arrays in tools" do SiteSetting.ai_openai_api_key = "123" - # Dall E tool uses an array for params - xml = <<~XML - - - dall_e - call_JtYQMful5QKqw97XFsHzPweB - - ["cat oil painting", "big car"] - - - - dall_e - abc - - ["pic3"] - - - - unknown - abc - - ["pic3"] - - - - XML - dall_e1, dall_e2 = - tools = - DiscourseAi::AiBot::Personas::DallE3.new.find_tools( - xml, - bot_user: nil, - llm: nil, - context: nil, - ) - expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"]) - expect(dall_e2.parameters[:prompts]).to eq(["pic3"]) - expect(tools.length).to eq(2) + tool_call = + DiscourseAi::Completions::ToolCall.new( + name: "dall_e", + id: "call_JtYQMful5QKqw97XFsHzPweB", + parameters: { + prompts: ["cat oil painting", "big car"], + }, + ) + + tool_instance = + DiscourseAi::AiBot::Personas::DallE3.new.find_tool( + tool_call, + bot_user: nil, + llm: nil, + context: nil, + ) + expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"]) end describe "custom personas" do diff --git a/spec/lib/modules/ai_bot/playground_spec.rb b/spec/lib/modules/ai_bot/playground_spec.rb index 9c98a08a..2a07ad52 100644 --- a/spec/lib/modules/ai_bot/playground_spec.rb +++ b/spec/lib/modules/ai_bot/playground_spec.rb @@ -55,6 +55,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do ) end + before { SiteSetting.ai_embeddings_enabled = false } + after do # we must reset cache on persona cause data can be rolled back AiPersona.persona_cache.flush! 
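For reviewers, a sketch of the consumer-side contract the playground specs below now exercise: `#generate` yields plain strings for text and ToolCall objects for tool invocations, so call sites no longer parse XML. The `handle_tool` and `handle_text` helpers here are hypothetical placeholders, not part of this patch:

    llm.generate(prompt, user: user) do |partial, cancel|
      case partial
      when DiscourseAi::Completions::ToolCall
        # a fully parsed tool invocation; name and parameters are ready to use
        handle_tool(partial.name, partial.parameters)
      when String
        # plain assistant text, streamed as it arrives
        handle_text(partial)
      end
    end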
@@ -83,17 +85,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do end let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) } - let(:function_call) { (<<~XML).strip } - - - search - 666 - - Can you use the custom tool - - - ", - XML + let(:tool_call) do + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "666", + parameters: { + query: "Can you use the custom tool", + }, + ) + end let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) } @@ -115,7 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do reply_post = nil prompts = nil - responses = [function_call] + responses = [tool_call] DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts| new_post = Fabricate(:post, raw: "Can you use the custom tool?") reply_post = playground.reply_to(new_post) @@ -133,7 +133,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do it "can force usage of a tool" do tool_name = "custom-#{custom_tool.id}" ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1) - responses = [function_call, "custom tool did stuff (maybe)"] + responses = [tool_call, "custom tool did stuff (maybe)"] prompts = nil reply_post = nil @@ -166,7 +166,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) playground = DiscourseAi::AiBot::Playground.new(bot) - responses = [function_call, "custom tool did stuff (maybe)"] + responses = [tool_call, "custom tool did stuff (maybe)"] reply_post = nil @@ -206,13 +206,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new) playground = DiscourseAi::AiBot::Playground.new(bot) + responses = ["custom tool did stuff (maybe)", tool_call] + # lets ensure tool does not run... DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt| new_post = Fabricate(:post, raw: "Can you use the custom tool?") reply_post = playground.reply_to(new_post) end - expect(reply_post.raw.strip).to eq(function_call) + expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)") end end @@ -452,10 +454,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do it "can run tools" do persona.update!(tools: ["Time"]) - responses = [ - "timetimeBuenos Aires", - "The time is 2023-12-14 17:24:00 -0300", - ] + tool_call1 = + DiscourseAi::Completions::ToolCall.new( + name: "time", + id: "time", + parameters: { + timezone: "Buenos Aires", + }, + ) + + tool_call2 = + DiscourseAi::Completions::ToolCall.new( + name: "time", + id: "time", + parameters: { + timezone: "Sydney", + }, + ) + + responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"] message = DiscourseAi::Completions::Llm.with_prepared_responses(responses) do @@ -470,7 +487,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do # it also needs to have tool details now set on message prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id) - expect(prompt.custom_prompt.length).to eq(3) + + expect(prompt.custom_prompt.length).to eq(5) # TODO in chat I am mixed on including this in the context, but I guess maybe? 
# thinking about this @@ -782,30 +800,29 @@ RSpec.describe DiscourseAi::AiBot::Playground do end it "supports multiple function calls" do - response1 = (<<~TXT).strip - - - search - search - - testing various things - - - - search - search - - another search - - - - TXT + tool_call1 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "testing various things", + }, + ) + + tool_call2 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "another search", + }, + ) response2 = "I found stuff" - DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do - playground.reply_to(third_post) - end + DiscourseAi::Completions::Llm.with_prepared_responses( + [[tool_call1, tool_call2], response2], + ) { playground.reply_to(third_post) } last_post = third_post.topic.reload.posts.order(:post_number).last @@ -819,17 +836,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new) playground = described_class.new(bot) - response1 = (<<~TXT).strip - - - search - search - - testing various things - - - - TXT + response1 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "testing various things", + }, + ) response2 = "I found stuff" @@ -843,17 +857,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do end it "does not include placeholders in conversation context but includes all completions" do - response1 = (<<~TXT).strip - - - search - search - - testing various things - - - - TXT + response1 = + DiscourseAi::Completions::ToolCall.new( + name: "search", + id: "search", + parameters: { + search_query: "testing various things", + }, + ) response2 = "I found some really amazing stuff!" 
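A compact sketch of the prepared-responses pattern the specs in this file rely on, now that canned responses can mix plain strings and ToolCall objects (the values below are placeholders):

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "search",
        parameters: { search_query: "example query" },
      )

    DiscourseAi::Completions::Llm.with_prepared_responses([tool_call, "all done"]) do
      playground.reply_to(post) # first completion yields the tool call, the next one the text
    end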
@@ -889,17 +900,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do [{ b64_json: image, revised_prompt: "a pink cow 1" }] end - let(:response) { (<<~TXT).strip } - - - dall_e - dall_e - - ["a pink cow"] - - - - TXT + let(:response) do + DiscourseAi::Completions::ToolCall.new( + name: "dall_e", + id: "dall_e", + parameters: { + prompts: ["a pink cow"], + }, + ) + end it "properly returns an image when skipping tool details" do persona.update!(tool_details: false) diff --git a/spec/requests/admin/ai_personas_controller_spec.rb b/spec/requests/admin/ai_personas_controller_spec.rb index fb42506e..16e0001b 100644 --- a/spec/requests/admin/ai_personas_controller_spec.rb +++ b/spec/requests/admin/ai_personas_controller_spec.rb @@ -541,16 +541,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do expect(topic.title).to eq("An amazing title") expect(topic.posts.count).to eq(2) - # now let's try to make a reply with a tool call - function_call = <<~XML - - - categories - - - XML + tool_call = + DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1") - fake_endpoint.fake_content = [function_call, "this is the response after the tool"] + fake_endpoint.fake_content = [tool_call, "this is the response after the tool"] # this simplifies function calls fake_endpoint.chunk_count = 1 diff --git a/spec/requests/ai_bot/bot_controller_spec.rb b/spec/requests/ai_bot/bot_controller_spec.rb index e7430185..007e868b 100644 --- a/spec/requests/ai_bot/bot_controller_spec.rb +++ b/spec/requests/ai_bot/bot_controller_spec.rb @@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do fab!(:user) fab!(:pm_topic) { Fabricate(:private_message_topic) } fab!(:pm_post) { Fabricate(:post, topic: pm_topic) } + fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) } + fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) } before { sign_in(user) } @@ -22,15 +24,37 @@ RSpec.describe DiscourseAi::AiBot::BotController do user = pm_topic.topic_allowed_users.first.user sign_in(user) - AiApiAuditLog.create!( - post_id: pm_post.id, - provider_id: 1, - topic_id: pm_topic.id, - raw_request_payload: "request", - raw_response_payload: "response", - request_tokens: 1, - response_tokens: 2, - ) + log1 = + AiApiAuditLog.create!( + provider_id: 1, + topic_id: pm_topic.id, + raw_request_payload: "request", + raw_response_payload: "response", + request_tokens: 1, + response_tokens: 2, + ) + + log2 = + AiApiAuditLog.create!( + post_id: pm_post.id, + provider_id: 1, + topic_id: pm_topic.id, + raw_request_payload: "request", + raw_response_payload: "response", + request_tokens: 1, + response_tokens: 2, + ) + + log3 = + AiApiAuditLog.create!( + post_id: pm_post2.id, + provider_id: 1, + topic_id: pm_topic.id, + raw_request_payload: "request", + raw_response_payload: "response", + request_tokens: 1, + response_tokens: 2, + ) Group.refresh_automatic_groups! 
SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s @@ -38,18 +62,26 @@ RSpec.describe DiscourseAi::AiBot::BotController do get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info" expect(response.status).to eq(200) + expect(response.parsed_body["id"]).to eq(log2.id) + expect(response.parsed_body["next_log_id"]).to eq(log3.id) + expect(response.parsed_body["prev_log_id"]).to eq(log1.id) + expect(response.parsed_body["topic_id"]).to eq(pm_topic.id) + expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["response_tokens"]).to eq(2) expect(response.parsed_body["raw_request_payload"]).to eq("request") expect(response.parsed_body["raw_response_payload"]).to eq("response") - post2 = Fabricate(:post, topic: pm_topic) - # return previous post if current has no debug info - get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info" + get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info" expect(response.status).to eq(200) expect(response.parsed_body["request_tokens"]).to eq(1) expect(response.parsed_body["response_tokens"]).to eq(2) + + # can return debug info by id as well + get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}" + expect(response.status).to eq(200) + expect(response.parsed_body["id"]).to eq(log1.id) end end
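The `eq` assertions against tool calls throughout these specs depend on `ToolCall#==` (added in lib/completions/tool_call.rb above) comparing id, name, and parameters structurally, with parameter keys symbolized on assignment. A quick sketch of that behavior:

    a = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "search", parameters: { q: "x" })
    b = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "search", parameters: { "q" => "x" })

    a == b # => true; the string key "q" was symbolized by ToolCall#parameters=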