FEATURE: Store provider-specific parameters. (#686)

Previously, we stored request parameters like the OpenAI organization and Bedrock's access key and region as site settings. This change stores them in the `llm_models` table instead, letting us drop more settings while also becoming more flexible.
This commit is contained in:
Roman Rizzi 2024-06-24 19:26:30 -03:00 committed by GitHub
parent 1d5fa0ce6c
commit f622e2644f
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
11 changed files with 255 additions and 57 deletions

View File

@ -16,6 +16,7 @@ module DiscourseAi
root: false,
).as_json,
meta: {
provider_params: LlmModel.provider_params,
presets: DiscourseAi::Completions::Llm.presets,
providers: DiscourseAi::Completions::Llm.provider_names,
tokenizers:
@ -43,7 +44,7 @@ module DiscourseAi
def update
llm_model = LlmModel.find(params[:id])
if llm_model.update(ai_llm_params)
if llm_model.update(ai_llm_params(updating: llm_model))
llm_model.toggle_companion_user
render json: llm_model
else
@ -90,17 +91,32 @@ module DiscourseAi
private
def ai_llm_params
def ai_llm_params(updating: nil)
permitted =
params.require(:ai_llm).permit(
:display_name,
:name,
:provider,
:tokenizer,
:max_prompt_tokens,
:url,
:api_key,
:enabled_chat_bot,
)
provider = updating ? updating.provider : permitted[:provider]
permit_url =
(updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) ||
provider != LlmModel::BEDROCK_PROVIDER_NAME
permitted[:url] = params.dig(:ai_llm, :url) if permit_url
extra_field_names = LlmModel.provider_params.dig(provider&.to_sym, :fields).to_a
received_prov_params = params.dig(:ai_llm, :provider_params)
permitted[:provider_params] = received_prov_params.slice(
*extra_field_names,
).permit! if !extra_field_names.empty? && received_prov_params.present?
permitted
end
end
end

View File

@ -3,6 +3,7 @@
class LlmModel < ActiveRecord::Base
FIRST_BOT_USER_ID = -1200
RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid"
BEDROCK_PROVIDER_NAME = "aws_bedrock"
belongs_to :user
@ -28,6 +29,19 @@ class LlmModel < ActiveRecord::Base
end
end
def self.provider_params
{
aws_bedrock: {
url_editable: false,
fields: %i[access_key_id region],
},
open_ai: {
url_editable: true,
fields: %i[organization],
},
}
end
def toggle_companion_user_before_save
toggle_companion_user if enabled_chat_bot_changed? || new_record?
end
@ -77,6 +91,10 @@ class LlmModel < ActiveRecord::Base
def tokenizer_class
tokenizer.constantize
end
def lookup_custom_param(key)
provider_params&.dig(key)
end
end
# == Schema Information

View File

@ -12,11 +12,12 @@ class LlmModelSerializer < ApplicationSerializer
:api_key,
:url,
:enabled_chat_bot,
:url_editable
:shadowed_by_srv,
:provider_params
has_one :user, serializer: BasicUserSerializer, embed: :object
def url_editable
object.url != LlmModel::RESERVED_VLLM_SRV_URL
def shadowed_by_srv
object.url == LlmModel::RESERVED_VLLM_SRV_URL
end
end

View File

@ -12,7 +12,8 @@ export default class AiLlm extends RestModel {
"max_prompt_tokens",
"url",
"api_key",
"enabled_chat_bot"
"enabled_chat_bot",
"provider_params"
);
}

View File

