FEATURE: improve tool support (#904)
This re-implements tool support in DiscourseAi::Completions::Llm#generate.

Previously, tool invocations were always returned as XML and it was the caller's responsibility to parse that XML. The endpoints now return ToolCall objects instead.

This also simplifies the Llm endpoint interface and gives it more clarity: endpoints must implement decode and decode_chunk (for streaming). It is now the implementer's responsibility to decode chunks; the base class no longer does so. To make this easy, we ship a flexible JSON stream decoder that is simple to wire up.

Also (new): better debugging for PMs. We now have next/previous buttons to step through all the LLM messages associated with a PM.

Token accounting is fixed for vLLM (we were not counting tokens correctly).
Parent: 644141ff08
Commit: e817b7dc11
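As a quick orientation before the diff, here is a minimal sketch of the new calling convention: the block passed to generate now receives either String partials or DiscourseAi::Completions::ToolCall objects. The proxy lookup string, prompt variable, and handle_tool helper below are illustrative assumptions, not taken from this commit.

    # Sketch only: text partials are Strings; tool calls arrive as ToolCall objects.
    llm = DiscourseAi::Completions::Llm.proxy("anthropic:claude-3-opus") # illustrative model name

    llm.generate(prompt, user: Discourse.system_user, feature_name: "bot") do |partial, cancel|
      if partial.is_a?(DiscourseAi::Completions::ToolCall)
        # No XML parsing needed anymore; id/name/parameters are structured.
        handle_tool(partial.id, partial.name, partial.parameters) # hypothetical helper
      else
        print partial # plain streamed text
      end
    end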
@@ -6,6 +6,14 @@ module DiscourseAi
     requires_plugin ::DiscourseAi::PLUGIN_NAME
     requires_login
 
+    def show_debug_info_by_id
+      log = AiApiAuditLog.find(params[:id])
+      raise Discourse::NotFound if !log.topic
+
+      guardian.ensure_can_debug_ai_bot_conversation!(log.topic)
+      render json: AiApiAuditLogSerializer.new(log, root: false), status: 200
+    end
+
     def show_debug_info
       post = Post.find(params[:post_id])
       guardian.ensure_can_debug_ai_bot_conversation!(post)
@@ -14,6 +14,14 @@ class AiApiAuditLog < ActiveRecord::Base
     Ollama = 7
     SambaNova = 8
   end
+
+  def next_log_id
+    self.class.where("id > ?", id).where(topic_id: topic_id).order(id: :asc).pluck(:id).first
+  end
+
+  def prev_log_id
+    self.class.where("id < ?", id).where(topic_id: topic_id).order(id: :desc).pluck(:id).first
+  end
 end
 
 # == Schema Information
@@ -12,5 +12,7 @@ class AiApiAuditLogSerializer < ApplicationSerializer
              :post_id,
              :feature_name,
              :language_model,
-             :created_at
+             :created_at,
+             :prev_log_id,
+             :next_log_id
 end
@@ -7,6 +7,7 @@ import { htmlSafe } from "@ember/template";
 import DButton from "discourse/components/d-button";
 import DModal from "discourse/components/d-modal";
 import { ajax } from "discourse/lib/ajax";
+import { popupAjaxError } from "discourse/lib/ajax-error";
 import { clipboardCopy, escapeExpression } from "discourse/lib/utilities";
 import i18n from "discourse-common/helpers/i18n";
 import discourseLater from "discourse-common/lib/later";
@@ -63,6 +64,28 @@ export default class DebugAiModal extends Component {
     this.copy(this.info.raw_response_payload);
   }
 
+  async loadLog(logId) {
+    try {
+      await ajax(`/discourse-ai/ai-bot/show-debug-info/${logId}.json`).then(
+        (result) => {
+          this.info = result;
+        }
+      );
+    } catch (e) {
+      popupAjaxError(e);
+    }
+  }
+
+  @action
+  prevLog() {
+    this.loadLog(this.info.prev_log_id);
+  }
+
+  @action
+  nextLog() {
+    this.loadLog(this.info.next_log_id);
+  }
+
   copy(text) {
     clipboardCopy(text);
     this.justCopiedText = I18n.t("discourse_ai.ai_bot.conversation_shared");
@@ -73,10 +96,12 @@ export default class DebugAiModal extends Component {
   }
 
   loadApiRequestInfo() {
-    ajax(
-      `/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`
-    ).then((result) => {
-      this.info = result;
-    });
+    ajax(`/discourse-ai/ai-bot/post/${this.args.model.id}/show-debug-info.json`)
+      .then((result) => {
+        this.info = result;
+      })
+      .catch((e) => {
+        popupAjaxError(e);
+      });
   }
 
@@ -147,6 +172,22 @@ export default class DebugAiModal extends Component {
           @action={{this.copyResponse}}
           @label="discourse_ai.ai_bot.debug_ai_modal.copy_response"
         />
+        {{#if this.info.prev_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-left"
+            @action={{this.prevLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.previous_log"
+          />
+        {{/if}}
+        {{#if this.info.next_log_id}}
+          <DButton
+            class="btn"
+            @icon="angles-right"
+            @action={{this.nextLog}}
+            @label="discourse_ai.ai_bot.debug_ai_modal.next_log"
+          />
+        {{/if}}
         <span class="ai-debut-modal__just-copied">{{this.justCopiedText}}</span>
       </:footer>
     </DModal>
@@ -415,6 +415,8 @@ en:
           response_tokens: "Response tokens:"
           request: "Request"
           response: "Response"
+          next_log: "Next"
+          previous_log: "Previous"
 
         share_full_topic_modal:
           title: "Share Conversation Publicly"
@@ -22,6 +22,7 @@ DiscourseAi::Engine.routes.draw do
   scope module: :ai_bot, path: "/ai-bot", defaults: { format: :json } do
     get "bot-username" => "bot#show_bot_username"
     get "post/:post_id/show-debug-info" => "bot#show_debug_info"
+    get "show-debug-info/:id" => "bot#show_debug_info_by_id"
     post "post/:post_id/stop-streaming" => "bot#stop_streaming_response"
   end
@@ -100,6 +100,7 @@ module DiscourseAi
         llm_kwargs[:top_p] = persona.top_p if persona.top_p
 
         needs_newlines = false
+        tools_ran = 0
 
         while total_completions <= MAX_COMPLETIONS && ongoing_chain
           tool_found = false
@@ -107,9 +108,10 @@ module DiscourseAi
 
           result =
             llm.generate(prompt, feature_name: "bot", **llm_kwargs) do |partial, cancel|
-              tools = persona.find_tools(partial, bot_user: user, llm: llm, context: context)
+              tool = persona.find_tool(partial, bot_user: user, llm: llm, context: context)
+              tool = nil if tools_ran >= MAX_TOOLS
 
-              if (tools.present?)
+              if tool.present?
                 tool_found = true
                 # a bit hacky, but extra newlines do no harm
                 if needs_newlines
@@ -117,15 +119,18 @@ module DiscourseAi
                   needs_newlines = false
                 end
 
-                tools[0..MAX_TOOLS].each do |tool|
-                  process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
-                  ongoing_chain &&= tool.chain_next_response?
-                end
+                process_tool(tool, raw_context, llm, cancel, update_blk, prompt, context)
+                tools_ran += 1
+                ongoing_chain &&= tool.chain_next_response?
               else
                 needs_newlines = true
-                update_blk.call(partial, cancel)
+                if partial.is_a?(DiscourseAi::Completions::ToolCall)
+                  Rails.logger.warn("DiscourseAi: Tool not found: #{partial.name}")
+                else
+                  update_blk.call(partial, cancel)
+                end
               end
             end
 
           if !tool_found
             ongoing_chain = false
@@ -199,23 +199,16 @@ module DiscourseAi
         prompt
       end
 
-      def find_tools(partial, bot_user:, llm:, context:)
-        return [] if !partial.include?("</invoke>")
-        parsed_function = Nokogiri::HTML5.fragment(partial)
-        parsed_function
-          .css("invoke")
-          .map do |fragment|
-            tool_instance(fragment, bot_user: bot_user, llm: llm, context: context)
-          end
-          .compact
+      def find_tool(partial, bot_user:, llm:, context:)
+        return nil if !partial.is_a?(DiscourseAi::Completions::ToolCall)
+        tool_instance(partial, bot_user: bot_user, llm: llm, context: context)
       end
 
       protected
 
-      def tool_instance(parsed_function, bot_user:, llm:, context:)
-        function_id = parsed_function.at("tool_id")&.text
-        function_name = parsed_function.at("tool_name")&.text
+      def tool_instance(tool_call, bot_user:, llm:, context:)
+        function_id = tool_call.id
+        function_name = tool_call.name
         return nil if function_name.nil?
 
         tool_klass = available_tools.find { |c| c.signature.dig(:name) == function_name }
@@ -224,7 +217,7 @@ module DiscourseAi
         arguments = {}
         tool_klass.signature[:parameters].to_a.each do |param|
           name = param[:name]
-          value = parsed_function.at(name)&.text
+          value = tool_call.parameters[name.to_sym]
 
           if param[:type] == "array" && value
             value =
@@ -13,6 +13,11 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
     def append(json)
       @raw_json << json
     end
+
+    def to_tool_call
+      parameters = JSON.parse(raw_json, symbolize_names: true)
+      DiscourseAi::Completions::ToolCall.new(id: id, name: name, parameters: parameters)
+    end
   end
 
   attr_reader :tool_calls, :input_tokens, :output_tokens
@@ -20,52 +25,32 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
   def initialize(streaming_mode:)
     @streaming_mode = streaming_mode
     @tool_calls = []
+    @current_tool_call = nil
   end
 
-  def to_xml_tool_calls(function_buffer)
-    return function_buffer if @tool_calls.blank?
-
-    function_buffer = Nokogiri::HTML5.fragment(<<~TEXT)
-      <function_calls>
-      </function_calls>
-    TEXT
-
-    @tool_calls.each do |tool_call|
-      node =
-        function_buffer.at("function_calls").add_child(
-          Nokogiri::HTML5::DocumentFragment.parse(
-            DiscourseAi::Completions::Endpoints::Base.noop_function_call_text + "\n",
-          ),
-        )
-
-      params = JSON.parse(tool_call.raw_json, symbolize_names: true)
-      xml =
-        params.map { |name, value| "<#{name}>#{CGI.escapeHTML(value.to_s)}</#{name}>" }.join("\n")
-
-      node.at("tool_name").content = tool_call.name
-      node.at("tool_id").content = tool_call.id
-      node.at("parameters").children = Nokogiri::HTML5::DocumentFragment.parse(xml) if xml.present?
-    end
-
-    function_buffer
+  def to_tool_calls
+    @tool_calls.map { |tool_call| tool_call.to_tool_call }
   end
 
-  def process_message(payload)
-    result = ""
-    parsed = JSON.parse(payload, symbolize_names: true)
-
-    if @streaming_mode
-      if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
-        tool_name = parsed.dig(:content_block, :name)
-        tool_id = parsed.dig(:content_block, :id)
-        @tool_calls << AnthropicToolCall.new(tool_name, tool_id) if tool_name
-      elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
-        if @tool_calls.present?
-          result = parsed.dig(:delta, :partial_json).to_s
-          @tool_calls.last.append(result)
-        else
-          result = parsed.dig(:delta, :text).to_s
-        end
-      elsif parsed[:type] == "message_start"
-        @input_tokens = parsed.dig(:message, :usage, :input_tokens)
-      elsif parsed[:type] == "message_delta"
+  def process_streamed_message(parsed)
+    result = nil
+
+    if parsed[:type] == "content_block_start" && parsed.dig(:content_block, :type) == "tool_use"
+      tool_name = parsed.dig(:content_block, :name)
+      tool_id = parsed.dig(:content_block, :id)
+      result = @current_tool_call.to_tool_call if @current_tool_call
+      @current_tool_call = AnthropicToolCall.new(tool_name, tool_id) if tool_name
+    elsif parsed[:type] == "content_block_start" || parsed[:type] == "content_block_delta"
+      if @current_tool_call
+        tool_delta = parsed.dig(:delta, :partial_json).to_s
+        @current_tool_call.append(tool_delta)
+      else
+        result = parsed.dig(:delta, :text).to_s
+      end
+    elsif parsed[:type] == "content_block_stop"
+      if @current_tool_call
+        result = @current_tool_call.to_tool_call
+        @current_tool_call = nil
+      end
+    elsif parsed[:type] == "message_start"
+      @input_tokens = parsed.dig(:message, :usage, :input_tokens)
+    elsif parsed[:type] == "message_delta"
@@ -78,21 +63,30 @@ class DiscourseAi::Completions::AnthropicMessageProcessor
         @output_tokens = bedrock_stats[:outputTokenCount] || @output_tokens
       end
     end
-  else
+
+    result
+  end
+
+  def process_message(payload)
+    result = ""
+    parsed = payload
+    parsed = JSON.parse(payload, symbolize_names: true) if payload.is_a?(String)
+
     content = parsed.dig(:content)
     if content.is_a?(Array)
-      tool_call = content.find { |c| c[:type] == "tool_use" }
-      if tool_call
-        @tool_calls << AnthropicToolCall.new(tool_call[:name], tool_call[:id])
-        @tool_calls.last.append(tool_call[:input].to_json)
-      else
-        result = parsed.dig(:content, 0, :text).to_s
-      end
+      result =
+        content.map do |data|
+          if data[:type] == "tool_use"
+            call = AnthropicToolCall.new(data[:name], data[:id])
+            call.append(data[:input].to_json)
+            call.to_tool_call
+          else
+            data[:text]
+          end
+        end
     end
 
     @input_tokens = parsed.dig(:usage, :input_tokens)
     @output_tokens = parsed.dig(:usage, :output_tokens)
-  end
 
     result
   end
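To illustrate the streaming state machine above: a tool call arrives as a content_block_start (type "tool_use"), followed by content_block_delta events carrying partial_json, and a content_block_stop that flushes the buffered call as a ToolCall. A rough walk-through mirroring process_streamed_message (event hashes abbreviated and illustrative):

    processor = DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: true)

    events = [
      { type: "content_block_start", content_block: { type: "tool_use", name: "google", id: "toolu_1" } },
      { type: "content_block_delta", delta: { partial_json: "{\"query\":" } },
      { type: "content_block_delta", delta: { partial_json: "\"discourse\"}" } },
      { type: "content_block_stop" },
    ]

    results = events.map { |event| processor.process_streamed_message(event) }.compact
    # => one DiscourseAi::Completions::ToolCall (name "google", parameters { query: "discourse" }),
    #    emitted when content_block_stop closes the buffered JSON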
@@ -63,8 +63,23 @@ module DiscourseAi
       def user_msg(msg)
         user_message = { role: "user", content: msg[:content] }
 
-        # TODO: Add support for user messages with empbeded user ids
-        # TODO: Add support for user messages with attachments
+        encoded_uploads = prompt.encoded_uploads(msg)
+        if encoded_uploads.present?
+          images =
+            encoded_uploads
+              .map do |upload|
+                if upload[:mime_type].start_with?("image/")
+                  upload[:base64]
+                else
+                  nil
+                end
+              end
+              .compact
+
+          user_message[:images] = images if images.present?
+        end
+
+        # TODO: Add support for user messages with embedded user ids
 
         user_message
       end
@@ -63,6 +63,10 @@ module DiscourseAi
           URI(llm_model.url)
         end
 
+        def xml_tools_enabled?
+          !@native_tool_support
+        end
+
         def prepare_payload(prompt, model_params, dialect)
           @native_tool_support = dialect.native_tool_support?
 
@@ -90,35 +94,34 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end
 
+        def decode_chunk(partial_data)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << partial_data)
+            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+            .compact
+        end
+
+        def decode(response_data)
+          processor.process_message(response_data)
+        end
+
         def processor
           @processor ||=
             DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
         end
 
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
-        def extract_completion_from(response_raw)
-          processor.process_message(response_raw)
-        end
-
         def has_tool?(_response_data)
           processor.tool_calls.present?
         end
 
+        def tool_calls
+          processor.to_tool_calls
+        end
+
         def final_log_update(log)
           log.request_tokens = processor.input_tokens if processor.input_tokens
           log.response_tokens = processor.output_tokens if processor.output_tokens
         end
 
-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").map { |line| line.split("data: ", 2)[1] }.compact
-        end
       end
     end
   end
@@ -117,7 +117,24 @@ module DiscourseAi
           end
         end
 
-        def decode(chunk)
+        def decode_chunk(partial_data)
+          bedrock_decode(partial_data)
+            .map do |decoded_partial_data|
+              @raw_response ||= +""
+              @raw_response << decoded_partial_data
+              @raw_response << "\n"
+
+              parsed_json = JSON.parse(decoded_partial_data, symbolize_names: true)
+              processor.process_streamed_message(parsed_json)
+            end
+            .compact
+        end
+
+        def decode(response_data)
+          processor.process_message(response_data)
+        end
+
+        def bedrock_decode(chunk)
           @decoder ||= Aws::EventStream::Decoder.new
 
           decoded, _done = @decoder.decode_chunk(chunk)
@@ -147,12 +164,13 @@ module DiscourseAi
           Aws::EventStream::Errors::MessageChecksumError,
           Aws::EventStream::Errors::PreludeChecksumError => e
           Rails.logger.error("#{self.class.name}: #{e.message}")
-          nil
+          []
         end
 
         def final_log_update(log)
           log.request_tokens = processor.input_tokens if processor.input_tokens
           log.response_tokens = processor.output_tokens if processor.output_tokens
+          log.raw_response_payload = @raw_response
         end
 
         def processor
@@ -160,30 +178,8 @@ module DiscourseAi
             DiscourseAi::Completions::AnthropicMessageProcessor.new(streaming_mode: @streaming_mode)
         end
 
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          processor.to_xml_tool_calls(function_buffer) if !partial
-        end
-
-        def extract_completion_from(response_raw)
-          processor.process_message(response_raw)
-        end
-
-        def has_tool?(_response_data)
-          processor.tool_calls.present?
-        end
-
-        def partials_from(decoded_chunks)
-          decoded_chunks
-        end
-
-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def chunk_to_string(chunk)
-          joined = +chunk.join("\n")
-          joined << "\n" if joined.length > 0
-          joined
+        def xml_tools_enabled?
+          !@native_tool_support
         end
       end
     end
   end
@@ -40,10 +40,6 @@ module DiscourseAi
           @llm_model = llm_model
         end
 
-        def native_tool_support?
-          false
-        end
-
         def use_ssl?
           if model_uri&.scheme.present?
             model_uri.scheme == "https"
@@ -64,22 +60,10 @@ module DiscourseAi
           feature_context: nil,
           &blk
         )
-          allow_tools = dialect.prompt.has_tools?
           model_params = normalize_model_params(model_params)
           orig_blk = blk
 
           @streaming_mode = block_given?
-          to_strip = xml_tags_to_strip(dialect)
-          @xml_stripper =
-            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
-
-          if @streaming_mode && @xml_stripper
-            blk =
-              lambda do |partial, cancel|
-                partial = @xml_stripper << partial
-                orig_blk.call(partial, cancel) if partial
-              end
-          end
 
           prompt = dialect.translate
 
@@ -108,177 +92,91 @@ module DiscourseAi
             raise CompletionFailed, response.body
           end
 
+          xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? &&
+            dialect.prompt.has_tools?
+
+          to_strip = xml_tags_to_strip(dialect)
+          xml_stripper =
+            DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
+
+          if @streaming_mode && xml_stripper
+            blk =
+              lambda do |partial, cancel|
+                partial = xml_stripper << partial if partial.is_a?(String)
+                orig_blk.call(partial, cancel) if partial
+              end
+          end
+
           log =
-            AiApiAuditLog.new(
+            start_log(
               provider_id: provider_id,
-              user_id: user&.id,
-              raw_request_payload: request_body,
-              request_tokens: prompt_size(prompt),
-              topic_id: dialect.prompt.topic_id,
-              post_id: dialect.prompt.post_id,
+              request_body: request_body,
+              dialect: dialect,
+              prompt: prompt,
+              user: user,
               feature_name: feature_name,
-              language_model: llm_model.name,
-              feature_context: feature_context.present? ? feature_context.as_json : nil,
+              feature_context: feature_context,
             )
 
           if !@streaming_mode
-            response_raw = response.read_body
-            response_data = extract_completion_from(response_raw)
-            partials_raw = response_data.to_s
-
-            if native_tool_support?
-              if allow_tools && has_tool?(response_data)
-                function_buffer = build_buffer # Nokogiri document
-                function_buffer =
-                  add_to_function_buffer(function_buffer, payload: response_data)
-                FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
-                response_data = +function_buffer.at("function_calls").to_s
-                response_data << "\n"
-              end
-            else
-              if allow_tools
-                response_data, function_calls = FunctionCallNormalizer.normalize(response_data)
-                response_data = function_calls if function_calls.present?
-              end
-            end
-
-            return response_data
+            return(
+              non_streaming_response(
+                response: response,
+                xml_tool_processor: xml_tool_processor,
+                xml_stripper: xml_stripper,
+                partials_raw: partials_raw,
+                response_raw: response_raw,
+              )
+            )
           end
 
-          has_tool = false
-
           begin
             cancelled = false
             cancel = -> { cancelled = true }
-            wrapped_blk = ->(partial, inner_cancel) do
-              response_data << partial
-              blk.call(partial, inner_cancel)
-            end
-
-            normalizer = FunctionCallNormalizer.new(wrapped_blk, cancel)
-
-            leftover = ""
-            function_buffer = build_buffer # Nokogiri document
-            prev_processed_partials = 0
 
             response.read_body do |chunk|
               if cancelled
                 http.finish
                 break
               end
 
-              decoded_chunk = decode(chunk)
-              if decoded_chunk.nil?
-                raise CompletionFailed, "#{self.class.name}: Failed to decode LLM completion"
-              end
-              response_raw << chunk_to_string(decoded_chunk)
-
-              if decoded_chunk.is_a?(String)
-                redo_chunk = leftover + decoded_chunk
-              else
-                # custom implementation for endpoint
-                # no implicit leftover support
-                redo_chunk = decoded_chunk
-              end
-
-              raw_partials = partials_from(redo_chunk)
-
-              raw_partials =
-                raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0
-
-              if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
-                leftover = redo_chunk
-                next
-              end
-
-              json_error = false
-
-              raw_partials.each do |raw_partial|
-                json_error = false
-                prev_processed_partials += 1
-
-                next if cancelled
-                next if raw_partial.blank?
-
-                begin
-                  partial = extract_completion_from(raw_partial)
-                  next if partial.nil?
-                  # empty vs blank... we still accept " "
-                  next if response_data.empty? && partial.empty?
-                  partials_raw << partial.to_s
-
-                  if native_tool_support?
-                    # Stop streaming the response as soon as you find a tool.
-                    # We'll buffer and yield it later.
-                    has_tool = true if allow_tools && has_tool?(partials_raw)
-
-                    if has_tool
-                      function_buffer =
-                        add_to_function_buffer(function_buffer, partial: partial)
-                    else
-                      response_data << partial
-                      blk.call(partial, cancel) if partial
-                    end
-                  else
-                    if allow_tools
-                      normalizer << partial
-                    else
-                      response_data << partial
-                      blk.call(partial, cancel) if partial
-                    end
-                  end
-                rescue JSON::ParserError
-                  leftover = redo_chunk
-                  json_error = true
-                end
-              end
-
-              if json_error
-                prev_processed_partials -= 1
-              else
-                leftover = ""
-              end
-
-              prev_processed_partials = 0 if leftover.blank?
+              response_raw << chunk
+              decode_chunk(chunk).each do |partial|
+                partials_raw << partial.to_s
+                response_data << partial if partial.is_a?(String)
+                partials = [partial]
+                if xml_tool_processor && partial.is_a?(String)
+                  partials = (xml_tool_processor << partial)
+                  if xml_tool_processor.should_cancel?
+                    cancel.call
+                    break
+                  end
+                end
+                partials.each { |inner_partial| blk.call(inner_partial, cancel) }
+              end
             end
           rescue IOError, StandardError
             raise if !cancelled
           end
 
-          has_tool ||= has_tool?(partials_raw)
-          # Once we have the full response, try to return the tool as a XML doc.
-          if has_tool && native_tool_support?
-            function_buffer = add_to_function_buffer(function_buffer, payload: partials_raw)
-
-            if function_buffer.at("tool_name").text.present?
-              FunctionCallNormalizer.normalize_function_ids!(function_buffer)
-
-              invocation = +function_buffer.at("function_calls").to_s
-              invocation << "\n"
-
-              response_data << invocation
-              blk.call(invocation, cancel)
+          if xml_stripper
+            stripped = xml_stripper.finish
+            if stripped.present?
+              response_data << stripped
+              result = []
+              result = (xml_tool_processor << stripped) if xml_tool_processor
+              result.each { |partial| blk.call(partial, cancel) }
             end
           end
 
-          if !native_tool_support? && function_calls = normalizer.function_calls
-            response_data << function_calls
-            blk.call(function_calls, cancel)
-          end
-
-          if @xml_stripper
-            leftover = @xml_stripper.finish
-            orig_blk.call(leftover, cancel) if leftover.present?
-          end
-
+          if xml_tool_processor
+            xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
+          end
+          decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
           return response_data
         ensure
           if log
             log.raw_response_payload = response_raw
-            log.response_tokens = tokenizer.size(partials_raw)
             final_log_update(log)
+            log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
             log.save!
 
             if Rails.env.development?
@@ -330,15 +228,15 @@ module DiscourseAi
           raise NotImplementedError
         end
 
-        def extract_completion_from(_response_raw)
+        def decode(_response_raw)
           raise NotImplementedError
         end
 
-        def decode(chunk)
-          chunk
+        def decode_chunk_finish
+          []
         end
 
-        def partials_from(_decoded_chunk)
+        def decode_chunk(_chunk)
           raise NotImplementedError
         end
 
@@ -346,49 +244,73 @@ module DiscourseAi
           prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
         end
 
-        def build_buffer
-          Nokogiri::HTML5.fragment(<<~TEXT)
-            <function_calls>
-            #{noop_function_call_text}
-            </function_calls>
-          TEXT
-        end
-
-        def self.noop_function_call_text
-          (<<~TEXT).strip
-            <invoke>
-            <tool_name></tool_name>
-            <parameters>
-            </parameters>
-            <tool_id></tool_id>
-            </invoke>
-          TEXT
-        end
-
-        def noop_function_call_text
-          self.class.noop_function_call_text
-        end
-
-        def has_tool?(response)
-          response.include?("<function_calls>")
-        end
-
-        def chunk_to_string(chunk)
-          if chunk.is_a?(String)
-            chunk
-          else
-            chunk.to_s
-          end
-        end
-
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if payload&.include?("</invoke>")
-            matches = payload.match(%r{<function_calls>.*</invoke>}m)
-            function_buffer =
-              Nokogiri::HTML5.fragment(matches[0] + "\n</function_calls>") if matches
-          end
-
-          function_buffer
+        def xml_tools_enabled?
+          raise NotImplementedError
+        end
+
+        private
+
+        def start_log(
+          provider_id:,
+          request_body:,
+          dialect:,
+          prompt:,
+          user:,
+          feature_name:,
+          feature_context:
+        )
+          AiApiAuditLog.new(
+            provider_id: provider_id,
+            user_id: user&.id,
+            raw_request_payload: request_body,
+            request_tokens: prompt_size(prompt),
+            topic_id: dialect.prompt.topic_id,
+            post_id: dialect.prompt.post_id,
+            feature_name: feature_name,
+            language_model: llm_model.name,
+            feature_context: feature_context.present? ? feature_context.as_json : nil,
+          )
+        end
+
+        def non_streaming_response(
+          response:,
+          xml_tool_processor:,
+          xml_stripper:,
+          partials_raw:,
+          response_raw:
+        )
+          response_raw << response.read_body
+          response_data = decode(response_raw)
+
+          response_data.each { |partial| partials_raw << partial.to_s }
+
+          if xml_tool_processor
+            response_data.each do |partial|
+              processed = (xml_tool_processor << partial)
+              processed << xml_tool_processor.finish
+              response_data = []
+              processed.flatten.compact.each { |inner| response_data << inner }
+            end
+          end
+
+          if xml_stripper
+            response_data.map! do |partial|
+              stripped = (xml_stripper << partial) if partial.is_a?(String)
+              if stripped.present?
+                stripped
+              else
+                partial
+              end
+            end
+            response_data << xml_stripper.finish
+          end
+
+          response_data.reject!(&:blank?)
+
+          # this is to keep stuff backwards compatible
+          response_data = response_data.first if response_data.length == 1
+
+          response_data
         end
       end
     end
   end
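With the new interface, an endpoint subclass only has to turn raw HTTP bytes into partials. A minimal sketch of what an implementation looks like (the class name is illustrative; JsonStreamDecoder is the "flexible json decoder" the commit message refers to, assumed from its usage throughout this diff to yield one parsed hash per complete JSON data line):

    module DiscourseAi
      module Completions
        module Endpoints
          # Sketch only: the three methods a new endpoint must supply.
          class MySseEndpoint < Base
            def xml_tools_enabled?
              # true  => tools are negotiated via XML and handled by XmlToolProcessor
              # false => decode/decode_chunk emit ToolCall objects natively
              true
            end

            def decode(response_raw)
              parsed = JSON.parse(response_raw, symbolize_names: true)
              [parsed.dig(:choices, 0, :message, :content)].compact
            end

            def decode_chunk(chunk)
              @json_decoder ||= JsonStreamDecoder.new
              (@json_decoder << chunk)
                .map { |parsed| parsed.dig(:choices, 0, :delta, :content) }
                .compact
            end
          end
        end
      end
    end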
@@ -45,6 +45,8 @@ module DiscourseAi
         cancel_fn = lambda { cancelled = true }
 
         # We buffer and return tool invocations in one go.
+        as_array = response.is_a?(Array) ? response : [response]
+        as_array.each do |response|
           if is_tool?(response)
             yield(response, cancel_fn)
           else
@@ -53,11 +55,13 @@ module DiscourseAi
             yield(char, cancel_fn)
           end
         end
-        else
-          response
         end
+
+        response = response.first if response.is_a?(Array) && response.length == 1
+        response
+      end
 
       def tokenizer
         DiscourseAi::Tokenizer::OpenAiTokenizer
       end
@@ -65,7 +69,7 @@ module DiscourseAi
       private
 
       def is_tool?(response)
-        Nokogiri::HTML5.fragment(response).at("function_calls").present?
+        response.is_a?(DiscourseAi::Completions::ToolCall)
       end
     end
   end
@@ -49,6 +49,47 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end
 
+        def decode(response_raw)
+          rval = []
+
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+
+          text = parsed[:text]
+          rval << parsed[:text] if !text.to_s.empty? # also allow " "
+
+          # TODO tool calls
+
+          update_usage(parsed)
+
+          rval
+        end
+
+        def decode_chunk(chunk)
+          @tool_idx ||= -1
+          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+          (@json_decoder << chunk)
+            .map do |parsed|
+              update_usage(parsed)
+              rval = []
+
+              rval << parsed[:text] if !parsed[:text].to_s.empty?
+
+              if tool_calls = parsed[:tool_calls]
+                tool_calls&.each do |tool_call|
+                  @tool_idx += 1
+                  tool_name = tool_call[:name]
+                  tool_params = tool_call[:parameters]
+                  tool_id = "tool_#{@tool_idx}"
+                  rval << ToolCall.new(id: tool_id, name: tool_name, parameters: tool_params)
+                end
+              end
+
+              rval
+            end
+            .flatten
+            .compact
+        end
+
         def extract_completion_from(response_raw)
           parsed = JSON.parse(response_raw, symbolize_names: true)
 
@@ -77,36 +118,8 @@ module DiscourseAi
           end
         end
 
-        def has_tool?(_ignored)
-          @has_tool
+        def xml_tools_enabled?
+          false
         end
 
-        def native_tool_support?
-          true
-        end
-
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if partial
-            tools = JSON.parse(partial)
-            tools.each do |tool|
-              name = tool["name"]
-              parameters = tool["parameters"]
-              xml_params = parameters.map { |k, v| "<#{k}>#{v}</#{k}>\n" }.join
-
-              current_function = function_buffer.at("invoke")
-              if current_function.nil? || current_function.at("tool_name").content.present?
-                current_function =
-                  function_buffer.at("function_calls").add_child(
-                    Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                  )
-              end
-
-              current_function.at("tool_name").content = name == "search_local" ? "search" : name
-              current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(xml_params)
-            end
-          end
-          function_buffer
-        end
-
         def final_log_update(log)
@@ -114,10 +127,6 @@ module DiscourseAi
           log.response_tokens = @output_tokens if @output_tokens
         end
 
-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").compact
-        end
-
         def extract_prompt_for_tokenizer(prompt)
           text = +""
           if prompt[:chat_history]
@@ -131,6 +140,18 @@ module DiscourseAi
 
           text
         end
+
+        private
+
+        def update_usage(parsed)
+          input_tokens = parsed.dig(:meta, :billed_units, :input_tokens)
+          input_tokens ||= parsed.dig(:response, :meta, :billed_units, :input_tokens)
+          @input_tokens = input_tokens if input_tokens.present?
+
+          output_tokens = parsed.dig(:meta, :billed_units, :output_tokens)
+          output_tokens ||= parsed.dig(:response, :meta, :billed_units, :output_tokens)
+          @output_tokens = output_tokens if output_tokens.present?
+        end
       end
     end
   end
@@ -133,6 +133,9 @@ module DiscourseAi
         content = content.shift if content.is_a?(Array)
 
         if block_given?
+          if content.is_a?(DiscourseAi::Completions::ToolCall)
+            yield(content, -> {})
+          else
           split_indices = (1...content.length).to_a.sample(self.class.chunk_count - 1).sort
           indexes = [0, *split_indices, content.length]
 
@@ -159,6 +162,7 @@ module DiscourseAi
             yield(chunk, cancel_proc)
           end
         end
+          end
 
         content
       end
@@ -103,15 +103,7 @@ module DiscourseAi
           end
         end
 
-        def partials_from(decoded_chunk)
-          decoded_chunk
-        end
-
-        def chunk_to_string(chunk)
-          chunk.to_s
-        end
-
-        class Decoder
+        class GeminiStreamingDecoder
           def initialize
             @buffer = +""
           end
@@ -151,43 +143,87 @@ module DiscourseAi
         end
 
         def decode(chunk)
-          @decoder ||= Decoder.new
-          @decoder.decode(chunk)
+          json = JSON.parse(chunk, symbolize_names: true)
+          idx = -1
+          json
+            .dig(:candidates, 0, :content, :parts)
+            .map do |part|
+              if part[:functionCall]
+                idx += 1
+                ToolCall.new(
+                  id: "tool_#{idx}",
+                  name: part[:functionCall][:name],
+                  parameters: part[:functionCall][:args],
+                )
+              else
+                part = part[:text]
+                if part != ""
+                  part
+                else
+                  nil
+                end
+              end
+            end
+        end
+
+        def decode_chunk(chunk)
+          @tool_index ||= -1
+
+          streaming_decoder
+            .decode(chunk)
+            .map do |parsed|
+              update_usage(parsed)
+              parsed
+                .dig(:candidates, 0, :content, :parts)
+                .map do |part|
+                  if part[:text]
+                    part = part[:text]
+                    if part != ""
+                      part
+                    else
+                      nil
+                    end
+                  elsif part[:functionCall]
+                    @tool_index += 1
+                    ToolCall.new(
+                      id: "tool_#{@tool_index}",
+                      name: part[:functionCall][:name],
+                      parameters: part[:functionCall][:args],
+                    )
+                  end
+                end
+            end
+            .flatten
+            .compact
+        end
+
+        def update_usage(parsed)
+          usage = parsed.dig(:usageMetadata)
+          if usage
+            if prompt_token_count = usage[:promptTokenCount]
+              @prompt_token_count = prompt_token_count
+            end
+            if candidate_token_count = usage[:candidatesTokenCount]
+              @candidate_token_count = candidate_token_count
+            end
+          end
+        end
+
+        def final_log_update(log)
+          log.request_tokens = @prompt_token_count if @prompt_token_count
+          log.response_tokens = @candidate_token_count if @candidate_token_count
+        end
+
+        def streaming_decoder
+          @decoder ||= GeminiStreamingDecoder.new
         end
 
         def extract_prompt_for_tokenizer(prompt)
           prompt.to_s
         end
 
-        def has_tool?(_response_data)
-          @has_function_call
-        end
-
-        def native_tool_support?
-          true
-        end
-
-        def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
-          if @streaming_mode
-            return function_buffer if !partial
-          else
-            partial = payload
-          end
-
-          function_buffer.at("tool_name").content = partial[:name] if partial[:name].present?
-
-          if partial[:args]
-            argument_fragments =
-              partial[:args].reduce(+"") do |memo, (arg_name, value)|
-                memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
-              end
-            argument_fragments << "\n"
-
-            function_buffer.at("parameters").children =
-              Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-          end
-
-          function_buffer
+        def xml_tools_enabled?
+          false
         end
       end
     end
   end
@@ -59,22 +59,30 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end
 
-        def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-          response_h.dig(:content)
+        def xml_tools_enabled?
+          true
         end
 
-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data:", 2)[1]
-              data&.squish == "[DONE]" ? nil : data
-            end
+        def decode(response_raw)
+          parsed = JSON.parse(response_raw, symbolize_names: true)
+          text = parsed.dig(:choices, 0, :message, :content)
+          if text.to_s.empty?
+            [""]
+          else
+            [text]
+          end
+        end
+
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |parsed|
+              text = parsed.dig(:choices, 0, :delta, :content)
+              if text.to_s.empty?
+                nil
+              else
+                text
+              end
+            end
             .compact
         end
@@ -37,12 +37,8 @@ module DiscourseAi
           URI(llm_model.url)
         end
 
-        def native_tool_support?
-          @native_tool_support
-        end
-
-        def has_tool?(_response_data)
-          @has_function_call
+        def xml_tools_enabled?
+          !@native_tool_support
         end
 
         def prepare_payload(prompt, model_params, dialect)
@@ -67,74 +63,30 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end
 
-        def partials_from(decoded_chunk)
-          decoded_chunk.split("\n").compact
+        def decode_chunk(chunk)
+          # Native tool calls are not working right in streaming mode, use XML
+          @json_decoder ||= JsonStreamDecoder.new(line_regex: /^\s*({.*})$/)
+          (@json_decoder << chunk).map { |parsed| parsed.dig(:message, :content) }.compact
         end
 
-        def extract_completion_from(response_raw)
+        def decode(response_raw)
+          rval = []
           parsed = JSON.parse(response_raw, symbolize_names: true)
-          return if !parsed
-
-          response_h = parsed.dig(:message)
-
-          @has_function_call ||= response_h.dig(:tool_calls).present?
-          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
-        end
-
-        def add_to_function_buffer(function_buffer, payload: nil, partial: nil)
-          @args_buffer ||= +""
-
-          if @streaming_mode
-            return function_buffer if !partial
-          else
-            partial = payload
-          end
-
-          f_name = partial.dig(:function, :name)
-
-          @current_function ||= function_buffer.at("invoke")
-
-          if f_name
-            current_name = function_buffer.at("tool_name").content
-
-            if current_name.blank?
-              # first call
-            else
-              # we have a previous function, so we need to add a noop
-              @args_buffer = +""
-              @current_function =
-                function_buffer.at("function_calls").add_child(
-                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                )
-            end
-          end
-
-          @current_function.at("tool_name").content = f_name if f_name
-          @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
-          args = partial.dig(:function, :arguments)
-
-          # allow for SPACE within arguments
-          if args && args != ""
-            @args_buffer << args.to_json
-
-            begin
-              json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
-              argument_fragments =
-                json_args.reduce(+"") do |memo, (arg_name, value)|
-                  memo << "\n<#{arg_name}>#{value}</#{arg_name}>"
-                end
-              argument_fragments << "\n"
-
-              @current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-            rescue JSON::ParserError
-              return function_buffer
-            end
-          end
-
-          function_buffer
+          content = parsed.dig(:message, :content)
+          rval << content if !content.to_s.empty?
+
+          idx = -1
+          parsed
+            .dig(:message, :tool_calls)
+            &.each do |tool_call|
+              idx += 1
+              id = "tool_#{idx}"
+              name = tool_call.dig(:function, :name)
+              args = tool_call.dig(:function, :arguments)
+              rval << ToolCall.new(id: id, name: name, parameters: args)
+            end
+
+          rval
         end
@@ -93,98 +93,34 @@ module DiscourseAi
         end
 
         def final_log_update(log)
-          log.request_tokens = @prompt_tokens if @prompt_tokens
-          log.response_tokens = @completion_tokens if @completion_tokens
+          log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
+          log.response_tokens = processor.completion_tokens if processor.completion_tokens
         end
 
-        def extract_completion_from(response_raw)
-          json = JSON.parse(response_raw, symbolize_names: true)
-
-          if @streaming_mode
-            @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
-            @completion_tokens ||= json.dig(:usage, :completion_tokens)
-          end
-
-          parsed = json.dig(:choices, 0)
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-          @has_function_call ||= response_h.dig(:tool_calls).present?
-          @has_function_call ? response_h.dig(:tool_calls, 0) : response_h.dig(:content)
+        def decode(response_raw)
+          processor.process_message(JSON.parse(response_raw, symbolize_names: true))
         end
 
-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
-            .compact
+        def decode_chunk(chunk)
+          @decoder ||= JsonStreamDecoder.new
+          (@decoder << chunk)
+            .map { |parsed_json| processor.process_streamed_message(parsed_json) }
+            .flatten
+            .compact
         end
 
-        def has_tool?(_response_data)
-          @has_function_call
+        def decode_chunk_finish
+          @processor.finish
         end
 
-        def native_tool_support?
-          true
+        def xml_tools_enabled?
+          false
         end
 
-        def add_to_function_buffer(function_buffer, partial: nil, payload: nil)
-          if @streaming_mode
-            return function_buffer if !partial
-          else
-            partial = payload
-          end
-
-          @args_buffer ||= +""
-
-          f_name = partial.dig(:function, :name)
-
-          @current_function ||= function_buffer.at("invoke")
-
-          if f_name
-            current_name = function_buffer.at("tool_name").content
-
-            if current_name.blank?
-              # first call
-            else
-              # we have a previous function, so we need to add a noop
-              @args_buffer = +""
-              @current_function =
-                function_buffer.at("function_calls").add_child(
-                  Nokogiri::HTML5::DocumentFragment.parse(noop_function_call_text + "\n"),
-                )
-            end
-          end
-
-          @current_function.at("tool_name").content = f_name if f_name
-          @current_function.at("tool_id").content = partial[:id] if partial[:id]
-
-          args = partial.dig(:function, :arguments)
-
-          # allow for SPACE within arguments
-          if args && args != ""
-            @args_buffer << args
-
-            begin
-              json_args = JSON.parse(@args_buffer, symbolize_names: true)
-
-              argument_fragments =
-                json_args.reduce(+"") do |memo, (arg_name, value)|
-                  memo << "\n<#{arg_name}>#{CGI.escapeHTML(value.to_s)}</#{arg_name}>"
-                end
-              argument_fragments << "\n"
-
-              @current_function.at("parameters").children =
-                Nokogiri::HTML5::DocumentFragment.parse(argument_fragments)
-            rescue JSON::ParserError
-              return function_buffer
-            end
-          end
-
-          function_buffer
-        end
+        private
+
+        def processor
+          @processor ||= OpenAiMessageProcessor.new
+        end
       end
     end
   end
@@ -55,27 +55,31 @@ module DiscourseAi
           log.response_tokens = @completion_tokens if @completion_tokens
         end

-        def extract_completion_from(response_raw)
-          json = JSON.parse(response_raw, symbolize_names: true)
+        def xml_tools_enabled?
+          true
+        end

+        def decode(response_raw)
+          json = JSON.parse(response_raw, symbolize_names: true)
+          [json.dig(:choices, 0, :message, :content)]
+        end

+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |json|
+              text = json.dig(:choices, 0, :delta, :content)

-          if @streaming_mode
               @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
               @completion_tokens ||= json.dig(:usage, :completion_tokens)
-          end

-          parsed = json.dig(:choices, 0)
-          return if !parsed
-
-          @streaming_mode ? parsed.dig(:delta, :content) : parsed.dig(:message, :content)
+              if !text.to_s.empty?
+                text
+              else
+                nil
+              end
             end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
-            end
+            .flatten
             .compact
         end
       end
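Endpoints now implement `decode` (blocking) and `decode_chunk` (streaming) themselves; the base class no longer parses raw chunks. For an OpenAI-compatible API the minimal shape, assuming the `JsonStreamDecoder` shipped in this commit, looks roughly like:

    def decode(response_raw)
      json = JSON.parse(response_raw, symbolize_names: true)
      [json.dig(:choices, 0, :message, :content)]
    end

    def decode_chunk(chunk)
      @json_decoder ||= DiscourseAi::Completions::JsonStreamDecoder.new
      (@json_decoder << chunk)
        .map { |json| json.dig(:choices, 0, :delta, :content) } # text deltas only
        .compact
    end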
@@ -42,7 +42,10 @@ module DiscourseAi
         def prepare_payload(prompt, model_params, dialect)
           payload = default_options.merge(model_params).merge(messages: prompt)
+          if @streaming_mode
             payload[:stream] = true if @streaming_mode
+            payload[:stream_options] = { include_usage: true }
+          end

           payload
         end
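With `stream_options: { include_usage: true }`, an OpenAI-compatible server reports token usage on streamed chunks, which is what makes the vLLM token accounting fix possible. The request body now looks roughly like this sketch (model and message values illustrative, taken from the specs below):

    payload = {
      model: "meta-llama/Meta-Llama-3.1-70B-Instruct",
      messages: [{ role: "user", content: "say hello" }],
      stream: true,
      stream_options: { include_usage: true }, # ask the server to include usage stats
    }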
@@ -56,25 +59,43 @@ module DiscourseAi
           Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
         end

-        def partials_from(decoded_chunk)
-          decoded_chunk
-            .split("\n")
-            .map do |line|
-              data = line.split("data: ", 2)[1]
-              data == "[DONE]" ? nil : data
+        def xml_tools_enabled?
+          true
+        end
+
+        def final_log_update(log)
+          log.request_tokens = @prompt_tokens if @prompt_tokens
+          log.response_tokens = @completion_tokens if @completion_tokens
+        end
+
+        def decode(response_raw)
+          json = JSON.parse(response_raw, symbolize_names: true)
+          @prompt_tokens = json.dig(:usage, :prompt_tokens)
+          @completion_tokens = json.dig(:usage, :completion_tokens)
+          [json.dig(:choices, 0, :message, :content)]
+        end
+
+        def decode_chunk(chunk)
+          @json_decoder ||= JsonStreamDecoder.new
+          (@json_decoder << chunk)
+            .map do |parsed|
+              # vLLM keeps sending usage over and over again
+              prompt_tokens = parsed.dig(:usage, :prompt_tokens)
+              completion_tokens = parsed.dig(:usage, :completion_tokens)
+
+              @prompt_tokens = prompt_tokens if prompt_tokens
+              @completion_tokens = completion_tokens if completion_tokens
+
+              text = parsed.dig(:choices, 0, :delta, :content)
+              if text.to_s.empty?
+                nil
+              else
+                text
+              end
             end
             .compact
         end
-
-        def extract_completion_from(response_raw)
-          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
-          # half a line sent here
-          return if !parsed
-
-          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)
-
-          response_h.dig(:content)
-        end
       end
     end
   end
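Because vLLM repeats a cumulative `usage` object on every chunk, the endpoint keeps overwriting its counters with the latest non-nil values rather than summing them. A small illustration of that logic (chunk hashes abridged from the spec payloads further down):

    chunks = [
      { usage: { prompt_tokens: 46, total_tokens: 46, completion_tokens: 0 } },
      { usage: { prompt_tokens: 46, total_tokens: 47, completion_tokens: 1 } },
      { usage: { prompt_tokens: 46, total_tokens: 48, completion_tokens: 2 } },
    ]

    prompt_tokens = completion_tokens = nil
    chunks.each do |parsed|
      pt = parsed.dig(:usage, :prompt_tokens)
      ct = parsed.dig(:usage, :completion_tokens)
      prompt_tokens = pt if pt # keep the latest value, do not sum
      completion_tokens = ct if ct
    end

    prompt_tokens     # => 46
    completion_tokens # => 2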
@@ -1,113 +0,0 @@
-# frozen_string_literal: true
-
-class DiscourseAi::Completions::FunctionCallNormalizer
-  attr_reader :done
-
-  # blk is the block to call with filtered data
-  def initialize(blk, cancel)
-    @blk = blk
-    @cancel = cancel
-    @done = false
-
-    @in_tool = false
-
-    @buffer = +""
-    @function_buffer = +""
-  end
-
-  def self.normalize(data)
-    text = +""
-    cancel = -> {}
-    blk = ->(partial, _) { text << partial }
-
-    normalizer = self.new(blk, cancel)
-    normalizer << data
-
-    [text, normalizer.function_calls]
-  end
-
-  def function_calls
-    return nil if @function_buffer.blank?
-
-    xml = Nokogiri::HTML5.fragment(@function_buffer)
-    self.class.normalize_function_ids!(xml)
-    last_invoke = xml.at("invoke:last")
-    if last_invoke
-      last_invoke.next_sibling.remove while last_invoke.next_sibling
-      xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
-    end
-    xml.at("function_calls").to_s.dup.force_encoding("UTF-8")
-  end
-
-  def <<(text)
-    @buffer << text
-
-    if !@in_tool
-      # double check if we are clearly in a tool
-      search_length = text.length + 20
-      search_string = @buffer[-search_length..-1] || @buffer
-
-      index = search_string.rindex("<function_calls>")
-      @in_tool = !!index
-      if @in_tool
-        @function_buffer = @buffer[index..-1]
-        text_index = text.rindex("<function_calls>")
-        @blk.call(text[0..text_index - 1].strip, @cancel) if text_index && text_index > 0
-      end
-    else
-      @function_buffer << text
-    end
-
-    if !@in_tool
-      if maybe_has_tool?(@buffer)
-        split_index = text.rindex("<").to_i - 1
-        if split_index >= 0
-          @function_buffer = text[split_index + 1..-1] || ""
-          text = text[0..split_index] || ""
-        else
-          @function_buffer << text
-          text = ""
-        end
-      else
-        if @function_buffer.length > 0
-          @blk.call(@function_buffer, @cancel)
-          @function_buffer = +""
-        end
-      end
-
-      @blk.call(text, @cancel) if text.length > 0
-    else
-      if text.include?("</function_calls>")
-        @done = true
-        @cancel.call
-      end
-    end
-  end
-
-  def self.normalize_function_ids!(function_buffer)
-    function_buffer
-      .css("invoke")
-      .each_with_index do |invoke, index|
-        if invoke.at("tool_id")
-          invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
-        else
-          invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
-        end
-      end
-  end
-
-  private
-
-  def maybe_has_tool?(text)
-    # 16 is the length of function calls
-    substring = text[-16..-1] || text
-    split = substring.split("<")
-
-    if split.length > 1
-      match = "<" + split.last
-      "<function_calls>".start_with?(match)
-    else
-      substring.ends_with?("<")
-    end
-  end
-end
@@ -0,0 +1,48 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    # will work for anthropic and open ai compatible
+    class JsonStreamDecoder
+      attr_reader :buffer
+
+      LINE_REGEX = /data: ({.*})\s*$/
+
+      def initialize(symbolize_keys: true, line_regex: LINE_REGEX)
+        @symbolize_keys = symbolize_keys
+        @buffer = +""
+        @line_regex = line_regex
+      end
+
+      def <<(raw)
+        @buffer << raw.to_s
+        rval = []
+
+        split = @buffer.scan(/.*\n?/)
+        split.pop if split.last.blank?
+
+        @buffer = +(split.pop.to_s)
+
+        split.each do |line|
+          matches = line.match(@line_regex)
+          next if !matches
+          rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+        end
+
+        if @buffer.present?
+          matches = @buffer.match(@line_regex)
+          if matches
+            begin
+              rval << JSON.parse(matches[1], symbolize_names: @symbolize_keys)
+              @buffer = +""
+            rescue JSON::ParserError
+              # maybe it is a partial line
+            end
+          end
+        end
+
+        rval
+      end
+    end
+  end
+end
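The decoder buffers partial lines, so it can be fed arbitrarily split network chunks and only emits complete `data: {...}` payloads. A rough usage sketch:

    decoder = DiscourseAi::Completions::JsonStreamDecoder.new

    decoder << "data: {\"hel"   # => [] - incomplete JSON stays buffered
    decoder << "lo\": 1}\n"     # => [{ hello: 1 }]
    decoder << "data: [DONE]\n" # => [] - no JSON payload to emit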
@@ -0,0 +1,103 @@
+# frozen_string_literal: true
+module DiscourseAi::Completions
+  class OpenAiMessageProcessor
+    attr_reader :prompt_tokens, :completion_tokens
+
+    def initialize
+      @tool = nil
+      @tool_arguments = +""
+      @prompt_tokens = nil
+      @completion_tokens = nil
+    end
+
+    def process_message(json)
+      result = []
+      tool_calls = json.dig(:choices, 0, :message, :tool_calls)
+
+      message = json.dig(:choices, 0, :message, :content)
+      result << message if message.present?
+
+      if tool_calls.present?
+        tool_calls.each do |tool_call|
+          id = tool_call.dig(:id)
+          name = tool_call.dig(:function, :name)
+          arguments = tool_call.dig(:function, :arguments)
+          parameters = arguments.present? ? JSON.parse(arguments, symbolize_names: true) : {}
+          result << ToolCall.new(id: id, name: name, parameters: parameters)
+        end
+      end
+
+      update_usage(json)
+
+      result
+    end
+
+    def process_streamed_message(json)
+      rval = nil
+
+      tool_calls = json.dig(:choices, 0, :delta, :tool_calls)
+      content = json.dig(:choices, 0, :delta, :content)
+
+      finished_tools = json.dig(:choices, 0, :finish_reason) || tool_calls == []
+
+      if tool_calls.present?
+        id = tool_calls.dig(0, :id)
+        name = tool_calls.dig(0, :function, :name)
+        arguments = tool_calls.dig(0, :function, :arguments)
+
+        # TODO: multiple tool support may require index
+        #index = tool_calls[0].dig(:index)
+
+        if id.present? && @tool && @tool.id != id
+          process_arguments
+          rval = @tool
+          @tool = nil
+        end
+
+        if id.present? && name.present?
+          @tool_arguments = +""
+          @tool = ToolCall.new(id: id, name: name)
+        end
+
+        @tool_arguments << arguments.to_s
+      elsif finished_tools && @tool
+        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+        @tool.parameters = parsed_args
+        rval = @tool
+        @tool = nil
+      elsif content.present?
+        rval = content
+      end
+
+      update_usage(json)
+
+      rval
+    end
+
+    def finish
+      rval = []
+      if @tool
+        process_arguments
+        rval << @tool
+        @tool = nil
+      end
+
+      rval
+    end
+
+    private
+
+    def process_arguments
+      if @tool_arguments.present?
+        parsed_args = JSON.parse(@tool_arguments, symbolize_names: true)
+        @tool.parameters = parsed_args
+        @tool_arguments = nil
+      end
+    end
+
+    def update_usage(json)
+      @prompt_tokens ||= json.dig(:usage, :prompt_tokens)
+      @completion_tokens ||= json.dig(:usage, :completion_tokens)
+    end
+  end
+end
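During streaming, an endpoint drives the processor with already-decoded JSON chunks; it returns text partials as strings and a `ToolCall` once its arguments are complete. A rough sketch of the driving loop (`parsed_chunks` stands in for the output of `JsonStreamDecoder`):

    processor = DiscourseAi::Completions::OpenAiMessageProcessor.new

    partials = []
    parsed_chunks.each do |json|
      partial = processor.process_streamed_message(json)
      partials << partial if partial # String or ToolCall
    end
    partials.concat(processor.finish) # flush a trailing tool call, if any

    processor.prompt_tokens     # usage picked up from the stream, may be nil
    processor.completion_tokens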
@@ -0,0 +1,29 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Completions
+    class ToolCall
+      attr_reader :id, :name, :parameters
+
+      def initialize(id:, name:, parameters: nil)
+        @id = id
+        @name = name
+        self.parameters = parameters if parameters
+        @parameters ||= {}
+      end
+
+      def parameters=(parameters)
+        raise ArgumentError, "parameters must be a hash" unless parameters.is_a?(Hash)
+        @parameters = parameters.symbolize_keys
+      end
+
+      def ==(other)
+        id == other.id && name == other.name && parameters == other.parameters
+      end
+
+      def to_s
+        "#{name} - #{id} (\n#{parameters.map(&:to_s).join("\n")}\n)"
+      end
+    end
+  end
+end
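`ToolCall` is a small value object; `#==` compares id, name, and symbolized parameters, which is what lets the updated specs below assert on whole objects with `eq`. For example:

    a = DiscourseAi::Completions::ToolCall.new(
      id: "tool_0",
      name: "google",
      parameters: { "query" => "sydney weather today" },
    )
    b = DiscourseAi::Completions::ToolCall.new(
      id: "tool_0",
      name: "google",
      parameters: { query: "sydney weather today" },
    )

    a.parameters # => { query: "sydney weather today" } - keys are symbolized
    a == b       # => true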
@@ -0,0 +1,124 @@
+# frozen_string_literal: true
+
+# This class can be used to process a stream of text that may contain XML tool
+# calls.
+# It will return either text or ToolCall objects.
+
+module DiscourseAi
+  module Completions
+    class XmlToolProcessor
+      def initialize
+        @buffer = +""
+        @function_buffer = +""
+        @should_cancel = false
+        @in_tool = false
+      end
+
+      def <<(text)
+        @buffer << text
+        result = []
+
+        if !@in_tool
+          # double check if we are clearly in a tool
+          search_length = text.length + 20
+          search_string = @buffer[-search_length..-1] || @buffer
+
+          index = search_string.rindex("<function_calls>")
+          @in_tool = !!index
+          if @in_tool
+            @function_buffer = @buffer[index..-1]
+            text_index = text.rindex("<function_calls>")
+            result << text[0..text_index - 1].strip if text_index && text_index > 0
+          end
+        else
+          @function_buffer << text
+        end
+
+        if !@in_tool
+          if maybe_has_tool?(@buffer)
+            split_index = text.rindex("<").to_i - 1
+            if split_index >= 0
+              @function_buffer = text[split_index + 1..-1] || ""
+              text = text[0..split_index] || ""
+            else
+              @function_buffer << text
+              text = ""
+            end
+          else
+            if @function_buffer.length > 0
+              result << @function_buffer
+              @function_buffer = +""
+            end
+          end
+
+          result << text if text.length > 0
+        else
+          @should_cancel = true if text.include?("</function_calls>")
+        end
+
+        result
+      end
+
+      def finish
+        return [] if @function_buffer.blank?
+
+        xml = Nokogiri::HTML5.fragment(@function_buffer)
+        normalize_function_ids!(xml)
+        last_invoke = xml.at("invoke:last")
+        if last_invoke
+          last_invoke.next_sibling.remove while last_invoke.next_sibling
+          xml.at("invoke:last").add_next_sibling("\n") if !last_invoke.next_sibling
+        end
+
+        xml
+          .css("invoke")
+          .map do |invoke|
+            tool_name = invoke.at("tool_name").content.force_encoding("UTF-8")
+            tool_id = invoke.at("tool_id").content.force_encoding("UTF-8")
+            parameters = {}
+            invoke
+              .at("parameters")
+              &.children
+              &.each do |node|
+                next if node.text?
+                name = node.name
+                value = node.content.to_s
+                parameters[name.to_sym] = value.to_s.force_encoding("UTF-8")
+              end
+            ToolCall.new(id: tool_id, name: tool_name, parameters: parameters)
+          end
+      end
+
+      def should_cancel?
+        @should_cancel
+      end
+
+      private
+
+      def normalize_function_ids!(function_buffer)
+        function_buffer
+          .css("invoke")
+          .each_with_index do |invoke, index|
+            if invoke.at("tool_id")
+              invoke.at("tool_id").content = "tool_#{index}" if invoke.at("tool_id").content.blank?
+            else
+              invoke.add_child("<tool_id>tool_#{index}</tool_id>\n") if !invoke.at("tool_id")
+            end
+          end
+      end
+
+      def maybe_has_tool?(text)
+        # 16 is the length of function calls
+        substring = text[-16..-1] || text
+        split = substring.split("<")
+
+        if split.length > 1
+          match = "<" + split.last
+          "<function_calls>".start_with?(match)
+        else
+          substring.ends_with?("<")
+        end
+      end
+    end
+  end
+end
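The processor accepts streamed text via `<<`, returning any plain text that precedes a tool call, and defers tool parsing until `finish`. A rough usage sketch:

    processor = DiscourseAi::Completions::XmlToolProcessor.new

    result = []
    result.concat(processor << "hello <function_calls><invoke><tool_name>search</tool_name>")
    result.concat(processor << "<parameters><query>ruby</query></parameters></invoke></function_calls>")

    result                   # => ["hello"] - text emitted before the tool call
    processor.should_cancel? # => true once </function_calls> is seen
    processor.finish         # => [ToolCall(id: "tool_0", name: "search", parameters: { query: "ruby" })]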
@@ -104,7 +104,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
      data: {"type":"message_stop"}
    STRING

-    result = +""
+    result = []
    body = body.scan(/.*\n/)
    EndpointMock.with_chunk_array_support do
      stub_request(:post, url).to_return(status: 200, body: body)
@@ -114,18 +114,17 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
      end
    end

-    expected = (<<~TEXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters><search_query>s<a>m sam</search_query>
-      <category>general</category></parameters>
-      <tool_id>toolu_01DjrShFRRHp9SnHYRFRc53F</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "toolu_01DjrShFRRHp9SnHYRFRc53F",
+        parameters: {
+          search_query: "s<a>m sam",
+          category: "general",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq([tool_call])
  end

  it "can stream a response" do
@@ -191,6 +190,8 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do
    expect(log.feature_name).to eq("testing")
    expect(log.response_tokens).to eq(15)
    expect(log.request_tokens).to eq(25)
+    expect(log.raw_request_payload).to eq(expected_body.to_json)
+    expect(log.raw_response_payload.strip).to eq(body.strip)
  end

  it "supports non streaming tool calls" do
@@ -242,17 +243,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Anthropic do

    result = llm.generate(prompt, user: Discourse.system_user)

-    expected = <<~TEXT.strip
-      <function_calls>
-      <invoke>
-      <tool_name>calculate</tool_name>
-      <parameters><expression>2758975 + 21.11</expression></parameters>
-      <tool_id>toolu_012kBdhG4eHaV68W56p4N94h</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "calculate",
+        id: "toolu_012kBdhG4eHaV68W56p4N94h",
+        parameters: {
+          expression: "2758975 + 21.11",
+        },
+      )

-    expect(result.strip).to eq(expected)
+    expect(result).to eq(["Here is the calculation:", tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(345)
+    expect(log.response_tokens).to eq(65)
  end

  it "can send images via a completion prompt" do
@@ -79,7 +79,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
          }

          prompt.tools = [tool]
-          response = +""
+          response = []
          proxy.generate(prompt, user: user) { |partial| response << partial }

          expect(request.headers["Authorization"]).to be_present
@@ -90,21 +90,18 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
          expect(parsed_body["tools"]).to eq(nil)
          expect(parsed_body["stop_sequences"]).to eq(["</function_calls>"])

-          # note we now have a tool_id cause we were normalized
-          function_call = <<~XML.strip
-            hello
-
-            <function_calls>
-            <invoke>
-            <tool_name>google</tool_name>
-            <parameters><query>sydney weather today</query></parameters>
-            <tool_id>tool_0</tool_id>
-            </invoke>
-            </function_calls>
-          XML
-
-          expect(response.strip).to eq(function_call)
+          expected = [
+            "hello\n",
+            DiscourseAi::Completions::ToolCall.new(
+              id: "tool_0",
+              name: "google",
+              parameters: {
+                query: "sydney weather today",
+              },
+            ),
+          ]
+
+          expect(response).to eq(expected)
        end
      end
@@ -230,23 +227,23 @@ RSpec.describe DiscourseAi::Completions::Endpoints::AwsBedrock do
          }

          prompt.tools = [tool]
-          response = +""
+          response = []
          proxy.generate(prompt, user: user) { |partial| response << partial }

          expect(request.headers["Authorization"]).to be_present
          expect(request.headers["X-Amz-Content-Sha256"]).to be_present

-          expected_response = (<<~RESPONSE).strip
-            <function_calls>
-            <invoke>
-            <tool_name>google</tool_name>
-            <parameters><query>sydney weather today</query></parameters>
-            <tool_id>toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7</tool_id>
-            </invoke>
-            </function_calls>
-          RESPONSE
+          expected_response = [
+            DiscourseAi::Completions::ToolCall.new(
+              id: "toolu_bdrk_014CMjxtGmKUtGoEFPgc7PF7",
+              name: "google",
+              parameters: {
+                query: "sydney weather today",
+              },
+            ),
+          ]

-          expect(response.strip).to eq(expected_response)
+          expect(response).to eq(expected_response)

          expected = {
            "max_tokens" => 3000,
@@ -66,7 +66,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
    TEXT

    parsed_body = nil
-    result = +""
+    result = []

    sig = {
      name: "google",
|
@ -91,21 +91,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Cohere do
|
||||||
},
|
},
|
||||||
).to_return(status: 200, body: body.split("|"))
|
).to_return(status: 200, body: body.split("|"))
|
||||||
|
|
||||||
result = llm.generate(prompt, user: user) { |partial, cancel| result << partial }
|
llm.generate(prompt, user: user) { |partial, cancel| result << partial }
|
||||||
end
|
end
|
||||||
|
|
||||||
expected = <<~TEXT
|
text = "I will search for 'who is sam saffron' and relay the information to the user."
|
||||||
<function_calls>
|
tool_call =
|
||||||
<invoke>
|
DiscourseAi::Completions::ToolCall.new(
|
||||||
<tool_name>google</tool_name>
|
id: "tool_0",
|
||||||
<parameters><query>who is sam saffron</query>
|
name: "google",
|
||||||
</parameters>
|
parameters: {
|
||||||
<tool_id>tool_0</tool_id>
|
query: "who is sam saffron",
|
||||||
</invoke>
|
},
|
||||||
</function_calls>
|
)
|
||||||
TEXT
|
|
||||||
|
|
||||||
expect(result.strip).to eq(expected.strip)
|
expect(result).to eq([text, tool_call])
|
||||||
|
|
||||||
expected = {
|
expected = {
|
||||||
model: "command-r-plus",
|
model: "command-r-plus",
|
||||||
|
|
|
@@ -62,18 +62,14 @@ class EndpointMock
  end

  def invocation_response
-    <<~TEXT
-      <function_calls>
-      <invoke>
-      <tool_name>get_weather</tool_name>
-      <parameters>
-      <location>Sydney</location>
-      <unit>c</unit>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    DiscourseAi::Completions::ToolCall.new(
+      id: "tool_0",
+      name: "get_weather",
+      parameters: {
+        location: "Sydney",
+        unit: "c",
+      },
+    )
  end

  def tool_id
@@ -185,7 +181,7 @@ class EndpointsCompliance
    mock.stub_tool_call(a_dialect.translate)

    completion_response = endpoint.perform_completion!(a_dialect, user)
-    expect(completion_response.strip).to eq(mock.invocation_response.strip)
+    expect(completion_response).to eq(mock.invocation_response)
  end

  def streaming_mode_simple_prompt(mock)
@@ -205,6 +201,7 @@ class EndpointsCompliance
      expect(log.raw_request_payload).to be_present
      expect(log.raw_response_payload).to be_present
      expect(log.request_tokens).to eq(endpoint.prompt_size(dialect.translate))
+
      expect(log.response_tokens).to eq(
        endpoint.llm_model.tokenizer_class.size(mock.streamed_simple_deltas[0...-1].join),
      )
@@ -216,14 +213,14 @@ class EndpointsCompliance
    a_dialect = dialect(prompt: prompt)

    mock.stub_streamed_tool_call(a_dialect.translate) do
-      buffered_partial = +""
+      buffered_partial = []

      endpoint.perform_completion!(a_dialect, user) do |partial, cancel|
        buffered_partial << partial
-        cancel.call if buffered_partial.include?("<function_calls>")
+        cancel.call if partial.is_a?(DiscourseAi::Completions::ToolCall)
      end

-      expect(buffered_partial.strip).to eq(mock.invocation_response.strip)
+      expect(buffered_partial).to eq([mock.invocation_response])
    end
  end
@@ -195,19 +195,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do

    response = llm.generate(prompt, user: user)

-    expected = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>echo</tool_name>
-      <parameters>
-      <text><S>ydney</text>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    XML
+    tool =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "echo",
+        parameters: {
+          text: "<S>ydney",
+        },
+      )

-    expect(response.strip).to eq(expected)
+    expect(response).to eq(tool)
  end

  it "Supports Vision API" do
@@ -265,6 +262,68 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
    expect(JSON.parse(req_body)).to eq(expected_prompt)
  end

+  it "Can stream tool calls correctly" do
+    rows = [
+      {
+        candidates: [
+          {
+            content: {
+              parts: [{ functionCall: { name: "echo", args: { text: "sam<>wh!s" } } }],
+              role: "model",
+            },
+            safetyRatings: [
+              { category: "HARM_CATEGORY_HATE_SPEECH", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_DANGEROUS_CONTENT", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_HARASSMENT", probability: "NEGLIGIBLE" },
+              { category: "HARM_CATEGORY_SEXUALLY_EXPLICIT", probability: "NEGLIGIBLE" },
+            ],
+          },
+        ],
+        usageMetadata: {
+          promptTokenCount: 625,
+          totalTokenCount: 625,
+        },
+        modelVersion: "gemini-1.5-pro-002",
+      },
+      {
+        candidates: [{ content: { parts: [{ text: "" }], role: "model" }, finishReason: "STOP" }],
+        usageMetadata: {
+          promptTokenCount: 625,
+          candidatesTokenCount: 4,
+          totalTokenCount: 629,
+        },
+        modelVersion: "gemini-1.5-pro-002",
+      },
+    ]
+
+    payload = rows.map { |r| "data: #{r.to_json}\n\n" }.join
+
+    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
+    url = "#{model.url}:streamGenerateContent?alt=sse&key=123"
+
+    prompt = DiscourseAi::Completions::Prompt.new("Hello", tools: [echo_tool])
+
+    output = []
+
+    stub_request(:post, url).to_return(status: 200, body: payload)
+    llm.generate(prompt, user: user) { |partial| output << partial }
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        id: "tool_0",
+        name: "echo",
+        parameters: {
+          text: "sam<>wh!s",
+        },
+      )
+
+    expect(output).to eq([tool_call])
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(625)
+    expect(log.response_tokens).to eq(4)
+  end
+
  it "Can correctly handle streamed responses even if they are chunked badly" do
    data = +""
    data << "da|ta: |"
@@ -279,12 +338,12 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Gemini do
    llm = DiscourseAi::Completions::Llm.proxy("custom:#{model.id}")
    url = "#{model.url}:streamGenerateContent?alt=sse&key=123"

-    output = +""
+    output = []
    gemini_mock.with_chunk_array_support do
      stub_request(:post, url).to_return(status: 200, body: split)
      llm.generate("Hello", user: user) { |partial| output << partial }
    end

-    expect(output).to eq("Hello World Sam")
+    expect(output.join).to eq("Hello World Sam")
  end
end
@@ -150,7 +150,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Ollama do
  end

  describe "when using streaming mode" do
-    context "with simpel prompts" do
+    context "with simple prompts" do
      it "completes a trivial prompt and logs the response" do
        compliance.streaming_mode_simple_prompt(ollama_mock)
      end
@@ -17,8 +17,8 @@ class OpenAiMock < EndpointMock
      created: 1_678_464_820,
      model: "gpt-3.5-turbo-0301",
      usage: {
-        prompt_tokens: 337,
-        completion_tokens: 162,
+        prompt_tokens: 8,
+        completion_tokens: 13,
        total_tokens: 499,
      },
      choices: [
@@ -231,19 +231,16 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do

      result = llm.generate(prompt, user: user)

-      expected = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>hello</text>
-        </parameters>
-        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+          name: "echo",
+          parameters: {
+            text: "hello",
+          },
+        )

-      expect(result.strip).to eq(expected)
+      expect(result).to eq(tool_call)

      stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
        body: { choices: [message: { content: "OK" }] }.to_json,
@@ -320,19 +317,20 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do

      expect(body_json[:tool_choice]).to eq({ type: "function", function: { name: "echo" } })

-      expected = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>echo</tool_name>
-        <parameters>
-        <text>h<e>llo</text>
-        </parameters>
-        <tool_id>call_I8LKnoijVuhKOM85nnEQgWwd</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
+      log = AiApiAuditLog.order(:id).last
+      expect(log.request_tokens).to eq(55)
+      expect(log.response_tokens).to eq(13)

-      expect(result.strip).to eq(expected)
+      expected =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "call_I8LKnoijVuhKOM85nnEQgWwd",
+          name: "echo",
+          parameters: {
+            text: "h<e>llo",
+          },
+        )
+
+      expect(result).to eq(expected)

      stub_request(:post, "https://api.openai.com/v1/chat/completions").to_return(
        body: { choices: [message: { content: "OK" }] }.to_json,
@@ -487,7 +485,7 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do

      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"e AI "}}]},"logprobs":null,"finish_reason":null}]}

-      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot\\"}"}}]},"logprobs":null,"finish_reason":null}]}
+      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{"tool_calls":[{"index":1,"function":{"arguments":"bot2\\"}"}}]},"logprobs":null,"finish_reason":null}]}

      data: {"id":"chatcmpl-8xjcr5ZOGZ9v8BDYCx0iwe57lJAGk","object":"chat.completion.chunk","created":1709247429,"model":"gpt-4-0125-preview","system_fingerprint":"fp_91aa3742b1","choices":[{"index":0,"delta":{},"logprobs":null,"finish_reason":"tool_calls"}]}
@@ -495,32 +493,30 @@ RSpec.describe DiscourseAi::Completions::Endpoints::OpenAi do
    TEXT

    open_ai_mock.stub_raw(raw_data)
-    content = +""
+    response = []

    dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))

-    endpoint.perform_completion!(dialect, user) { |partial| content << partial }
+    endpoint.perform_completion!(dialect, user) { |partial| response << partial }

-    expected = <<~TEXT
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters>
-      <search_query>Discourse AI bot</search_query>
-      </parameters>
-      <tool_id>call_3Gyr3HylFJwfrtKrL6NaIit1</tool_id>
-      </invoke>
-      <invoke>
-      <tool_name>search</tool_name>
-      <parameters>
-      <query>Discourse AI bot</query>
-      </parameters>
-      <tool_id>call_H7YkbgYurHpyJqzwUN4bghwN</tool_id>
-      </invoke>
-      </function_calls>
-    TEXT
+    tool_calls = [
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_3Gyr3HylFJwfrtKrL6NaIit1",
+        parameters: {
+          search_query: "Discourse AI bot",
+        },
+      ),
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_H7YkbgYurHpyJqzwUN4bghwN",
+        parameters: {
+          query: "Discourse AI bot2",
+        },
+      ),
+    ]

-    expect(content).to eq(expected)
+    expect(response).to eq(tool_calls)
  end

  it "uses proper token accounting" do
@@ -593,21 +589,16 @@ TEXT
      dialect = compliance.dialect(prompt: compliance.generic_prompt(tools: tools))
      endpoint.perform_completion!(dialect, user) { |partial| partials << partial }

-      expect(partials.length).to eq(1)
-
-      function_call = (<<~TXT).strip
-        <function_calls>
-        <invoke>
-        <tool_name>google</tool_name>
-        <parameters>
-        <query>Adabas 9.1</query>
-        </parameters>
-        <tool_id>func_id</tool_id>
-        </invoke>
-        </function_calls>
-      TXT
-
-      expect(partials[0].strip).to eq(function_call)
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(
+          id: "func_id",
+          name: "google",
+          parameters: {
+            query: "Adabas 9.1",
+          },
+        )
+
+      expect(partials).to eq([tool_call])
    end
  end
end
@@ -22,10 +22,15 @@ data: [DONE]
      },
    ).to_return(status: 200, body: body, headers: {})

-    response = +""
+    response = []
    llm.generate("who are you?", user: Discourse.system_user) { |partial| response << partial }

-    expect(response).to eq("I am a bot")
+    expect(response).to eq(["I am a bot"])
+
+    log = AiApiAuditLog.order(:id).last
+
+    expect(log.request_tokens).to eq(21)
+    expect(log.response_tokens).to eq(41)
  end

  it "can perform regular completions" do
@@ -51,7 +51,13 @@ class VllmMock < EndpointMock

    WebMock
      .stub_request(:post, "https://test.dev/v1/chat/completions")
-      .with(body: model.default_options.merge(messages: prompt, stream: true).to_json)
+      .with(
+        body:
+          model
+            .default_options
+            .merge(messages: prompt, stream: true, stream_options: { include_usage: true })
+            .to_json,
+      )
      .to_return(status: 200, body: chunks)
  end
end
@@ -136,29 +142,115 @@ RSpec.describe DiscourseAi::Completions::Endpoints::Vllm do

      result = llm.generate(prompt, user: Discourse.system_user)

-      expected = <<~TEXT
-        <function_calls>
-        <invoke>
-        <tool_name>calculate</tool_name>
-        <parameters>
-        <expression>1+1</expression></parameters>
-        <tool_id>tool_0</tool_id>
-        </invoke>
-        </function_calls>
-      TEXT
-
-      expect(result.strip).to eq(expected.strip)
+      expected =
+        DiscourseAi::Completions::ToolCall.new(
+          name: "calculate",
+          id: "tool_0",
+          parameters: {
+            expression: "1+1",
+          },
+        )
+
+      expect(result).to eq(expected)
+    end
+  end
+
+  it "correctly accounts for tokens in non streaming mode" do
+    body = (<<~TEXT).strip
+      {"id":"chat-c580e4a9ebaa44a0becc802ed5dc213a","object":"chat.completion","created":1731294404,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"message":{"role":"assistant","content":"Random Number Generator Produces Smallest Possible Result","tool_calls":[]},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":146,"total_tokens":156,"completion_tokens":10},"prompt_logprobs":null}
+    TEXT
+
+    stub_request(:post, "https://test.dev/v1/chat/completions").to_return(status: 200, body: body)
+
+    result = llm.generate("generate a title", user: Discourse.system_user)
+
+    expect(result).to eq("Random Number Generator Produces Smallest Possible Result")
+
+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(146)
+    expect(log.response_tokens).to eq(10)
+  end
+
+  it "can properly include usage in streaming mode" do
+    payload = <<~TEXT.strip
+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"role":"assistant","content":""},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":46,"completion_tokens":0}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"Hello"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":47,"completion_tokens":1}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Sam"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":48,"completion_tokens":2}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":49,"completion_tokens":3}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" It"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":50,"completion_tokens":4}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"'s"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":51,"completion_tokens":5}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" nice"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":52,"completion_tokens":6}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":53,"completion_tokens":7}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" meet"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":54,"completion_tokens":8}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":55,"completion_tokens":9}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"."},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":56,"completion_tokens":10}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" Is"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":57,"completion_tokens":11}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" there"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":58,"completion_tokens":12}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" something"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":59,"completion_tokens":13}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" I"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":60,"completion_tokens":14}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" can"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":61,"completion_tokens":15}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" help"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":62,"completion_tokens":16}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":63,"completion_tokens":17}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" with"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":64,"completion_tokens":18}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" or"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":65,"completion_tokens":19}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" would"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":66,"completion_tokens":20}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" you"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":67,"completion_tokens":21}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" like"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":68,"completion_tokens":22}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" to"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":69,"completion_tokens":23}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":" chat"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":70,"completion_tokens":24}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":"?"},"logprobs":null,"finish_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":71,"completion_tokens":25}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[{"index":0,"delta":{"content":""},"logprobs":null,"finish_reason":"stop","stop_reason":null}],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}

+      data: {"id":"chat-b183bb5829194e8891cacceabfdb5274","object":"chat.completion.chunk","created":1731295402,"model":"meta-llama/Meta-Llama-3.1-70B-Instruct","choices":[],"usage":{"prompt_tokens":46,"total_tokens":72,"completion_tokens":26}}

+      data: [DONE]
+    TEXT

+    stub_request(:post, "https://test.dev/v1/chat/completions").to_return(
+      status: 200,
+      body: payload,
+    )

+    response = []
+    llm.generate("say hello", user: Discourse.system_user) { |partial| response << partial }

+    expect(response.join).to eq(
+      "Hello Sam. It's nice to meet you. Is there something I can help you with or would you like to chat?",
+    )

+    log = AiApiAuditLog.order(:id).last
+    expect(log.request_tokens).to eq(46)
+    expect(log.response_tokens).to eq(26)
+  end

  describe "#perform_completion!" do
    context "when using regular mode" do
-      context "with simple prompts" do
-        it "completes a trivial prompt and logs the response" do
-          compliance.regular_mode_simple_prompt(vllm_mock)
-        end
-      end
-
      context "with tools" do
        it "returns a function invocation" do
          compliance.regular_mode_tools(vllm_mock)
@@ -1,182 +0,0 @@
-# frozen_string_literal: true
-
-RSpec.describe DiscourseAi::Completions::FunctionCallNormalizer do
-  let(:buffer) { +"" }
-
-  let(:normalizer) do
-    blk = ->(data, cancel) { buffer << data }
-    cancel = -> { @done = true }
-    DiscourseAi::Completions::FunctionCallNormalizer.new(blk, cancel)
-  end
-
-  def pass_through!(data)
-    normalizer << data
-    expect(buffer[-data.length..-1]).to eq(data)
-  end
-
-  it "is usable in non streaming mode" do
-    xml = (<<~XML).strip
-      hello
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      </invoke>
-    XML
-
-    text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
-
-    expect(text).to eq("hello")
-
-    expected_function_calls = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    expect(function_calls).to eq(expected_function_calls)
-  end
-
-  it "strips junk from end of function calls" do
-    xml = (<<~XML).strip
-      hello
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      </invoke>
-      junk
-    XML
-
-    _text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(xml)
-
-    expected_function_calls = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    expect(function_calls).to eq(expected_function_calls)
-  end
-
-  it "returns nil for function calls if there are none" do
-    input = "hello world\n"
-    text, function_calls = DiscourseAi::Completions::FunctionCallNormalizer.normalize(input)
-
-    expect(text).to eq(input)
-    expect(function_calls).to eq(nil)
-  end
-
-  it "passes through data if there are no function calls detected" do
-    pass_through!("hello")
-    pass_through!("<tool_name>hello</tool_name>")
-    pass_through!("<parameters><hello>world</hello></parameters>")
-    pass_through!("<function_call>")
-  end
-
-  it "properly handles non English tools" do
-    normalizer << "hello<function"
-    expect(buffer).to eq("hello")
-
-    normalizer << "_calls>\n"
-
-    normalizer << (<<~XML).strip
-      <invoke>
-      <tool_name>hello</tool_name>
-      <parameters>
-      <hello>世界</hello>
-      </parameters>
-      </invoke>
-    XML
-
-    expected = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      <parameters>
-      <hello>世界</hello>
-      </parameters>
-      <tool_id>tool_0</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    function_calls = normalizer.function_calls
-    expect(function_calls).to eq(expected)
-  end
-
-  it "works correctly even if you only give it 1 letter at a time" do
-    xml = (<<~XML).strip
-      abc
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      <parameters>
-      <hello>world</hello>
-      </parameters>
-      <tool_id>abc</tool_id>
-      </invoke>
-      <invoke>
-      <tool_name>hello2</tool_name>
-      <parameters>
-      <hello>world</hello>
-      </parameters>
-      <tool_id>aba</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    xml.each_char { |char| normalizer << char }
-
-    expect(buffer + normalizer.function_calls).to eq(xml)
-  end
-
-  it "supports multiple invokes" do
-    xml = (<<~XML).strip
-      <function_calls>
-      <invoke>
-      <tool_name>hello</tool_name>
-      <parameters>
-      <hello>world</hello>
-      </parameters>
-      <tool_id>abc</tool_id>
-      </invoke>
-      <invoke>
-      <tool_name>hello2</tool_name>
-      <parameters>
-      <hello>world</hello>
-      </parameters>
-      <tool_id>aba</tool_id>
-      </invoke>
-      </function_calls>
-    XML
-
-    normalizer << xml
-
-    expect(normalizer.function_calls).to eq(xml)
-  end
-
-  it "can will cancel if it encounteres </function_calls>" do
-    normalizer << "<function_calls>"
-    expect(normalizer.done).to eq(false)
-    normalizer << "</function_calls>"
-    expect(normalizer.done).to eq(true)
-    expect(@done).to eq(true)
-
-    expect(normalizer.function_calls).to eq("<function_calls></function_calls>")
-  end
-
-  it "pauses on function call and starts buffering" do
-    normalizer << "hello<function_call"
-    expect(buffer).to eq("hello")
-    expect(normalizer.done).to eq(false)
-
-    normalizer << ">"
-    expect(buffer).to eq("hello<function_call>")
-    expect(normalizer.done).to eq(false)
-  end
-end
@ -0,0 +1,47 @@
# frozen_string_literal: true

describe DiscourseAi::Completions::JsonStreamDecoder do
  let(:decoder) { DiscourseAi::Completions::JsonStreamDecoder.new }

  it "should be able to parse simple messages" do
    result = decoder << "data: #{{ hello: "world" }.to_json}"
    expect(result).to eq([{ hello: "world" }])
  end

  it "should handle anthropic mixed style streams" do
    stream = (<<~TEXT).split("|")
      event: |message_start|
      data: |{"hel|lo": "world"}|

      event: |message_start
      data: {"foo": "bar"}

      event: |message_start
      data: {"ba|z": "qux"|}

      [DONE]
    TEXT

    results = []
    stream.each { |chunk| results << (decoder << chunk) }

    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
  end

  it "should be able to handle complex overlaps" do
    stream = (<<~TEXT).split("|")
      data: |{"hel|lo": "world"}

      data: {"foo": "bar"}

      data: {"ba|z": "qux"|}

      [DONE]
    TEXT

    results = []
    stream.each { |chunk| results << (decoder << chunk) }

    expect(results.flatten.compact).to eq([{ hello: "world" }, { foo: "bar" }, { baz: "qux" }])
  end
end
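The spec above pins down the decoder's whole surface: `decoder << raw_chunk` returns an array of symbolized hashes (or nil), however the SSE `data:` payloads were split across chunks. A minimal sketch of wiring it into an endpoint's `decode_chunk`, assuming an OpenAI-style delta payload — the method body, memoization, and payload shape are illustrative, not part of this commit:

```ruby
# Illustrative sketch only. The JsonStreamDecoder usage matches the spec
# above; the choices/delta payload shape is an assumption.
def decode_chunk(raw_chunk)
  @decoder ||= DiscourseAi::Completions::JsonStreamDecoder.new
  (@decoder << raw_chunk).to_a.filter_map do |parsed|
    # each element is one fully reassembled "data:" JSON object
    parsed.dig(:choices, 0, :delta, :content)
  end
end
```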
@ -0,0 +1,188 @@
# frozen_string_literal: true

RSpec.describe DiscourseAi::Completions::XmlToolProcessor do
  let(:processor) { DiscourseAi::Completions::XmlToolProcessor.new }

  it "can process simple text" do
    result = []
    result << (processor << "hello")
    result << (processor << " world ")
    expect(result).to eq([["hello"], [" world "]])
    expect(processor.finish).to eq([])
    expect(processor.should_cancel?).to eq(false)
  end

  it "is usable for simple single message mode" do
    xml = (<<~XML).strip
      hello
      <function_calls>
      <invoke>
      <tool_name>hello</tool_name>
      <parameters>
      <hello>world</hello>
      <test>value</test>
      </parameters>
      </invoke>
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "hello",
        parameters: {
          hello: "world",
          test: "value",
        },
      )
    expect(result).to eq([["hello"], [tool_call]])
    expect(processor.should_cancel?).to eq(false)
  end

  it "handles multiple tool calls in sequence" do
    xml = (<<~XML).strip
      start
      <function_calls>
      <invoke>
      <tool_name>first_tool</tool_name>
      <parameters>
      <param1>value1</param1>
      </parameters>
      </invoke>
      <invoke>
      <tool_name>second_tool</tool_name>
      <parameters>
      <param2>value2</param2>
      </parameters>
      </invoke>
      </function_calls>
      end
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    first_tool =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "first_tool",
        parameters: {
          param1: "value1",
        },
      )

    second_tool =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_1",
        name: "second_tool",
        parameters: {
          param2: "value2",
        },
      )

    expect(result).to eq([["start"], [first_tool, second_tool]])
    expect(processor.should_cancel?).to eq(true)
  end

  it "handles non-English parameters correctly" do
    xml = (<<~XML).strip
      こんにちは
      <function_calls>
      <invoke>
      <tool_name>translator</tool_name>
      <parameters>
      <text>世界</text>
      </parameters>
      </invoke>
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(
        id: "tool_0",
        name: "translator",
        parameters: {
          text: "世界",
        },
      )

    expect(result).to eq([["こんにちは"], [tool_call]])
  end

  it "processes input character by character" do
    xml =
      "hi<function_calls><invoke><tool_name>test</tool_name><parameters><p>v</p></parameters></invoke>"

    result = []
    xml.each_char { |char| result << (processor << char) }
    result << processor.finish

    tool_call =
      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { p: "v" })

    filtered_result = result.reject(&:empty?)
    expect(filtered_result).to eq([["h"], ["i"], [tool_call]])
  end

  it "handles malformed XML gracefully" do
    xml = (<<~XML).strip
      text
      <function_calls>
      <invoke>
      <tool_name>test</tool_name>
      <parameters>
      <param>value
      </parameters>
      </invoke>
      malformed
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    # Should just do its best to parse the XML
    tool_call =
      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "test", parameters: { param: "" })
    expect(result).to eq([["text"], [tool_call]])
  end

  it "correctly processes empty parameter sets" do
    xml = (<<~XML).strip
      hello
      <function_calls>
      <invoke>
      <tool_name>no_params</tool_name>
      <parameters>
      </parameters>
      </invoke>
    XML

    result = []
    result << (processor << xml)
    result << (processor.finish)

    tool_call =
      DiscourseAi::Completions::ToolCall.new(id: "tool_0", name: "no_params", parameters: {})

    expect(result).to eq([["hello"], [tool_call]])
  end

  it "properly handles cancelled processing" do
    xml = "start<function_calls></function_calls>"
    result = []
    result << (processor << xml)
    result << (processor << "more text")
    result << processor.finish

    expect(result).to eq([["start"], [], []])
    expect(processor.should_cancel?).to eq(true)
  end
end
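Taken together, these specs document the processor's streaming contract: `processor << chunk` yields an array mixing plain-text segments and `ToolCall` objects, `finish` flushes whatever is still buffered, and `should_cancel?` flips to true once `</function_calls>` has been seen. A minimal consumer sketch under those assumptions — the `chunks` source is illustrative:

```ruby
# Illustrative sketch: drain a stream of text chunks through the processor.
# Only the processor API exercised in the spec above is relied on here.
processor = DiscourseAi::Completions::XmlToolProcessor.new
output = []

chunks.each do |chunk|
  output.concat(processor << chunk)
  break if processor.should_cancel? # model emitted </function_calls>, stop streaming
end
output.concat(processor.finish)

text = output.grep(String).join
tool_calls = output.grep(DiscourseAi::Completions::ToolCall)
```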
@ -72,40 +72,27 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
   it "can parse strings that are wrapped in quotes" do
     SiteSetting.ai_stability_api_key = "123"
-    xml = <<~XML
-      <function_calls>
-      <invoke>
-      <tool_name>image</tool_name>
-      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-      <parameters>
-      <prompts>["cat oil painting", "big car"]</prompts>
-      <aspect_ratio>"16:9"</aspect_ratio>
-      </parameters>
-      </invoke>
-      <invoke>
-      <tool_name>image</tool_name>
-      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-      <parameters>
-      <prompts>["cat oil painting", "big car"]</prompts>
-      <aspect_ratio>'16:9'</aspect_ratio>
-      </parameters>
-      </invoke>
-      </function_calls>
-    XML
 
-    image1, image2 =
-      tools =
-        DiscourseAi::AiBot::Personas::Artist.new.find_tools(
-          xml,
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "image",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          prompts: ["cat oil painting", "big car"],
+          aspect_ratio: "16:9",
+        },
+      )
+
+    tool_instance =
+      DiscourseAi::AiBot::Personas::Artist.new.find_tool(
+        tool_call,
         bot_user: nil,
         llm: nil,
         context: nil,
       )
-    expect(image1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
-    expect(image1.parameters[:aspect_ratio]).to eq("16:9")
-    expect(image2.parameters[:aspect_ratio]).to eq("16:9")
-
-    expect(tools.length).to eq(2)
+    expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
+    expect(tool_instance.parameters[:aspect_ratio]).to eq("16:9")
   end
 
   it "enforces enums" do
@ -132,38 +119,64 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
       </function_calls>
     XML
 
-    search1, search2 =
-      tools =
-        DiscourseAi::AiBot::Personas::General.new.find_tools(
-          xml,
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          max_posts: "3.2",
+          status: "cow",
+          foo: "bar",
+        },
+      )
+
+    tool_instance =
+      DiscourseAi::AiBot::Personas::General.new.find_tool(
+        tool_call,
         bot_user: nil,
         llm: nil,
         context: nil,
       )
 
-    expect(search1.parameters.key?(:status)).to eq(false)
-    expect(search2.parameters[:status]).to eq("open")
+    expect(tool_instance.parameters.key?(:status)).to eq(false)
+
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          max_posts: "3.2",
+          status: "open",
+          foo: "bar",
+        },
+      )
+
+    tool_instance =
+      DiscourseAi::AiBot::Personas::General.new.find_tool(
+        tool_call,
+        bot_user: nil,
+        llm: nil,
+        context: nil,
+      )
+
+    expect(tool_instance.parameters[:status]).to eq("open")
   end
 
   it "can coerce integers" do
-    xml = <<~XML
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-      <parameters>
-      <max_posts>"3.2"</max_posts>
-      <search_query>hello world</search_query>
-      <foo>bar</foo>
-      </parameters>
-      </invoke>
-      </function_calls>
-    XML
-
-    search, =
-      tools =
-        DiscourseAi::AiBot::Personas::General.new.find_tools(
-          xml,
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          max_posts: "3.2",
+          search_query: "hello world",
+          foo: "bar",
+        },
+      )
+
+    search =
+      DiscourseAi::AiBot::Personas::General.new.find_tool(
+        tool_call,
         bot_user: nil,
         llm: nil,
         context: nil,
@ -177,43 +190,23 @@ RSpec.describe DiscourseAi::AiBot::Personas::Persona do
   it "can correctly parse arrays in tools" do
     SiteSetting.ai_openai_api_key = "123"
 
-    # Dall E tool uses an array for params
-    xml = <<~XML
-      <function_calls>
-      <invoke>
-      <tool_name>dall_e</tool_name>
-      <tool_id>call_JtYQMful5QKqw97XFsHzPweB</tool_id>
-      <parameters>
-      <prompts>["cat oil painting", "big car"]</prompts>
-      </parameters>
-      </invoke>
-      <invoke>
-      <tool_name>dall_e</tool_name>
-      <tool_id>abc</tool_id>
-      <parameters>
-      <prompts>["pic3"]</prompts>
-      </parameters>
-      </invoke>
-      <invoke>
-      <tool_name>unknown</tool_name>
-      <tool_id>abc</tool_id>
-      <parameters>
-      <prompts>["pic3"]</prompts>
-      </parameters>
-      </invoke>
-      </function_calls>
-    XML
-
-    dall_e1, dall_e2 =
-      tools =
-        DiscourseAi::AiBot::Personas::DallE3.new.find_tools(
-          xml,
+    tool_call =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "dall_e",
+        id: "call_JtYQMful5QKqw97XFsHzPweB",
+        parameters: {
+          prompts: ["cat oil painting", "big car"],
+        },
+      )
+
+    tool_instance =
+      DiscourseAi::AiBot::Personas::DallE3.new.find_tool(
+        tool_call,
         bot_user: nil,
         llm: nil,
         context: nil,
       )
-    expect(dall_e1.parameters[:prompts]).to eq(["cat oil painting", "big car"])
-    expect(dall_e2.parameters[:prompts]).to eq(["pic3"])
-
-    expect(tools.length).to eq(2)
+    expect(tool_instance.parameters[:prompts]).to eq(["cat oil painting", "big car"])
   end
 
   describe "custom personas" do
@ -55,6 +55,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     )
   end
 
+  before { SiteSetting.ai_embeddings_enabled = false }
+
   after do
     # we must reset cache on persona cause data can be rolled back
     AiPersona.persona_cache.flush!
@ -83,17 +85,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   end
 
   let!(:ai_persona) { Fabricate(:ai_persona, tools: ["custom-#{custom_tool.id}"]) }
-  let(:function_call) { (<<~XML).strip }
-    <function_calls>
-    <invoke>
-    <tool_name>search</tool_name>
-    <tool_id>666</tool_id>
-    <parameters>
-    <query>Can you use the custom tool</query>
-    </parameters>
-    </invoke>
-    </function_calls>
-  XML
+  let(:tool_call) do
+    DiscourseAi::Completions::ToolCall.new(
+      name: "search",
+      id: "666",
+      parameters: {
+        query: "Can you use the custom tool",
+      },
+    )
+  end
 
   let(:bot) { DiscourseAi::AiBot::Bot.as(bot_user, persona: ai_persona.class_instance.new) }
 
@ -115,7 +115,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     reply_post = nil
     prompts = nil
 
-    responses = [function_call]
+    responses = [tool_call]
     DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompts|
       new_post = Fabricate(:post, raw: "Can you use the custom tool?")
       reply_post = playground.reply_to(new_post)
@ -133,7 +133,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   it "can force usage of a tool" do
     tool_name = "custom-#{custom_tool.id}"
     ai_persona.update!(tools: [[tool_name, nil, true]], forced_tool_count: 1)
-    responses = [function_call, "custom tool did stuff (maybe)"]
+    responses = [tool_call, "custom tool did stuff (maybe)"]
 
     prompts = nil
     reply_post = nil
@ -166,7 +166,7 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
     playground = DiscourseAi::AiBot::Playground.new(bot)
 
-    responses = [function_call, "custom tool did stuff (maybe)"]
+    responses = [tool_call, "custom tool did stuff (maybe)"]
 
     reply_post = nil
 
@ -206,13 +206,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona_klass.new)
     playground = DiscourseAi::AiBot::Playground.new(bot)
 
+    responses = ["custom tool did stuff (maybe)", tool_call]
+
     # lets ensure tool does not run...
     DiscourseAi::Completions::Llm.with_prepared_responses(responses) do |_, _, _prompt|
       new_post = Fabricate(:post, raw: "Can you use the custom tool?")
       reply_post = playground.reply_to(new_post)
     end
 
-    expect(reply_post.raw.strip).to eq(function_call)
+    expect(reply_post.raw.strip).to eq("custom tool did stuff (maybe)")
   end
 end
 
@ -452,10 +454,25 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   it "can run tools" do
     persona.update!(tools: ["Time"])
 
-    responses = [
-      "<function_calls><invoke><tool_name>time</tool_name><tool_id>time</tool_id><parameters><timezone>Buenos Aires</timezone></parameters></invoke></function_calls>",
-      "The time is 2023-12-14 17:24:00 -0300",
-    ]
+    tool_call1 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "time",
+        id: "time",
+        parameters: {
+          timezone: "Buenos Aires",
+        },
+      )
+
+    tool_call2 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "time",
+        id: "time",
+        parameters: {
+          timezone: "Sydney",
+        },
+      )
+
+    responses = [[tool_call1, tool_call2], "The time is 2023-12-14 17:24:00 -0300"]
 
     message =
       DiscourseAi::Completions::Llm.with_prepared_responses(responses) do
|
@ -470,7 +487,8 @@ RSpec.describe DiscourseAi::AiBot::Playground do
|
||||||
|
|
||||||
# it also needs to have tool details now set on message
|
# it also needs to have tool details now set on message
|
||||||
prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
|
prompt = ChatMessageCustomPrompt.find_by(message_id: reply.id)
|
||||||
expect(prompt.custom_prompt.length).to eq(3)
|
|
||||||
|
expect(prompt.custom_prompt.length).to eq(5)
|
||||||
|
|
||||||
# TODO in chat I am mixed on including this in the context, but I guess maybe?
|
# TODO in chat I am mixed on including this in the context, but I guess maybe?
|
||||||
# thinking about this
|
# thinking about this
|
||||||
|
@ -782,30 +800,29 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   end
 
   it "supports multiple function calls" do
-    response1 = (<<~TXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <tool_id>search</tool_id>
-      <parameters>
-      <search_query>testing various things</search_query>
-      </parameters>
-      </invoke>
-      <invoke>
-      <tool_name>search</tool_name>
-      <tool_id>search</tool_id>
-      <parameters>
-      <search_query>another search</search_query>
-      </parameters>
-      </invoke>
-      </function_calls>
-    TXT
+    tool_call1 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "search",
+        parameters: {
+          search_query: "testing various things",
+        },
+      )
+
+    tool_call2 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "search",
+        parameters: {
+          search_query: "another search",
+        },
+      )
 
     response2 = "I found stuff"
 
-    DiscourseAi::Completions::Llm.with_prepared_responses([response1, response2]) do
-      playground.reply_to(third_post)
-    end
+    DiscourseAi::Completions::Llm.with_prepared_responses(
+      [[tool_call1, tool_call2], response2],
+    ) { playground.reply_to(third_post) }
 
     last_post = third_post.topic.reload.posts.order(:post_number).last
 
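These playground specs all lean on the same new convention: a prepared response that is an array of `ToolCall` objects stands in for an LLM turn that invokes tools, and callers receive those objects directly instead of XML they must parse. A hypothetical caller-side sketch — the block follows the ToolCall-yielding behaviour this commit describes, but both helpers are illustrative:

```ruby
# Illustrative sketch: partials yielded during generation may now be
# ToolCall objects rather than XML text that the caller has to parse.
llm.generate(prompt, user: Discourse.system_user) do |partial|
  if partial.is_a?(DiscourseAi::Completions::ToolCall)
    invoke_tool(partial.name, partial.parameters) # hypothetical helper
  else
    append_to_reply(partial) # hypothetical helper
  end
end
```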
@ -819,17 +836,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
     bot = DiscourseAi::AiBot::Bot.as(bot_user, persona: persona.class_instance.new)
     playground = described_class.new(bot)
 
-    response1 = (<<~TXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <tool_id>search</tool_id>
-      <parameters>
-      <search_query>testing various things</search_query>
-      </parameters>
-      </invoke>
-      </function_calls>
-    TXT
+    response1 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "search",
+        parameters: {
+          search_query: "testing various things",
+        },
+      )
 
     response2 = "I found stuff"
 
@ -843,17 +857,14 @@ RSpec.describe DiscourseAi::AiBot::Playground do
   end
 
   it "does not include placeholders in conversation context but includes all completions" do
-    response1 = (<<~TXT).strip
-      <function_calls>
-      <invoke>
-      <tool_name>search</tool_name>
-      <tool_id>search</tool_id>
-      <parameters>
-      <search_query>testing various things</search_query>
-      </parameters>
-      </invoke>
-      </function_calls>
-    TXT
+    response1 =
+      DiscourseAi::Completions::ToolCall.new(
+        name: "search",
+        id: "search",
+        parameters: {
+          search_query: "testing various things",
+        },
+      )
 
     response2 = "I found some really amazing stuff!"
 
@ -889,17 +900,15 @@ RSpec.describe DiscourseAi::AiBot::Playground do
       [{ b64_json: image, revised_prompt: "a pink cow 1" }]
     end
 
-    let(:response) { (<<~TXT).strip }
-      <function_calls>
-      <invoke>
-      <tool_name>dall_e</tool_name>
-      <tool_id>dall_e</tool_id>
-      <parameters>
-      <prompts>["a pink cow"]</prompts>
-      </parameters>
-      </invoke>
-      </function_calls>
-    TXT
+    let(:response) do
+      DiscourseAi::Completions::ToolCall.new(
+        name: "dall_e",
+        id: "dall_e",
+        parameters: {
+          prompts: ["a pink cow"],
+        },
+      )
+    end
 
     it "properly returns an image when skipping tool details" do
       persona.update!(tool_details: false)
 
@ -541,16 +541,10 @@ RSpec.describe DiscourseAi::Admin::AiPersonasController do
       expect(topic.title).to eq("An amazing title")
       expect(topic.posts.count).to eq(2)
 
-      # now let's try to make a reply with a tool call
-      function_call = <<~XML
-        <function_calls>
-        <invoke>
-        <tool_name>categories</tool_name>
-        </invoke>
-        </function_calls>
-      XML
-
-      fake_endpoint.fake_content = [function_call, "this is the response after the tool"]
+      tool_call =
+        DiscourseAi::Completions::ToolCall.new(name: "categories", parameters: {}, id: "tool_1")
+
+      fake_endpoint.fake_content = [tool_call, "this is the response after the tool"]
       # this simplifies function calls
       fake_endpoint.chunk_count = 1
 
@ -4,6 +4,8 @@ RSpec.describe DiscourseAi::AiBot::BotController do
   fab!(:user)
   fab!(:pm_topic) { Fabricate(:private_message_topic) }
   fab!(:pm_post) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post2) { Fabricate(:post, topic: pm_topic) }
+  fab!(:pm_post3) { Fabricate(:post, topic: pm_topic) }
 
   before { sign_in(user) }
 
@ -22,6 +24,17 @@ RSpec.describe DiscourseAi::AiBot::BotController do
       user = pm_topic.topic_allowed_users.first.user
       sign_in(user)
 
+      log1 =
+        AiApiAuditLog.create!(
+          provider_id: 1,
+          topic_id: pm_topic.id,
+          raw_request_payload: "request",
+          raw_response_payload: "response",
+          request_tokens: 1,
+          response_tokens: 2,
+        )
+
+      log2 =
         AiApiAuditLog.create!(
           post_id: pm_post.id,
           provider_id: 1,
@ -32,24 +45,43 @@ RSpec.describe DiscourseAi::AiBot::BotController do
           response_tokens: 2,
         )
 
+      log3 =
+        AiApiAuditLog.create!(
+          post_id: pm_post2.id,
+          provider_id: 1,
+          topic_id: pm_topic.id,
+          raw_request_payload: "request",
+          raw_response_payload: "response",
+          request_tokens: 1,
+          response_tokens: 2,
+        )
+
       Group.refresh_automatic_groups!
       SiteSetting.ai_bot_debugging_allowed_groups = user.groups.first.id.to_s
 
       get "/discourse-ai/ai-bot/post/#{pm_post.id}/show-debug-info"
       expect(response.status).to eq(200)
 
+      expect(response.parsed_body["id"]).to eq(log2.id)
+      expect(response.parsed_body["next_log_id"]).to eq(log3.id)
+      expect(response.parsed_body["prev_log_id"]).to eq(log1.id)
+      expect(response.parsed_body["topic_id"]).to eq(pm_topic.id)
+
       expect(response.parsed_body["request_tokens"]).to eq(1)
       expect(response.parsed_body["response_tokens"]).to eq(2)
       expect(response.parsed_body["raw_request_payload"]).to eq("request")
       expect(response.parsed_body["raw_response_payload"]).to eq("response")
 
-      post2 = Fabricate(:post, topic: pm_topic)
-
       # return previous post if current has no debug info
-      get "/discourse-ai/ai-bot/post/#{post2.id}/show-debug-info"
+      get "/discourse-ai/ai-bot/post/#{pm_post3.id}/show-debug-info"
       expect(response.status).to eq(200)
       expect(response.parsed_body["request_tokens"]).to eq(1)
       expect(response.parsed_body["response_tokens"]).to eq(2)
+
+      # can return debug info by id as well
+      get "/discourse-ai/ai-bot/show-debug-info/#{log1.id}"
+      expect(response.status).to eq(200)
+      expect(response.parsed_body["id"]).to eq(log1.id)
     end
   end
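The new `prev_log_id` / `next_log_id` fields plus the by-id route are exactly what the debug modal's next/previous buttons consume. A hypothetical walk over a PM's logs using only what the spec asserts — the loop itself is illustrative, not part of this commit:

```ruby
# Illustrative sketch inside a request spec: follow prev_log_id back to the
# first LLM exchange of the topic. Route and JSON fields are asserted above.
log = response.parsed_body
while (prev_id = log["prev_log_id"])
  get "/discourse-ai/ai-bot/show-debug-info/#{prev_id}"
  log = response.parsed_body
end
first_exchange = log # oldest AiApiAuditLog for this PM
```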