mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-02-27 22:09:23 +00:00
* DEV: AI bot migration to the Llm pattern. We added tool and conversation context support to the Llm service in discourse-ai#366, meaning we met all the conditions to migrate this module. This PR migrates to the new pattern, meaning adding a new bot now requires minimal effort as long as the service supports it. On top of this, we introduce the concept of a "Playground" to separate the PM-specific bits from the completion, allowing us to use the bot in other contexts like chat in the future. Commands are now called tools, and we simplified all the placeholder logic to perform updates in a single place, making the flow more one-directional. * Follow-up fixes based on testing * Clean up unused inference code * FIX: text-based tools could appear in the middle of a sentence * GPT-4-turbo support * Use new LLM API
297 lines
9.0 KiB
Ruby
297 lines
9.0 KiB
Ruby
# frozen_string_literal: true
|
|
|
|
module DiscourseAi
  module Completions
    module Endpoints
      # Shared request/streaming/audit plumbing for LLM completion endpoints.
      #
      # Subclasses provide the provider-specific pieces (model_uri, prepare_payload,
      # prepare_request, extract_completion_from, partials_from, ...); this class sends
      # the request, optionally streams partials to a block, buffers tool invocations,
      # and records an AiApiAuditLog entry.
      class Base
        CompletionFailed = Class.new(StandardError)

        # Seconds used for the open, read and write timeouts of each HTTP request.
        TIMEOUT = 60

        # Returns the first endpoint class whose +can_contact?+ accepts +model_name+.
        #
        # Order is important: Bedrock has priority over Anthropic if credentials are present.
        #
        # @param model_name [String]
        # @return [Class] the matching endpoint class
        # @raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL when no endpoint matches
        def self.endpoint_for(model_name)
          [
            DiscourseAi::Completions::Endpoints::AwsBedrock,
            DiscourseAi::Completions::Endpoints::Anthropic,
            DiscourseAi::Completions::Endpoints::OpenAi,
            DiscourseAi::Completions::Endpoints::HuggingFace,
            DiscourseAi::Completions::Endpoints::Gemini,
            DiscourseAi::Completions::Endpoints::Vllm,
          ].find(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
            ek.can_contact?(model_name)
          end
        end

        # Whether this endpoint can serve +_model_name+. Subclasses must override.
        def self.can_contact?(_model_name)
          raise NotImplementedError
        end

        # @param model_name [String] model identifier, available to subclasses via #model
        # @param tokenizer [#size] counts tokens for the audit log (see #prompt_size)
        def initialize(model_name, tokenizer)
          @model = model_name
          @tokenizer = tokenizer
        end

        # Sends the prompt produced by +dialect.translate+ and returns the completion text.
        #
        # Without a block, the whole response body is read and returned at once.
        # With a block, each text partial is yielded as +yield partial, cancel+ as it
        # arrives; calling +cancel.call+ stops reading the stream. Once a tool invocation
        # is detected (see #has_tool?), text is no longer yielded; the invocation is
        # buffered and, after the stream ends, yielded and appended to the return value
        # as a single "<function_calls>" XML snippet followed by a newline.
        #
        # Requests answered with HTTP 200 are recorded in an AiApiAuditLog.
        #
        # @param dialect [#translate] builds the provider-specific prompt
        # @param user [#id, nil] stored on the audit log
        # @param model_params [Hash] generic params, normalized by #normalize_model_params
        # @return [String] the completion text (plus tool invocation XML, if any)
        # @raise [CompletionFailed] when the endpoint responds with a non-200 status
        def perform_completion!(dialect, user, model_params = {})
          model_params = normalize_model_params(model_params)

          @streaming_mode = block_given?

          prompt = dialect.translate

          Net::HTTP.start(
            model_uri.host,
            model_uri.port,
            use_ssl: true,
            read_timeout: TIMEOUT,
            open_timeout: TIMEOUT,
            write_timeout: TIMEOUT,
          ) do |http|
            response_data = +""
            response_raw = +""

            # Needed for response token calculations. Cannot rely on response_data due to function buffering.
            partials_raw = +""
            request_body = prepare_payload(prompt, model_params, dialect).to_json

            request = prepare_request(request_body)

            http.request(request) do |response|
              if response.code.to_i != 200
                Rails.logger.error(
                  "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
                )
                raise CompletionFailed
              end

              log =
                AiApiAuditLog.new(
                  provider_id: provider_id,
                  user_id: user&.id,
                  raw_request_payload: request_body,
                  request_tokens: prompt_size(prompt),
                )

              unless @streaming_mode
                response_raw = response.read_body
                response_data = extract_completion_from(response_raw)
                partials_raw = response_data.to_s

                if has_tool?(response_data)
                  function_buffer = build_buffer # Nokogiri document
                  function_buffer = add_to_buffer(function_buffer, "", response_data)

                  response_data = +function_buffer.at("function_calls").to_s
                  response_data << "\n"
                end

                return response_data
              end

              has_tool = false

              begin
                cancelled = false
                cancel = lambda { cancelled = true }

                leftover = ""
                function_buffer = build_buffer # Nokogiri document
                prev_processed_partials = 0

                response.read_body do |chunk|
                  if cancelled
                    http.finish
                    break
                  end

                  decoded_chunk = decode(chunk)
                  response_raw << decoded_chunk

                  # Re-parse any incomplete data left over from the previous chunk.
                  redo_chunk = leftover + decoded_chunk

                  raw_partials = partials_from(redo_chunk)

                  # Skip partials that were already handled in an earlier pass over redo_chunk.
                  raw_partials =
                    raw_partials[prev_processed_partials..-1] if prev_processed_partials > 0

                  if raw_partials.blank? || (raw_partials.size == 1 && raw_partials.first.blank?)
                    leftover = redo_chunk
                    next
                  end

                  json_error = false

                  raw_partials.each do |raw_partial|
                    json_error = false
                    prev_processed_partials += 1

                    next if cancelled
                    next if raw_partial.blank?

                    begin
                      partial = extract_completion_from(raw_partial)
                      next if response_data.empty? && partial.blank?
                      next if partial.nil?
                      partials_raw << partial.to_s

                      # Stop streaming the response as soon as you find a tool.
                      # We'll buffer and yield it later.
                      has_tool = true if has_tool?(partials_raw)

                      if has_tool
                        # partials_raw already ends with this partial; passing it again
                        # could fabricate a "</invoke>" across the duplicated seam.
                        function_buffer = add_to_buffer(function_buffer, partials_raw, "")
                      else
                        response_data << partial

                        yield partial, cancel if partial
                      end
                    rescue JSON::ParserError
                      # Incomplete JSON: keep the whole chunk and retry with the next one.
                      leftover = redo_chunk
                      json_error = true
                    end
                  end

                  if json_error
                    prev_processed_partials -= 1
                  else
                    leftover = ""
                  end

                  prev_processed_partials = 0 if leftover.blank?
                end
              rescue StandardError
                # After the caller cancels (and we call http.finish), errors from the torn-down
                # stream are swallowed; otherwise they propagate. IOError is a StandardError.
                raise unless cancelled
              end

              # Once we have the full response, try to return the tool as a XML doc.
              if has_tool && function_buffer.at("tool_name").text.present?
                invocation = +function_buffer.at("function_calls").to_s
                invocation << "\n"

                response_data << invocation
                yield invocation, cancel
              end

              return response_data
            ensure
              if log
                log.raw_response_payload = response_raw
                log.response_tokens = tokenizer.size(partials_raw)
                log.save!

                if Rails.env.development?
                  puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
                end
              end
            end
          end
        end

        def default_options
          raise NotImplementedError
        end

        def provider_id
          raise NotImplementedError
        end

        # Number of tokens in +prompt+, measured with #tokenizer on the text
        # returned by #extract_prompt_for_tokenizer.
        def prompt_size(prompt)
          tokenizer.size(extract_prompt_for_tokenizer(prompt))
        end

        attr_reader :tokenizer

        protected

        attr_reader :model

        # should normalize temperature, max_tokens, stop_words to endpoint specific values
        def normalize_model_params(_model_params)
          raise NotImplementedError
        end

        def model_uri
          raise NotImplementedError
        end

        # Builds the provider payload; the caller serializes it with #to_json.
        # Accepts the dialect because #perform_completion! passes it as a third argument.
        def prepare_payload(_prompt, _model_params, _dialect = nil)
          raise NotImplementedError
        end

        def prepare_request(_payload)
          raise NotImplementedError
        end

        def extract_completion_from(_response_raw)
          raise NotImplementedError
        end

        # Hook for decoding raw stream chunks; identity by default.
        def decode(chunk)
          chunk
        end

        def partials_from(_decoded_chunk)
          raise NotImplementedError
        end

        # Hook returning the part of +prompt+ to tokenize; identity by default.
        def extract_prompt_for_tokenizer(prompt)
          prompt
        end

        # Empty tool-invocation skeleton that #add_to_buffer fills in.
        def build_buffer
          Nokogiri::HTML5.fragment(<<~TEXT)
            <function_calls>
            <invoke>
            <tool_name></tool_name>
            <tool_id></tool_id>
            <parameters>
            </parameters>
            </invoke>
            </function_calls>
          TEXT
        end

        # True once +response+ contains the start of a "<function" tag.
        def has_tool?(response)
          response.include?("<function")
        end

        # Merges the tool invocation found in +response_data+ + +partial+ into
        # +function_buffer+ and returns the buffer.
        #
        # Nothing is merged until a closing "</invoke>" is present. Anything after the
        # first "</invoke>" (e.g. a leaked stop word) is dropped and the closing
        # "</function_calls>" is restored before parsing.
        def add_to_buffer(function_buffer, response_data, partial)
          raw_data = response_data + partial

          invoke_end = raw_data.index("</invoke>")
          return function_buffer if invoke_end.nil?

          # recover stop word potentially
          raw_data = raw_data[0, invoke_end] + "</invoke>\n</function_calls>"

          read_function = Nokogiri::HTML5.fragment(raw_data)

          # No separate id is read here, so tool_id is filled with the tool name too.
          if (tool_name = read_function.at("tool_name")&.text)
            function_buffer.at("tool_name").inner_html = tool_name
            function_buffer.at("tool_id").inner_html = tool_name
          end

          read_params = read_function.at("parameters")
          return function_buffer if read_params.nil?

          buffered_params = function_buffer.at("parameters")

          # Match by direct-child element name (not a document-wide CSS lookup) so a
          # parameter named like a skeleton node (e.g. "tool_name") cannot collide with it.
          read_params.elements.to_a.each do |elem|
            existing = buffered_params.elements.find { |node| node.name == elem.name }

            if existing
              # Refresh the buffered value with the one just read.
              existing.inner_html = elem.inner_html
            else
              buffered_params.add_child(elem)
              buffered_params.add_child("\n")
            end
          end

          function_buffer
        end
      end
    end
  end
end
|