FEATURE: Store provider-specific parameters. (#686)

Previously, we stored request parameters like the OpenAI organization and Bedrock's access key and region as site settings. This change stores them in the `llm_models` table instead, letting us drop more settings while also becoming more flexible.
This commit is contained in:
Roman Rizzi 2024-06-24 19:26:30 -03:00 committed by GitHub
parent 1d5fa0ce6c
commit f622e2644f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 255 additions and 57 deletions

View File

@ -16,6 +16,7 @@ module DiscourseAi
root: false, root: false,
).as_json, ).as_json,
meta: { meta: {
provider_params: LlmModel.provider_params,
presets: DiscourseAi::Completions::Llm.presets, presets: DiscourseAi::Completions::Llm.presets,
providers: DiscourseAi::Completions::Llm.provider_names, providers: DiscourseAi::Completions::Llm.provider_names,
tokenizers: tokenizers:
@ -43,7 +44,7 @@ module DiscourseAi
def update def update
llm_model = LlmModel.find(params[:id]) llm_model = LlmModel.find(params[:id])
if llm_model.update(ai_llm_params) if llm_model.update(ai_llm_params(updating: llm_model))
llm_model.toggle_companion_user llm_model.toggle_companion_user
render json: llm_model render json: llm_model
else else
@ -90,17 +91,32 @@ module DiscourseAi
private private
def ai_llm_params def ai_llm_params(updating: nil)
permitted =
params.require(:ai_llm).permit( params.require(:ai_llm).permit(
:display_name, :display_name,
:name, :name,
:provider, :provider,
:tokenizer, :tokenizer,
:max_prompt_tokens, :max_prompt_tokens,
:url,
:api_key, :api_key,
:enabled_chat_bot, :enabled_chat_bot,
) )
provider = updating ? updating.provider : permitted[:provider]
permit_url =
(updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) ||
provider != LlmModel::BEDROCK_PROVIDER_NAME
permitted[:url] = params.dig(:ai_llm, :url) if permit_url
extra_field_names = LlmModel.provider_params.dig(provider&.to_sym, :fields).to_a
received_prov_params = params.dig(:ai_llm, :provider_params)
permitted[:provider_params] = received_prov_params.slice(
*extra_field_names,
).permit! if !extra_field_names.empty? && received_prov_params.present?
permitted
end end
end end
end end

View File

@ -3,6 +3,7 @@
class LlmModel < ActiveRecord::Base class LlmModel < ActiveRecord::Base
FIRST_BOT_USER_ID = -1200 FIRST_BOT_USER_ID = -1200
RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid" RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid"
BEDROCK_PROVIDER_NAME = "aws_bedrock"
belongs_to :user belongs_to :user
@ -28,6 +29,19 @@ class LlmModel < ActiveRecord::Base
end end
end end
def self.provider_params
{
aws_bedrock: {
url_editable: false,
fields: %i[access_key_id region],
},
open_ai: {
url_editable: true,
fields: %i[organization],
},
}
end
def toggle_companion_user_before_save def toggle_companion_user_before_save
toggle_companion_user if enabled_chat_bot_changed? || new_record? toggle_companion_user if enabled_chat_bot_changed? || new_record?
end end
@ -77,6 +91,10 @@ class LlmModel < ActiveRecord::Base
def tokenizer_class def tokenizer_class
tokenizer.constantize tokenizer.constantize
end end
def lookup_custom_param(key)
provider_params&.dig(key)
end
end end
# == Schema Information # == Schema Information

View File

@ -12,11 +12,12 @@ class LlmModelSerializer < ApplicationSerializer
:api_key, :api_key,
:url, :url,
:enabled_chat_bot, :enabled_chat_bot,
:url_editable :shadowed_by_srv,
:provider_params
has_one :user, serializer: BasicUserSerializer, embed: :object has_one :user, serializer: BasicUserSerializer, embed: :object
def url_editable def shadowed_by_srv
object.url != LlmModel::RESERVED_VLLM_SRV_URL object.url == LlmModel::RESERVED_VLLM_SRV_URL
end end
end end

View File

@ -12,7 +12,8 @@ export default class AiLlm extends RestModel {
"max_prompt_tokens", "max_prompt_tokens",
"url", "url",
"api_key", "api_key",
"enabled_chat_bot" "enabled_chat_bot",
"provider_params"
); );
} }

