2023-11-23 10:58:54 -05:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module DiscourseAi
|
|
|
|
module Completions
|
|
|
|
module Endpoints
|
|
|
|
class Base
|
|
|
|
# Raised by perform_completion! when the endpoint answers with a
# non-200 HTTP status (the status and body are logged first).
class CompletionFailed < StandardError
end

# Seconds used for the open, read and write timeouts of every request.
TIMEOUT = 60
|
|
|
|
|
|
|
|
# Returns the first endpoint class whose can_contact? accepts model_name.
#
# Raises DiscourseAi::Completions::Llm::UNKNOWN_MODEL when none does.
def self.endpoint_for(model_name)
  # Order is important.
  # Bedrock has priority over Anthropic if credentials are present.
  candidates = [
    DiscourseAi::Completions::Endpoints::AwsBedrock,
    DiscourseAi::Completions::Endpoints::Anthropic,
    DiscourseAi::Completions::Endpoints::OpenAi,
    DiscourseAi::Completions::Endpoints::HuggingFace,
    DiscourseAi::Completions::Endpoints::Gemini,
    DiscourseAi::Completions::Endpoints::Vllm,
  ]

  # The fake endpoint is only offered in the test and development environments.
  if Rails.env.test? || Rails.env.development?
    candidates << DiscourseAi::Completions::Endpoints::Fake
  end

  match = candidates.find { |endpoint_class| endpoint_class.can_contact?(model_name) }
  raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL unless match

  match
end
|
|
|
|
|
|
|
|
# Abstract class-level predicate: a concrete endpoint returns truthy when it
# can serve the given model. endpoint_for picks the first endpoint class
# for which this is truthy.
def self.can_contact?(_model_name)
  raise(NotImplementedError)
end
|
|
|
|
|
|
|
|
# Stores the model identifier and the tokenizer whose #size is used to count
# prompt and response tokens.
def initialize(model_name, tokenizer)
  @model, @tokenizer = model_name, tokenizer
