Mirror of https://github.com/discourse/discourse-ai.git (synced 2025-07-08 23:32:45 +00:00)
Adds a comprehensive quota management system for LLM models that allows:

- Setting per-group (applied per user in the group) token and usage limits with configurable durations
- Tracking and enforcing token/usage limits across user groups
- Quota reset periods (hourly, daily, weekly, or custom)
- Admin UI for managing quotas with real-time updates

This system provides granular control over LLM API usage by allowing admins to define limits on both total tokens and number of requests per group. It supports multiple concurrent quotas per model and automatically handles quota resets.

Co-authored-by: Keegan George <kgeorge13@gmail.com>
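For context, here is a minimal, self-contained sketch of the quota model described above. The class and field names (QuotaTracker, Quota, Usage, and so on) are illustrative assumptions, not the plugin's actual schema; the only enforcement points taken from the source below are the LlmQuota.check_quotas! call made before a request and the LlmQuota.log_usage call made after it.

# Illustrative sketch only -- names are assumptions, not the plugin's schema.
# It mirrors the two touch points used by the endpoint code below: a check
# before the request and a usage log after it.
Quota = Struct.new(:max_tokens, :max_usages, :duration_seconds, keyword_init: true)
Usage = Struct.new(:tokens, :usages, :started_at, keyword_init: true)

class QuotaExceededError < StandardError; end

class QuotaTracker
  def initialize(quotas_by_group) # { group_id => Quota }
    @quotas = quotas_by_group
    @usage = Hash.new { |h, group_id| h[group_id] = fresh_usage }
  end

  # Analogous to LlmQuota.check_quotas!: raise if any of the user's groups is over budget.
  def check!(group_ids)
    group_ids.each do |group_id|
      quota = @quotas[group_id]
      next if quota.nil?

      usage = reset_if_expired(group_id, quota)
      if usage.tokens >= quota.max_tokens || usage.usages >= quota.max_usages
        raise QuotaExceededError, "quota exhausted for group #{group_id}"
      end
    end
  end

  # Analogous to LlmQuota.log_usage: record tokens and request count after a call.
  def log(group_ids, request_tokens, response_tokens)
    group_ids.each do |group_id|
      usage = @usage[group_id]
      usage.tokens += request_tokens + response_tokens
      usage.usages += 1
    end
  end

  private

  def fresh_usage
    Usage.new(tokens: 0, usages: 0, started_at: Time.now)
  end

  def reset_if_expired(group_id, quota)
    usage = @usage[group_id]
    return usage if Time.now - usage.started_at <= quota.duration_seconds

    @usage[group_id] = fresh_usage
  end
end

A caller in this sketch would run tracker.check!(user_group_ids) before the request and tracker.log(user_group_ids, request_tokens, response_tokens) afterwards, which is the same shape the base endpoint below follows.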
358 lines
11 KiB
Ruby
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class Base
        attr_reader :partial_tool_calls

        CompletionFailed = Class.new(StandardError)
        TIMEOUT = 60

        class << self
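          # Picks the first endpoint class whose can_contact? accepts the given
          # provider name; raises Llm::UNKNOWN_MODEL when no registered endpoint matches.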
          def endpoint_for(provider_name)
            endpoints = [
              DiscourseAi::Completions::Endpoints::AwsBedrock,
              DiscourseAi::Completions::Endpoints::OpenAi,
              DiscourseAi::Completions::Endpoints::HuggingFace,
              DiscourseAi::Completions::Endpoints::Gemini,
              DiscourseAi::Completions::Endpoints::Vllm,
              DiscourseAi::Completions::Endpoints::Anthropic,
              DiscourseAi::Completions::Endpoints::Cohere,
              DiscourseAi::Completions::Endpoints::SambaNova,
              DiscourseAi::Completions::Endpoints::Mistral,
              DiscourseAi::Completions::Endpoints::OpenRouter,
            ]

            endpoints << DiscourseAi::Completions::Endpoints::Ollama if Rails.env.development?

            if Rails.env.test? || Rails.env.development?
              endpoints << DiscourseAi::Completions::Endpoints::Fake
            end

            endpoints.detect(-> { raise DiscourseAi::Completions::Llm::UNKNOWN_MODEL }) do |ek|
              ek.can_contact?(provider_name)
            end
          end

          def can_contact?(_model_provider)
            raise NotImplementedError
          end
        end

        def initialize(llm_model)
          @llm_model = llm_model
        end

        def use_ssl?
          if model_uri&.scheme.present?
            model_uri.scheme == "https"
          else
            true
          end
        end

        def xml_tags_to_strip(dialect)
          []
        end

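        # Main entry point: enforces per-group LLM quotas, translates the prompt via
        # the dialect, performs the HTTP request (streaming when a block is given and
        # streaming is not disabled), then records audit data and quota usage on the way out.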
        def perform_completion!(
          dialect,
          user,
          model_params = {},
          feature_name: nil,
          feature_context: nil,
          partial_tool_calls: false,
          &blk
        )
          LlmQuota.check_quotas!(@llm_model, user)

          @partial_tool_calls = partial_tool_calls
          model_params = normalize_model_params(model_params)
          orig_blk = blk

          if block_given? && disable_streaming?
            result =
              perform_completion!(
                dialect,
                user,
                model_params,
                feature_name: feature_name,
                feature_context: feature_context,
                partial_tool_calls: partial_tool_calls,
              )

            result = [result] if !result.is_a?(Array)
            cancelled_by_caller = false
            cancel_proc = -> { cancelled_by_caller = true }
            result.each do |partial|
              blk.call(partial, cancel_proc)
              break if cancelled_by_caller
            end
            return result
          end

          @streaming_mode = block_given?

          prompt = dialect.translate

          FinalDestination::HTTP.start(
            model_uri.host,
            model_uri.port,
            use_ssl: use_ssl?,
            read_timeout: TIMEOUT,
            open_timeout: TIMEOUT,
            write_timeout: TIMEOUT,
          ) do |http|
            response_data = +""
            response_raw = +""

            # Needed for response token calculations. Cannot rely on response_data due to function buffering.
            partials_raw = +""
            request_body = prepare_payload(prompt, model_params, dialect).to_json

            request = prepare_request(request_body)

            http.request(request) do |response|
              if response.code.to_i != 200
                Rails.logger.error(
                  "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
                )
                raise CompletionFailed, response.body
              end

              xml_tool_processor =
                XmlToolProcessor.new(
                  partial_tool_calls: partial_tool_calls,
                ) if xml_tools_enabled? && dialect.prompt.has_tools?

              to_strip = xml_tags_to_strip(dialect)
              xml_stripper =
                DiscourseAi::Completions::XmlTagStripper.new(to_strip) if to_strip.present?

              if @streaming_mode && xml_stripper
                blk =
                  lambda do |partial, cancel|
                    partial = xml_stripper << partial if partial.is_a?(String)
                    orig_blk.call(partial, cancel) if partial
                  end
              end

              log =
                start_log(
                  provider_id: provider_id,
                  request_body: request_body,
                  dialect: dialect,
                  prompt: prompt,
                  user: user,
                  feature_name: feature_name,
                  feature_context: feature_context,
                )

              if !@streaming_mode
                return(
                  non_streaming_response(
                    response: response,
                    xml_tool_processor: xml_tool_processor,
                    xml_stripper: xml_stripper,
                    partials_raw: partials_raw,
                    response_raw: response_raw,
                  )
                )
              end

              begin
                cancelled = false
                cancel = -> do
                  cancelled = true
                  http.finish
                end

                break if cancelled

                response.read_body do |chunk|
                  break if cancelled

                  response_raw << chunk

                  decode_chunk(chunk).each do |partial|
                    break if cancelled
                    partials_raw << partial.to_s
                    response_data << partial if partial.is_a?(String)
                    partials = [partial]
                    if xml_tool_processor && partial.is_a?(String)
                      partials = (xml_tool_processor << partial)
                      if xml_tool_processor.should_cancel?
                        cancel.call
                        break
                      end
                    end
                    partials.each { |inner_partial| blk.call(inner_partial, cancel) }
                  end
                end
              rescue IOError, StandardError
                raise if !cancelled
              end
              if xml_stripper
                stripped = xml_stripper.finish
                if stripped.present?
                  response_data << stripped
                  result = []
                  result = (xml_tool_processor << stripped) if xml_tool_processor
                  result.each { |partial| blk.call(partial, cancel) }
                end
              end
              if xml_tool_processor
                xml_tool_processor.finish.each { |partial| blk.call(partial, cancel) }
              end
              decode_chunk_finish.each { |partial| blk.call(partial, cancel) }
              return response_data
            ensure
              if log
                log.raw_response_payload = response_raw
                final_log_update(log)
                log.response_tokens = tokenizer.size(partials_raw) if log.response_tokens.blank?
                log.save!
                LlmQuota.log_usage(@llm_model, user, log.request_tokens, log.response_tokens)
                if Rails.env.development?
                  puts "#{self.class.name}: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
                end
              end
            end
          end
        end

        def final_log_update(log)
          # for people that need to override
        end

        def default_options
          raise NotImplementedError
        end

        def provider_id
          raise NotImplementedError
        end

        def prompt_size(prompt)
          tokenizer.size(extract_prompt_for_tokenizer(prompt))
        end

        attr_reader :llm_model

        protected

        def tokenizer
          llm_model.tokenizer_class
        end

        # should normalize temperature, max_tokens, stop_words to endpoint specific values
        def normalize_model_params(model_params)
          raise NotImplementedError
        end

        def model_uri
          raise NotImplementedError
        end

        def prepare_payload(_prompt, _model_params, _dialect)
          raise NotImplementedError
        end

        def prepare_request(_payload)
          raise NotImplementedError
        end

        def decode(_response_raw)
          raise NotImplementedError
        end

        def decode_chunk_finish
          []
        end

        def decode_chunk(_chunk)
          raise NotImplementedError
        end

        def extract_prompt_for_tokenizer(prompt)
          prompt.map { |message| message[:content] || message["content"] || "" }.join("\n")
        end

        def xml_tools_enabled?
          raise NotImplementedError
        end

        def disable_streaming?
          @disable_streaming = !!llm_model.lookup_custom_param("disable_streaming")
        end

        private

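        # Builds (without saving) the AiApiAuditLog row for this call; request tokens
        # are estimated from the prompt up front, response tokens are filled in later.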
        def start_log(
          provider_id:,
          request_body:,
          dialect:,
          prompt:,
          user:,
          feature_name:,
          feature_context:
        )
          AiApiAuditLog.new(
            provider_id: provider_id,
            user_id: user&.id,
            raw_request_payload: request_body,
            request_tokens: prompt_size(prompt),
            topic_id: dialect.prompt.topic_id,
            post_id: dialect.prompt.post_id,
            feature_name: feature_name,
            language_model: llm_model.name,
            feature_context: feature_context.present? ? feature_context.as_json : nil,
          )
        end

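        # Handles the non-streaming path: reads the whole body, decodes it, then applies
        # the same XML tool-call and tag-stripping post-processing as the streaming path.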
        def non_streaming_response(
          response:,
          xml_tool_processor:,
          xml_stripper:,
          partials_raw:,
          response_raw:
        )
          response_raw << response.read_body
          response_data = decode(response_raw)

          response_data.each { |partial| partials_raw << partial.to_s }

          if xml_tool_processor
            response_data.each do |partial|
              processed = (xml_tool_processor << partial)
              processed << xml_tool_processor.finish
              response_data = []
              processed.flatten.compact.each { |inner| response_data << inner }
            end
          end

          if xml_stripper
            response_data.map! do |partial|
              stripped = (xml_stripper << partial) if partial.is_a?(String)
              if stripped.present?
                stripped
              else
                partial
              end
            end
            response_data << xml_stripper.finish
          end

          response_data.reject!(&:blank?)

          # this is to keep stuff backwards compatible
          response_data = response_data.first if response_data.length == 1

          response_data
        end
      end
    end
  end
end
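
For orientation, below is a hypothetical minimal subclass sketching the contract the NotImplementedError stubs above define. The class name, provider string, return values, and URL are invented for illustration; the real providers (OpenAi, Anthropic, AwsBedrock, and the others listed in endpoint_for) live in the sibling endpoint files.

# Hypothetical example only -- not part of the plugin. Shows the minimum surface
# a provider subclass implements for Base#perform_completion! to drive it.
module DiscourseAi
  module Completions
    module Endpoints
      class EchoEndpoint < Base
        def self.can_contact?(model_provider)
          model_provider == "echo" # invented provider name
        end

        def provider_id
          0 # illustrative; real subclasses return an AiApiAuditLog provider id
        end

        def default_options
          {}
        end

        def normalize_model_params(model_params)
          model_params # nothing to translate for this toy provider
        end

        def model_uri
          @uri ||= URI("https://llm.example.com/v1/complete") # invented URL
        end

        def xml_tools_enabled?
          false
        end

        def prepare_payload(prompt, model_params, _dialect)
          { prompt: prompt }.merge(model_params)
        end

        def prepare_request(payload)
          Net::HTTP::Post
            .new(model_uri, "Content-Type" => "application/json")
            .tap { |request| request.body = payload }
        end

        def decode(response_raw)
          [response_raw] # whole body is the completion text
        end

        def decode_chunk(chunk)
          [chunk] # pass streamed chunks through untouched
        end
      end
    end
  end
end

An instance of such a subclass would be constructed with an LlmModel record and driven through perform_completion!, which supplies the quota checks, audit logging, and streaming plumbing shown in the base class above.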