Mirror of https://github.com/discourse/discourse-ai.git, synced 2025-02-15 16:04:48 +00:00
* FIX: Add tool support to the OpenAI-compatible dialect and vLLM

  Automatic tool support is in progress in vLLM, see:
  https://github.com/vllm-project/vllm/pull/5649

  Even once that lands, initial support will be uneven: only some models have
  native tool support, notably Mistral, which has special tokens for tool
  calls. After the above PR lands in vLLM we will still need to swap to
  XML-based tools for models without native tool support.

* fix specs
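For context, this is roughly what the two approaches mentioned in the commit message look like at the request level. The hash below follows the public OpenAI-compatible chat/completions schema rather than this repository's dialect code; the model name and function are made up for illustration.

# Native tool support: the request carries a `tools` array and the model
# replies with structured tool calls (OpenAI-compatible schema, illustrative
# names only).
payload = {
  model: "mistralai/Mistral-7B-Instruct-v0.3",
  messages: [{ role: "user", content: "What time is it in Sydney?" }],
  tools: [
    {
      type: "function",
      function: {
        name: "get_time",
        description: "Look up the current time for a timezone",
        parameters: {
          type: "object",
          properties: { timezone: { type: "string" } },
          required: ["timezone"],
        },
      },
    },
  ],
}

# XML-based fallback: for models without native tool support, tool definitions
# and calls are instead serialized into the prompt text (an XML block such as
# <function_calls>...</function_calls>); the exact markup is handled by the
# dialect layer, not by this endpoint file.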
# frozen_string_literal: true

module DiscourseAi
  module Completions
    module Endpoints
      class Vllm < Base
        # this endpoint handles LLMs configured with the "vllm" provider
        def self.can_contact?(model_provider)
          model_provider == "vllm"
        end

        def normalize_model_params(model_params)
          model_params = model_params.dup

          # max_tokens and temperature already use the names vLLM expects;
          # only stop_sequences needs renaming to the OpenAI-style :stop key
          if model_params[:stop_sequences]
            model_params[:stop] = model_params.delete(:stop_sequences)
          end

          model_params
        end

        def default_options
          { max_tokens: 2000, model: llm_model.name }
        end

        def provider_id
          AiApiAuditLog::Provider::Vllm
        end

        private

        def model_uri
          if llm_model.url.to_s.starts_with?("srv://")
            # srv:// URLs are resolved via a DNS SRV lookup to find the backend
            # host and port
            service = DiscourseAi::Utils::DnsSrv.lookup(llm_model.url.sub("srv://", ""))
            api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
          else
            api_endpoint = llm_model.url
          end

          @uri ||= URI(api_endpoint)
        end

        def prepare_payload(prompt, model_params, dialect)
          payload = default_options.merge(model_params).merge(messages: prompt)

          payload[:stream] = true if @streaming_mode

          payload
        end

        def prepare_request(payload)
          headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

          api_key = llm_model&.api_key || SiteSetting.ai_vllm_api_key
          headers["X-API-KEY"] = api_key if api_key.present?

          Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
        end

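        # The OpenAI-compatible streaming API replies with server-sent events.
        # A decoded chunk looks roughly like this (illustrative, not captured
        # output):
        #
        #   data: {"choices":[{"delta":{"content":"Hel"}}]}
        #   data: {"choices":[{"delta":{"content":"lo"}}]}
        #   data: [DONE]
        #
        # partials_from strips the "data: " prefix from each line and drops the
        # [DONE] sentinel; extract_completion_from then parses each JSON payload.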
        def partials_from(decoded_chunk)
          decoded_chunk
            .split("\n")
            .map do |line|
              data = line.split("data: ", 2)[1]
              data == "[DONE]" ? nil : data
            end
            .compact
        end

        def extract_completion_from(response_raw)
          parsed = JSON.parse(response_raw, symbolize_names: true).dig(:choices, 0)
          # the stream can flush a partial line here, in which case there is no
          # choice to extract yet
          return if !parsed

          response_h = @streaming_mode ? parsed.dig(:delta) : parsed.dig(:message)

          response_h.dig(:content)
        end
      end
    end
  end
end
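A minimal usage sketch of the pieces above, under stated assumptions: the `LlmModel` record, the `Vllm.new(llm_model)` constructor arity, and the way `Base` drives these hooks all live outside this file and are assumed here for illustration.

# Assumes an LlmModel configured with provider "vllm" and a Base constructor
# that accepts it; both are defined elsewhere in the plugin.
llm_model = LlmModel.find_by(provider: "vllm")
endpoint = DiscourseAi::Completions::Endpoints::Vllm.new(llm_model)

DiscourseAi::Completions::Endpoints::Vllm.can_contact?("vllm") # => true

# stop_sequences is renamed to the OpenAI-compatible :stop key
endpoint.normalize_model_params({ temperature: 0.2, stop_sequences: ["</answer>"] })
# => { temperature: 0.2, stop: ["</answer>"] }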