mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-03-09 11:48:47 +00:00
Implement streaming tool calls for Anthropic and OpenAI. When calling `llm.generate(..., partial_tool_calls: true) do ... end`, partials may contain ToolCall instances with `partial: true`. These tool calls are populated from partially parsed JSON, so, for example, when performing a search you may get `ToolCall(..., {search: "hello" })` followed by `ToolCall(..., {search: "hello world" })`. The library used to parse the JSON is https://github.com/dgraham/json-stream; we use a fork because we need access to its internal buffer. This prepares the internals for partial tool calls but does not implement them yet.
323 lines
9.6 KiB
Ruby
323 lines
9.6 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
  module Completions
    module Endpoints
      # Base class for LLM HTTP endpoints.
      #
      # Subclasses supply the provider-specific pieces (URI, payload, request,
      # response decoding); this class drives the HTTP exchange, streaming,
      # optional XML tool / tag post-processing, and audit logging.
      class Base
        # Value of the +partial_tool_calls:+ flag passed to the most recent
        # #perform_completion! call.
        attr_reader :partial_tool_calls

        CompletionFailed = Class.new(StandardError)

        # Used as the open, read and write timeout of the HTTP connection.
        TIMEOUT = 60

        class << self
          # Finds the endpoint class able to contact +provider_name+.
          #
          # Ollama is only considered in development; the Fake endpoint only in
          # development and test.
          #
          # @param provider_name [String]
          # @return [Class] the first endpoint class whose +can_contact?+ is truthy
          # @raise the error referenced by DiscourseAi::Completions::Llm::UNKNOWN_MODEL
          #   when no endpoint matches
          def endpoint_for(provider_name)
            endpoints = [
              DiscourseAi::Completions::Endpoints::AwsBedrock,
              DiscourseAi::Completions::Endpoints::OpenAi,
              DiscourseAi::Completions::Endpoints::HuggingFace,
              DiscourseAi::Completions::Endpoints::Gemini,
              DiscourseAi::Completions::Endpoints::Vllm,
              DiscourseAi::Completions::Endpoints::Anthropic,
              DiscourseAi::Completions::Endpoints::Cohere,
              DiscourseAi::Completions::Endpoints::SambaNova,
            ]

            endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?

            if Rails.env.test? || Rails.env.development?
              endpoints << DiscourseAi::Completions::Endpoints::Fake
            end

            endpoints.find(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |endpoint|
              endpoint.can_contact?(provider_name)
            end
          end

          # Whether this endpoint class handles +_model_provider+.
          # Subclasses must override.
          #
          # @raise [NotImplementedError] in the base class
          def can_contact?(_model_provider)
            raise NotImplementedError
          end
        end

        # @param llm_model model record; this class reads its +name+ and
        #   +tokenizer_class+
        def initialize(llm_model)
          @llm_model = llm_model
        end

        # Whether to connect over TLS: follows the scheme of #model_uri when it
        # has one, otherwise defaults to true.
        def use_ssl?
          if model_uri&.scheme.present?
            model_uri.scheme == "https"
          else
            true
          end
        end

        # Tags an XmlTagStripper should remove from the model output.
        # None by default; subclasses may override.
        def xml_tags_to_strip(dialect)
          []
        end

        # Sends the translated prompt to the model and returns the completion.
        #
        # When a block is given the response is streamed: each decoded partial
        # is yielded as +(partial, cancel)+. Calling +cancel+ (any number of
        # times) closes the HTTP connection and stops delivery of further
        # streamed partials; final partials produced by the XML stripper, the
        # XML tool processor and #decode_chunk_finish are still delivered.
        #
        # @param dialect must respond to +translate+ and +prompt+
        # @param user recorded on the audit log (may be nil)
        # @param model_params [Hash] passed through #normalize_model_params
        # @param feature_name stored on the audit log
        # @param feature_context stored (as JSON) on the audit log when present
        # @param partial_tool_calls [Boolean] readable via #partial_tool_calls
        # @return [String] the concatenated String partials when streaming;
        #   otherwise the result of #non_streaming_response
        # @raise [CompletionFailed] when the server answers with a non-200 status
        def perform_completion!(
          dialect,
          user,
          model_params = {},
          feature_name: nil,
          feature_context: nil,
          partial_tool_calls: false,
          &blk
        )
          @partial_tool_calls = partial_tool_calls
          model_params = normalize_model_params(model_params)
          orig_blk = blk

          @streaming_mode = block_given?

          prompt = dialect.translate

          FinalDestination::HTTP.start(
            model_uri.host,
            model_uri.port,
            use_ssl: use_ssl?,
            read_timeout: TIMEOUT,
            open_timeout: TIMEOUT,
            write_timeout: TIMEOUT,
          ) do |http|
            response_data = +""
            response_raw = +""

            # Needed for response token calculations. Cannot rely on response_data due to function buffering.
            partials_raw = +""
            request_body = prepare_payload(prompt, model_params, dialect).to_json

            request = prepare_request(request_body)

            http.request(request) do |response|
              if response.code.to_i != 200
                Rails.logger.error(
                  "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
                )
                raise CompletionFailed, response.body
              end

              xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? && dialect.prompt.has_tools?

              to_strip = xml_tags_to_strip(dialect)
              xml_stripper =
                DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?

              if @streaming_mode && xml_stripper
                # Run String partials through the stripper before the caller sees them;
                # skip partials the stripper turned into nil.
                blk =
                  lambda do |partial, cancel|
                    partial = xml_stripper << partial if partial.is_a?(String)
                    orig_blk.call(partial, cancel) if partial
                  end
              end

              log =
                start_log(
                  provider_id: provider_id,
                  request_body: request_body,
                  dialect: dialect,
                  prompt: prompt,
                  user: user,
                  feature_name: feature_name,
                  feature_context: feature_context,
                )

              unless @streaming_mode
                return(
                  non_streaming_response(
                    response: response,
                    xml_tool_processor: xml_tool_processor,
                    xml_stripper: xml_stripper,
                    partials_raw: partials_raw,
                    response_raw: response_raw,
                  )
                )
              end

              cancelled = false
              # Cancelling closes the connection so the pending read in read_body
              # fails; the rescue below swallows that error because +cancelled+ is
              # set. Repeat calls are no-ops (finishing twice would raise).
              cancel =
                lambda do
                  unless cancelled
                    cancelled = true
                    http.finish
                  end
                end

              begin
                response.read_body do |chunk|
                  break if cancelled
                  response_raw << chunk

                  decode_chunk(chunk).each do |partial|
                    break if cancelled

                    partials_raw << partial.to_s
                    response_data << partial if partial.is_a?(String)
                    partials = [partial]
                    if xml_tool_processor && partial.is_a?(String)
                      partials = (xml_tool_processor << partial)
                      if xml_tool_processor.should_cancel?
                        cancel.call
                        break
                      end
                    end
                    partials.each do |inner_partial|
                      break if cancelled
                      blk.call(inner_partial, cancel)
                    end
                  end
                end
              rescue IOError, StandardError
                raise unless cancelled
              end

              if xml_stripper
                stripped = xml_stripper.finish
                if stripped.present?
                  response_data << stripped
                  result = []
                  result = (xml_tool_processor << stripped) if xml_tool_processor
                  result.each { |partial| blk.call(partial, cancel) }
                end
              end

              if xml_tool_processor
                xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
              end

              decode_chunk_finish.each { |partial| blk.call(partial, cancel) }

              return response_data
            ensure
              if log
                log.raw_response_payload = response_raw
                final_log_update(log)

                log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
                log.save!

                if Rails.env.development?
                  puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
                end
              end
            end
          end
        end

        # Hook invoked with the audit log just before it is saved; no-op here.
        def final_log_update(log)
          # for people that need to override
        end

        # Subclasses must override.
        def default_options
          raise NotImplementedError
        end

        # Identifier stored on the audit log. Subclasses must override.
        def provider_id
          raise NotImplementedError
        end

        # Token count of the prompt's message contents, per the model's tokenizer.
        def prompt_size(prompt)
          tokenizer.size(extract_prompt_for_tokenizer(prompt))
        end

        attr_reader :llm_model

        protected

        # Tokenizer class configured on the model.
        def tokenizer
          llm_model.tokenizer_class
        end

        # should normalize temperature, max_tokens, stop_words to endpoint specific values
        def normalize_model_params(model_params)
          raise NotImplementedError
        end

        # URI of the model's HTTP endpoint. Subclasses must override.
        def model_uri
          raise NotImplementedError
        end

        # Builds the request payload (serialized with +to_json+ by the caller).
        # Subclasses must override; #perform_completion! passes the dialect as
        # the third argument.
        def prepare_payload(_prompt, _model_params, _dialect = nil)
          raise NotImplementedError
        end

        # Builds the HTTP request object for the serialized payload.
        def prepare_request(_payload)
          raise NotImplementedError
        end

        # Decodes a full (non-streamed) response body into an Array of partials.
        def decode(_response_raw)
          raise NotImplementedError
        end

        # Partials to emit once the stream has ended; none by default.
        def decode_chunk_finish
          []
        end

        # Decodes one streamed chunk into an enumerable of partials.
        def decode_chunk(_chunk)
          raise NotImplementedError
        end

        # Joins each message's content (symbol or string key) with newlines;
        # messages without content contribute an empty line.
        def extract_prompt_for_tokenizer(prompt)
          prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
        end

        # Whether tool calls are parsed from XML in the text output.
        def xml_tools_enabled?
          raise NotImplementedError
        end

        private

        # Builds (without saving) the audit log record for this request.
        def start_log(
          provider_id:,
          request_body:,
          dialect:,
          prompt:,
          user:,
          feature_name:,
          feature_context:
        )
          AiApiAuditLog.new(
            provider_id: provider_id,
            user_id: user&.id,
            raw_request_payload: request_body,
            request_tokens: prompt_size(prompt),
            topic_id: dialect.prompt.topic_id,
            post_id: dialect.prompt.post_id,
            feature_name: feature_name,
            language_model: llm_model.name,
            feature_context: feature_context.present? ? feature_context.as_json : nil,
          )
        end

        # Reads and decodes a complete (non-streamed) response.
        #
        # Runs the decoded partials through the XML tool processor and tag
        # stripper when present, then drops blank entries. Appends to
        # +response_raw+ and +partials_raw+ so the caller's audit log sees them.
        #
        # @return the single remaining entry when exactly one is left
        #   (backwards compatibility), otherwise the Array of entries
        def non_streaming_response(
          response:,
          xml_tool_processor:,
          xml_stripper:,
          partials_raw:,
          response_raw:
        )
          response_raw << response.read_body
          response_data = decode(response_raw)

          response_data.each { |partial| partials_raw << partial.to_s }

          if xml_tool_processor
            # Feed every String partial once and call #finish exactly once, so
            # output from earlier partials is not discarded. Non-String partials
            # pass through untouched, as in the streaming path.
            processed =
              response_data.map do |partial|
                partial.is_a?(String) ? (xml_tool_processor << partial) : partial
              end
            processed << xml_tool_processor.finish
            response_data = processed.flatten.compact
          end

          if xml_stripper
            response_data.map! do |partial|
              stripped = (xml_stripper << partial) if partial.is_a?(String)
              if stripped.present?
                stripped
              else
                partial
              end
            end
            response_data << xml_stripper.finish
          end

          response_data.reject!(&:blank?)

          # this is to keep stuff backwards compatible
          response_data = response_data.first if response_data.length == 1

          response_data
        end
      end
    end
  end
end
|