FEATURE: Store provider-specific parameters. (#686)
Previously, we stored request parameters like the OpenAI organization and Bedrock's access key and region as site settings. This change stores them in the `llm_models` table instead, letting us drop more settings while also becoming more flexible.
This commit is contained in:
parent
1d5fa0ce6c
commit
f622e2644f
|
@ -16,6 +16,7 @@ module DiscourseAi
|
|||
root: false,
|
||||
).as_json,
|
||||
meta: {
|
||||
provider_params: LlmModel.provider_params,
|
||||
presets: DiscourseAi::Completions::Llm.presets,
|
||||
providers: DiscourseAi::Completions::Llm.provider_names,
|
||||
tokenizers:
|
||||
|
@ -43,7 +44,7 @@ module DiscourseAi
|
|||
def update
|
||||
llm_model = LlmModel.find(params[:id])
|
||||
|
||||
if llm_model.update(ai_llm_params)
|
||||
if llm_model.update(ai_llm_params(updating: llm_model))
|
||||
llm_model.toggle_companion_user
|
||||
render json: llm_model
|
||||
else
|
||||
|
@ -90,17 +91,32 @@ module DiscourseAi
|
|||
|
||||
private
|
||||
|
||||
def ai_llm_params
|
||||
params.require(:ai_llm).permit(
|
||||
:display_name,
|
||||
:name,
|
||||
:provider,
|
||||
:tokenizer,
|
||||
:max_prompt_tokens,
|
||||
:url,
|
||||
:api_key,
|
||||
:enabled_chat_bot,
|
||||
)
|
||||
def ai_llm_params(updating: nil)
|
||||
permitted =
|
||||
params.require(:ai_llm).permit(
|
||||
:display_name,
|
||||
:name,
|
||||
:provider,
|
||||
:tokenizer,
|
||||
:max_prompt_tokens,
|
||||
:api_key,
|
||||
:enabled_chat_bot,
|
||||
)
|
||||
|
||||
provider = updating ? updating.provider : permitted[:provider]
|
||||
permit_url =
|
||||
(updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) ||
|
||||
provider != LlmModel::BEDROCK_PROVIDER_NAME
|
||||
|
||||
permitted[:url] = params.dig(:ai_llm, :url) if permit_url
|
||||
|
||||
extra_field_names = LlmModel.provider_params.dig(provider&.to_sym, :fields).to_a
|
||||
received_prov_params = params.dig(:ai_llm, :provider_params)
|
||||
permitted[:provider_params] = received_prov_params.slice(
|
||||
*extra_field_names,
|
||||
).permit! if !extra_field_names.empty? && received_prov_params.present?
|
||||
|
||||
permitted
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -3,6 +3,7 @@
|
|||
class LlmModel < ActiveRecord::Base
|
||||
FIRST_BOT_USER_ID = -1200
|
||||
RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid"
|
||||
BEDROCK_PROVIDER_NAME = "aws_bedrock"
|
||||
|
||||
belongs_to :user
|
||||
|
||||
|
@ -28,6 +29,19 @@ class LlmModel < ActiveRecord::Base
|
|||
end
|
||||
end
|
||||
|
||||
def self.provider_params
|
||||
{
|
||||
aws_bedrock: {
|
||||
url_editable: false,
|
||||
fields: %i[access_key_id region],
|
||||
},
|
||||
open_ai: {
|
||||
url_editable: true,
|
||||
fields: %i[organization],
|
||||
},
|
||||
}
|
||||
end
|
||||
|
||||
def toggle_companion_user_before_save
|
||||
toggle_companion_user if enabled_chat_bot_changed? || new_record?
|
||||
end
|
||||
|
@ -77,6 +91,10 @@ class LlmModel < ActiveRecord::Base
|
|||
def tokenizer_class
|
||||
tokenizer.constantize
|
||||
end
|
||||
|
||||
def lookup_custom_param(key)
|
||||
provider_params&.dig(key)
|
||||
end
|
||||
end
|
||||
|
||||
# == Schema Information
|
||||
|
|
|
@ -12,11 +12,12 @@ class LlmModelSerializer < ApplicationSerializer
|
|||
:api_key,
|
||||
:url,
|
||||
:enabled_chat_bot,
|
||||
:url_editable
|
||||
:shadowed_by_srv,
|
||||
:provider_params
|
||||
|
||||
has_one :user, serializer: BasicUserSerializer, embed: :object
|
||||
|
||||
def url_editable
|
||||
object.url != LlmModel::RESERVED_VLLM_SRV_URL
|
||||
def shadowed_by_srv
|
||||
object.url == LlmModel::RESERVED_VLLM_SRV_URL
|
||||
end
|
||||
end
|
||||
|
|
|
@ -12,7 +12,8 @@ export default class AiLlm extends RestModel {
|
|||
"max_prompt_tokens",
|
||||
"url",
|
||||
"api_key",
|
||||
"enabled_chat_bot"
|
||||
"enabled_chat_bot",
|
||||
"provider_params"
|
||||
);
|
||||
}
|
||||
|
||||
|
|
|
@ -1,12 +1,12 @@
|
|||
import Component from "@glimmer/component";
|
||||
import { tracked } from "@glimmer/tracking";
|
||||
import { Input } from "@ember/component";
|
||||
import { concat, get } from "@ember/helper";
|
||||
import { on } from "@ember/modifier";
|
||||
import { action } from "@ember/object";
|
||||
import { action, computed } from "@ember/object";
|
||||
import { LinkTo } from "@ember/routing";
|
||||
import { later } from "@ember/runloop";
|
||||
import { inject as service } from "@ember/service";
|
||||
import { or } from "truth-helpers";
|
||||
import DButton from "discourse/components/d-button";
|
||||
import DToggleSwitch from "discourse/components/d-toggle-switch";
|
||||
import Avatar from "discourse/helpers/bound-avatar-template";
|
||||
|
@ -30,6 +30,14 @@ export default class AiLlmEditorForm extends Component {
|
|||
@tracked testError = null;
|
||||
@tracked apiKeySecret = true;
|
||||
|
||||
didReceiveAttrs() {
|
||||
super.didReceiveAttrs(...arguments);
|
||||
|
||||
if (!this.args.model.provider_params) {
|
||||
this.populateProviderParams(this.args.model.provider);
|
||||
}
|
||||
}
|
||||
|
||||
get selectedProviders() {
|
||||
const t = (provName) => {
|
||||
return I18n.t(`discourse_ai.llms.providers.${provName}`);
|
||||
|
@ -44,6 +52,35 @@ export default class AiLlmEditorForm extends Component {
|
|||
return AdminUser.create(this.args.model?.user);
|
||||
}
|
||||
|
||||
get testErrorMessage() {
|
||||
return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError });
|
||||
}
|
||||
|
||||
get displayTestResult() {
|
||||
return this.testRunning || this.testResult !== null;
|
||||
}
|
||||
|
||||
get displaySRVWarning() {
|
||||
return this.args.model.shadowed_by_srv && !this.args.model.isNew;
|
||||
}
|
||||
|
||||
get canEditURL() {
|
||||
// Explicitly false.
|
||||
if (this.metaProviderParams.url_editable === false) {
|
||||
return false;
|
||||
}
|
||||
|
||||
return !this.args.model.shadowed_by_srv || this.args.model.isNew;
|
||||
}
|
||||
|
||||
@computed("args.model.provider")
|
||||
get metaProviderParams() {
|
||||
return (
|
||||
this.args.llms.resultSetMeta.provider_params[this.args.model.provider] ||
|
||||
{}
|
||||
);
|
||||
}
|
||||
|
||||
@action
|
||||
async save() {
|
||||
this.isSaving = true;
|
||||
|
@ -94,14 +131,6 @@ export default class AiLlmEditorForm extends Component {
|
|||
}
|
||||
}
|
||||
|
||||
get testErrorMessage() {
|
||||
return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError });
|
||||
}
|
||||
|
||||
get displayTestResult() {
|
||||
return this.testRunning || this.testResult !== null;
|
||||
}
|
||||
|
||||
@action
|
||||
makeApiKeySecret() {
|
||||
this.apiKeySecret = true;
|
||||
|
@ -145,12 +174,12 @@ export default class AiLlmEditorForm extends Component {
|
|||
}
|
||||
|
||||
<template>
|
||||
{{#unless (or @model.url_editable @model.isNew)}}
|
||||
{{#if this.displaySRVWarning}}
|
||||
<div class="alert alert-info">
|
||||
{{icon "exclamation-circle"}}
|
||||
{{I18n.t "discourse_ai.llms.srv_warning"}}
|
||||
</div>
|
||||
{{/unless}}
|
||||
{{/if}}
|
||||
<form class="form-horizontal ai-llm-editor">
|
||||
<div class="control-group">
|
||||
<label>{{i18n "discourse_ai.llms.display_name"}}</label>
|
||||
|
@ -179,13 +208,14 @@ export default class AiLlmEditorForm extends Component {
|
|||
@content={{this.selectedProviders}}
|
||||
/>
|
||||
</div>
|
||||
{{#if (or @model.url_editable @model.isNew)}}
|
||||
{{#if this.canEditURL}}
|
||||
<div class="control-group">
|
||||
<label>{{I18n.t "discourse_ai.llms.url"}}</label>
|
||||
<Input
|
||||
class="ai-llm-editor-input ai-llm-editor__url"
|
||||
@type="text"
|
||||
@value={{@model.url}}
|
||||
required="true"
|
||||
/>
|
||||
</div>
|
||||
{{/if}}
|
||||
|
@ -196,11 +226,24 @@ export default class AiLlmEditorForm extends Component {
|
|||
@value={{@model.api_key}}
|
||||
class="ai-llm-editor-input ai-llm-editor__api-key"
|
||||
@type={{if this.apiKeySecret "password" "text"}}
|
||||
required="true"
|
||||
{{on "focusout" this.makeApiKeySecret}}
|
||||
/>
|
||||
<DButton @action={{this.toggleApiKeySecret}} @icon="far-eye-slash" />
|
||||
</div>
|
||||
</div>
|
||||
{{#each this.metaProviderParams.fields as |field|}}
|
||||
<div class="control-group">
|
||||
<label>{{I18n.t
|
||||
(concat "discourse_ai.llms.provider_fields." field)
|
||||
}}</label>
|
||||
<Input
|
||||
@type="text"
|
||||
@value={{mut (get @model.provider_params field)}}
|
||||
class="ai-llm-editor-input ai-llm-editor__{{field}}"
|
||||
/>
|
||||
</div>
|
||||
{{/each}}
|
||||
<div class="control-group">
|
||||
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
|
||||
<ComboBox
|
||||
|
@ -217,6 +260,7 @@ export default class AiLlmEditorForm extends Component {
|
|||
min="0"
|
||||
lang="en"
|
||||
@value={{@model.max_prompt_tokens}}
|
||||
required="true"
|
||||
/>
|
||||
<DTooltip
|
||||
@icon="question-circle"
|
||||
|
|
|
@ -231,6 +231,11 @@ en:
|
|||
azure: "Azure"
|
||||
ollama: "Ollama"
|
||||
|
||||
provider_fields:
|
||||
access_key_id: "AWS Bedrock Access key ID"
|
||||
region: "AWS Bedrock Region"
|
||||
organization: "Optional OpenAI Organization ID"
|
||||
|
||||
related_topics:
|
||||
title: "Related Topics"
|
||||
pill: "Related"
|
||||
|
|
|
@ -0,0 +1,6 @@
|
|||
# frozen_string_literal: true
|
||||
class LlmModelCustomParams < ActiveRecord::Migration[7.1]
|
||||
def change
|
||||
add_column :llm_models, :provider_params, :jsonb
|
||||
end
|
||||
end
|
|
@ -0,0 +1,41 @@
|
|||
# frozen_string_literal: true
|
||||
class AddProviderSpecificParamsToLlmModels < ActiveRecord::Migration[7.1]
|
||||
def up
|
||||
open_ai_organization = fetch_setting("ai_openai_organization")
|
||||
|
||||
DB.exec(<<~SQL, organization: open_ai_organization) if open_ai_organization
|
||||
UPDATE llm_models
|
||||
SET provider_params = jsonb_build_object('organization', :organization)
|
||||
WHERE provider = 'open_ai' AND provider_params IS NULL
|
||||
SQL
|
||||
|
||||
bedrock_region = fetch_setting("ai_bedrock_region") || "us-east-1"
|
||||
bedrock_access_key_id = fetch_setting("ai_bedrock_access_key_id")
|
||||
|
||||
DB.exec(<<~SQL, key_id: bedrock_access_key_id, region: bedrock_region) if bedrock_access_key_id
|
||||
UPDATE llm_models
|
||||
SET
|
||||
provider_params = jsonb_build_object('access_key_id', :key_id, 'region', :region),
|
||||
name = CASE name WHEN 'claude-2' THEN 'anthropic.claude-v2:1'
|
||||
WHEN 'claude-3-haiku' THEN 'anthropic.claude-3-haiku-20240307-v1:0'
|
||||
WHEN 'claude-3-sonnet' THEN 'anthropic.claude-3-sonnet-20240229-v1:0'
|
||||
WHEN 'claude-instant-1' THEN 'anthropic.claude-instant-v1'
|
||||
WHEN 'claude-3-opus' THEN 'anthropic.claude-3-opus-20240229-v1:0'
|
||||
WHEN 'claude-3-5-sonnet' THEN 'anthropic.claude-3-5-sonnet-20240620-v1:0'
|
||||
ELSE name
|
||||
END
|
||||
WHERE provider = 'aws_bedrock' AND provider_params IS NULL
|
||||
SQL
|
||||
end
|
||||
|
||||
def fetch_setting(name)
|
||||
DB.query_single(
|
||||
"SELECT value FROM site_settings WHERE name = :setting_name",
|
||||
setting_name: name,
|
||||
).first
|
||||
end
|
||||
|
||||
def down
|
||||
raise ActiveRecord::IrreversibleMigration
|
||||
end
|
||||
end
|
|
@ -62,31 +62,37 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def model_uri
|
||||
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
#
|
||||
# FYI there is a 2.0 version of Claude, very little need to support it given
|
||||
# haiku/sonnet are better fits anyway, we map to claude-2.1
|
||||
bedrock_model_id =
|
||||
case model
|
||||
when "claude-2"
|
||||
"anthropic.claude-v2:1"
|
||||
when "claude-3-haiku"
|
||||
"anthropic.claude-3-haiku-20240307-v1:0"
|
||||
when "claude-3-sonnet"
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
when "claude-instant-1"
|
||||
"anthropic.claude-instant-v1"
|
||||
when "claude-3-opus"
|
||||
"anthropic.claude-3-opus-20240229-v1:0"
|
||||
when "claude-3-5-sonnet"
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
else
|
||||
model
|
||||
end
|
||||
if llm_model
|
||||
region = llm_model.lookup_custom_param("region")
|
||||
|
||||
api_url =
|
||||
llm_model&.url ||
|
||||
api_url =
|
||||
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{llm_model.name}/invoke"
|
||||
else
|
||||
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
|
||||
#
|
||||
# FYI there is a 2.0 version of Claude, very little need to support it given
|
||||
# haiku/sonnet are better fits anyway, we map to claude-2.1
|
||||
bedrock_model_id =
|
||||
case model
|
||||
when "claude-2"
|
||||
"anthropic.claude-v2:1"
|
||||
when "claude-3-haiku"
|
||||
"anthropic.claude-3-haiku-20240307-v1:0"
|
||||
when "claude-3-sonnet"
|
||||
"anthropic.claude-3-sonnet-20240229-v1:0"
|
||||
when "claude-instant-1"
|
||||
"anthropic.claude-instant-v1"
|
||||
when "claude-3-opus"
|
||||
"anthropic.claude-3-opus-20240229-v1:0"
|
||||
when "claude-3-5-sonnet"
|
||||
"anthropic.claude-3-5-sonnet-20240620-v1:0"
|
||||
else
|
||||
model
|
||||
end
|
||||
|
||||
api_url =
|
||||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
|
||||
end
|
||||
|
||||
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
|
||||
|
||||
|
@ -108,8 +114,10 @@ module DiscourseAi
|
|||
|
||||
signer =
|
||||
Aws::Sigv4::Signer.new(
|
||||
access_key_id: SiteSetting.ai_bedrock_access_key_id,
|
||||
region: SiteSetting.ai_bedrock_region,
|
||||
access_key_id:
|
||||
llm_model&.lookup_custom_param("access_key_id") ||
|
||||
SiteSetting.ai_bedrock_access_key_id,
|
||||
region: llm_model&.lookup_custom_param("region") || SiteSetting.ai_bedrock_region,
|
||||
secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
|
||||
service: "bedrock",
|
||||
)
|
||||
|
|
|
@ -128,9 +128,9 @@ module DiscourseAi
|
|||
headers["Authorization"] = "Bearer #{api_key}"
|
||||
end
|
||||
|
||||
if SiteSetting.ai_openai_organization.present?
|
||||
headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
|
||||
end
|
||||
org_id =
|
||||
llm_model&.lookup_custom_param("organization") || SiteSetting.ai_openai_organization
|
||||
headers["OpenAI-Organization"] = org_id if org_id.present?
|
||||
|
||||
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
|
||||
end
|
||||
|
|
|
@ -41,6 +41,34 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
|
|||
expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer])
|
||||
expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens])
|
||||
end
|
||||
|
||||
it "stores provider-specific config params" do
|
||||
provider_params = { organization: "Discourse" }
|
||||
|
||||
post "/admin/plugins/discourse-ai/ai-llms.json",
|
||||
params: {
|
||||
ai_llm: valid_attrs.merge(provider_params: provider_params),
|
||||
}
|
||||
|
||||
created_model = LlmModel.last
|
||||
|
||||
expect(created_model.lookup_custom_param("organization")).to eq(
|
||||
provider_params[:organization],
|
||||
)
|
||||
end
|
||||
|
||||
it "ignores parameters not associated with that provider" do
|
||||
provider_params = { access_key_id: "random_key" }
|
||||
|
||||
post "/admin/plugins/discourse-ai/ai-llms.json",
|
||||
params: {
|
||||
ai_llm: valid_attrs.merge(provider_params: provider_params),
|
||||
}
|
||||
|
||||
created_model = LlmModel.last
|
||||
|
||||
expect(created_model.lookup_custom_param("access_key_id")).to be_nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
|
@ -66,6 +94,36 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
|
|||
expect(response.status).to eq(404)
|
||||
end
|
||||
end
|
||||
|
||||
context "with provider-specific params" do
|
||||
it "updates provider-specific config params" do
|
||||
provider_params = { organization: "Discourse" }
|
||||
|
||||
put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
|
||||
params: {
|
||||
ai_llm: {
|
||||
provider_params: provider_params,
|
||||
},
|
||||
}
|
||||
|
||||
expect(llm_model.reload.lookup_custom_param("organization")).to eq(
|
||||
provider_params[:organization],
|
||||
)
|
||||
end
|
||||
|
||||
it "ignores parameters not associated with that provider" do
|
||||
provider_params = { access_key_id: "random_key" }
|
||||
|
||||
put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
|
||||
params: {
|
||||
ai_llm: {
|
||||
provider_params: provider_params,
|
||||
},
|
||||
}
|
||||
|
||||
expect(llm_model.reload.lookup_custom_param("access_key_id")).to be_nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
describe "GET #test" do
|
||||
|
|
Loading…
Reference in New Issue