Sam e817b7dc11
FEATURE: improve tool support (#904)
This re-implements tool support in DiscourseAi::Completions::Llm #generate

Previously tool support was always returned via XML and it would be the responsibility of the caller to parse XML

New implementation has the endpoints return ToolCall objects.

Additionally this simplifies the Llm endpoint interface and gives it more clarity. Llms must implement

decode, decode_chunk (for streaming)

It is the implementers responsibility to figure out how to decode chunks, base no longer implements. To make this easy we ship a flexible json decoder which is easy to wire up.

Also (new)

    Better debugging for PMs, we now have a next / previous button to see all the Llm messages associated with a PM
    Token accounting is fixed for vllm (we were not correctly counting tokens)
2024-11-12 08:14:30 +11:00

319 lines
9.5 KiB
Ruby

# frozen_string_literal: true
module DiscourseAi
module Completions
module Endpoints
class Base
CompletionFailed = Class.new(StandardError)
TIMEOUT = 60
class << self
def endpoint_for(provider_name)
endpoints = [
DiscourseAi::Completions::Endpoints::AwsBedrock,
DiscourseAi::Completions::Endpoints::OpenAi,
DiscourseAi::Completions::Endpoints::HuggingFace,
DiscourseAi::Completions::Endpoints::Gemini,
DiscourseAi::Completions::Endpoints::Vllm,
DiscourseAi::Completions::Endpoints::Anthropic,
DiscourseAi::Completions::Endpoints::Cohere,
DiscourseAi::Completions::Endpoints::SambaNova,
]
endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?
if Rails.env.test? || Rails.env.development?
endpoints << DiscourseAi::Completions::Endpoints::Fake
end
endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
ek.can_contact?(provider_name)
end
end
def can_contact?(_model_provider)
raise NotImplementedError
end
end
def initialize(llm_model)
@llm_model = llm_model
end
def use_ssl?
if model_uri&.scheme.present?
model_uri.scheme == "https"
else
true
end
end
def xml_tags_to_strip(dialect)
[]
end
def perform_completion!(
dialect,
user,
model_params = {},
feature_name: nil,
feature_context: nil,
&blk
)
model_params = normalize_model_params(model_params)
orig_blk = blk
@streaming_mode = block_given?
prompt = dialect.translate
FinalDestination::HTTP.start(
model_uri.host,
model_uri.port,
use_ssl: use_ssl?,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) do |http|
response_data = +""
response_raw = +""
# Needed to response token calculations. Cannot rely on response_data due to function buffering.
partials_raw = +""
request_body = prepare_payload(prompt, model_params, dialect).to_json
request = prepare_request(request_body)
http.request(request) do |response|
if response.code.to_i != 200
Rails.logger.error(
"#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
)
raise CompletionFailed, response.body
end
xml_tool_processor = XmlToolProcessor.new if xml_tools_enabled? &&
dialect.prompt.has_tools?
to_strip = xml_tags_to_strip(dialect)
xml_stripper =
DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?
if @streaming_mode && xml_stripper
blk =
lambda do |partial, cancel|
partial = xml_stripper << partial if partial.is_a?(String)
orig_blk.call(partial, cancel) if partial
end
end
log =
start_log(
provider_id: provider_id,
request_body: request_body,
dialect: dialect,
prompt: prompt,
user: user,
feature_name: feature_name,
feature_context: feature_context,
)
if !@streaming_mode
return(
non_streaming_response(
response: response,
xml_tool_processor: xml_tool_processor,
xml_stripper: xml_stripper,
partials_raw: partials_raw,
response_raw: response_raw,
)
)
end
begin
cancelled = false
cancel = -> { cancelled = true }
if cancelled
http.finish
break
end
response.read_body do |chunk|
response_raw << chunk
decode_chunk(chunk).each do |partial|
partials_raw << partial.to_s
response_data << partial if partial.is_a?(String)
partials = [partial]
if xml_tool_processor && partial.is_a?(String)
partials = (xml_tool_processor << partial)
if xml_tool_processor.should_cancel?
cancel.call
break
end
end
partials.each { |inner_partial| blk.call(inner_partial, cancel) }
end
end
rescue IOError, StandardError
raise if !cancelled
end
if xml_stripper
stripped = xml_stripper.finish
if stripped.present?
response_data << stripped
result = []
result = (xml_tool_processor << stripped) if xml_tool_processor
result.each { |partial| blk.call(partial, cancel) }
end
end
if xml_tool_processor
xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
end
decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
return response_data
ensure
if log
log.raw_response_payload = response_raw
final_log_update(log)
log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
log.save!
if Rails.env.development?
puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
end
end
end
end
end
def final_log_update(log)
# for people that need to override
end
def default_options
raise NotImplementedError
end
def provider_id
raise NotImplementedError
end
def prompt_size(prompt)
tokenizer.size(extract_prompt_for_tokenizer(prompt))
end
attr_reader :llm_model
protected
def tokenizer
llm_model.tokenizer_class
end
# should normalize temperature, max_tokens, stop_words to endpoint specific values
def normalize_model_params(model_params)
raise NotImplementedError
end
def model_uri
raise NotImplementedError
end
def prepare_payload(_prompt, _model_params)
raise NotImplementedError
end
def prepare_request(_payload)
raise NotImplementedError
end
def decode(_response_raw)
raise NotImplementedError
end
def decode_chunk_finish
[]
end
def decode_chunk(_chunk)
raise NotImplementedError
end
def extract_prompt_for_tokenizer(prompt)
prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
end
def xml_tools_enabled?
raise NotImplementedError
end
private
def start_log(
provider_id:,
request_body:,
dialect:,
prompt:,
user:,
feature_name:,
feature_context:
)
AiApiAuditLog.new(
provider_id: provider_id,
user_id: user&.id,
raw_request_payload: request_body,
request_tokens: prompt_size(prompt),
topic_id: dialect.prompt.topic_id,
post_id: dialect.prompt.post_id,
feature_name: feature_name,
language_model: llm_model.name,
feature_context: feature_context.present? ? feature_context.as_json : nil,
)
end
def non_streaming_response(
response:,
xml_tool_processor:,
xml_stripper:,
partials_raw:,
response_raw:
)
response_raw << response.read_body
response_data = decode(response_raw)
response_data.each { |partial| partials_raw << partial.to_s }
if xml_tool_processor
response_data.each do |partial|
processed = (xml_tool_processor << partial)
processed << xml_tool_processor.finish
response_data = []
processed.flatten.compact.each { |inner| response_data << inner }
end
end
if xml_stripper
response_data.map! do |partial|
stripped = (xml_stripper << partial) if partial.is_a?(String)
if stripped.present?
stripped
else
partial
end
end
response_data << xml_stripper.finish
end
response_data.reject!(&:blank?)
# this is to keep stuff backwards compatible
response_data = response_data.first if response_data.length == 1
response_data
end
end
end
end
end