FEATURE: Set endpoint credentials directly from LlmModel. (#625)
* FEATURE: Set endpoint credentials directly from LlmModel. Drop Llama2Tokenizer since we no longer use it.

* Allow http for custom LLMs

---------

Co-authored-by: Rafael Silva <xfalcox@gmail.com>
parent 255139056d
commit 1d786fbaaf
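Reviewer note: the unifying pattern in this diff is that every endpoint and dialect now consults the LlmModel record first and only falls back to the global site setting. A minimal sketch of that resolution order, using hypothetical stand-ins for the real ActiveRecord model and SiteSetting:

    # Sketch only: LlmModelStub and fetch_site_setting are hypothetical stand-ins.
    LlmModelStub = Struct.new(:url, :api_key, keyword_init: true)

    def fetch_site_setting
      "site-wide-key" # stands in for e.g. SiteSetting.ai_anthropic_api_key
    end

    def resolve_api_key(llm_model)
      # Per-model credentials win; the site setting remains the fallback.
      llm_model&.api_key || fetch_site_setting
    end

    resolve_api_key(LlmModelStub.new(api_key: "model-key")) # => "model-key"
    resolve_api_key(nil)                                    # => "site-wide-key"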
@@ -58,6 +58,8 @@ module DiscourseAi
           :provider,
           :tokenizer,
           :max_prompt_tokens,
+          :url,
+          :api_key,
         )
       end
     end

@@ -3,5 +3,5 @@
 class LlmModelSerializer < ApplicationSerializer
   root "llm"

-  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
+  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer, :api_key, :url
 end
@@ -3,11 +3,14 @@ import RestModel from "discourse/models/rest";
 export default class AiLlm extends RestModel {
   createProperties() {
     return this.getProperties(
+      "id",
       "display_name",
       "name",
       "provider",
       "tokenizer",
-      "max_prompt_tokens"
+      "max_prompt_tokens",
+      "url",
+      "api_key"
     );
   }

@@ -33,7 +33,9 @@ export default class AiLlmEditor extends Component {
     const isNew = this.args.model.isNew;

     try {
-      await this.args.model.save();
+      const result = await this.args.model.save();

+      this.args.model.setProperties(result.responseJson.ai_persona);
+
       if (isNew) {
         this.args.llms.addObject(this.args.model);
@@ -61,7 +63,7 @@ export default class AiLlmEditor extends Component {
     <div class="control-group">
       <label>{{i18n "discourse_ai.llms.display_name"}}</label>
       <Input
-        class="ai-llm-editor__display-name"
+        class="ai-llm-editor-input ai-llm-editor__display-name"
         @type="text"
         @value={{@model.display_name}}
       />

@@ -69,7 +71,7 @@ export default class AiLlmEditor extends Component {
     <div class="control-group">
       <label>{{i18n "discourse_ai.llms.name"}}</label>
       <Input
-        class="ai-llm-editor__name"
+        class="ai-llm-editor-input ai-llm-editor__name"
         @type="text"
         @value={{@model.name}}
       />

@@ -85,6 +87,22 @@ export default class AiLlmEditor extends Component {
         @content={{this.selectedProviders}}
       />
     </div>
+    <div class="control-group">
+      <label>{{I18n.t "discourse_ai.llms.url"}}</label>
+      <Input
+        class="ai-llm-editor-input ai-llm-editor__url"
+        @type="text"
+        @value={{@model.url}}
+      />
+    </div>
+    <div class="control-group">
+      <label>{{I18n.t "discourse_ai.llms.api_key"}}</label>
+      <Input
+        class="ai-llm-editor-input ai-llm-editor__api-key"
+        @type="text"
+        @value={{@model.api_key}}
+      />
+    </div>
     <div class="control-group">
       <label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
       <ComboBox

@@ -96,7 +114,7 @@ export default class AiLlmEditor extends Component {
       <label>{{i18n "discourse_ai.llms.max_prompt_tokens"}}</label>
       <Input
         @type="number"
-        class="ai-llm-editor__max-prompt-tokens"
+        class="ai-llm-editor-input ai-llm-editor__max-prompt-tokens"
         step="any"
         min="0"
         lang="en"
@@ -29,4 +29,8 @@
     text-align: center;
     font-size: var(--font-up-1);
   }
+
+  .ai-llm-editor-input {
+    width: 350px;
+  }
 }

@@ -204,6 +204,8 @@ en:
           provider: "Service hosting the model:"
           tokenizer: "Tokenizer:"
           max_prompt_tokens: "Number of tokens for the prompt:"
+          url: "URL of the service hosting the model:"
+          api_key: "API Key of the service hosting the model:"
           save: "Save"
           saved: "LLM Model Saved"
@@ -0,0 +1,8 @@
+# frozen_string_literal: true
+
+class AddEndpointDataToLlmModel < ActiveRecord::Migration[7.0]
+  def change
+    add_column :llm_models, :url, :string
+    add_column :llm_models, :api_key, :string
+  end
+end
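With these two columns in place, a custom model can carry its own endpoint and key. A hypothetical console session (attribute names match the controller params and serializer above; the values, and the assumption that the tokenizer column stores a tokenizer class name, are illustrative only):

    # Illustrative only; not taken from this diff.
    LlmModel.create!(
      display_name: "Internal Llama 3",
      name: "llama3",
      provider: "hugging_face",
      tokenizer: "DiscourseAi::Tokenizer::Llama3Tokenizer",
      max_prompt_tokens: 8192,
      url: "http://llama.internal:8080", # plain http is now allowed for custom LLMs
      api_key: "model-scoped-secret",
    )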
@@ -8,14 +8,14 @@ module DiscourseAi
         def can_translate?(model_name)
           model_name.starts_with?("gpt-")
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end

       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def native_tool_support?
         true
       end
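Note that tokenizer moves from the dialect's `class << self` block to an instance method here and in the dialects below: only an instance holds the llm_model reference, so call sites switch from self.class.tokenizer to self.tokenizer. A toy sketch of the new shape (names are illustrative stand-ins, not the plugin's classes):

    # DefaultTokenizer stands in for e.g. DiscourseAi::Tokenizer::OpenAiTokenizer.
    class DefaultTokenizer
      def self.size(text)
        text.split.size # toy token count
      end
    end

    class ToyDialect
      def initialize(llm_model = nil)
        @llm_model = llm_model
      end

      def tokenizer
        # Per-model tokenizer first, hardcoded default second.
        @llm_model&.tokenizer_class || DefaultTokenizer
      end
    end

    ToyDialect.new.tokenizer.size("hello world") # => 2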
@@ -30,7 +30,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         # provide a buffer of 120 tokens - our function counting is not
         # 100% accurate and getting numbers to align exactly is very hard

@@ -38,7 +38,7 @@ module DiscourseAi

         if tools.present?
           # note this is about 100 tokens over, OpenAI have a more optimal representation
-          @function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
+          @function_size ||= self.tokenizer.size(tools.to_json.to_s)
           buffer += @function_size
         end

@@ -113,7 +113,7 @@ module DiscourseAi
       end

       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end

       def model_max_tokens
@@ -10,10 +10,6 @@ module DiscourseAi
             model_name,
           )
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::AnthropicTokenizer
-        end
       end

       class ClaudePrompt

@@ -26,6 +22,10 @@ module DiscourseAi
         end
       end

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
+      end
+
       def translate
         messages = super

@@ -50,7 +50,8 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
+
         # Longer term it will have over 1 million
         200_000 # Claude-3 has a 200k context window for now
       end
@@ -10,14 +10,14 @@ module DiscourseAi
         def can_translate?(model_name)
           %w[command-light command command-r command-r-plus].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer
-        end
       end

       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
+      end
+
       def translate
         messages = super

@@ -38,7 +38,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         case model_name
         when "command-light"

@@ -61,7 +61,7 @@ module DiscourseAi
       end

       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end

       def system_msg(msg)
@@ -20,10 +20,6 @@ module DiscourseAi
           ]
         end

-        def available_tokenizers
-          all_dialects.map(&:tokenizer)
-        end
-
         def dialect_for(model_name)
           dialects = []

@@ -38,20 +34,21 @@ module DiscourseAi

          dialect
        end
-
-        def tokenizer
-          raise NotImplemented
-        end
       end

-      def initialize(generic_prompt, model_name, opts: {})
+      def initialize(generic_prompt, model_name, opts: {}, llm_model: nil)
         @prompt = generic_prompt
         @model_name = model_name
         @opts = opts
+        @llm_model = llm_model
       end

       VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/

+      def tokenizer
+        raise NotImplemented
+      end
+
       def can_end_with_assistant_msg?
         false
       end

@@ -88,7 +85,7 @@ module DiscourseAi

       private

-      attr_reader :model_name, :opts
+      attr_reader :model_name, :opts, :llm_model

       def trim_messages(messages)
         prompt_limit = max_prompt_tokens

@@ -147,7 +144,7 @@ module DiscourseAi
       end

       def calculate_message_token(msg)
-        self.class.tokenizer.size(msg[:content].to_s)
+        self.tokenizer.size(msg[:content].to_s)
       end

       def tools_dialect
@@ -8,10 +8,10 @@ module DiscourseAi
         def can_translate?(model_name)
           model_name == "fake"
         end
+      end

       def tokenizer
         DiscourseAi::Tokenizer::OpenAiTokenizer
       end
-      end

       def translate
@@ -8,16 +8,16 @@ module DiscourseAi
         def can_translate?(model_name)
           %w[gemini-pro gemini-1.5-pro].include?(model_name)
         end
-
-        def tokenizer
-          DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
-        end
       end

       def native_tool_support?
         true
       end

+      def tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
+      end
+
       def translate
         # Gemini complains if we don't alternate model/user roles.
         noop_model_response = { role: "model", parts: { text: "Ok." } }

@@ -68,7 +68,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         if model_name == "gemini-1.5-pro"
           # technically we support 1 million tokens, but we're being conservative

@@ -81,7 +81,7 @@ module DiscourseAi
       protected

       def calculate_message_token(context)
-        self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
+        self.tokenizer.size(context[:content].to_s + context[:name].to_s)
       end

       def system_msg(msg)
@@ -12,10 +12,10 @@ module DiscourseAi
             mistral
           ].include?(model_name)
         end
+      end

       def tokenizer
-        DiscourseAi::Tokenizer::MixtralTokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::MixtralTokenizer
       end
-      end

       def tools

@@ -23,7 +23,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         32_000
       end
@@ -8,10 +8,10 @@ module DiscourseAi
         def can_translate?(_model_name)
           true
         end
+      end

       def tokenizer
-        DiscourseAi::Tokenizer::Llama3Tokenizer
+        llm_model&.tokenizer_class || DiscourseAi::Tokenizer::Llama3Tokenizer
       end
-      end

       def tools

@@ -19,7 +19,7 @@ module DiscourseAi
       end

       def max_prompt_tokens
-        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+        return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens

         32_000
       end
@@ -63,7 +63,9 @@ module DiscourseAi
       end

       def model_uri
-        @uri ||= URI("https://api.anthropic.com/v1/messages")
+        url = llm_model&.url || "https://api.anthropic.com/v1/messages"
+
+        URI(url)
       end

       def prepare_payload(prompt, model_params, dialect)

@@ -78,7 +80,7 @@ module DiscourseAi
       def prepare_request(payload)
         headers = {
           "anthropic-version" => "2023-06-01",
-          "x-api-key" => SiteSetting.ai_anthropic_api_key,
+          "x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key,
           "content-type" => "application/json",
         }

@@ -71,7 +71,8 @@ module DiscourseAi
         end

         api_url =
-          "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
+          llm_model&.url ||
+            "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"

         api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url

@@ -91,7 +92,7 @@ module DiscourseAi
         Aws::Sigv4::Signer.new(
           access_key_id: SiteSetting.ai_bedrock_access_key_id,
           region: SiteSetting.ai_bedrock_region,
-          secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
+          secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
           service: "bedrock",
         )

@@ -60,9 +60,10 @@ module DiscourseAi
         end
       end

-      def initialize(model_name, tokenizer)
+      def initialize(model_name, tokenizer, llm_model: nil)
         @model = model_name
         @tokenizer = tokenizer
+        @llm_model = llm_model
       end

       def native_tool_support?

@@ -70,7 +71,11 @@ module DiscourseAi
       end

       def use_ssl?
-        true
+        if model_uri&.scheme.present?
+          model_uri.scheme == "https"
+        else
+          true
+        end
       end

       def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
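This use_ssl? change is the "Allow http for custom LLMs" part of the commit: instead of hardcoding SSL, the endpoint derives it from the scheme of the configured URL, defaulting to true when no scheme is set. A small sketch of the decision in plain Ruby (no ActiveSupport #present?):

    require "uri"

    # Mirrors the use_ssl? logic above.
    def use_ssl_for?(uri)
      uri&.scheme ? uri.scheme == "https" : true
    end

    use_ssl_for?(URI("http://llama.internal:8080")) # => false
    use_ssl_for?(URI("https://api.anthropic.com"))  # => true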
@@ -294,7 +299,7 @@ module DiscourseAi
         tokenizer.size(extract_prompt_for_tokenizer(prompt))
       end

-      attr_reader :tokenizer, :model
+      attr_reader :tokenizer, :model, :llm_model

       protected

@@ -42,7 +42,9 @@ module DiscourseAi
       private

       def model_uri
-        URI("https://api.cohere.ai/v1/chat")
+        url = llm_model&.url || "https://api.cohere.ai/v1/chat"
+
+        URI(url)
       end

       def prepare_payload(prompt, model_params, dialect)

@@ -56,7 +58,7 @@ module DiscourseAi
       def prepare_request(payload)
         headers = {
           "Content-Type" => "application/json",
-          "Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}",
+          "Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}",
         }

         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
@@ -51,9 +51,16 @@ module DiscourseAi
       private

       def model_uri
-        mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
-        url =
-          "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"
+        if llm_model
+          url = llm_model.url
+        else
+          mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
+          url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
+        end
+
+        key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
+
+        url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"

         URI(url)
       end
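For a record-backed model, only the base URL comes from llm_model; the action and key are still appended by the endpoint. A sketch of the resulting composition (values are illustrative stand-ins for llm_model.url and the resolved key):

    require "uri"

    url = "http://gemini.internal/models/my-model" # llm_model.url
    key = "model-key"                              # llm_model&.api_key || site setting
    streaming_mode = false

    action = streaming_mode ? "streamGenerateContent" : "generateContent"
    URI("#{url}:#{action}?key=#{key}")
    # => http://gemini.internal/models/my-model:generateContent?key=model-key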
@@ -44,7 +44,7 @@ module DiscourseAi
       private

       def model_uri
-        URI(SiteSetting.ai_hugging_face_api_url)
+        URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url)
       end

       def prepare_payload(prompt, model_params, _dialect)

@@ -53,7 +53,8 @@ module DiscourseAi
           .merge(messages: prompt)
           .tap do |payload|
             if !payload[:max_tokens]
-              token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
+              token_limit =
+                llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit

               payload[:max_tokens] = token_limit - prompt_size(prompt)
             end

@@ -63,11 +64,11 @@ module DiscourseAi
       end

       def prepare_request(payload)
+        api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key
+
         headers =
           { "Content-Type" => "application/json" }.tap do |h|
-            if SiteSetting.ai_hugging_face_api_key.present?
-              h["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
-            end
+            h["Authorization"] = "Bearer #{api_key}" if api_key.present?
           end

         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
@@ -48,7 +48,7 @@ module DiscourseAi
       private

       def model_uri
-        URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
+        URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
       end

       def prepare_payload(prompt, model_params, _dialect)
@@ -78,6 +78,8 @@ module DiscourseAi
       private

       def model_uri
+        return URI(llm_model.url) if llm_model&.url
+
         url =
           if model.include?("gpt-4")
             if model.include?("32k")

@@ -115,10 +117,12 @@ module DiscourseAi
       def prepare_request(payload)
         headers = { "Content-Type" => "application/json" }

+        api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key
+
         if model_uri.host.include?("azure")
-          headers["api-key"] = SiteSetting.ai_openai_api_key
+          headers["api-key"] = api_key
         else
-          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
+          headers["Authorization"] = "Bearer #{api_key}"
         end

         if SiteSetting.ai_openai_organization.present?
@@ -44,6 +44,8 @@ module DiscourseAi
       private

       def model_uri
+        return URI(llm_model.url) if llm_model&.url
+
         service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
         if service.present?
           api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"

@@ -63,7 +65,8 @@ module DiscourseAi
       def prepare_request(payload)
         headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }

-        headers["X-API-KEY"] = SiteSetting.ai_vllm_api_key if SiteSetting.ai_vllm_api_key.present?
+        api_key = llm_model&.api_key || SiteSetting.ai_vllm_api_key
+        headers["X-API-KEY"] = api_key if api_key.present?

         Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
       end
@@ -25,7 +25,7 @@ module DiscourseAi
       end

       def tokenizer_names
-        DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq
+        DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
       end

       def models_by_provider
@@ -91,17 +91,16 @@ module DiscourseAi
       end

       def proxy(model_name)
-        # We are in the process of transitioning to always use objects here.
-        # We'll live with this hack for a while.
         provider_and_model_name = model_name.split(":")
         provider_name = provider_and_model_name.first
         model_name_without_prov = provider_and_model_name[1..].join
-        is_custom_model = provider_name == "custom"

-        if is_custom_model
+        # We are in the process of transitioning to always use objects here.
+        # We'll live with this hack for a while.
+        if provider_name == "custom"
           llm_model = LlmModel.find(model_name_without_prov)
-          provider_name = llm_model.provider
-          model_name_without_prov = llm_model.name
+          raise UNKNOWN_MODEL if !llm_model
+          return proxy_from_obj(llm_model)
         end

         dialect_klass =

@@ -111,24 +110,32 @@ module DiscourseAi
         if @canned_llm && @canned_llm != model_name
           raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
         end

-          return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
+          return new(dialect_klass, nil, model_name, gateway: @canned_response)
         end
-
-        opts = {}
-        opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model

         gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)

-        new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
+        new(dialect_klass, gateway_klass, model_name_without_prov)
+      end
+
+      def proxy_from_obj(llm_model)
+        provider_name = llm_model.provider
+        model_name = llm_model.name
+
+        dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
+        gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
+
+        new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
       end
     end

-    def initialize(dialect_klass, gateway_klass, model_name, opts: {})
+    def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil)
       @dialect_klass = dialect_klass
       @gateway_klass = gateway_klass
       @model_name = model_name
-      @gateway = opts[:gateway]
-      @max_prompt_tokens = opts[:max_prompt_tokens]
+      @gateway = gateway
+      @llm_model = llm_model
     end

     # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
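Callers keep the existing string interface: a record-backed model is addressed with the "custom" provider prefix, which now short-circuits into proxy_from_obj. A hedged usage sketch:

    # Assumes llm_model is an LlmModel record; the "custom:<id>" convention
    # comes from the proxy code above.
    llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")
    llm.max_prompt_tokens # resolved from llm_model, falling back to the dialect default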
@@ -185,13 +192,9 @@ module DiscourseAi

       model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }

-      gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
-      dialect =
-        dialect_klass.new(
-          prompt,
-          model_name,
-          opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
-        )
+      dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model)
+      gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model)
       gateway.perform_completion!(
         dialect,
         user,

@@ -202,18 +205,20 @@ module DiscourseAi
     end

     def max_prompt_tokens
-      return @max_prompt_tokens if @max_prompt_tokens.present?
-
-      dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
+      llm_model&.max_prompt_tokens ||
+        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
     end

-    delegate :tokenizer, to: :dialect_klass
+    def tokenizer
+      llm_model&.tokenizer_class ||
+        dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
+    end

     attr_reader :model_name

     private

-    attr_reader :dialect_klass, :gateway_klass
+    attr_reader :dialect_klass, :gateway_klass, :llm_model
    end
  end
end
@@ -18,19 +18,16 @@
       model_name_without_prov = provider_and_model_name[1..].join
       is_custom_model = provider_name == "custom"

-      if is_custom_model
-        llm_model = LlmModel.find(model_name_without_prov)
-        provider_name = llm_model.provider
-        model_name_without_prov = llm_model.name
-      end
-
-      endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
-
-      return false if endpoint.nil?
-
-      if !endpoint.correctly_configured?(model_name_without_prov)
-        @endpoint = endpoint
-        return false
-      end
+      # Bypass setting validations for custom models. They don't rely on site settings.
+      if !is_custom_model
+        endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
+
+        return false if endpoint.nil?
+
+        if !endpoint.correctly_configured?(model_name_without_prov)
+          @endpoint = endpoint
+          return false
+        end
+      end

       if !can_talk_to_model?(val)
@@ -4,6 +4,15 @@ module DiscourseAi
   module Tokenizer
     class BasicTokenizer
       class << self
+        def available_llm_tokenizers
+          [
+            DiscourseAi::Tokenizer::AnthropicTokenizer,
+            DiscourseAi::Tokenizer::Llama3Tokenizer,
+            DiscourseAi::Tokenizer::MixtralTokenizer,
+            DiscourseAi::Tokenizer::OpenAiTokenizer,
+          ]
+        end
+
         def tokenizer
           raise NotImplementedError
         end
@@ -1,12 +0,0 @@
-# frozen_string_literal: true
-
-module DiscourseAi
-  module Tokenizer
-    class Llama2Tokenizer < BasicTokenizer
-      def self.tokenizer
-        @@tokenizer ||=
-          Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
-      end
-    end
-  end
-end
@@ -7,7 +7,7 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
     trim_messages(messages)
   end

-  def self.tokenizer
+  def tokenizer
     Class.new do
       def self.size(str)
         str.length

@@ -6,9 +6,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
       DiscourseAi::Completions::Dialects::Mistral,
       canned_response,
       "hugging_face:Upstage-Llama-2-*-instruct-v2",
-      opts: {
-        gateway: canned_response,
-      },
+      gateway: canned_response,
     )
   end

@@ -126,23 +126,6 @@ describe DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer do
   end
 end

-describe DiscourseAi::Tokenizer::Llama2Tokenizer do
-  describe "#size" do
-    describe "returns a token count" do
-      it "for a sentence with punctuation and capitalization and numbers" do
-        expect(described_class.size("Hello, World! 123")).to eq(9)
-      end
-    end
-  end
-
-  describe "#truncate" do
-    it "truncates a sentence" do
-      sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
-      expect(described_class.truncate(sentence, 3)).to eq("foo bar")
-    end
-  end
-end
-
 describe DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer do
   describe "#size" do
     describe "returns a token count" do
File diff suppressed because it is too large