FEATURE: improve tool support
This work-in-progress PR amends LLM completion so that tool invocations are returned as objects instead of XML fragments. This will empower future features such as parameter streaming. The XML approach was error prone; the object implementation is more robust. Still very much in progress: a lot of code needs to change. Only partially implemented (Anthropic) at the moment.
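For illustration, a minimal sketch of the consumer-facing change (the generate call and block shape are taken from this diff; treat the exact signature as an assumption):

    # Before: tool invocations arrived as XML text fragments that callers parsed, e.g.
    #   "<function_calls><invoke><tool_name>time</tool_name>...</invoke></function_calls>"
    # After: streamed partials are either plain text or structured objects.
    llm.generate(prompt, user: user) do |partial, cancel|
      case partial
      when DiscourseAi::Completions::ToolCall
        puts "tool: #{partial.name} id=#{partial.id} params=#{partial.parameters.inspect}"
      when String
        print partial # regular text tokens stream as before
      end
    end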
This commit is contained in:
parent 1ad5321c09
commit bb6df426ae
@@ -6,6 +6,16 @@ module DiscourseAi
     requires_plugin ::DiscourseAi::PLUGIN_NAME
     requires_login

+    def show_debug_info_by_id
+      log = AiApiAuditLog.find(params[:id])
+      if !log.topic
+        raise Discourse::NotFound
+      end
+
+      guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+      render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+    end
+
     def show_debug_info
       post = Post.find(params[:post_id])
       guardian.ensure_can_debug_ai_bot_conversation!(post)
@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
     Ollama = 7
     SambaNova = 8
   end

+  def next_log_id
+    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+  end
+
+  def prev_log_id
+    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+  end
+
 end

 # == Schema Information
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
   :post_id,
   :feature_name,
   :language_model,
-  :created_at
+  :created_at,
+  :prev_log_id,
+  :next_log_id
 end
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
 import DButton from "discourse/components/d-button";
 import DModal from "discourse/components/d-modal";
 import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
 import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
 import i18n from "discourse-common/helpers/i18n";
 import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
     this.copy(this.info.raw_response_payload);
   }

+  async loadLog(logId) {
+    try {
+      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+        (result) => {
+          this.info = result;
+        }
+      );
+    } catch (e) {
+      popupAjaxError(e);
+    }
+  }
+
+  @action
+  prevLog() {
+    this.loadLog(this.info.prev_log_id);
+  }
+
+  @action
+  nextLog() {
+    this.loadLog(this.info.next_log_id);
+  }
+
   copy(text) {
     clipboardCopy(text);
     this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -77,6 +100,8 @@ export default class DebugAiModal extends Component {
       `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
     ).then((result) => {
       this.info = result;
+    }).catch((e) => {
+      popupAjaxError(e);
     });
   }
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
         @action={{this.copyResponse}}
         @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
       />
+      {{#if this.info.prev_log_id}}
+        <DButton
+          class="btn"
+          @icon="angles-left"
+          @action={{this.prevLog}}
+          @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+        />
+      {{/if}}
+      {{#if this.info.next_log_id}}
+        <DButton
+          class="btn"
+          @icon="angles-right"
+          @action={{this.nextLog}}
+          @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+        />
+      {{/if}}
       <span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
     </:footer>
   </DModal>
@@ -415,6 +415,8 @@ en:
           response_tokens: "Response tokens:"
           request: "Request"
           response: "Response"
+          next_log: "Next"
+          previous_log: "Previous"

         share_full_topic_modal:
           title: "Share Conversation Publicly"
@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
     get "bot-username" => "bot#show_bot_username"
     get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
     post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
   end
@@ -100,6 +100,7 @@ module DiscourseAi
         llm_kwargs[:top_p] = persona.top_p if persona.top_p

         needs_newlines = false
+        tools_ran = 0

         while total_completions <= MAX_COMPLETIONS && ongoing_chain
           tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi

           result =
             llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
-              tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+              tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+              tool = nil if tools_ran >= MAX_TOOLS

-              if (tools.present?)
+              if (tool.present?)
                 tool_found = true
                 # a bit hacky, but extra newlines do no harm
                 if needs_newlines
@@ -117,10 +119,9 @@ module DiscourseAi
                   needs_newlines = false
                 end

-                tools[0..MAX_TOOLS].each do |tool|
-                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
-                  ongoing_chain &&= tool.chain_next_response?
-                end
+                process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+                tools_ran += 1
+                ongoing_chain &&= tool.chain_next_response?
               else
                 needs_newlines = true
                 update_blk.call(partial, cancel)
@@ -199,23 +199,17 @@ module DiscourseAi
       prompt
     end

-    def find_tools(partial, bot_user:, llm:, context:)
-      return [] if !partial.include?("</invoke>")
-
-      parsed_function = Nokogiri::HTML5.fragment(partial)
-      parsed_function
-        .css("invoke")
-        .map do |fragment|
-          tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
-        end
-        .compact
+    def find_tool(partial, bot_user:, llm:, context:)
+      return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+      tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
     end

     protected

-    def tool_instance(parsed_function, bot_user:, llm:, context:)
-      function_id = parsed_function.at("tool_id")&.text
-      function_name = parsed_function.at("tool_name")&.text
+    def tool_instance(tool_call, bot_user:, llm:, context:)
+      function_id = tool_call.id
+      function_name = tool_call.name
       return nil if function_name.nil?

       tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +218,7 @@ module DiscourseAi
       arguments = {}
       tool_klass.signature[:parameters].to_a.each do |param|
         name = param[:name]
-        value = parsed_function.at(name)&.text
+        value = tool_call.parameters[name.to_sym]

         if param[:type] == "array" && value
           value =
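A sketch of the new find_tool contract under this change (the persona, user, llm and context objects are placeholders, not from the diff):

    # Anything that is not a ToolCall (e.g. a plain text partial) yields nil.
    persona.find_tool("some text", bot_user: user, llm: llm, context: context) # => nil

    tool_call =
      DiscourseAi::Completions::ToolCall.new(id: "toolu_1", name: "time", parameters: { timezone: "Sydney" })
    persona.find_tool(tool_call, bot_user: user, llm: llm, context: context)
    # => an instance of the matching tool class, or nil if no available tool is named "time"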
@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
     def append(json)
       @raw_json << json
     end
+
+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+    end
   end

   attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,80 +25,70 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
   def initialize(streaming_mode:)
     @streaming_mode = streaming_mode
     @tool_calls = []
+    @current_tool_call = nil
   end

-  def to_xml_tool_calls(function_buffer)
-    return function_buffer if @tool_calls.blank?
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
+  end

-    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
-      <function_calls>
-      </function_calls>
-    TEXT
-
-    @tool_calls.each do |tool_call|
-      node =
-        function_buffer.at("function_calls").add_child(
-          Nokogiri::HTML5::DocumentFragment.parse(
-            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
-          ),
-        )
-
-      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
-      xml =
-        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
-      node.at("tool_name").content = tool_call.name
-      node.at("tool_id").content = tool_call.id
-      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
-    end
-
-    function_buffer
-  end
+  def process_streamed_message(parsed)
+    result = nil
+    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+      tool_name = parsed.dig(:content_block, :name)
+      tool_id = parsed.dig(:content_block, :id)
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+      end
+      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+      if @current_tool_call
+        tool_delta = parsed.dig(:delta, :partial_json).to_s
+        @current_tool_call.append(tool_delta)
+      else
+        result = parsed.dig(:delta, :text).to_s
+      end
+    elsif parsed[:type] == "content_block_stop"
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+        @current_tool_call = nil
+      end
+    elsif parsed[:type] == "message_start"
+      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+    elsif parsed[:type] == "message_delta"
+      @output_tokens =
+        parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
+    elsif parsed[:type] == "message_stop"
+      # bedrock has this ...
+      if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
+        @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
+        @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
+      end
+    end
+
+    result
+  end

   def process_message(payload)
     result = ""
     parsed = JSON.parse(payload, symbolize_names: true)

-    if @streaming_mode
-      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
-        tool_name = parsed.dig(:content_block, :name)
-        tool_id = parsed.dig(:content_block, :id)
-        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
-      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        if @tool_calls.present?
-          result = parsed.dig(:delta, :partial_json).to_s
-          @tool_calls.last.append(result)
-        else
-          result = parsed.dig(:delta, :text).to_s
-        end
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
-        @output_tokens =
-          parsed.dig(:usage, :output_tokens) || parsed.dig(:delta, :usage, :output_tokens)
-      elsif parsed[:type] == "message_stop"
-        # bedrock has this ...
-        if bedrock_stats = parsed.dig("amazon-bedrock-invocationMetrics".to_sym)
-          @input_tokens = bedrock_stats[:inputTokenCount] || @input_tokens
-          @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
-        end
-      end
-    else
-      content = parsed.dig(:content)
-      if content.is_a?(Array)
-        tool_call = content.find { |c| c[:type] == "tool_use" }
-        if tool_call
-          @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
-          @tool_calls.last.append(tool_call[:input].to_json)
-        else
-          result = parsed.dig(:content, 0, :text).to_s
-        end
-      end
-
-      @input_tokens = parsed.dig(:usage, :input_tokens)
-      @output_tokens = parsed.dig(:usage, :output_tokens)
-    end
+    content = parsed.dig(:content)
+    if content.is_a?(Array)
+      result =
+        content.map do |data|
+          if data[:type] == "tool_use"
+            call = AnthropicToolCall.new(data[:name], data[:id])
+            call.append(data[:input].to_json)
+            call.to_tool_call
+          else
+            data[:text]
+          end
+        end
+    end
+
+    @input_tokens = parsed.dig(:usage, :input_tokens)
+    @output_tokens = parsed.dig(:usage, :output_tokens)

     result
   end
 end
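A hedged sketch of the event flow process_streamed_message handles. The event hashes are abbreviated from Anthropic's streaming format and assumed to be already framed and symbolized by JsonStreamDecoder; the expected output follows from the code above:

    processor = DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: true)

    events = [
      { type: "message_start", message: { usage: { input_tokens: 10 } } },
      { type: "content_block_delta", delta: { text: "Hello" } },
      { type: "content_block_start", content_block: { type: "tool_use", name: "search", id: "toolu_1" } },
      { type: "content_block_delta", delta: { partial_json: "{\"query\":\"sam\"}" } },
      { type: "content_block_stop" },
    ]

    partials = events.map { |e| processor.process_streamed_message(e) }.compact
    # => ["Hello", ToolCall(name: "search", id: "toolu_1", parameters: { query: "sam" })]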
@@ -90,15 +90,18 @@ module DiscourseAi
         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
       end

+      def decode_chunk(partial_data)
+        @decoder ||= JsonStreamDecoder.new
+        (@decoder << partial_data).map do |parsed_json|
+          processor.process_streamed_message(parsed_json)
+        end.compact
+      end
+
       def processor
         @processor ||=
           DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
       end

-      def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-        processor.to_xml_tool_calls(function_buffer) if !partial
-      end
-
       def extract_completion_from(response_raw)
         processor.process_message(response_raw)
       end
@@ -107,6 +110,10 @@ module DiscourseAi
         processor.tool_calls.present?
       end

+      def tool_calls
+        processor.to_tool_calls
+      end
+
       def final_log_update(log)
         log.request_tokens = processor.input_tokens if processor.input_tokens
         log.response_tokens = processor.output_tokens if processor.output_tokens
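Taken together, the streaming pipeline becomes: raw SSE bytes -> JsonStreamDecoder (framing) -> process_streamed_message (semantics) -> String or ToolCall partials. Roughly, for a single self-contained chunk (illustrative values only):

    decode_chunk(%(data: {"type":"content_block_delta","delta":{"text":"Hi"}}\n))
    # => ["Hi"]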
@@ -126,22 +126,129 @@ module DiscourseAi
         response_data = extract_completion_from(response_raw)
         partials_raw = response_data.to_s

-        if native_tool_support?
-          if allow_tools && has_tool?(response_data)
-            function_buffer = build_buffer # Nokogiri document
-            function_buffer =
-              add_to_function_buffer(function_buffer, payload: response_data)
-            FunctionCallNormalizer.normalize_function_ids!(function_buffer)
+        if allow_tools && !native_tool_support?
+          response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+          response_data = function_calls if function_calls.present?
+        end

-            response_data = +function_buffer.at("function_calls").to_s
-            response_data << "\n"
-          end
-        else
-          if allow_tools
-            response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
-            response_data = function_calls if function_calls.present?
-          end
-        end
+        if response_data.is_a?(Array) && response_data.length == 1
+          response_data = response_data.first
+        end
+
+        return response_data
+      end
+
+      begin
+        cancelled = false
+        cancel = -> { cancelled = true }
+        if cancelled
+          http.finish
+          break
+        end
+
+        response.read_body do |chunk|
+          decode_chunk(chunk).each do |partial|
+            yield partial, cancel
+          end
+        end
+      rescue IOError, StandardError
+        raise if !cancelled
+      end
+      return response_data
+    ensure
+      if log
+        log.raw_response_payload = response_raw
+        final_log_update(log)
+
+        log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
+        log.save!
+
+        if Rails.env.development?
+          puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
+        end
+      end
+    end
+  end
+
+      def perform_completionx!(
+        dialect,
+        user,
+        model_params = {},
+        feature_name: nil,
+        feature_context: nil,
+        &blk
+      )
+        allow_tools = dialect.prompt.has_tools?
+        model_params = normalize_model_params(model_params)
+        orig_blk = blk
+
+        @streaming_mode = block_given?
+        to_strip = xml_tags_to_strip(dialect)
+        @xml_stripper =
+          DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+        if @streaming_mode && @xml_stripper
+          blk =
+            lambda do |partial, cancel|
+              partial = @xml_stripper << partial
+              orig_blk.call(partial, cancel) if partial
+            end
+        end
+
+        prompt = dialect.translate
+
+        FinalDestination::HTTP.start(
+          model_uri.host,
+          model_uri.port,
+          use_ssl: use_ssl?,
+          read_timeout: TIMEOUT,
+          open_timeout: TIMEOUT,
+          write_timeout: TIMEOUT,
+        ) do |http|
+          response_data = +""
+          response_raw = +""
+
+          # Needed to response token calculations. Cannot rely on response_data due to function buffering.
+          partials_raw = +""
+          request_body = prepare_payload(prompt, model_params, dialect).to_json
+
+          request = prepare_request(request_body)
+
+          http.request(request) do |response|
+            if response.code.to_i != 200
+              Rails.logger.error(
+                "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
+              )
+              raise CompletionFailed, response.body
+            end
+
+            log =
+              AiApiAuditLog.new(
+                provider_id: provider_id,
+                user_id: user&.id,
+                raw_request_payload: request_body,
+                request_tokens: prompt_size(prompt),
+                topic_id: dialect.prompt.topic_id,
+                post_id: dialect.prompt.post_id,
+                feature_name: feature_name,
+                language_model: llm_model.name,
+                feature_context: feature_context.present? ? feature_context.as_json : nil,
+              )
+
+            if !@streaming_mode
+              response_raw = response.read_body
+              response_data = extract_completion_from(response_raw)
+              partials_raw = response_data.to_s
+
+              if allow_tools && !native_tool_support?
+                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
+                response_data = function_calls if function_calls.present?
+              end
+
+              if response_data.is_a?(Array) && response_data.length == 1
+                response_data = response_data.first
+              end
+
+              return response_data
+            end
@@ -277,8 +384,9 @@ module DiscourseAi
     ensure
       if log
         log.raw_response_payload = response_raw
-        log.response_tokens = tokenizer.size(partials_raw)
         final_log_update(log)
+
+        log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
         log.save!

         if Rails.env.development?
@@ -45,12 +45,16 @@ module DiscourseAi
         cancel_fn = lambda { cancelled = true }

         # We buffer and return tool invocations in one go.
-        if is_tool?(response)
-          yield(response, cancel_fn)
-        else
-          response.each_char do |char|
-            break if cancelled
-            yield(char, cancel_fn)
+        as_array = response.is_a?(Array) ? response : [response]
+
+        as_array.each do |response|
+          if is_tool?(response)
+            yield(response, cancel_fn)
+          else
+            response.each_char do |char|
+              break if cancelled
+              yield(char, cancel_fn)
+            end
           end
         end
       else
@@ -65,7 +69,7 @@ module DiscourseAi
       private

       def is_tool?(response)
-        Nokogiri::HTML5.fragment(response).at("function_calls").present?
+        response.is_a?(DiscourseAi::Completions::ToolCall)
       end
     end
   end
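With the array support above, a canned response can interleave text and tool calls, mirroring what the live endpoints now yield (a sketch with made-up values; see the playground spec below for the real usage):

    response = [
      "Here is the calculation:",
      DiscourseAi::Completions::ToolCall.new(id: "tool_1", name: "calculate", parameters: { expression: "1 + 1" }),
    ]
    # strings are streamed char by char, tool calls are yielded in one go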
@@ -0,0 +1,47 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    # will work for anthropic and open ai compatible
+    class JsonStreamDecoder
+      attr_reader :buffer
+
+      LINE_REGEX = /data: ({.*})\s*$/
+
+      def initialize(symbolize_keys: true)
+        @symbolize_keys = symbolize_keys
+        @buffer = +""
+      end
+
+      def <<(raw)
+        @buffer << raw.to_s
+        rval = []
+
+        split = @buffer.scan(/.*\n?/)
+        split.pop if split.last.blank?
+
+        @buffer = +(split.pop.to_s)
+
+        split.each do |line|
+          matches = line.match(LINE_REGEX)
+          next if !matches
+          rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+        end
+
+        if @buffer.present?
+          matches = @buffer.match(LINE_REGEX)
+          if matches
+            begin
+              rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+              @buffer = +""
+            rescue JSON::ParserError
+              # maybe it is a partial line
+            end
+          end
+        end
+
+        rval
+      end
+    end
+  end
+end
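Usage sketch for the decoder: partial SSE lines are buffered, and parsed hashes are only emitted once a complete data: line is available:

    decoder = DiscourseAi::Completions::JsonStreamDecoder.new
    decoder << "data: {\"a\":" # => [] (incomplete JSON stays in the buffer)
    decoder << " 1}\n"         # => [{ a: 1 }]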
@@ -0,0 +1,24 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class ToolCall
+      attr_reader :id, :name, :parameters
+
+      def initialize(id:, name:, parameters: nil)
+        @id = id
+        @name = name
+        @parameters = parameters || {}
+      end
+
+      def ==(other)
+        id == other.id && name == other.name && parameters == other.parameters
+      end
+
+      def to_s
+        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+      end
+
+    end
+  end
+end
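ToolCall is a simple value object; equality compares id, name and parameters, which is what the rewritten specs below lean on:

    a = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "time", parameters: { timezone: "Sydney" })
    b = DiscourseAi::Completions::ToolCall.new(id: "t1", name: "time", parameters: { timezone: "Sydney" })
    a == b # => true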
@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       data: {"type":"message_stop"}
     STRING

-    result = +""
+    result = []
     body = body.scan(/.*\n/)
     EndpointMock.with_chunk_array_support do
       stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
       end
     end

-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters><search_query>s<a>m sam</search_query>
-      <category>general</category></parameters>
-      <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+        parameters: {
+          search_query: "s<a>m sam",
+          category: "general",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq([tool_call])
   end

   it "can stream a response" do
@@ -242,17 +241,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do

     result = llm.generate(prompt, user: Discourse.system_user)

-    expected = <<~TEXT.strip
-      <function_calls>
-      <invoke>
-      <tool_name>calculate</tool_name>
-      <parameters><expression>2758975 + 21.11</expression></parameters>
-      <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "calculate",
+        id: "toolu_012kBdhG4eHaV68W56p4N94h",
+        parameters: {
+          expression: "2758975 + 21.11",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq(["Here is the calculation:", tool_call])

     log = AiApiAuditLog.order(:id).last
     expect(log.request_tokens).to eq(345)
     expect(log.response_tokens).to eq(65)
   end

   it "can send images via a completion prompt" do
@@ -0,0 +1,51 @@
+# frozen_string_literal: true
+
+describe DiscourseAi::Completions::JsonStreamDecoder do
+  let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }
+
+  it "should be able to parse simple messages" do
+    result = decoder << "data: #{{ hello: "world" }.to_json}"
+    expect(result).to eq([{ hello: "world" }])
+  end
+
+  it "should handle anthropic mixed style streams" do
+    stream = (<<~TEXT).split("|")
+      event: |message_start|
+      data: |{"hel|lo": "world"}|
+
+      event: |message_start
+      data: {"foo": "bar"}
+
+      event: |message_start
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each do |chunk|
+      results << (decoder << chunk)
+    end
+
+    expect(results.flatten.compact).to eq([{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }])
+  end
+
+  it "should be able to handle complex overlaps" do
+    stream = (<<~TEXT).split("|")
+      data: |{"hel|lo": "world"}
+
+      data: {"foo": "bar"}
+
+      data: {"ba|z": "qux"|}
+
+      [DONE]
+    TEXT
+
+    results = []
+    stream.each do |chunk|
+      results << (decoder << chunk)
+    end
+
+    expect(results.flatten.compact).to eq([{ "hello": "world" }, { "foo": "bar" }, { "baz": "qux" }])
+  end
+end
@@ -220,6 +220,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   before do
     Jobs.run_immediately!
     SiteSetting.ai_bot_allowed_groups = "#{Group::AUTO_GROUPS[:trust_level_0]}"
+    SiteSetting.ai_embeddings_enabled = false
   end

   fab!(:persona) do
@@ -452,10 +453,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   it "can run tools" do
     persona.update!(tools: ["Time"])

-    responses = [
-      "<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
-      "The time is 2023-12-14 17:24:00 -0300",
-    ]
+    tool_call1 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "time",
+        id: "time",
+        parameters: {
+          timezone: "Buenos Aires",
+        },
+      )
+
+    tool_call2 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "time",
+        id: "time",
+        parameters: {
+          timezone: "Sydney",
+        },
+      )
+
+    responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]

     message =
       DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
@@ -470,7 +486,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do

     # it also needs to have tool details now set on message
     prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
-    expect(prompt.custom_prompt.length).to eq(3)
+    expect(prompt.custom_prompt.length).to eq(5)

+    # TODO in chat I am mixed on including this in the context, but I guess maybe?
     # thinking about this
@@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
   fab!(:user)
   fab!(:pm_topic) { Fabricate(:private_message_topic) }
   fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }

   before { sign_in(user) }
@@ -22,7 +24,17 @@ RSpec.describe DiscourseAi::AiBot::BotController do
     user = pm_topic.topic_allowed_users.first.user
     sign_in(user)

-    AiApiAuditLog.create!(
+    log1 = AiApiAuditLog.create!(
+      provider_id: 1,
+      topic_id: pm_topic.id,
+      raw_request_payload: "request",
+      raw_response_payload: "response",
+      request_tokens: 1,
+      response_tokens: 2,
+    )
+
+    log2 = AiApiAuditLog.create!(
+      post_id: pm_post.id,
       provider_id: 1,
       topic_id: pm_topic.id,
@@ -32,24 +44,43 @@ RSpec.describe DiscourseAi::AiBot::BotController do
       response_tokens: 2,
     )

+    log3 = AiApiAuditLog.create!(
+      post_id: pm_post2.id,
+      provider_id: 1,
+      topic_id: pm_topic.id,
+      raw_request_payload: "request",
+      raw_response_payload: "response",
+      request_tokens: 1,
+      response_tokens: 2,
+    )
+
     Group.refresh_automatic_groups!
     SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s

     get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
     expect(response.status).to eq(200)

+    expect(response.parsed_body["id"]).to eq(log2.id)
+    expect(response.parsed_body["next_log_id"]).to eq(log3.id)
+    expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
+    expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
+
     expect(response.parsed_body["request_tokens"]).to eq(1)
     expect(response.parsed_body["response_tokens"]).to eq(2)
     expect(response.parsed_body["raw_request_payload"]).to eq("request")
     expect(response.parsed_body["raw_response_payload"]).to eq("response")

-    post2 = Fabricate(:post, topic: pm_topic)
-
     # return previous post if current has no debug info
-    get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
+    get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
     expect(response.status).to eq(200)
     expect(response.parsed_body["request_tokens"]).to eq(1)
     expect(response.parsed_body["response_tokens"]).to eq(2)
+
+    # can return debug info by id as well
+    get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
+    expect(response.status).to eq(200)
+    expect(response.parsed_body["id"]).to eq(log1.id)
   end
 end