end
|
|
|
|
|
2023-12-18 16:06:01 -05:00
|
|
|
# Sends the translated dialect prompt to this endpoint and returns the
# completion.
#
# dialect      - object responding to #translate; the translated prompt is
#                passed to prepare_payload and counted by prompt_size.
# user         - optional; its id (when present) is stored on the audit log.
# model_params - generic options, run through normalize_model_params before
#                the payload is built.
#
# Without a block, the whole body is read at once. The return value is
# whatever extract_completion_from produced (possibly nil). If it contains a
# tool call, the return value is only the <function_calls> XML plus "\n".
#
# With a block, the response is streamed. Each text partial is yielded as
# (partial, cancel); calling cancel.call stops the stream. As soon as a tool
# marker ("<function") appears, partials stop being yielded and are buffered
# instead. If the buffered invocation names a tool, it is appended to the
# result and yielded once at the end. Returns the accumulated text.
#
# Raises CompletionFailed when the HTTP status is not 200. Once a 200
# response is received, an AiApiAuditLog record is built and saved in an
# ensure clause.
def perform_completion!(dialect, user, model_params = {})
  model_params = normalize_model_params(model_params)

  @streaming_mode = block_given?

  prompt = dialect.translate

  Net::HTTP.start(
    model_uri.host,
    model_uri.port,
    use_ssl: true,
    read_timeout: TIMEOUT,
    open_timeout: TIMEOUT,
    write_timeout: TIMEOUT,
  ) do |http|
    response_data = +""
    response_raw = +""

    # Needed for response token calculations. Cannot rely on response_data due to function buffering.
    partials_raw = +""
    request_body = prepare_payload(prompt, model_params, dialect).to_json

    request = prepare_request(request_body)

    http.request(request) do |response|
      if response.code.to_i != 200
        Rails.logger.error(
          "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
        )
        raise CompletionFailed
      end

      log =
        AiApiAuditLog.new(
          provider_id: provider_id,
          user_id: user&.id,
          raw_request_payload: request_body,
          request_tokens: prompt_size(prompt),
        )

      unless @streaming_mode
        response_raw = response.read_body
        response_data = extract_completion_from(response_raw)
        partials_raw = response_data.to_s

        # Check the nil-safe copy: extract_completion_from may return nil,
        # and calling include? on nil would raise NoMethodError.
        if has_tool?(partials_raw)
          function_buffer = build_buffer # Nokogiri document
          function_buffer = add_to_buffer(function_buffer, "", response_data)

          response_data = +function_buffer.at("function_calls").to_s
          response_data << "\n"
        end

        return response_data
      end

      has_tool = false

      begin
        cancelled = false
        cancel = lambda { cancelled = true }

        leftover = ""
        function_buffer = build_buffer # Nokogiri document
        prev_processed_partials = 0

        response.read_body do |chunk|
          if cancelled
            http.finish
            break
          end

          decoded_chunk = decode(chunk)
          response_raw << decoded_chunk

          # Any text left unconsumed by a previous pass is re-split together
          # with the new chunk. Partials already handled in that pass are
          # skipped using prev_processed_partials.
          redo_chunk = leftover + decoded_chunk

          raw_partials = partials_from(redo_chunk)

          raw_partials =
            raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0

          if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
            leftover = redo_chunk
            next
          end

          json_error = false

          raw_partials.each do |raw_partial|
            json_error = false
            prev_processed_partials += 1

            next if cancelled
            next if raw_partial.blank?

            begin
              partial = extract_completion_from(raw_partial)
              next if response_data.empty? && partial.blank?
              next if partial.nil?
              partials_raw << partial.to_s

              # Stop streaming the response as soon as you find a tool.
              # We'll buffer and yield it later. Once a tool is detected,
              # skip re-scanning the ever-growing partials_raw.
              has_tool ||= has_tool?(partials_raw)

              if has_tool
                # NOTE(review): partials_raw already ends with `partial`, so the
                # newest text reaches add_to_buffer twice — confirm intent.
                function_buffer = add_to_buffer(function_buffer, partials_raw, partial)
              else
                response_data << partial
                # partial is non-nil here: nil partials were skipped above.
                yield partial, cancel
              end
            rescue JSON::ParserError
              # Unparseable partial: keep the whole text as leftover. The
              # counter is decremented below so this partial is retried with
              # the next chunk.
              leftover = redo_chunk
              json_error = true
            end
          end

          if json_error
            prev_processed_partials -= 1
          else
            leftover = ""
          end

          prev_processed_partials = 0 if leftover.blank?
        end
      rescue StandardError
        # IOError is a StandardError. Errors are only swallowed after a
        # cancel, which presumably covers read_body failing once http.finish
        # has closed the connection.
        raise unless cancelled
      end

      # Once we have the full response, try to return the tool as a XML doc.
      if has_tool && function_buffer.at("tool_name").text.present?
        invocation = +function_buffer.at("function_calls").to_s
        invocation << "\n"

        response_data << invocation
        yield invocation, cancel
      end

      return response_data
    ensure
      # log is nil when the status check raised before it was built.
      if log
        log.raw_response_payload = response_raw
        log.response_tokens = tokenizer.size(partials_raw)
        log.save!

        if Rails.env.development?
          puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
        end
      end
    end
  end
end
|
|
|
|
|
|
|
|
# Abstract: concrete endpoints must override this.
def default_options
  raise(NotImplementedError)
end
|
|
|
|
|
|
|
|
# Abstract: the provider identifier that perform_completion! records as
# provider_id on each AiApiAuditLog entry.
def provider_id
  raise(NotImplementedError)
end
|
|
|
|
|
|
|
|
# Token count of a prompt, as measured by the tokenizer on the text returned
# by extract_prompt_for_tokenizer (an identity hook by default).
def prompt_size(prompt)
  text_for_counting = extract_prompt_for_tokenizer(prompt)
  tokenizer.size(text_for_counting)
end
|
|
|
|
|
2024-01-17 13:08:49 -05:00
|
|
|
# Public readers for the constructor arguments.
attr_reader :model, :tokenizer

# Methods defined below are protected: hooks for concrete endpoint subclasses.
protected
|
|
|
|
|
2024-01-04 07:53:47 -05:00
|
|
|
# Abstract: converts generic model params (temperature, max_tokens,
# stop_words) into endpoint-specific values. perform_completion! calls this
# first and hands the result to prepare_payload.
def normalize_model_params(model_params)
  raise(NotImplementedError)
end
|
|
|
|
|
2023-11-23 10:58:54 -05:00
|
|
|
# Abstract: the completion API URI. perform_completion! connects to its
# host and port with TLS enabled.
def model_uri
  raise(NotImplementedError)
