FEATURE: Set endpoint credentials directly from LlmModel. (#625)

* FEATURE: Set endpoint credentials directly from LlmModel.

Drop Llama2Tokenizer since we no longer use it.

* Allow http for custom LLMs

---------

Co-authored-by: Rafael Silva <xfalcox@gmail.com>
This commit is contained in:
Roman Rizzi 2024-05-16 09:50:22 -03:00 committed by GitHub
parent 255139056d
commit 1d786fbaaf
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
32 changed files with 184 additions and 93535 deletions

View File

@ -58,6 +58,8 @@ module DiscourseAi
:provider,
:tokenizer,
:max_prompt_tokens,
:url,
:api_key,
)
end
end

View File

@ -3,5 +3,5 @@
class LlmModelSerializer < ApplicationSerializer
root "llm"
attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer, :api_key, :url
end

View File

@ -3,11 +3,14 @@ import RestModel from "discourse/models/rest";
export default class AiLlm extends RestModel {
createProperties() {
return this.getProperties(
"id",
"display_name",
"name",
"provider",
"tokenizer",
"max_prompt_tokens"
"max_prompt_tokens",
"url",
"api_key"
);
}

View File

@ -33,7 +33,9 @@ export default class AiLlmEditor extends Component {
const isNew = this.args.model.isNew;
try {
await this.args.model.save();
const result = await this.args.model.save();
this.args.model.setProperties(result.responseJson.ai_persona);
if (isNew) {
this.args.llms.addObject(this.args.model);
@ -61,7 +63,7 @@ export default class AiLlmEditor extends Component {
<div class="control-group">
<label>{{i18n "discourse_ai.llms.display_name"}}</label>
<Input
class="ai-llm-editor__display-name"
class="ai-llm-editor-input ai-llm-editor__display-name"
@type="text"
@value={{@model.display_name}}
/>
@ -69,7 +71,7 @@ export default class AiLlmEditor extends Component {
<div class="control-group">
<label>{{i18n "discourse_ai.llms.name"}}</label>
<Input
class="ai-llm-editor__name"
class="ai-llm-editor-input ai-llm-editor__name"
@type="text"
@value={{@model.name}}
/>
@ -85,6 +87,22 @@ export default class AiLlmEditor extends Component {
@content={{this.selectedProviders}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.url"}}</label>
<Input
class="ai-llm-editor-input ai-llm-editor__url"
@type="text"
@value={{@model.url}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.api_key"}}</label>
<Input
class="ai-llm-editor-input ai-llm-editor__api-key"
@type="text"
@value={{@model.api_key}}
/>
</div>
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
<ComboBox
@ -96,7 +114,7 @@ export default class AiLlmEditor extends Component {
<label>{{i18n "discourse_ai.llms.max_prompt_tokens"}}</label>
<Input
@type="number"
class="ai-llm-editor__max-prompt-tokens"
class="ai-llm-editor-input ai-llm-editor__max-prompt-tokens"
step="any"
min="0"
lang="en"

View File

@ -29,4 +29,8 @@
text-align: center;
font-size: var(--font-up-1);
}
.ai-llm-editor-input {
width: 350px;
}
}

View File

@ -204,6 +204,8 @@ en:
provider: "Service hosting the model:"
tokenizer: "Tokenizer:"
max_prompt_tokens: "Number of tokens for the prompt:"
url: "URL of the service hosting the model:"
api_key: "API Key of the service hosting the model:"
save: "Save"
saved: "LLM Model Saved"

View File

@ -0,0 +1,8 @@
# frozen_string_literal: true
class AddEndpointDataToLlmModel < ActiveRecord::Migration[7.0]
def change
add_column :llm_models, :url, :string
add_column :llm_models, :api_key, :string
end
end

View File

@ -8,14 +8,14 @@ module DiscourseAi
def can_translate?(model_name)
model_name.starts_with?("gpt-")
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
end
def native_tool_support?
true
end
@ -30,7 +30,7 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
# provide a buffer of 120 tokens - our function counting is not
# 100% accurate and getting numbers to align exactly is very hard
@ -38,7 +38,7 @@ module DiscourseAi
if tools.present?
# note this is about 100 tokens over, OpenAI have a more optimal representation
@function_size ||= self.class.tokenizer.size(tools.to_json.to_s)
@function_size ||= self.tokenizer.size(tools.to_json.to_s)
buffer += @function_size
end
@ -113,7 +113,7 @@ module DiscourseAi
end
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def model_max_tokens

View File

@ -10,10 +10,6 @@ module DiscourseAi
model_name,
)
end
def tokenizer
DiscourseAi::Tokenizer::AnthropicTokenizer
end
end
class ClaudePrompt
@ -26,6 +22,10 @@ module DiscourseAi
end
end
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::AnthropicTokenizer
end
def translate
messages = super
@ -50,7 +50,8 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
# Longer term it will have over 1 million
200_000 # Claude-3 has a 200k context window for now
end

View File

@ -10,14 +10,14 @@ module DiscourseAi
def can_translate?(model_name)
%w[command-light command command-r command-r-plus].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer
end
def translate
messages = super
@ -38,7 +38,7 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
case model_name
when "command-light"
@ -61,7 +61,7 @@ module DiscourseAi
end
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def system_msg(msg)

View File

@ -20,10 +20,6 @@ module DiscourseAi
]
end
def available_tokenizers
all_dialects.map(&:tokenizer)
end
def dialect_for(model_name)
dialects = []
@ -38,19 +34,20 @@ module DiscourseAi
dialect
end
end
def initialize(generic_prompt, model_name, opts: {}, llm_model: nil)
@prompt = generic_prompt
@model_name = model_name
@opts = opts
@llm_model = llm_model
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def tokenizer
raise NotImplemented
end
end
def initialize(generic_prompt, model_name, opts: {})
@prompt = generic_prompt
@model_name = model_name
@opts = opts
end
VALID_ID_REGEX = /\A[a-zA-Z0-9_]+\z/
def can_end_with_assistant_msg?
false
@ -88,7 +85,7 @@ module DiscourseAi
private
attr_reader :model_name, :opts
attr_reader :model_name, :opts, :llm_model
def trim_messages(messages)
prompt_limit = max_prompt_tokens
@ -147,7 +144,7 @@ module DiscourseAi
end
def calculate_message_token(msg)
self.class.tokenizer.size(msg[:content].to_s)
self.tokenizer.size(msg[:content].to_s)
end
def tools_dialect

View File

@ -8,11 +8,11 @@ module DiscourseAi
def can_translate?(model_name)
model_name == "fake"
end
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer
end
end
def translate
""

View File

@ -8,16 +8,16 @@ module DiscourseAi
def can_translate?(model_name)
%w[gemini-pro gemini-1.5-pro].include?(model_name)
end
def tokenizer
DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
end
end
def native_tool_support?
true
end
def tokenizer
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::OpenAiTokenizer ## TODO Replace with GeminiTokenizer
end
def translate
# Gemini complains if we don't alternate model/user roles.
noop_model_response = { role: "model", parts: { text: "Ok." } }
@ -68,7 +68,7 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
if model_name == "gemini-1.5-pro"
# technically we support 1 million tokens, but we're being conservative
@ -81,7 +81,7 @@ module DiscourseAi
protected
def calculate_message_token(context)
self.class.tokenizer.size(context[:content].to_s + context[:name].to_s)
self.tokenizer.size(context[:content].to_s + context[:name].to_s)
end
def system_msg(msg)

View File

@ -12,10 +12,10 @@ module DiscourseAi
mistral
].include?(model_name)
end
end
def tokenizer
DiscourseAi::Tokenizer::MixtralTokenizer
end
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::MixtralTokenizer
end
def tools
@ -23,7 +23,7 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
32_000
end

View File

@ -8,10 +8,10 @@ module DiscourseAi
def can_translate?(_model_name)
true
end
end
def tokenizer
DiscourseAi::Tokenizer::Llama3Tokenizer
end
llm_model&.tokenizer_class || DiscourseAi::Tokenizer::Llama3Tokenizer
end
def tools
@ -19,7 +19,7 @@ module DiscourseAi
end
def max_prompt_tokens
return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
return llm_model.max_prompt_tokens if llm_model&.max_prompt_tokens
32_000
end

View File

@ -63,7 +63,9 @@ module DiscourseAi
end
def model_uri
@uri ||= URI("https://api.anthropic.com/v1/messages")
url = llm_model&.url || "https://api.anthropic.com/v1/messages"
URI(url)
end
def prepare_payload(prompt, model_params, dialect)
@ -78,7 +80,7 @@ module DiscourseAi
def prepare_request(payload)
headers = {
"anthropic-version" => "2023-06-01",
"x-api-key" => SiteSetting.ai_anthropic_api_key,
"x-api-key" => llm_model&.api_key || SiteSetting.ai_anthropic_api_key,
"content-type" => "application/json",
}

View File

@ -71,6 +71,7 @@ module DiscourseAi
end
api_url =
llm_model&.url ||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
@ -91,7 +92,7 @@ module DiscourseAi
Aws::Sigv4::Signer.new(
access_key_id: SiteSetting.ai_bedrock_access_key_id,
region: SiteSetting.ai_bedrock_region,
secret_access_key: SiteSetting.ai_bedrock_secret_access_key,
secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
service: "bedrock",
)

View File

@ -60,9 +60,10 @@ module DiscourseAi
end
end
def initialize(model_name, tokenizer)
def initialize(model_name, tokenizer, llm_model: nil)
@model = model_name
@tokenizer = tokenizer
@llm_model = llm_model
end
def native_tool_support?
@ -70,8 +71,12 @@ module DiscourseAi
end
def use_ssl?
if model_uri&.scheme.present?
model_uri.scheme == "https"
else
true
end
end
def perform_completion!(dialect, user, model_params = {}, feature_name: nil, &blk)
allow_tools = dialect.prompt.has_tools?
@ -294,7 +299,7 @@ module DiscourseAi
tokenizer.size(extract_prompt_for_tokenizer(prompt))
end
attr_reader :tokenizer, :model
attr_reader :tokenizer, :model, :llm_model
protected

View File

@ -42,7 +42,9 @@ module DiscourseAi
private
def model_uri
URI("https://api.cohere.ai/v1/chat")
url = llm_model&.url || "https://api.cohere.ai/v1/chat"
URI(url)
end
def prepare_payload(prompt, model_params, dialect)
@ -56,7 +58,7 @@ module DiscourseAi
def prepare_request(payload)
headers = {
"Content-Type" => "application/json",
"Authorization" => "Bearer #{SiteSetting.ai_cohere_api_key}",
"Authorization" => "Bearer #{llm_model&.api_key || SiteSetting.ai_cohere_api_key}",
}
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }

View File

@ -51,9 +51,16 @@ module DiscourseAi
private
def model_uri
if llm_model
url = llm_model.url
else
mapped_model = model == "gemini-1.5-pro" ? "gemini-1.5-pro-latest" : model
url =
"https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{SiteSetting.ai_gemini_api_key}"
url = "https://generativelanguage.googleapis.com/v1beta/models/#{mapped_model}"
end
key = llm_model&.api_key || SiteSetting.ai_gemini_api_key
url = "#{url}:#{@streaming_mode ? "streamGenerateContent" : "generateContent"}?key=#{key}"
URI(url)
end

View File

@ -44,7 +44,7 @@ module DiscourseAi
private
def model_uri
URI(SiteSetting.ai_hugging_face_api_url)
URI(llm_model&.url || SiteSetting.ai_hugging_face_api_url)
end
def prepare_payload(prompt, model_params, _dialect)
@ -53,7 +53,8 @@ module DiscourseAi
.merge(messages: prompt)
.tap do |payload|
if !payload[:max_tokens]
token_limit = SiteSetting.ai_hugging_face_token_limit || 4_000
token_limit =
llm_model&.max_prompt_tokens || SiteSetting.ai_hugging_face_token_limit
payload[:max_tokens] = token_limit - prompt_size(prompt)
end
@ -63,11 +64,11 @@ module DiscourseAi
end
def prepare_request(payload)
api_key = llm_model&.api_key || SiteSetting.ai_hugging_face_api_key
headers =
{ "Content-Type" => "application/json" }.tap do |h|
if SiteSetting.ai_hugging_face_api_key.present?
h["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}"
end
h["Authorization"] = "Bearer #{api_key}" if api_key.present?
end
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }

View File

@ -48,7 +48,7 @@ module DiscourseAi
private
def model_uri
URI("#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
URI(llm_model&.url || "#{SiteSetting.ai_ollama_endpoint}/v1/chat/completions")
end
def prepare_payload(prompt, model_params, _dialect)

View File

@ -78,6 +78,8 @@ module DiscourseAi
private
def model_uri
return URI(llm_model.url) if llm_model&.url
url =
if model.include?("gpt-4")
if model.include?("32k")
@ -115,10 +117,12 @@ module DiscourseAi
def prepare_request(payload)
headers = { "Content-Type" => "application/json" }
api_key = llm_model&.api_key || SiteSetting.ai_openai_api_key
if model_uri.host.include?("azure")
headers["api-key"] = SiteSetting.ai_openai_api_key
headers["api-key"] = api_key
else
headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
headers["Authorization"] = "Bearer #{api_key}"
end
if SiteSetting.ai_openai_organization.present?

View File

@ -44,6 +44,8 @@ module DiscourseAi
private
def model_uri
return URI(llm_model.url) if llm_model&.url
service = DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_vllm_endpoint_srv)
if service.present?
api_endpoint = "https://#{service.target}:#{service.port}/v1/chat/completions"
@ -63,7 +65,8 @@ module DiscourseAi
def prepare_request(payload)
headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" }
headers["X-API-KEY"] = SiteSetting.ai_vllm_api_key if SiteSetting.ai_vllm_api_key.present?
api_key = llm_model&.api_key || SiteSetting.ai_vllm_api_key
headers["X-API-KEY"] = api_key if api_key.present?
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end

View File

@ -25,7 +25,7 @@ module DiscourseAi
end
def tokenizer_names
DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq
DiscourseAi::Tokenizer::BasicTokenizer.available_llm_tokenizers.map(&:name)
end
def models_by_provider
@ -91,17 +91,16 @@ module DiscourseAi
end
def proxy(model_name)
# We are in the process of transitioning to always use objects here.
# We'll live with this hack for a while.
provider_and_model_name = model_name.split(":")
provider_name = provider_and_model_name.first
model_name_without_prov = provider_and_model_name[1..].join
is_custom_model = provider_name == "custom"
if is_custom_model
# We are in the process of transitioning to always use objects here.
# We'll live with this hack for a while.
if provider_name == "custom"
llm_model = LlmModel.find(model_name_without_prov)
provider_name = llm_model.provider
model_name_without_prov = llm_model.name
raise UNKNOWN_MODEL if !llm_model
return proxy_from_obj(llm_model)
end
dialect_klass =
@ -111,24 +110,32 @@ module DiscourseAi
if @canned_llm && @canned_llm != model_name
raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
end
return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
end
opts = {}
opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
return new(dialect_klass, nil, model_name, gateway: @canned_response)
end
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
new(dialect_klass, gateway_klass, model_name_without_prov)
end
def proxy_from_obj(llm_model)
provider_name = llm_model.provider
model_name = llm_model.name
dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
end
end
def initialize(dialect_klass, gateway_klass, model_name, opts: {})
def initialize(dialect_klass, gateway_klass, model_name, gateway: nil, llm_model: nil)
@dialect_klass = dialect_klass
@gateway_klass = gateway_klass
@model_name = model_name
@gateway = opts[:gateway]
@max_prompt_tokens = opts[:max_prompt_tokens]
@gateway = gateway
@llm_model = llm_model
end
# @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
@ -185,13 +192,9 @@ module DiscourseAi
model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
dialect =
dialect_klass.new(
prompt,
model_name,
opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
)
dialect = dialect_klass.new(prompt, model_name, opts: model_params, llm_model: llm_model)
gateway = @gateway || gateway_klass.new(model_name, dialect.tokenizer, llm_model: llm_model)
gateway.perform_completion!(
dialect,
user,
@ -202,18 +205,20 @@ module DiscourseAi
end
def max_prompt_tokens
return @max_prompt_tokens if @max_prompt_tokens.present?
llm_model&.max_prompt_tokens ||
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
end
delegate :tokenizer, to: :dialect_klass
def tokenizer
llm_model&.tokenizer_class ||
dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).tokenizer
end
attr_reader :model_name
private
attr_reader :dialect_klass, :gateway_klass
attr_reader :dialect_klass, :gateway_klass, :llm_model
end
end
end

View File

@ -18,12 +18,8 @@ module DiscourseAi
model_name_without_prov = provider_and_model_name[1..].join
is_custom_model = provider_name == "custom"
if is_custom_model
llm_model = LlmModel.find(model_name_without_prov)
provider_name = llm_model.provider
model_name_without_prov = llm_model.name
end
# Bypass setting validations for custom models. They don't rely on site settings.
if !is_custom_model
endpoint = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
return false if endpoint.nil?
@ -32,6 +28,7 @@ module DiscourseAi
@endpoint = endpoint
return false
end
end
if !can_talk_to_model?(val)
@unreachable = true

View File

@ -4,6 +4,15 @@ module DiscourseAi
module Tokenizer
class BasicTokenizer
class << self
def available_llm_tokenizers
[
DiscourseAi::Tokenizer::AnthropicTokenizer,
DiscourseAi::Tokenizer::Llama3Tokenizer,
DiscourseAi::Tokenizer::MixtralTokenizer,
DiscourseAi::Tokenizer::OpenAiTokenizer,
]
end
def tokenizer
raise NotImplementedError
end

View File

@ -1,12 +0,0 @@
# frozen_string_literal: true
module DiscourseAi
module Tokenizer
class Llama2Tokenizer < BasicTokenizer
def self.tokenizer
@@tokenizer ||=
Tokenizers.from_file("./plugins/discourse-ai/tokenizers/llama-2-70b-chat-hf.json")
end
end
end
end

View File

@ -7,7 +7,7 @@ class TestDialect < DiscourseAi::Completions::Dialects::Dialect
trim_messages(messages)
end
def self.tokenizer
def tokenizer
Class.new do
def self.size(str)
str.length

View File

@ -6,9 +6,7 @@ RSpec.describe DiscourseAi::Completions::Llm do
DiscourseAi::Completions::Dialects::Mistral,
canned_response,
"hugging_face:Upstage-Llama-2-*-instruct-v2",
opts: {
gateway: canned_response,
},
)
end

View File

@ -126,23 +126,6 @@ describe DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer do
end
end
describe DiscourseAi::Tokenizer::Llama2Tokenizer do
describe "#size" do
describe "returns a token count" do
it "for a sentence with punctuation and capitalization and numbers" do
expect(described_class.size("Hello, World! 123")).to eq(9)
end
end
end
describe "#truncate" do
it "truncates a sentence" do
sentence = "foo bar baz qux quux corge grault garply waldo fred plugh xyzzy thud"
expect(described_class.truncate(sentence, 3)).to eq("foo bar")
end
end
end
describe DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer do
describe "#size" do
describe "returns a token count" do

File diff suppressed because it is too large Load Diff