View File

@ -1,12 +1,12 @@
import Component from "@glimmer/component"; import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking"; import { tracked } from "@glimmer/tracking";
import { Input } from "@ember/component"; import { Input } from "@ember/component";
import { concat, get } from "@ember/helper";
import { on } from "@ember/modifier"; import { on } from "@ember/modifier";
import { action } from "@ember/object"; import { action, computed } from "@ember/object";
import { LinkTo } from "@ember/routing"; import { LinkTo } from "@ember/routing";
import { later } from "@ember/runloop"; import { later } from "@ember/runloop";
import { inject as service } from "@ember/service"; import { inject as service } from "@ember/service";
import { or } from "truth-helpers";
import DButton from "discourse/components/d-button"; import DButton from "discourse/components/d-button";
import DToggleSwitch from "discourse/components/d-toggle-switch"; import DToggleSwitch from "discourse/components/d-toggle-switch";
import Avatar from "discourse/helpers/bound-avatar-template"; import Avatar from "discourse/helpers/bound-avatar-template";
@ -30,6 +30,14 @@ export default class AiLlmEditorForm extends Component {
@tracked testError = null; @tracked testError = null;
@tracked apiKeySecret = true; @tracked apiKeySecret = true;
didReceiveAttrs() {
super.didReceiveAttrs(...arguments);
if (!this.args.model.provider_params) {
this.populateProviderParams(this.args.model.provider);
}
}
get selectedProviders() { get selectedProviders() {
const t = (provName) => { const t = (provName) => {
return I18n.t(`discourse_ai.llms.providers.${provName}`); return I18n.t(`discourse_ai.llms.providers.${provName}`);
@ -44,6 +52,35 @@ export default class AiLlmEditorForm extends Component {
return AdminUser.create(this.args.model?.user); return AdminUser.create(this.args.model?.user);
} }
get testErrorMessage() {
return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError });
}
get displayTestResult() {
return this.testRunning || this.testResult !== null;
}
get displaySRVWarning() {
return this.args.model.shadowed_by_srv && !this.args.model.isNew;
}
get canEditURL() {
// Explicitly false.
if (this.metaProviderParams.url_editable === false) {
return false;
}
return !this.args.model.shadowed_by_srv || this.args.model.isNew;
}
@computed("args.model.provider")
get metaProviderParams() {
return (
this.args.llms.resultSetMeta.provider_params[this.args.model.provider] ||
{}
);
}
@action @action
async save() { async save() {
this.isSaving = true; this.isSaving = true;
@ -94,14 +131,6 @@ export default class AiLlmEditorForm extends Component {
} }
} }
get testErrorMessage() {
return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError });
}
get displayTestResult() {
return this.testRunning || this.testResult !== null;
}
@action @action
makeApiKeySecret() { makeApiKeySecret() {
this.apiKeySecret = true; this.apiKeySecret = true;
@ -145,12 +174,12 @@ export default class AiLlmEditorForm extends Component {
} }
<template> <template>
{{#unless (or @model.url_editable @model.isNew)}} {{#if this.displaySRVWarning}}
<div class="alert alert-info"> <div class="alert alert-info">
{{icon "exclamation-circle"}} {{icon "exclamation-circle"}}
{{I18n.t "discourse_ai.llms.srv_warning"}} {{I18n.t "discourse_ai.llms.srv_warning"}}
</div> </div>
{{/unless}} {{/if}}
<form class="form-horizontal ai-llm-editor"> <form class="form-horizontal ai-llm-editor">
<div class="control-group"> <div class="control-group">
<label>{{i18n "discourse_ai.llms.display_name"}}</label> <label>{{i18n "discourse_ai.llms.display_name"}}</label>
@ -179,13 +208,14 @@ export default class AiLlmEditorForm extends Component {
@content={{this.selectedProviders}} @content={{this.selectedProviders}}
/> />
</div> </div>
{{#if (or @model.url_editable @model.isNew)}} {{#if this.canEditURL}}
<div class="control-group"> <div class="control-group">
<label>{{I18n.t "discourse_ai.llms.url"}}</label> <label>{{I18n.t "discourse_ai.llms.url"}}</label>
<Input <Input
class="ai-llm-editor-input ai-llm-editor__url" class="ai-llm-editor-input ai-llm-editor__url"
@type="text" @type="text"
@value={{@model.url}} @value={{@model.url}}
required="true"
/> />
</div> </div>
{{/if}} {{/if}}
@ -196,11 +226,24 @@ export default class AiLlmEditorForm extends Component {
@value={{@model.api_key}} @value={{@model.api_key}}
class="ai-llm-editor-input ai-llm-editor__api-key" class="ai-llm-editor-input ai-llm-editor__api-key"
@type={{if this.apiKeySecret "password" "text"}} @type={{if this.apiKeySecret "password" "text"}}
required="true"
{{on "focusout" this.makeApiKeySecret}} {{on "focusout" this.makeApiKeySecret}}
/> />
<DButton @action={{this.toggleApiKeySecret}} @icon="far-eye-slash" /> <DButton @action={{this.toggleApiKeySecret}} @icon="far-eye-slash" />
</div> </div>
</div> </div>
{{#each this.metaProviderParams.fields as |field|}}
<div class="control-group">
<label>{{I18n.t
(concat "discourse_ai.llms.provider_fields." field)
}}</label>
<Input
@type="text"
@value={{mut (get @model.provider_params field)}}
class="ai-llm-editor-input ai-llm-editor__{{field}}"
/>
</div>
{{/each}}
<div class="control-group"> <div class="control-group">
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label> <label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
<ComboBox <ComboBox
@ -217,6 +260,7 @@ export default class AiLlmEditorForm extends Component {
min="0" min="0"
lang="en" lang="en"
@value={{@model.max_prompt_tokens}} @value={{@model.max_prompt_tokens}}
required="true"
/> />
<DTooltip <DTooltip
@icon="question-circle" @icon="question-circle"

View File

@ -231,6 +231,11 @@ en:
azure: "Azure" azure: "Azure"
ollama: "Ollama" ollama: "Ollama"
provider_fields:
access_key_id: "AWS Bedrock Access key ID"
region: "AWS Bedrock Region"
organization: "Optional OpenAI Organization ID"
related_topics: related_topics:
title: "Related Topics" title: "Related Topics"
pill: "Related" pill: "Related"

View File

@ -0,0 +1,6 @@
# frozen_string_literal: true
class LlmModelCustomParams < ActiveRecord::Migration[7.1]
def change
add_column :llm_models, :provider_params, :jsonb
end
end

View File

@ -0,0 +1,41 @@
# frozen_string_literal: true
class AddProviderSpecificParamsToLlmModels < ActiveRecord::Migration[7.1]
def up
open_ai_organization = fetch_setting("ai_openai_organization")
DB.exec(<<~SQL, organization: open_ai_organization) if open_ai_organization
UPDATE llm_models
SET provider_params = jsonb_build_object('organization', :organization)
WHERE provider = 'open_ai' AND provider_params IS NULL
SQL
bedrock_region = fetch_setting("ai_bedrock_region") || "us-east-1"
bedrock_access_key_id = fetch_setting("ai_bedrock_access_key_id")
DB.exec(<<~SQL, key_id: bedrock_access_key_id, region: bedrock_region) if bedrock_access_key_id
UPDATE llm_models
SET
provider_params = jsonb_build_object('access_key_id', :key_id, 'region', :region),
name = CASE name WHEN 'claude-2' THEN 'anthropic.claude-v2:1'
WHEN 'claude-3-haiku' THEN 'anthropic.claude-3-haiku-20240307-v1:0'
WHEN 'claude-3-sonnet' THEN 'anthropic.claude-3-sonnet-20240229-v1:0'
WHEN 'claude-instant-1' THEN 'anthropic.claude-instant-v1'
WHEN 'claude-3-opus' THEN 'anthropic.claude-3-opus-20240229-v1:0'
WHEN 'claude-3-5-sonnet' THEN 'anthropic.claude-3-5-sonnet-20240620-v1:0'
ELSE name
END
WHERE provider = 'aws_bedrock' AND provider_params IS NULL
SQL
end
def fetch_setting(name)
DB.query_single(
"SELECT value FROM site_settings WHERE name = :setting_name",
setting_name: name,
).first
end
def down
raise ActiveRecord::IrreversibleMigration
end
end

View File

@ -62,6 +62,12 @@ module DiscourseAi
end end
def model_uri def model_uri
if llm_model
region = llm_model.lookup_custom_param("region")
api_url =
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{llm_model.name}/invoke"
else
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html # See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
# #
# FYI there is a 2.0 version of Claude, very little need to support it given # FYI there is a 2.0 version of Claude, very little need to support it given
@ -85,8 +91,8 @@ module DiscourseAi
end end
api_url = api_url =
llm_model&.url ||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke" "https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
end
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
@ -108,8 +114,10 @@ module DiscourseAi
signer = signer =
Aws::Sigv4::Signer.new( Aws::Sigv4::Signer.new(
access_key_id: SiteSetting.ai_bedrock_access_key_id, access_key_id:
region: SiteSetting.ai_bedrock_region, llm_model&.lookup_custom_param("access_key_id") ||
SiteSetting.ai_bedrock_access_key_id,
region: llm_model&.lookup_custom_param("region") || SiteSetting.ai_bedrock_region,
secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key, secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
service: "bedrock", service: "bedrock",
) )

View File

@ -128,9 +128,9 @@ module DiscourseAi
headers["Authorization"] = "Bearer #{api_key}" headers["Authorization"] = "Bearer #{api_key}"
end end
if SiteSetting.ai_openai_organization.present? org_id =
headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization llm_model&.lookup_custom_param("organization") || SiteSetting.ai_openai_organization
end headers["OpenAI-Organization"] = org_id if org_id.present?
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload } Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end end

View File

@ -41,6 +41,34 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer]) expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer])
expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens]) expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens])
end end
it "stores provider-specific config params" do
provider_params = { organization: "Discourse" }
post "/admin/plugins/discourse-ai/ai-llms.json",
params: {
ai_llm: valid_attrs.merge(provider_params: provider_params),
}
created_model = LlmModel.last
expect(created_model.lookup_custom_param("organization")).to eq(
provider_params[:organization],
)
end
it "ignores parameters not associated with that provider" do
provider_params = { access_key_id: "random_key" }
post "/admin/plugins/discourse-ai/ai-llms.json",
params: {
ai_llm: valid_attrs.merge(provider_params: provider_params),
}
created_model = LlmModel.last
expect(created_model.lookup_custom_param("access_key_id")).to be_nil
end
end end
end end
@ -66,6 +94,36 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
expect(response.status).to eq(404) expect(response.status).to eq(404)
end end
end end
context "with provider-specific params" do
it "updates provider-specific config params" do
provider_params = { organization: "Discourse" }
put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
params: {
ai_llm: {
provider_params: provider_params,
},
}
expect(llm_model.reload.lookup_custom_param("organization")).to eq(
provider_params[:organization],
)
end
it "ignores parameters not associated with that provider" do
provider_params = { access_key_id: "random_key" }
put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
params: {
ai_llm: {
provider_params: provider_params,
},
}
expect(llm_model.reload.lookup_custom_param("access_key_id")).to be_nil
end
end
end end
describe "GET #test" do describe "GET #test" do