FEATURE: Set endpoint credentials directly from LlmModel. (#625)

* FEATURE: Set endpoint credentials directly from LlmModel. Drop Llama2Tokenizer since we no longer use it.
* Allow http for custom LLMs

Co-authored-by: Rafael Silva <xfalcox@gmail.com>

commit 1d786fbaaf, parent 255139056d

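A minimal usage sketch (not part of the commit): the attribute names come from the serializer and editor changes in this diff, while the `create!` call and concrete values are assumptions. The "custom:<id>" prefix matches the proxy lookup changed in llm.rb below.

# Sketch, assuming a hand-created custom model; attribute names are from this
# diff, values are illustrative only.
llm_model =
  LlmModel.create!(
    display_name: "Internal Mixtral",
    name: "mistralai/Mixtral-8x7B-Instruct-v0.1",
    provider: "vllm",
    tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer",
    max_prompt_tokens: 32_000,
    url: "http://vllm.internal:8000/v1/chat/completions",
    api_key: "secret",
  )

# The "custom:" prefix routes through LlmModel.find (see the llm.rb hunks below).
llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")
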
@@ -58,6 +58,8 @@ module DiscourseAi
         :provider,
         :tokenizer,
         :max_prompt_tokens,
+        :url,
+        :api_key,
       )
     end
   end

@@ -3,5 +3,5 @@
 class LlmModelSerializer < ApplicationSerializer
   root "llm"
 
-  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
+  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer, :api_key, :url
 end

@@ -3,11 +3,14 @@ import RestModel from "discourse/models/rest";
 export default class AiLlm extends RestModel {
   createProperties() {
     return this.getProperties(
       "id",
       "display_name",
       "name",
       "provider",
       "tokenizer",
-      "max_prompt_tokens"
+      "max_prompt_tokens",
+      "url",
+      "api_key"
     );
   }

@@ -33,7 +33,9 @@ export default class AiLlmEditor extends Component {
     const isNew = this.args.model.isNew;
 
     try {
-      await this.args.model.save();
+      const result = await this.args.model.save();
+
+      this.args.model.setProperties(result.responseJson.ai_persona);
 
       if (isNew) {
         this.args.llms.addObject(this.args.model);

@@ -61,7 +63,7 @@ export default class AiLlmEditor extends Component {
     <div class="control-group">
       <label>{{i18n "discourse_ai.llms.display_name"}}</label>
       <Input
-        class="ai-llm-editor__display-name"
+        class="ai-llm-editor-input ai-llm-editor__display-name"
         @type="text"
         @value={{@model.display_name}}
       />

@@ -69,7 +71,7 @@ export default class AiLlmEditor extends Component {
     <div class="control-group">
       <label>{{i18n "discourse_ai.llms.name"}}</label>
       <Input
-        class="ai-llm-editor__name"
+        class="ai-llm-editor-input ai-llm-editor__name"
         @type="text"
         @value={{@model.name}}
       />

@@ -85,6 +87,22 @@ export default class AiLlmEditor extends Component {
         @content={{this.selectedProviders}}
       />
     </div>
+    <div class="control-group">
+      <label>{{I18n.t "discourse_ai.llms.url"}}</label>
+      <Input
+        class="ai-llm-editor-input ai-llm-editor__url"
+        @type="text"
+        @value={{@model.url}}
+      />
+    </div>
+    <div class="control-group">
+      <label>{{I18n.t "discourse_ai.llms.api_key"}}</label>
+      <Input
+        class="ai-llm-editor-input ai-llm-editor__api-key"
+        @type="text"
+        @value={{@model.api_key}}
+      />
+    </div>
     <div class="control-group">
       <label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
       <ComboBox

@@ -96,7 +114,7 @@ export default class AiLlmEditor extends Component {
       <label>{{i18n "discourse_ai.llms.max_prompt_tokens"}}</label>
       <Input
         @type="number"
-        class="ai-llm-editor__max-prompt-tokens"
+        class="ai-llm-editor-input ai-llm-editor__max-prompt-tokens"
         step="any"
         min="0"
         lang="en"

@@ -29,4 +29,8 @@
     text-align: center;
     font-size: var(--font-up-1);
   }
+
+  .ai-llm-editor-input {
+    width: 350px;
+  }
 }

@@ -204,6 +204,8 @@ en:
       provider: "Service hosting the model:"
       tokenizer: "Tokenizer:"
       max_prompt_tokens: "Number of tokens for the prompt:"
+      url: "URL of the service hosting the model:"
+      api_key: "API Key of the service hosting the model:"
       save: "Save"
       saved: "LLM Model Saved"

@@ -0,0 +1,8 @@
+# frozen_string_literal: true
+
+class AddEndpointDataToLlmModel < ActiveRecord::Migration[7.0]
+  def change
+    add_column :llm_models, :url, :string
+    add_column :llm_models, :api_key, :string
+  end
+end

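Both new columns are nullable strings, so a record without them behaves exactly as before; every endpoint below resolves credentials with the same fallback shape. A condensed sketch of that pattern (the Anthropic values are taken from the endpoint hunk further down):

# Per-model value wins; otherwise fall back to the global site setting / default.
api_key = llm_model&.api_key || SiteSetting.ai_anthropic_api_key
url = llm_model&.url || "https://api.anthropic.com/v1/messages"
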
@@ -8,14 +8,14 @@ module DiscourseAi
         def can_translate?(model_name)
           model_name.starts_with?("gpt-")
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end
 
       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
 
+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def native_tool_support?
         true
       end

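The dialect hunks all follow the shape above: the class-level tokenizer becomes an instance method so it can consult llm_model, which is why call sites switch from self.class.tokenizer to self.tokenizer. A hedged usage sketch, assuming this hunk is the ChatGpt dialect (the prompt value is illustrative; the llm_model keyword is added in dialect.rb below):

prompt = DiscourseAi::Completions::Prompt.new("You are a helpful bot")
dialect = DiscourseAi::Completions::Dialects::ChatGpt.new(prompt, "gpt-4", llm_model: llm_model)
dialect.tokenizer # => llm_model.tokenizer_class if set, DiscourseAi::Tokenizer::OpenAiTokenizer otherwise
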
@@ -30,7 +30,7 @@ module DiscourseAi
       end
 
       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
 
         # provide a buffer of 120 tokens - our function counting is not
         # 100% accurate and getting numbers to align exactly is very hard

@@ -38,7 +38,7 @@ module DiscourseAi
 
         if tools.present?
           # note this is about 100 tokens over, OpenAI have a more optimal representation
-          @function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
+          @function_size ||= self.tokenizer.size(tools.to_json.to_s)
           buffer += @function_size
         end
 

@@ -113,7 +113,7 @@ module DiscourseAi
       end
 
       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end
 
       def model_max_tokens

@@ -10,10 +10,6 @@ module DiscourseAi
             model_name,
           )
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::AnthropicTokenizer
-        end
       end
 
       class ClaudePrompt

@@ -26,6 +22,10 @@ module DiscourseAi
         end
       end
 
+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
+      end
+
       def translate
         messages = super
 

@@ -50,7 +50,8 @@ module DiscourseAi
       end
 
       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
+
         # Longer term it will have over 1 million
         200_000 # Claude-3 has a 200k context window for now
       end

@@ -10,14 +10,14 @@ module DiscourseAi
         def can_translate?(model_name)
           %w[command-light command command-r command-r-plus].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end
 
       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
 
+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def translate
         messages = super
 

@@ -38,7 +38,7 @@ module DiscourseAi
       end
 
       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
 
         case model_name
         when "command-light"

@@ -61,7 +61,7 @@ module DiscourseAi
       end
 
       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end
 
       def system_msg(msg)

@@ -20,10 +20,6 @@ module DiscourseAi
           ]
         end
 
-        def available_tokenizers
-          all_dialects.map(&:tokenizer)
-        end
-
         def dialect_for(model_name)
           dialects = []
 

@@ -38,20 +34,21 @@ module DiscourseAi
 
           dialect
         end
-
-        def tokenizer
-          raise NotImplemented
-        end
       end
 
-      def initialize(generic_prompt, model_name, opts: {})
+      def initialize(generic_prompt, model_name, opts: {}, llm_model: nil)
         @prompt = generic_prompt
         @model_name = model_name
         @opts = opts
+        @llm_model = llm_model
       end
 
+      VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
+
+      def tokenizer
+        raise NotImplemented
+      end
+
       def can_end_with_assistant_msg?
         false
       end

@@ -88,7 +85,7 @@ module DiscourseAi
 
       private
 
-      attr_reader :model_name, :opts
+      attr_reader :model_name, :opts, :llm_model
 
       def trim_messages(messages)
         prompt_limit = max_prompt_tokens

@@ -147,7 +144,7 @@ module DiscourseAi
       end
 
       def calculate_message_token(msg)
-        self.class.tokenizer.size(msg[:content].to_s)
+        self.tokenizer.size(msg[:content].to_s)
       end
 
       def tools_dialect

@@ -8,10 +8,10 @@ module DiscourseAi
         def can_translate?(model_name)
           model_name == "fake"
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end
 
+      def tokenizer
+        DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def translate

@@ -8,16 +8,16 @@ module DiscourseAi
         def can_translate?(model_name)
           %w[gemini-pro gemini-1.5-pro].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
-        end
       end
 
       def native_tool_support?
         true
       end
 
+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
+      end
+
       def translate
         # Gemini complains if we don't alternate model/user roles.
         noop_model_response = { role: "model", parts: { text: "Ok." } }

@@ -68,7 +68,7 @@ module DiscourseAi
       end
 
       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
 
         if model_name == "gemini-1.5-pro"
           # technically we support 1 million tokens, but we're being conservative

@@ -81,7 +81,7 @@ module DiscourseAi
       protected
 
       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end
 
       def system_msg(msg)

@@ -12,10 +12,10 @@ module DiscourseAi
             mistral
           ].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::MixtralTokenizer
-        end
       end
 
+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::MixtralTokenizer
+      end
+
       def tools

@@ -23,7 +23,7 @@ module DiscourseAi
       end
 
       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
 
         32_000
       end

@@ -8,10 +8,10 @@ module DiscourseAi
         def can_translate?(_model_name)
           true
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::Llama3Tokenizer
-        end
       end
 
+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::Llama3Tokenizer
+      end
+
       def tools

@@ -19,7 +19,7 @@ module DiscourseAi
       end
 
       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
 
         32_000
       end

@@ -63,7 +63,9 @@ module DiscourseAi
       end
 
       def model_uri
-        @uri ||= URI("https://api.anthropic.com/v1/messages")
+        url = llm_model&.url || "https://api.anthropic.com/v1/messages"
+
+        URI(url)
       end
 
       def prepare_payload(prompt, model_params, dialect)

@@ -78,7 +80,7 @@ module DiscourseAi
       def prepare_request(payload)
         headers = {
           "anthropic-version" => "2023-06-01",
-          "x-api-key" => SiteSetting.ai_anthropic_api_key,
+          "x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key,
           "content-type" => "application/json",
         }
 

@@ -71,7 +71,8 @@ module DiscourseAi
         end
 
         api_url =
-          "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
+          llm_model&.url ||
+            "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
 
         api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
 

@@ -91,7 +92,7 @@ module DiscourseAi
         Aws::Sigv4::Signer.new(
           access_key_id: SiteSetting.ai_bedrock_access_key_id,
           region: SiteSetting.ai_bedrock_region,
-          secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
+          secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
           service: "bedrock",
         )
 

@@ -60,9 +60,10 @@ module DiscourseAi
         end
       end
 
-      def initialize(model_name, tokenizer)
+      def initialize(model_name, tokenizer, llm_model: nil)
         @model = model_name
         @tokenizer = tokenizer
+        @llm_model = llm_model
       end
 
       def native_tool_support?

@@ -70,7 +71,11 @@ module DiscourseAi
       end
 
       def use_ssl?
-        true
+        if model_uri&.scheme.present?
+          model_uri.scheme == "https"
+        else
+          true
+        end
       end
 
       def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)

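This use_ssl? change is the "Allow http for custom LLMs" part of the commit: TLS is now derived from the configured URL's scheme instead of being hard-coded on. A tiny sketch of the resulting behavior (URLs illustrative):

URI("http://vllm.internal:8000/v1/chat/completions").scheme == "https" # => false, plain HTTP allowed
URI("https://api.anthropic.com/v1/messages").scheme == "https"         # => true, TLS as before
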
@@ -294,7 +299,7 @@ module DiscourseAi
         tokenizer.size(extract_prompt_for_tokenizer(prompt))
       end
 
-      attr_reader :tokenizer, :model
+      attr_reader :tokenizer, :model, :llm_model
 
       protected
 

@@ -42,7 +42,9 @@ module DiscourseAi
       private
 
       def model_uri
-        URI("https://api.cohere.ai/v1/chat")
+        url = llm_model&.url || "https://api.cohere.ai/v1/chat"
+
+        URI(url)
       end
 
       def prepare_payload(prompt, model_params, dialect)

@@ -56,7 +58,7 @@ module DiscourseAi
       def prepare_request(payload)
         headers = {
           "Content-Type" => "application/json",
-          "Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}",
+          "Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}",
         }
 
         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }

@@ -51,9 +51,16 @@ module DiscourseAi
       private
 
       def model_uri
-        mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
-        url =
-          "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"
+        if llm_model
+          url = llm_model.url
+        else
+          mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
+          url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
+        end
+
+        key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
+
+        url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"
 
         URI(url)
       end

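With a custom model, the configured URL replaces only the base; the action and key are appended either way. An illustrative trace of the new logic (all values assumed):

url = "https://gemini.example.com/models/my-gemini" # assumed llm_model.url
key = "secret"                                      # assumed api key
streaming = false
url = "#{url}:#{streaming ? "streamGenerateContent" : "generateContent"}?key=#{key}"
URI(url) # => https://gemini.example.com/models/my-gemini:generateContent?key=secret
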
|
@ -44,7 +44,7 @@ module DiscourseAi
|
|||
private
|
||||
|
||||
def model_uri
|
||||
URI(SiteSetting.ai_hugging_face_api_url)
|
||||
URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url)
|
||||
end
|
||||
|
||||
def prepare_payload(prompt, model_params, _dialect)
|
||||
|
@@ -53,7 +53,8 @@ module DiscourseAi
           .merge(messages: prompt)
           .tap do |payload|
             if !payload[:max_tokens]
-              token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
+              token_limit =
+                llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit
 
               payload[:max_tokens] = token_limit - prompt_size(prompt)
             end

@@ -63,11 +64,11 @@ module DiscourseAi
       end
 
       def prepare_request(payload)
+        api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key
+
         headers =
           { "Content-Type" => "application/json" }.tap do |h|
-            if SiteSetting.ai_hugging_face_api_key.present?
-              h["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
-            end
+            h["Authorization"] = "Bearer #{api_key}" if api_key.present?
           end
 
         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }

@@ -48,7 +48,7 @@ module DiscourseAi
       private
 
       def model_uri
-        URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
+        URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
       end
 
       def prepare_payload(prompt, model_params, _dialect)

@@ -78,6 +78,8 @@ module DiscourseAi
       private
 
       def model_uri
+        return URI(llm_model.url) if llm_model&.url
+
         url =
           if model.include?("gpt-4")
             if model.include?("32k")

@@ -115,10 +117,12 @@ module DiscourseAi
       def prepare_request(payload)
         headers = { "Content-Type" => "application/json" }
 
+        api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key
+
         if model_uri.host.include?("azure")
-          headers["api-key"] = SiteSetting.ai_openai_api_key
+          headers["api-key"] = api_key
         else
-          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+          headers["Authorization"] = "Bearer #{api_key}"
         end
 
         if SiteSetting.ai_openai_organization.present?

@@ -44,6 +44,8 @@ module DiscourseAi
       private
 
       def model_uri
+        return URI(llm_model.url) if llm_model&.url
+
         service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
         if service.present?
           api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"

@@ -63,7 +65,7 @@ module DiscourseAi
       def prepare_request(payload)
         headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
 
-        headers["X-API-KEY"] = SiteSetting.ai_vllm_api_key if SiteSetting.ai_vllm_api_key.present?
+        api_key = llm_model&.api_key || SiteSetting.ai_vllm_api_key
+        headers["X-API-KEY"] = api_key if api_key.present?
 
         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
       end

@@ -25,7 +25,7 @@ module DiscourseAi
       end
 
       def tokenizer_names
-        DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq
+        DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
       end
 
       def models_by_provider

@@ -91,17 +91,16 @@ module DiscourseAi
       end
 
       def proxy(model_name)
-        # We are in the process of transitioning to always use objects here.
-        # We'll live with this hack for a while.
         provider_and_model_name = model_name.split(":")
         provider_name = provider_and_model_name.first
         model_name_without_prov = provider_and_model_name[1..].join
-        is_custom_model = provider_name == "custom"
 
-        if is_custom_model
+        # We are in the process of transitioning to always use objects here.
+        # We'll live with this hack for a while.
+        if provider_name == "custom"
           llm_model = LlmModel.find(model_name_without_prov)
-          provider_name = llm_model.provider
-          model_name_without_prov = llm_model.name
+          raise UNKNOWN_MODEL if !llm_model
+          return proxy_from_obj(llm_model)
         end
 
         dialect_klass =

@@ -111,24 +110,32 @@ module DiscourseAi
           if @canned_llm && @canned_llm != model_name
             raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
           end
-          return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
+          return new(dialect_klass, nil, model_name, gateway: @canned_response)
         end
 
-        opts = {}
-        opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
-
         gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
 
-        new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
+        new(dialect_klass, gateway_klass, model_name_without_prov)
+      end
+
+      def proxy_from_obj(llm_model)
+        provider_name = llm_model.provider
+        model_name = llm_model.name
+
+        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
+        gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
+
+        new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
       end
     end
 
-    def initialize(dialect_klass, gateway_klass, model_name, opts: {})
+    def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil)
      @dialect_klass = dialect_klass
      @gateway_klass = gateway_klass
      @model_name = model_name
-      @gateway = opts[:gateway]
-      @max_prompt_tokens = opts[:max_prompt_tokens]
+      @gateway = gateway
+      @llm_model = llm_model
    end
 
    # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object

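After this refactor there are two converging entry points; a sketch (the "open_ai:gpt-4" provider-string format is an assumption based on the surrounding code, not shown in this diff):

DiscourseAi::Completions::Llm.proxy("open_ai:gpt-4")          # string form, assumed provider prefix
DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}") # resolves the record, then...
DiscourseAi::Completions::Llm.proxy_from_obj(llm_model)       # ...delegates to the object form
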
@@ -185,13 +192,9 @@ module DiscourseAi
 
       model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
 
-      gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
-      dialect =
-        dialect_klass.new(
-          prompt,
-          model_name,
-          opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
-        )
+      dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model)
+
+      gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model)
       gateway.perform_completion!(
         dialect,
         user,

@@ -202,18 +205,20 @@ module DiscourseAi
     end
 
     def max_prompt_tokens
-      return @max_prompt_tokens if @max_prompt_tokens.present?
-
-      dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
+      llm_model&.max_prompt_tokens ||
+        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
     end
 
-    delegate :tokenizer, to: :dialect_klass
+    def tokenizer
+      llm_model&.tokenizer_class ||
+        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
+    end
 
     attr_reader :model_name
 
     private
 
-    attr_reader :dialect_klass, :gateway_klass
+    attr_reader :dialect_klass, :gateway_klass, :llm_model
   end
 end
end

@@ -18,19 +18,16 @@ module DiscourseAi
       model_name_without_prov = provider_and_model_name[1..].join
       is_custom_model = provider_name == "custom"
 
-      if is_custom_model
-        llm_model = LlmModel.find(model_name_without_prov)
-        provider_name = llm_model.provider
-        model_name_without_prov = llm_model.name
-      end
-
-      endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
-
-      return false if endpoint.nil?
-
-      if !endpoint.correctly_configured?(model_name_without_prov)
-        @endpoint = endpoint
-        return false
+      # Bypass setting validations for custom models. They don't rely on site settings.
+      if !is_custom_model
+        endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
+
+        return false if endpoint.nil?
+
+        if !endpoint.correctly_configured?(model_name_without_prov)
+          @endpoint = endpoint
+          return false
+        end
       end
 
       if !can_talk_to_model?(val)

@@ -4,6 +4,15 @@ module DiscourseAi
   module Tokenizer
     class BasicTokenizer
+      class << self
+        def available_llm_tokenizers
+          [
+            DiscourseAi::Tokenizer::AnthropicTokenizer,
+            DiscourseAi::Tokenizer::Llama3Tokenizer,
+            DiscourseAi::Tokenizer::MixtralTokenizer,
+            DiscourseAi::Tokenizer::OpenAiTokenizer,
+          ]
+        end
+
        def tokenizer
          raise NotImplementedError
        end

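tokenizer_names in llm.rb (changed above) now derives the admin-facing tokenizer list from this hardcoded set rather than from the dialects:

DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
# => ["DiscourseAi::Tokenizer::AnthropicTokenizer", "DiscourseAi::Tokenizer::Llama3Tokenizer",
#     "DiscourseAi::Tokenizer::MixtralTokenizer", "DiscourseAi::Tokenizer::OpenAiTokenizer"]
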
@@ -1,12 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Tokenizer
-    class Llama2Tokenizer < BasicTokenizer
-      def self.tokenizer
-        @@tokenizer ||=
-          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
-      end
-    end
-  end
-end

@@ -7,7 +7,7 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
     trim_messages(messages)
   end
 
-  def self.tokenizer
+  def tokenizer
     Class.new do
       def self.size(str)
         str.length

@@ -6,9 +6,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
       DiscourseAi::Completions::Dialects::Mistral,
       canned_response,
       "hugging_face:Upstage-Llama-2-*-instruct-v2",
-      opts: {
-        gateway: canned_response,
-      },
+      gateway: canned_response,
     )
   end

@@ -126,23 +126,6 @@ describe DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer do
   end
 end
 
-describe DiscourseAi::Tokenizer::Llama2Tokenizer do
-  describe "#size" do
-    describe "returns a token count" do
-      it "for a sentence with punctuation and capitalization and numbers" do
-        expect(described_class.size("Hello, World! 123")).to eq(9)
-      end
-    end
-  end
-
-  describe "#truncate" do
-    it "truncates a sentence" do
-      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
-      expect(described_class.truncate(sentence, 3)).to eq("foo bar")
-    end
-  end
-end
-
 describe DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer do
   describe "#size" do
     describe "returns a token count" do

File diff suppressed because it is too large.