@ -1,12 +1,12 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { Input } from "@ember/component";
import { concat, get } from "@ember/helper";
import { on } from "@ember/modifier";
import { action } from "@ember/object";
import { action, computed } from "@ember/object";
import { LinkTo } from "@ember/routing";
import { later } from "@ember/runloop";
import { inject as service } from "@ember/service";
import { or } from "truth-helpers";
import DButton from "discourse/components/d-button";
import DToggleSwitch from "discourse/components/d-toggle-switch";
import Avatar from "discourse/helpers/bound-avatar-template";
@ -30,6 +30,14 @@ export default class AiLlmEditorForm extends Component {
@tracked testError = null;
@tracked apiKeySecret = true;
didReceiveAttrs() {
super.didReceiveAttrs(...arguments);
if (!this.args.model.provider_params) {
this.populateProviderParams(this.args.model.provider);
}
}
get selectedProviders() {
const t = (provName) => {
return I18n.t(`discourse_ai.llms.providers.${provName}`);
@ -44,6 +52,35 @@ export default class AiLlmEditorForm extends Component {
return AdminUser.create(this.args.model?.user);
}
get testErrorMessage() {
return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError });
}
get displayTestResult() {
return this.testRunning || this.testResult !== null;
}
get displaySRVWarning() {
return this.args.model.shadowed_by_srv && !this.args.model.isNew;
}
get canEditURL() {
// Explicitly false.
if (this.metaProviderParams.url_editable === false) {
return false;
}
return !this.args.model.shadowed_by_srv || this.args.model.isNew;
}
@computed("args.model.provider")
get metaProviderParams() {
return (
this.args.llms.resultSetMeta.provider_params[this.args.model.provider] ||
{}
);
}
@action
async save() {
this.isSaving = true;
@ -94,14 +131,6 @@ export default class AiLlmEditorForm extends Component {
}
}
get testErrorMessage() {
return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError });
}
get displayTestResult() {
return this.testRunning || this.testResult !== null;
}
@action
makeApiKeySecret() {
this.apiKeySecret = true;
@ -145,12 +174,12 @@ export default class AiLlmEditorForm extends Component {
}
<template>
{{#unless (or @model.url_editable @model.isNew)}}
{{#if this.displaySRVWarning}}
<div class="alert alert-info">
{{icon "exclamation-circle"}}
{{I18n.t "discourse_ai.llms.srv_warning"}}
</div>
{{/unless}}
{{/if}}
<form class="form-horizontal ai-llm-editor">
<div class="control-group">
<label>{{i18n "discourse_ai.llms.display_name"}}</label>
@ -179,13 +208,14 @@ export default class AiLlmEditorForm extends Component {
@content={{this.selectedProviders}}
/>
</div>
{{#if (or @model.url_editable @model.isNew)}}
{{#if this.canEditURL}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.url"}}</label>
<Input
class="ai-llm-editor-input ai-llm-editor__url"
@type="text"
@value={{@model.url}}
required="true"
/>
</div>
{{/if}}
@ -196,11 +226,24 @@ export default class AiLlmEditorForm extends Component {
@value={{@model.api_key}}
class="ai-llm-editor-input ai-llm-editor__api-key"
@type={{if this.apiKeySecret "password" "text"}}
required="true"
{{on "focusout" this.makeApiKeySecret}}
/>
<DButton @action={{this.toggleApiKeySecret}} @icon="far-eye-slash" />
</div>
</div>
{{#each this.metaProviderParams.fields as |field|}}
<div class="control-group">
<label>{{I18n.t
(concat "discourse_ai.llms.provider_fields." field)
}}</label>
<Input
@type="text"
@value={{mut (get @model.provider_params field)}}
class="ai-llm-editor-input ai-llm-editor__{{field}}"
/>
</div>
{{/each}}
<div class="control-group">
<label>{{I18n.t "discourse_ai.llms.tokenizer"}}</label>
<ComboBox
@ -217,6 +260,7 @@ export default class AiLlmEditorForm extends Component {
min="0"
lang="en"
@value={{@model.max_prompt_tokens}}
required="true"
/>
<DTooltip
@icon="question-circle"

View File

@ -231,6 +231,11 @@ en:
azure: "Azure"
ollama: "Ollama"
provider_fields:
access_key_id: "AWS Bedrock Access key ID"
region: "AWS Bedrock Region"
organization: "Optional OpenAI Organization ID"
related_topics:
title: "Related Topics"
pill: "Related"

View File

@ -0,0 +1,6 @@
# frozen_string_literal: true
class LlmModelCustomParams < ActiveRecord::Migration[7.1]
def change
add_column :llm_models, :provider_params, :jsonb
end
end

View File

@ -0,0 +1,41 @@
# frozen_string_literal: true
class AddProviderSpecificParamsToLlmModels < ActiveRecord::Migration[7.1]
def up
open_ai_organization = fetch_setting("ai_openai_organization")
DB.exec(<<~SQL, organization: open_ai_organization) if open_ai_organization
UPDATE llm_models
SET provider_params = jsonb_build_object('organization', :organization)
WHERE provider = 'open_ai' AND provider_params IS NULL
SQL
bedrock_region = fetch_setting("ai_bedrock_region") || "us-east-1"
bedrock_access_key_id = fetch_setting("ai_bedrock_access_key_id")
DB.exec(<<~SQL, key_id: bedrock_access_key_id, region: bedrock_region) if bedrock_access_key_id
UPDATE llm_models
SET
provider_params = jsonb_build_object('access_key_id', :key_id, 'region', :region),
name = CASE name WHEN 'claude-2' THEN 'anthropic.claude-v2:1'
WHEN 'claude-3-haiku' THEN 'anthropic.claude-3-haiku-20240307-v1:0'
WHEN 'claude-3-sonnet' THEN 'anthropic.claude-3-sonnet-20240229-v1:0'
WHEN 'claude-instant-1' THEN 'anthropic.claude-instant-v1'
WHEN 'claude-3-opus' THEN 'anthropic.claude-3-opus-20240229-v1:0'
WHEN 'claude-3-5-sonnet' THEN 'anthropic.claude-3-5-sonnet-20240620-v1:0'
ELSE name
END
WHERE provider = 'aws_bedrock' AND provider_params IS NULL
SQL
end
def fetch_setting(name)
DB.query_single(
"SELECT value FROM site_settings WHERE name = :setting_name",
setting_name: name,
).first
end
def down
raise ActiveRecord::IrreversibleMigration
end
end

View File

@ -62,6 +62,12 @@ module DiscourseAi
end
def model_uri
if llm_model
region = llm_model.lookup_custom_param("region")
api_url =
"https://bedrock-runtime.#{region}.amazonaws.com/model/#{llm_model.name}/invoke"
else
# See: https://docs.aws.amazon.com/bedrock/latest/userguide/model-ids-arns.html
#
# FYI there is a 2.0 version of Claude, very little need to support it given
@ -85,8 +91,8 @@ module DiscourseAi
end
api_url =
llm_model&.url ||
"https://bedrock-runtime.#{SiteSetting.ai_bedrock_region}.amazonaws.com/model/#{bedrock_model_id}/invoke"
end
api_url = @streaming_mode ? (api_url + "-with-response-stream") : api_url
@ -108,8 +114,10 @@ module DiscourseAi
signer =
Aws::Sigv4::Signer.new(
access_key_id: SiteSetting.ai_bedrock_access_key_id,
region: SiteSetting.ai_bedrock_region,
access_key_id:
llm_model&.lookup_custom_param("access_key_id") ||
SiteSetting.ai_bedrock_access_key_id,
region: llm_model&.lookup_custom_param("region") || SiteSetting.ai_bedrock_region,
secret_access_key: llm_model&.api_key || SiteSetting.ai_bedrock_secret_access_key,
service: "bedrock",
)

View File

@ -128,9 +128,9 @@ module DiscourseAi
headers["Authorization"] = "Bearer #{api_key}"
end
if SiteSetting.ai_openai_organization.present?
headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
end
org_id =
llm_model&.lookup_custom_param("organization") || SiteSetting.ai_openai_organization
headers["OpenAI-Organization"] = org_id if org_id.present?
Net::HTTP::Post.new(model_uri, headers).tap { |r| r.body = payload }
end

View File

@ -41,6 +41,34 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer])
expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens])
end
it "stores provider-specific config params" do
provider_params = { organization: "Discourse" }
post "/admin/plugins/discourse-ai/ai-llms.json",
params: {
ai_llm: valid_attrs.merge(provider_params: provider_params),
}
created_model = LlmModel.last
expect(created_model.lookup_custom_param("organization")).to eq(
provider_params[:organization],
)
end
it "ignores parameters not associated with that provider" do
provider_params = { access_key_id: "random_key" }
post "/admin/plugins/discourse-ai/ai-llms.json",
params: {
ai_llm: valid_attrs.merge(provider_params: provider_params),
}
created_model = LlmModel.last
expect(created_model.lookup_custom_param("access_key_id")).to be_nil
end
end
end
@ -66,6 +94,36 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
expect(response.status).to eq(404)
end
end
context "with provider-specific params" do
it "updates provider-specific config params" do
provider_params = { organization: "Discourse" }
put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
params: {
ai_llm: {
provider_params: provider_params,
},
}
expect(llm_model.reload.lookup_custom_param("organization")).to eq(
provider_params[:organization],
)
end
it "ignores parameters not associated with that provider" do
provider_params = { access_key_id: "random_key" }
put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json",
params: {
ai_llm: {
provider_params: provider_params,
},
}
expect(llm_model.reload.lookup_custom_param("access_key_id")).to be_nil
end
end
end
describe "GET #test" do