Sam 20612fde52
FEATURE: add the ability to disable streaming on an OpenAI LLM

Disabling streaming is required for models such as o1 that do not have
streaming enabled yet.

It is good to carry this feature around in case various APIs decide not to support streaming endpoints, so Discourse AI can continue to work just as it did before.

Also: fixes an issue where sharing artifacts would miss the viewport, leading to tiny artifacts on mobile.
2025-01-13 17:01:01 +11:00
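
For context, the flag is a per-model custom param read via lookup_custom_param, so it can be toggled only for the models that need it. A minimal console sketch, assuming LlmModel keeps its custom params in a provider_params JSON column (the column name is an assumption for illustration):

  # Hypothetical rails console snippet: force blocking requests for an o1 model.
  model = LlmModel.find_by(name: "o1")
  model.update!(
    provider_params: (model.provider_params || {}).merge("disable_streaming" => true),
  )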

# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class OpenAi < Base
        def self.can_contact?(model_provider)
          %w[open_ai azure].include?(model_provider)
        end

        def normalize_model_params(model_params)
          model_params = model_params.dup

          # max_tokens, temperature are already supported
          if model_params[:stop_sequences]
            model_params[:stop] = model_params.delete(:stop_sequences)
          end

          model_params
        end

        def default_options
          { model: llm_model.name }
        end

        def provider_id
          AiApiAuditLog::Provider::OpenAI
        end

        def perform_completion!(
          dialect,
          user,
          model_params = {},
          feature_name: nil,
          feature_context: nil,
          partial_tool_calls: false,
          &blk
        )
          @disable_native_tools = dialect.disable_native_tools?
          super
        end

        private
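
        # Per-model custom param added in this commit. When truthy, the shared
        # endpoint code skips streaming and issues a single blocking request,
        # which models such as o1 currently require. Illustrative only: the
        # Base class (not part of this file) consults the predicate roughly
        # along the lines of
        #
        #   @streaming_mode = block_given? && !disable_streaming?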
        def disable_streaming?
          @disable_streaming = llm_model.lookup_custom_param("disable_streaming")
        end

        def model_uri
          if llm_model.url.to_s.starts_with?("srv://")
            # Resolve srv:// URLs through a DNS SRV lookup to find the real
            # host and port before building the chat completions endpoint
            service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
            api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
          else
            api_endpoint = llm_model.url
          end
          @uri ||= URI(api_endpoint)
        end

        def prepare_payload(prompt, model_params, dialect)
          payload = default_options.merge(model_params).merge(messages: prompt)

          if @streaming_mode
            payload[:stream] = true
            # Usage is not available in Azure yet.
            # We'll fall back to estimating it with the tokenizer.
            payload[:stream_options] = { include_usage: true } if llm_model.provider == "open_ai"
          end

          if !xml_tools_enabled? && dialect.tools.present?
            payload[:tools] = dialect.tools
            if dialect.tool_choice.present?
              payload[:tool_choice] = {
                type: "function",
                function: {
                  name: dialect.tool_choice,
                },
              }
            end
          end

          payload
        end
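
        # Example of the resulting payload for the open_ai provider in
        # streaming mode (model and message values made up for illustration):
        #
        #   {
        #     model: "gpt-4o",
        #     messages: [{ role: "user", content: "hi" }],
        #     stream: true,
        #     stream_options: { include_usage: true },
        #   }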

        def prepare_request(payload)
          headers = { "Content-Type" => "application/json" }
          api_key = llm_model.api_key

          if llm_model.provider == "azure"
            headers["api-key"] = api_key
          else
            headers["Authorization"] = "Bearer #{api_key}"
            org_id = llm_model.lookup_custom_param("organization")
            headers["OpenAI-Organization"] = org_id if org_id.present?
          end

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

        def final_log_update(log)
          log.request_tokens = processor.prompt_tokens if processor.prompt_tokens
          log.response_tokens = processor.completion_tokens if processor.completion_tokens
          log.cached_tokens = processor.cached_tokens if processor.cached_tokens
        end

        def decode(response_raw)
          processor.process_message(JSON.parse(response_raw, symbolize_names: true))
        end

        def decode_chunk(chunk)
          @decoder ||= JsonStreamDecoder.new
          elements =
            (@decoder << chunk)
              .map { |parsed_json| processor.process_streamed_message(parsed_json) }
              .flatten
              .compact

          # Remove duplicate partial tool calls; the API sometimes streams
          # the same chunk more than once.
          seen_tools = Set.new
          elements.select { |item| !item.is_a?(ToolCall) || seen_tools.add?(item) }
        end

        def decode_chunk_finish
          @processor.finish
        end

        def xml_tools_enabled?
          !!@disable_native_tools
        end

        def processor
          @processor ||= OpenAiMessageProcessor.new(partial_tool_calls: partial_tool_calls)
        end
      end
    end
  end
end
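
The dedup in decode_chunk leans on Set#add?, which returns the set when the element was newly added and nil when it was already present. A standalone sketch of the same pattern, with symbols standing in for ToolCall instances:

  require "set"

  seen = Set.new
  chunks = ["text", :tool_a, :tool_a, "more text", :tool_b]

  # Non-tool elements always pass the filter; a tool passes only the first
  # time it is seen, because Set#add? returns nil for duplicates.
  deduped = chunks.select { |item| !item.is_a?(Symbol) || seen.add?(item) }
  # => ["text", :tool_a, "more text", :tool_b]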