From f622e2644f456bcf2212552311c8cc3405d76604 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Mon, 24 Jun 2024 19:26:30 -0300 Subject: [PATCH] FEATURE: Store provider-specific parameters. (#686) Previously, we stored request parameters like the OpenAI organization and Bedrock's access key and region as site settings. This change stores them in the `llm_models` table instead, letting us drop more settings while also becoming more flexible. --- .../discourse_ai/admin/ai_llms_controller.rb | 40 +++++++---- app/models/llm_model.rb | 18 +++++ app/serializers/llm_model_serializer.rb | 7 +- .../discourse/admin/models/ai-llm.js | 3 +- .../components/ai-llm-editor-form.gjs | 70 +++++++++++++++---- config/locales/client.en.yml | 5 ++ .../20240624135356_llm_model_custom_params.rb | 6 ++ ..._provider_specific_params_to_llm_models.rb | 41 +++++++++++ lib/completions/endpoints/aws_bedrock.rb | 58 ++++++++------- lib/completions/endpoints/open_ai.rb | 6 +- .../requests/admin/ai_llms_controller_spec.rb | 58 +++++++++++++++ 11 files changed, 255 insertions(+), 57 deletions(-) create mode 100644 db/migrate/20240624135356_llm_model_custom_params.rb create mode 100644 db/post_migrate/20240624202602_add_provider_specific_params_to_llm_models.rb diff --git a/app/controllers/discourse_ai/admin/ai_llms_controller.rb b/app/controllers/discourse_ai/admin/ai_llms_controller.rb index a98b6803..156b3297 100644 --- a/app/controllers/discourse_ai/admin/ai_llms_controller.rb +++ b/app/controllers/discourse_ai/admin/ai_llms_controller.rb @@ -16,6 +16,7 @@ module DiscourseAi root: false, ).as_json, meta: { + provider_params: LlmModel.provider_params, presets: DiscourseAi::Completions::Llm.presets, providers: DiscourseAi::Completions::Llm.provider_names, tokenizers: @@ -43,7 +44,7 @@ module DiscourseAi def update llm_model = LlmModel.find(params[:id]) - if llm_model.update(ai_llm_params) + if llm_model.update(ai_llm_params(updating: llm_model)) llm_model.toggle_companion_user render json: llm_model else @@ -90,17 +91,32 @@ module DiscourseAi private - def ai_llm_params - params.require(:ai_llm).permit( - :display_name, - :name, - :provider, - :tokenizer, - :max_prompt_tokens, - :url, - :api_key, - :enabled_chat_bot, - ) + def ai_llm_params(updating: nil) + permitted = + params.require(:ai_llm).permit( + :display_name, + :name, + :provider, + :tokenizer, + :max_prompt_tokens, + :api_key, + :enabled_chat_bot, + ) + + provider = updating ? updating.provider : permitted[:provider] + permit_url = + (updating && updating.url != LlmModel::RESERVED_VLLM_SRV_URL) || + provider != LlmModel::BEDROCK_PROVIDER_NAME + + permitted[:url] = params.dig(:ai_llm, :url) if permit_url + + extra_field_names = LlmModel.provider_params.dig(provider&.to_sym, :fields).to_a + received_prov_params = params.dig(:ai_llm, :provider_params) + permitted[:provider_params] = received_prov_params.slice( + *extra_field_names, + ).permit! if !extra_field_names.empty? && received_prov_params.present? + + permitted end end end diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb index ab4c85c9..28ed5b17 100644 --- a/app/models/llm_model.rb +++ b/app/models/llm_model.rb @@ -3,6 +3,7 @@ class LlmModel < ActiveRecord::Base FIRST_BOT_USER_ID = -1200 RESERVED_VLLM_SRV_URL = "https://vllm.shadowed-by-srv.invalid" + BEDROCK_PROVIDER_NAME = "aws_bedrock" belongs_to :user @@ -28,6 +29,19 @@ class LlmModel < ActiveRecord::Base end end + def self.provider_params + { + aws_bedrock: { + url_editable: false, + fields: %i[access_key_id region], + }, + open_ai: { + url_editable: true, + fields: %i[organization], + }, + } + end + def toggle_companion_user_before_save toggle_companion_user if enabled_chat_bot_changed? || new_record? end @@ -77,6 +91,10 @@ class LlmModel < ActiveRecord::Base def tokenizer_class tokenizer.constantize end + + def lookup_custom_param(key) + provider_params&.dig(key) + end end # == Schema Information diff --git a/app/serializers/llm_model_serializer.rb b/app/serializers/llm_model_serializer.rb index e93ca69a..268f41b2 100644 --- a/app/serializers/llm_model_serializer.rb +++ b/app/serializers/llm_model_serializer.rb @@ -12,11 +12,12 @@ class LlmModelSerializer < ApplicationSerializer :api_key, :url, :enabled_chat_bot, - :url_editable + :shadowed_by_srv, + :provider_params has_one :user, serializer: BasicUserSerializer, embed: :object - def url_editable - object.url != LlmModel::RESERVED_VLLM_SRV_URL + def shadowed_by_srv + object.url == LlmModel::RESERVED_VLLM_SRV_URL end end diff --git a/assets/javascripts/discourse/admin/models/ai-llm.js b/assets/javascripts/discourse/admin/models/ai-llm.js index 3c6cb8e9..e81d0d04 100644 --- a/assets/javascripts/discourse/admin/models/ai-llm.js +++ b/assets/javascripts/discourse/admin/models/ai-llm.js @@ -12,7 +12,8 @@ export default class AiLlm extends RestModel { "max_prompt_tokens", "url", "api_key", - "enabled_chat_bot" + "enabled_chat_bot", + "provider_params" ); } diff --git a/assets/javascripts/discourse/components/ai-llm-editor-form.gjs b/assets/javascripts/discourse/components/ai-llm-editor-form.gjs index 0fc46ec2..11a198fc 100644 --- a/assets/javascripts/discourse/components/ai-llm-editor-form.gjs +++ b/assets/javascripts/discourse/components/ai-llm-editor-form.gjs @@ -1,12 +1,12 @@ import Component from "@glimmer/component"; import { tracked } from "@glimmer/tracking"; import { Input } from "@ember/component"; +import { concat, get } from "@ember/helper"; import { on } from "@ember/modifier"; -import { action } from "@ember/object"; +import { action, computed } from "@ember/object"; import { LinkTo } from "@ember/routing"; import { later } from "@ember/runloop"; import { inject as service } from "@ember/service"; -import { or } from "truth-helpers"; import DButton from "discourse/components/d-button"; import DToggleSwitch from "discourse/components/d-toggle-switch"; import Avatar from "discourse/helpers/bound-avatar-template"; @@ -30,6 +30,14 @@ export default class AiLlmEditorForm extends Component { @tracked testError = null; @tracked apiKeySecret = true; + didReceiveAttrs() { + super.didReceiveAttrs(...arguments); + + if (!this.args.model.provider_params) { + this.populateProviderParams(this.args.model.provider); + } + } + get selectedProviders() { const t = (provName) => { return I18n.t(`discourse_ai.llms.providers.${provName}`); @@ -44,6 +52,35 @@ export default class AiLlmEditorForm extends Component { return AdminUser.create(this.args.model?.user); } + get testErrorMessage() { + return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError }); + } + + get displayTestResult() { + return this.testRunning || this.testResult !== null; + } + + get displaySRVWarning() { + return this.args.model.shadowed_by_srv && !this.args.model.isNew; + } + + get canEditURL() { + // Explicitly false. + if (this.metaProviderParams.url_editable === false) { + return false; + } + + return !this.args.model.shadowed_by_srv || this.args.model.isNew; + } + + @computed("args.model.provider") + get metaProviderParams() { + return ( + this.args.llms.resultSetMeta.provider_params[this.args.model.provider] || + {} + ); + } + @action async save() { this.isSaving = true; @@ -94,14 +131,6 @@ export default class AiLlmEditorForm extends Component { } } - get testErrorMessage() { - return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError }); - } - - get displayTestResult() { - return this.testRunning || this.testResult !== null; - } - @action makeApiKeySecret() { this.apiKeySecret = true; @@ -145,12 +174,12 @@ export default class AiLlmEditorForm extends Component { }