end
|
|
|
|
|
|
|
|
# Abstract: builds the request body (any object responding to #to_json) from
# the translated prompt and the normalized model params.
#
# perform_completion! calls this as prepare_payload(prompt, model_params,
# dialect). The optional third parameter lets a missing override surface as
# NotImplementedError rather than ArgumentError. It stays optional so
# two-argument callers keep working.
def prepare_payload(_prompt, _model_params, _dialect = nil)
  raise NotImplementedError
end
|
|
|
|
|
|
|
|
# Abstract: builds the request object handed to Net::HTTP#request, given the
# JSON-encoded payload string.
def prepare_request(_payload)
  raise(NotImplementedError)
end
|
|
|
|
|
|
|
|
# Abstract: pulls the completion text out of either a full response body
# (non-streaming) or one raw streamed partial.
#
# Returning nil makes the streaming loop skip the partial. Raising
# JSON::ParserError makes it retry the partial once more data arrives.
def extract_completion_from(_response_raw)
  raise(NotImplementedError)
end
|
|
|
|
|
|
|
|
# Hook applied to every raw network chunk before it is buffered.
# Default: pass the chunk through untouched (same object).
def decode(chunk)
  chunk
end
|
|
|
|
|
|
|
|
# Abstract: splits the decoded stream text (leftover plus newest chunk) into
# an array of raw partials, each later fed to extract_completion_from.
def partials_from(_decoded_chunk)
  raise(NotImplementedError)
end
|
|
|
|
|
|
|
|
# Hook that selects what prompt_size hands to the tokenizer.
# Default: the prompt itself, unchanged (same object).
def extract_prompt_for_tokenizer(prompt)
  prompt
end
|
2023-12-18 16:06:01 -05:00
|
|
|
|
|
|
|
# Returns a fresh, empty tool-invocation skeleton parsed with Nokogiri's
# HTML5 fragment parser. add_to_buffer later fills in tool_name, tool_id and
# the parameters children.
def build_buffer
  skeleton_tags = %w[
    <function_calls>
    <invoke>
    <tool_name></tool_name>
    <tool_id></tool_id>
    <parameters>
    </parameters>
    </invoke>
    </function_calls>
  ]
  # One tag per line, each newline-terminated.
  Nokogiri::HTML5.fragment(skeleton_tags.map { |tag| "#{tag}\n" }.join)
end
|
|
|
|
|
2024-01-02 09:21:13 -05:00
|
|
|
# True when the text contains the start of a tool-call tag ("<function",
# e.g. <function_calls>).
#
# nil is treated as empty text: extract_completion_from may return nil, and
# the non-streaming path passes its result straight in.
def has_tool?(response)
  response.to_s.include?("<function")
end
|
|
|
|
|
|
|
|
# Merges streamed tool-call text into function_buffer (the Nokogiri skeleton
# from build_buffer) once a closing </invoke> is present.
#
# function_buffer - Nokogiri fragment to update in place
# response_data   - text accumulated so far (String)
# partial         - newest text (String), appended after response_data
#
# Returns function_buffer. It is left untouched when no </invoke> has been
# seen yet.
def add_to_buffer(function_buffer, response_data, partial)
  raw_data = response_data + partial

  # Keep only the first invocation: drop everything after the first
  # </invoke> and close the wrapper (presumably recovering output that ran
  # past a stop word). The split is computed once and reused.
  invoke_segments = raw_data.split("</invoke>")
  if invoke_segments.length > 1
    raw_data = "#{invoke_segments.first}</invoke>\n</function_calls>"
  end

  return function_buffer unless raw_data.include?("</invoke>")

  read_function = Nokogiri::HTML5.fragment(raw_data)

  if tool_name = read_function.at("tool_name")&.text
    function_buffer.at("tool_name").inner_html = tool_name
    # tool_id is filled with the tool name as well.
    function_buffer.at("tool_id").inner_html = tool_name
  end

  read_function
    .at("parameters")
    &.elements
    .to_a
    .each do |elem|
      if parameter = function_buffer.at(elem.name)&.text
        # NOTE(review): this writes the buffer node's own text back into it;
        # the freshly parsed value in `elem` is not used — confirm intent.
        function_buffer.at(elem.name).inner_html = parameter
      else
        # First time this parameter is seen: move the parsed node into the
        # buffer's <parameters>, followed by a newline text node.
        param_node = read_function.at(elem.name)
        function_buffer.at("parameters").add_child(param_node)
        function_buffer.at("parameters").add_child("\n")
      end
    end

  function_buffer
end
|
2023-11-23 10:58:54 -05:00
|
|
|
end
|
|
|
|
end
|
|
|
|
end
|
|
|
|
end
|