FEATURE: Set endpoint credentials directly from LlmModel. (#625)

* FEATURE: Set endpoint credentials directly from LlmModel.

Drop Llama2Tokenizer since we no longer use it.

* Allow http for custom LLMs
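
With this change a custom LlmModel record carries its own endpoint settings: the new `url` and `api_key` columns are exposed in the admin UI and serializer, and the endpoints prefer `llm_model&.url` / `llm_model&.api_key` over the matching `ai_*` site settings. Because `use_ssl?` now checks the URL scheme, a plain `http://` endpoint also works. A minimal sketch of how such a record could be configured and resolved follows; the attribute values, provider/tokenizer strings, and hostname are illustrative assumptions, not part of this commit:

# Hypothetical Rails console sketch -- values below are assumptions, not from this commit.
llm = LlmModel.create!(
  display_name: "Internal Llama 3",
  name: "llama-3-70b-instruct",                          # assumed model name
  provider: "vllm",                                      # assumed; must match an existing endpoint
  tokenizer: "DiscourseAi::Tokenizer::Llama3Tokenizer",
  max_prompt_tokens: 32_000,
  url: "http://vllm.internal:8000/v1/chat/completions",  # plain http is now accepted
  api_key: "example-key",                                # used instead of SiteSetting.ai_vllm_api_key
)

# "custom:<id>" makes Llm.proxy load the record and build the endpoint from it.
llm_proxy = DiscourseAi::Completions::Llm.proxy("custom:#{llm.id}")
llm_proxy.max_prompt_tokens # => 32_000, taken from the record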

---------

Co-authored-by: Rafael Silva <xfalcox@gmail.com>
Roman Rizzi 2024-05-16 09:50:22 -03:00 committed by GitHub
parent 255139056d
commit 1d786fbaaf
32 changed files with 184 additions and 93535 deletions


@@ -58,6 +58,8 @@ module DiscourseAi
             :provider,
             :tokenizer,
             :max_prompt_tokens,
+            :url,
+            :api_key,
           )
         end
       end


@@ -3,5 +3,5 @@
 class LlmModelSerializer < ApplicationSerializer
   root "llm"
-  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
+  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer, :api_key, :url
 end


@@ -3,11 +3,14 @@ import RestModel from "discourse/models/rest";
 export default class AiLlm extends RestModel {
   createProperties() {
     return this.getProperties(
+      "id",
       "display_name",
       "name",
       "provider",
       "tokenizer",
-      "max_prompt_tokens"
+      "max_prompt_tokens",
+      "url",
+      "api_key"
     );
   }


@@ -33,7 +33,9 @@ export default class AiLlmEditor extends Component {
     const isNew = this.args.model.isNew;

     try {
-      await this.args.model.save();
+      const result = await this.args.model.save();
+      this.args.model.setProperties(result.responseJson.ai_persona);

       if (isNew) {
         this.args.llms.addObject(this.args.model);
@@ -61,7 +63,7 @@ export default class AiLlmEditor extends Component {
     <div class="control-group">
       <label>{{i18n "discourse_ai.llms.display_name"}}</label>
       <Input
-        class="ai-llm-editor__display-name"
+        class="ai-llm-editor-input ai-llm-editor__display-name"
         @type="text"
         @value={{@model.display_name}}
       />
@@ -69,7 +71,7 @@ export default class AiLlmEditor extends Component {
     <div class="control-group">
       <label>{{i18n "discourse_ai.llms.name"}}</label>
       <Input
-        class="ai-llm-editor__name"
+        class="ai-llm-editor-input ai-llm-editor__name"
         @type="text"
         @value={{@model.name}}
       />
@@ -85,6 +87,22 @@ export default class AiLlmEditor extends Component {
         @content={{this.selectedProviders}}
       />
     </div>
+    <div class="control-group">
+      <label>{{I18n.t "discourse_ai.llms.url"}}</label>
+      <Input
+        class="ai-llm-editor-input ai-llm-editor__url"
+        @type="text"
+        @value={{@model.url}}
+      />
+    </div>
+    <div class="control-group">
+      <label>{{I18n.t "discourse_ai.llms.api_key"}}</label>
+      <Input
+        class="ai-llm-editor-input ai-llm-editor__api-key"
+        @type="text"
+        @value={{@model.api_key}}
+      />
+    </div>
     <div class="control-group">
       <label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
       <ComboBox
@@ -96,7 +114,7 @@ export default class AiLlmEditor extends Component {
       <label>{{i18n "discourse_ai.llms.max_prompt_tokens"}}</label>
       <Input
         @type="number"
-        class="ai-llm-editor__max-prompt-tokens"
+        class="ai-llm-editor-input ai-llm-editor__max-prompt-tokens"
         step="any"
         min="0"
         lang="en"


@@ -29,4 +29,8 @@
     text-align: center;
     font-size: var(--font-up-1);
   }
+
+  .ai-llm-editor-input {
+    width: 350px;
+  }
 }


@@ -204,6 +204,8 @@ en:
           provider: "Service hosting the model:"
           tokenizer: "Tokenizer:"
           max_prompt_tokens: "Number of tokens for the prompt:"
+          url: "URL of the service hosting the model:"
+          api_key: "API Key of the service hosting the model:"
           save: "Save"
           saved: "LLM Model Saved"


@@ -0,0 +1,8 @@
+# frozen_string_literal: true
+
+class AddEndpointDataToLlmModel < ActiveRecord::Migration[7.0]
+  def change
+    add_column :llm_models, :url, :string
+    add_column :llm_models, :api_key, :string
+  end
+end


@@ -8,14 +8,14 @@ module DiscourseAi
         def can_translate?(model_name)
           model_name.starts_with?("gpt-")
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end

       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def native_tool_support?
         true
       end
@@ -30,7 +30,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         # provide a buffer of 120 tokens - our function counting is not
         # 100% accurate and getting numbers to align exactly is very hard
@@ -38,7 +38,7 @@ module DiscourseAi
         if tools.present?
           # note this is about 100 tokens over, OpenAI have a more optimal representation
-          @function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
+          @function_size ||= self.tokenizer.size(tools.to_json.to_s)
           buffer += @function_size
         end
@@ -113,7 +113,7 @@ module DiscourseAi
       end

       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end

       def model_max_tokens


@@ -10,10 +10,6 @@ module DiscourseAi
             model_name,
           )
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::AnthropicTokenizer
-        end
       end

       class ClaudePrompt
@@ -26,6 +22,10 @@ module DiscourseAi
         end
       end

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
+      end
+
       def translate
         messages = super
@@ -50,7 +50,8 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
+
         # Longer term it will have over 1 million
         200_000 # Claude-3 has a 200k context window for now
       end


@@ -10,14 +10,14 @@ module DiscourseAi
         def can_translate?(model_name)
           %w[command-light command command-r command-r-plus].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end

       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def translate
         messages = super
@@ -38,7 +38,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         case model_name
         when "command-light"
@@ -61,7 +61,7 @@ module DiscourseAi
       end

       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end

       def system_msg(msg)


@@ -20,10 +20,6 @@ module DiscourseAi
           ]
         end
-
-        def available_tokenizers
-          all_dialects.map(&:tokenizer)
-        end

         def dialect_for(model_name)
           dialects = []
@@ -38,20 +34,21 @@ module DiscourseAi
           dialect
         end

-        def tokenizer
-          raise NotImplemented
-        end
       end

-      def initialize(generic_prompt, model_name, opts: {})
+      def initialize(generic_prompt, model_name, opts: {}, llm_model: nil)
         @prompt = generic_prompt
         @model_name = model_name
         @opts = opts
+        @llm_model = llm_model
       end

       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+      def tokenizer
+        raise NotImplemented
+      end
+
       def can_end_with_assistant_msg?
         false
       end
@@ -88,7 +85,7 @@ module DiscourseAi
       private

-      attr_reader :model_name, :opts
+      attr_reader :model_name, :opts, :llm_model

       def trim_messages(messages)
         prompt_limit = max_prompt_tokens
@@ -147,7 +144,7 @@ module DiscourseAi
       end

       def calculate_message_token(msg)
-        self.class.tokenizer.size(msg[:content].to_s)
+        self.tokenizer.size(msg[:content].to_s)
       end

       def tools_dialect


@@ -8,10 +8,10 @@ module DiscourseAi
         def can_translate?(model_name)
           model_name == "fake"
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end

+      def tokenizer
+        DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def translate


@@ -8,16 +8,16 @@ module DiscourseAi
         def can_translate?(model_name)
           %w[gemini-pro gemini-1.5-pro].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
-        end
       end

       def native_tool_support?
         true
       end

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
+      end
+
       def translate
         # Gemini complains if we don't alternate model/user roles.
         noop_model_response = { role: "model", parts: { text: "Ok." } }
@@ -68,7 +68,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         if model_name == "gemini-1.5-pro"
           # technically we support 1 million tokens, but we're being conservative
@@ -81,7 +81,7 @@ module DiscourseAi
       protected

       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end

       def system_msg(msg)


@@ -12,10 +12,10 @@ module DiscourseAi
             mistral
           ].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::MixtralTokenizer
-        end
       end

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::MixtralTokenizer
+      end
+
       def tools
@@ -23,7 +23,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         32_000
       end


@@ -8,10 +8,10 @@ module DiscourseAi
         def can_translate?(_model_name)
           true
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::Llama3Tokenizer
-        end
       end

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::Llama3Tokenizer
+      end
+
       def tools
@@ -19,7 +19,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         32_000
       end


@@ -63,7 +63,9 @@ module DiscourseAi
       end

       def model_uri
-        @uri ||= URI("https://api.anthropic.com/v1/messages")
+        url = llm_model&.url || "https://api.anthropic.com/v1/messages"
+
+        URI(url)
       end

       def prepare_payload(prompt, model_params, dialect)
@@ -78,7 +80,7 @@ module DiscourseAi
       def prepare_request(payload)
         headers = {
           "anthropic-version" => "2023-06-01",
-          "x-api-key" => SiteSetting.ai_anthropic_api_key,
+          "x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key,
           "content-type" => "application/json",
         }


@@ -71,7 +71,8 @@ module DiscourseAi
         end

         api_url =
-          "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
+          llm_model&.url ||
+            "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"

         api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
@@ -91,7 +92,7 @@ module DiscourseAi
         Aws::Sigv4::Signer.new(
           access_key_id: SiteSetting.ai_bedrock_access_key_id,
           region: SiteSetting.ai_bedrock_region,
-          secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
+          secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
           service: "bedrock",
         )


@@ -60,9 +60,10 @@ module DiscourseAi
         end
       end

-      def initialize(model_name, tokenizer)
+      def initialize(model_name, tokenizer, llm_model: nil)
         @model = model_name
         @tokenizer = tokenizer
+        @llm_model = llm_model
       end

       def native_tool_support?
@@ -70,7 +71,11 @@ module DiscourseAi
       end

       def use_ssl?
-        true
+        if model_uri&.scheme.present?
+          model_uri.scheme == "https"
+        else
+          true
+        end
       end

       def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
@@ -294,7 +299,7 @@ module DiscourseAi
         tokenizer.size(extract_prompt_for_tokenizer(prompt))
       end

-      attr_reader :tokenizer, :model
+      attr_reader :tokenizer, :model, :llm_model

       protected


@@ -42,7 +42,9 @@ module DiscourseAi
       private

       def model_uri
-        URI("https://api.cohere.ai/v1/chat")
+        url = llm_model&.url || "https://api.cohere.ai/v1/chat"
+
+        URI(url)
       end

       def prepare_payload(prompt, model_params, dialect)
@@ -56,7 +58,7 @@ module DiscourseAi
       def prepare_request(payload)
         headers = {
           "Content-Type" => "application/json",
-          "Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}",
+          "Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}",
         }

         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }


@@ -51,9 +51,16 @@ module DiscourseAi
       private

       def model_uri
-        mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
-        url =
-          "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"
+        if llm_model
+          url = llm_model.url
+        else
+          mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
+          url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
+        end
+
+        key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
+
+        url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"

         URI(url)
       end


@@ -44,7 +44,7 @@ module DiscourseAi
       private

       def model_uri
-        URI(SiteSetting.ai_hugging_face_api_url)
+        URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url)
       end

       def prepare_payload(prompt, model_params, _dialect)
@@ -53,7 +53,8 @@ module DiscourseAi
           .merge(messages: prompt)
           .tap do |payload|
             if !payload[:max_tokens]
-              token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
+              token_limit =
+                llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit

               payload[:max_tokens] = token_limit - prompt_size(prompt)
             end
@@ -63,11 +64,11 @@ module DiscourseAi
       end

       def prepare_request(payload)
+        api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key
+
         headers =
           { "Content-Type" => "application/json" }.tap do |h|
-            if SiteSetting.ai_hugging_face_api_key.present?
-              h["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
-            end
+            h["Authorization"] = "Bearer #{api_key}" if api_key.present?
           end

         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }


@@ -48,7 +48,7 @@ module DiscourseAi
       private

       def model_uri
-        URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
+        URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
       end

       def prepare_payload(prompt, model_params, _dialect)


@@ -78,6 +78,8 @@ module DiscourseAi
       private

       def model_uri
+        return URI(llm_model.url) if llm_model&.url
+
         url =
           if model.include?("gpt-4")
             if model.include?("32k")
@@ -115,10 +117,12 @@ module DiscourseAi
       def prepare_request(payload)
         headers = { "Content-Type" => "application/json" }

+        api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key
+
         if model_uri.host.include?("azure")
-          headers["api-key"] = SiteSetting.ai_openai_api_key
+          headers["api-key"] = api_key
         else
-          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+          headers["Authorization"] = "Bearer #{api_key}"
         end

         if SiteSetting.ai_openai_organization.present?


@@ -44,6 +44,8 @@ module DiscourseAi
       private

       def model_uri
+        return URI(llm_model.url) if llm_model&.url
+
         service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
         if service.present?
           api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
@@ -63,7 +65,8 @@ module DiscourseAi
       def prepare_request(payload)
         headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

-        headers["X-API-KEY"] = SiteSetting.ai_vllm_api_key if SiteSetting.ai_vllm_api_key.present?
+        api_key = llm_model&.api_key || SiteSetting.ai_vllm_api_key
+        headers["X-API-KEY"] = api_key if api_key.present?

         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
       end


@@ -25,7 +25,7 @@ module DiscourseAi
        end

        def tokenizer_names
-          DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq
+          DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
        end

        def models_by_provider
@@ -91,17 +91,16 @@ module DiscourseAi
        end

        def proxy(model_name)
-          # We are in the process of transitioning to always use objects here.
-          # We'll live with this hack for a while.
          provider_and_model_name = model_name.split(":")
          provider_name = provider_and_model_name.first
          model_name_without_prov = provider_and_model_name[1..].join
-          is_custom_model = provider_name == "custom"

-          if is_custom_model
+          # We are in the process of transitioning to always use objects here.
+          # We'll live with this hack for a while.
+          if provider_name == "custom"
            llm_model = LlmModel.find(model_name_without_prov)
-            provider_name = llm_model.provider
-            model_name_without_prov = llm_model.name
+            raise UNKNOWN_MODEL if !llm_model
+            return proxy_from_obj(llm_model)
          end

          dialect_klass =
@@ -111,24 +110,32 @@ module DiscourseAi
            if @canned_llm && @canned_llm != model_name
              raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
            end

-            return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
+            return new(dialect_klass, nil, model_name, gateway: @canned_response)
          end

-          opts = {}
-          opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
-
          gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)

-          new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
+          new(dialect_klass, gateway_klass, model_name_without_prov)
+        end
+
+        def proxy_from_obj(llm_model)
+          provider_name = llm_model.provider
+          model_name = llm_model.name
+
+          dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
+          gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
+
+          new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
        end
      end

-      def initialize(dialect_klass, gateway_klass, model_name, opts: {})
+      def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil)
        @dialect_klass = dialect_klass
        @gateway_klass = gateway_klass
        @model_name = model_name
-        @gateway = opts[:gateway]
-        @max_prompt_tokens = opts[:max_prompt_tokens]
+        @gateway = gateway
+        @llm_model = llm_model
      end

      # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
@@ -185,13 +192,9 @@ module DiscourseAi
        model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }

-      gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
-
-      dialect =
-        dialect_klass.new(
-          prompt,
-          model_name,
-          opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
-        )
+      dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model)
+
+      gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model)

        gateway.perform_completion!(
          dialect,
          user,
@@ -202,18 +205,20 @@ module DiscourseAi
      end

      def max_prompt_tokens
-        return @max_prompt_tokens if @max_prompt_tokens.present?
-
-        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
+        llm_model&.max_prompt_tokens ||
+          dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
      end

-      delegate :tokenizer, to: :dialect_klass
+      def tokenizer
+        llm_model&.tokenizer_class ||
+          dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
+      end

      attr_reader :model_name

      private

-      attr_reader :dialect_klass, :gateway_klass
+      attr_reader :dialect_klass, :gateway_klass, :llm_model
    end
  end
end


@@ -18,19 +18,16 @@ module DiscourseAi
        model_name_without_prov = provider_and_model_name[1..].join
        is_custom_model = provider_name == "custom"

-        if is_custom_model
-          llm_model = LlmModel.find(model_name_without_prov)
-          provider_name = llm_model.provider
-          model_name_without_prov = llm_model.name
-        end
-
-        endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
-
-        return false if endpoint.nil?
-
-        if !endpoint.correctly_configured?(model_name_without_prov)
-          @endpoint = endpoint
-          return false
+        # Bypass setting validations for custom models. They don't rely on site settings.
+        if !is_custom_model
+          endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
+
+          return false if endpoint.nil?
+
+          if !endpoint.correctly_configured?(model_name_without_prov)
+            @endpoint = endpoint
+            return false
+          end
        end

        if !can_talk_to_model?(val)


@@ -4,6 +4,15 @@ module DiscourseAi
   module Tokenizer
     class BasicTokenizer
       class << self
+        def available_llm_tokenizers
+          [
+            DiscourseAi::Tokenizer::AnthropicTokenizer,
+            DiscourseAi::Tokenizer::Llama3Tokenizer,
+            DiscourseAi::Tokenizer::MixtralTokenizer,
+            DiscourseAi::Tokenizer::OpenAiTokenizer,
+          ]
+        end
+
        def tokenizer
          raise NotImplementedError
        end


@@ -1,12 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Tokenizer
-    class Llama2Tokenizer < BasicTokenizer
-      def self.tokenizer
-        @@tokenizer ||=
-          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
-      end
-    end
-  end
-end


@@ -7,7 +7,7 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
     trim_messages(messages)
   end

-  def self.tokenizer
+  def tokenizer
     Class.new do
       def self.size(str)
         str.length


@@ -6,9 +6,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
       DiscourseAi::Completions::Dialects::Mistral,
       canned_response,
       "hugging_face:Upstage-Llama-2-*-instruct-v2",
-      opts: {
-        gateway: canned_response,
-      },
+      gateway: canned_response,
     )
   end


@@ -126,23 +126,6 @@ describe DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer do
   end
 end

-describe DiscourseAi::Tokenizer::Llama2Tokenizer do
-  describe "#size" do
-    describe "returns a token count" do
-      it "for a sentence with punctuation and capitalization and numbers" do
-        expect(described_class.size("Hello, World! 123")).to eq(9)
-      end
-    end
-  end
-
-  describe "#truncate" do
-    it "truncates a sentence" do
-      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
-      expect(described_class.truncate(sentence, 3)).to eq("foo bar")
-    end
-  end
-end
-
 describe DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer do
   describe "#size" do
     describe "returns a token count" do

File diff suppressed because it is too large.