From 62fc7d6ed0e174d5929121369fd0aa2fad850290 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Mon, 13 May 2024 12:46:42 -0300 Subject: [PATCH] FEATURE: Configurable LLMs. (#606) This PR introduces the concept of "LlmModel" as a new way to quickly add new LLM models without making any code changes. We are releasing this first version and will add incremental improvements, so expect changes. The AI Bot can't fully take advantage of this feature as users are hard-coded. We'll fix this in a separate PR. --- ...dmin-plugins-show-discourse-ai-llms-new.js | 16 +++ ...min-plugins-show-discourse-ai-llms-show.js | 17 +++ .../admin-plugins-show-discourse-ai-llms.js | 7 + .../show/discourse-ai-llms/index.hbs | 1 + .../show/discourse-ai-llms/new.hbs | 1 + .../show/discourse-ai-llms/show.hbs | 1 + .../discourse_ai/admin/ai_llms_controller.rb | 64 +++++++++ app/models/llm_model.rb | 7 + app/serializers/llm_model_serializer.rb | 7 + .../admin-discourse-ai-plugin-route-map.js | 5 + .../discourse/admin/adapters/ai-llm.js | 21 +++ .../discourse/admin/models/ai-llm.js | 20 +++ .../discourse/components/ai-llm-editor.gjs | 122 ++++++++++++++++++ .../components/ai-llms-list-editor.gjs | 61 +++++++++ .../admin-plugin-configuration-nav.js | 4 + .../modules/llms/common/ai-llms-editor.scss | 32 +++++ config/locales/client.en.yml | 26 ++++ config/routes.rb | 5 + .../20240504222307_create_llm_model_table.rb | 14 ++ lib/ai_bot/bot.rb | 114 +++++++++------- lib/automation.rb | 3 + lib/completions/dialects/chat_gpt.rb | 2 + lib/completions/dialects/claude.rb | 1 + lib/completions/dialects/command.rb | 6 +- lib/completions/dialects/dialect.rb | 12 +- lib/completions/dialects/gemini.rb | 2 + lib/completions/dialects/mistral.rb | 2 + lib/completions/llm.rb | 45 ++++++- lib/configuration/llm_enumerator.rb | 20 ++- plugin.rb | 2 + spec/fabricators/llm_model_fabricator.rb | 9 ++ .../endpoints/hugging_face_spec.rb | 5 +- spec/lib/completions/llm_spec.rb | 4 +-
.../requests/admin/ai_llms_controller_spec.rb | 68 ++++++++++ 34 files changed, 656 insertions(+), 70 deletions(-) create mode 100644 admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js create mode 100644 admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js create mode 100644 admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js create mode 100644 admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs create mode 100644 admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs create mode 100644 admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs create mode 100644 app/controllers/discourse_ai/admin/ai_llms_controller.rb create mode 100644 app/models/llm_model.rb create mode 100644 app/serializers/llm_model_serializer.rb create mode 100644 assets/javascripts/discourse/admin/adapters/ai-llm.js create mode 100644 assets/javascripts/discourse/admin/models/ai-llm.js create mode 100644 assets/javascripts/discourse/components/ai-llm-editor.gjs create mode 100644 assets/javascripts/discourse/components/ai-llms-list-editor.gjs create mode 100644 assets/stylesheets/modules/llms/common/ai-llms-editor.scss create mode 100644 db/migrate/20240504222307_create_llm_model_table.rb create mode 100644 spec/fabricators/llm_model_fabricator.rb create mode 100644 spec/requests/admin/ai_llms_controller_spec.rb diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js new file mode 100644 index 00000000..aafc69f2 --- /dev/null +++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js @@ -0,0 +1,16 @@ +import DiscourseRoute from "discourse/routes/discourse"; + +export default DiscourseRoute.extend({ + async model() { + const 
record = this.store.createRecord("ai-llm"); + return record; + }, + + setupController(controller, model) { + this._super(controller, model); + controller.set( + "allLlms", + this.modelFor("adminPlugins.show.discourse-ai-llms") + ); + }, +}); diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js new file mode 100644 index 00000000..7a9fa379 --- /dev/null +++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js @@ -0,0 +1,17 @@ +import DiscourseRoute from "discourse/routes/discourse"; + +export default DiscourseRoute.extend({ + async model(params) { + const allLlms = this.modelFor("adminPlugins.show.discourse-ai-llms"); + const id = parseInt(params.id, 10); + return allLlms.findBy("id", id); + }, + + setupController(controller, model) { + this._super(controller, model); + controller.set( + "allLlms", + this.modelFor("adminPlugins.show.discourse-ai-llms") + ); + }, +}); diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js new file mode 100644 index 00000000..21f8f44b --- /dev/null +++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js @@ -0,0 +1,7 @@ +import DiscourseRoute from "discourse/routes/discourse"; + +export default class DiscourseAiAiLlmsRoute extends DiscourseRoute { + model() { + return this.store.findAll("ai-llm"); + } +} diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs new file mode 100644 index 00000000..e1ab7f35 --- /dev/null +++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs @@ -0,0 +1 @@ + \ No newline at end of file diff --git 
a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs new file mode 100644 index 00000000..77f3b0f3 --- /dev/null +++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs new file mode 100644 index 00000000..77f3b0f3 --- /dev/null +++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs @@ -0,0 +1 @@ + \ No newline at end of file diff --git a/app/controllers/discourse_ai/admin/ai_llms_controller.rb b/app/controllers/discourse_ai/admin/ai_llms_controller.rb new file mode 100644 index 00000000..6e57f841 --- /dev/null +++ b/app/controllers/discourse_ai/admin/ai_llms_controller.rb @@ -0,0 +1,64 @@ +# frozen_string_literal: true + +module DiscourseAi + module Admin + class AiLlmsController < ::Admin::AdminController + requires_plugin ::DiscourseAi::PLUGIN_NAME + + def index + llms = LlmModel.all + + render json: { + ai_llms: + ActiveModel::ArraySerializer.new( + llms, + each_serializer: LlmModelSerializer, + root: false, + ).as_json, + meta: { + providers: DiscourseAi::Completions::Llm.provider_names, + tokenizers: + DiscourseAi::Completions::Llm.tokenizer_names.map { |tn| + { id: tn, name: tn.split("::").last } + }, + }, + } + end + + def show + llm_model = LlmModel.find(params[:id]) + render json: LlmModelSerializer.new(llm_model) + end + + def create + if llm_model = LlmModel.new(ai_llm_params).save + render json: { ai_persona: llm_model }, status: :created + else + render_json_error llm_model + end + end + + def update + llm_model = LlmModel.find(params[:id]) + + if llm_model.update(ai_llm_params) + render json: llm_model + 
else + render_json_error llm_model + end + end + + private + + def ai_llm_params + params.require(:ai_llm).permit( + :display_name, + :name, + :provider, + :tokenizer, + :max_prompt_tokens, + ) + end + end + end +end diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb new file mode 100644 index 00000000..aefb9202 --- /dev/null +++ b/app/models/llm_model.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +class LlmModel < ActiveRecord::Base + def tokenizer_class + tokenizer.constantize + end +end diff --git a/app/serializers/llm_model_serializer.rb b/app/serializers/llm_model_serializer.rb new file mode 100644 index 00000000..77f264b8 --- /dev/null +++ b/app/serializers/llm_model_serializer.rb @@ -0,0 +1,7 @@ +# frozen_string_literal: true + +class LlmModelSerializer < ApplicationSerializer + root "llm" + + attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer +end diff --git a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js index 54411f90..97ed05e5 100644 --- a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js +++ b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js @@ -8,5 +8,10 @@ export default { this.route("new"); this.route("show", { path: "/:id" }); }); + + this.route("discourse-ai-llms", { path: "ai-llms" }, function () { + this.route("new"); + this.route("show", { path: "/:id" }); + }); }, }; diff --git a/assets/javascripts/discourse/admin/adapters/ai-llm.js b/assets/javascripts/discourse/admin/adapters/ai-llm.js new file mode 100644 index 00000000..fe82163d --- /dev/null +++ b/assets/javascripts/discourse/admin/adapters/ai-llm.js @@ -0,0 +1,21 @@ +import RestAdapter from "discourse/adapters/rest"; + +export default class Adapter extends RestAdapter { + jsonMode = true; + + basePath() { + return "/admin/plugins/discourse-ai/"; + } + + pathFor(store, type, findArgs) { + // removes underscores which 
are implemented in base + let path = + this.basePath(store, type, findArgs) + + store.pluralize(this.apiNameFor(type)); + return this.appendQueryParams(path, findArgs); + } + + apiNameFor() { + return "ai-llm"; + } +} diff --git a/assets/javascripts/discourse/admin/models/ai-llm.js b/assets/javascripts/discourse/admin/models/ai-llm.js new file mode 100644 index 00000000..69f9fbfe --- /dev/null +++ b/assets/javascripts/discourse/admin/models/ai-llm.js @@ -0,0 +1,20 @@ +import RestModel from "discourse/models/rest"; + +export default class AiLlm extends RestModel { + createProperties() { + return this.getProperties( + "display_name", + "name", + "provider", + "tokenizer", + "max_prompt_tokens" + ); + } + + updateProperties() { + const attrs = this.createProperties(); + attrs.id = this.id; + + return attrs; + } +} diff --git a/assets/javascripts/discourse/components/ai-llm-editor.gjs b/assets/javascripts/discourse/components/ai-llm-editor.gjs new file mode 100644 index 00000000..bac97ee5 --- /dev/null +++ b/assets/javascripts/discourse/components/ai-llm-editor.gjs @@ -0,0 +1,122 @@ +import Component from "@glimmer/component"; +import { tracked } from "@glimmer/tracking"; +import { Input } from "@ember/component"; +import { action } from "@ember/object"; +import { later } from "@ember/runloop"; +import { inject as service } from "@ember/service"; +import DButton from "discourse/components/d-button"; +import { popupAjaxError } from "discourse/lib/ajax-error"; +import i18n from "discourse-common/helpers/i18n"; +import I18n from "discourse-i18n"; +import ComboBox from "select-kit/components/combo-box"; +import DTooltip from "float-kit/components/d-tooltip"; + +export default class AiLlmEditor extends Component { + @service toasts; + @service router; + + @tracked isSaving = false; + + get selectedProviders() { + const t = (provName) => { + return I18n.t(`discourse_ai.llms.providers.${provName}`); + }; + + return this.args.llms.resultSetMeta.providers.map((prov) => { + 
return { id: prov, name: t(prov) }; + }); + } + + @action + async save() { + this.isSaving = true; + const isNew = this.args.model.isNew; + + try { + await this.args.model.save(); + + if (isNew) { + this.args.llms.addObject(this.args.model); + this.router.transitionTo( + "adminPlugins.show.discourse-ai-llms.show", + this.args.model + ); + } else { + this.toasts.success({ + data: { message: I18n.t("discourse_ai.llms.saved") }, + duration: 2000, + }); + } + } catch (e) { + popupAjaxError(e); + } finally { + later(() => { + this.isSaving = false; + }, 1000); + } + } + + +} diff --git a/assets/javascripts/discourse/components/ai-llms-list-editor.gjs b/assets/javascripts/discourse/components/ai-llms-list-editor.gjs new file mode 100644 index 00000000..73ba310e --- /dev/null +++ b/assets/javascripts/discourse/components/ai-llms-list-editor.gjs @@ -0,0 +1,61 @@ +import Component from "@glimmer/component"; +import { LinkTo } from "@ember/routing"; +import icon from "discourse-common/helpers/d-icon"; +import i18n from "discourse-common/helpers/i18n"; +import I18n from "discourse-i18n"; +import AiLlmEditor from "./ai-llm-editor"; + +export default class AiLlmsListEditor extends Component { + get hasNoLLMElements() { + this.args.llms.length !== 0; + } + + +} diff --git a/assets/javascripts/initializers/admin-plugin-configuration-nav.js b/assets/javascripts/initializers/admin-plugin-configuration-nav.js index 2c2bac59..51d88222 100644 --- a/assets/javascripts/initializers/admin-plugin-configuration-nav.js +++ b/assets/javascripts/initializers/admin-plugin-configuration-nav.js @@ -16,6 +16,10 @@ export default { label: "discourse_ai.ai_persona.short_title", route: "adminPlugins.show.discourse-ai-personas", }, + { + label: "discourse_ai.llms.short_title", + route: "adminPlugins.show.discourse-ai-llms", + }, ]); }); }, diff --git a/assets/stylesheets/modules/llms/common/ai-llms-editor.scss b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss new file mode 100644 index 
00000000..351deeef --- /dev/null +++ b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss @@ -0,0 +1,32 @@ +.ai-llms-list-editor { + &__header { + display: flex; + justify-content: space-between; + align-items: center; + margin: 0 0 1em 0; + + h3 { + margin: 0; + } + } + + &__container { + display: flex; + flex-direction: row; + align-items: center; + gap: 20px; + width: 100%; + align-items: stretch; + } + + &__empty_list, + &__content_list { + min-width: 300px; + } + + &__empty_list { + align-content: center; + text-align: center; + font-size: var(--font-up-1); + } +} diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index 1c3d9c0c..044d00f3 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -194,6 +194,32 @@ en: uploading: "Uploading..." remove: "Remove upload" + llms: + short_title: "LLMs" + no_llms: "No LLMs yet" + new: "New Model" + display_name: "Name to display:" + name: "Model name:" + provider: "Service hosting the model:" + tokenizer: "Tokenizer:" + max_prompt_tokens: "Number of tokens for the prompt:" + save: "Save" + saved: "LLM Model Saved" + + hints: + max_prompt_tokens: "Max numbers of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window." + name: "We include this in the API call to specify which model we'll use." 
+ + providers: + aws_bedrock: "AWS Bedrock" + anthropic: "Anthropic" + vllm: "vLLM" + hugging_face: "Hugging Face" + cohere: "Cohere" + open_ai: "OpenAI" + google: "Google" + azure: "Azure" + related_topics: title: "Related Topics" pill: "Related" diff --git a/config/routes.rb b/config/routes.rb index eb20cf98..f33dfd68 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -45,5 +45,10 @@ Discourse::Application.routes.draw do post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file" put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file" get "/ai-personas/:id/files/status", to: "discourse_ai/admin/ai_personas#indexing_status_check" + + resources :ai_llms, + only: %i[index create show update], + path: "ai-llms", + controller: "discourse_ai/admin/ai_llms" end end diff --git a/db/migrate/20240504222307_create_llm_model_table.rb b/db/migrate/20240504222307_create_llm_model_table.rb new file mode 100644 index 00000000..96bcc3dd --- /dev/null +++ b/db/migrate/20240504222307_create_llm_model_table.rb @@ -0,0 +1,14 @@ +# frozen_string_literal: true + +class CreateLlmModelTable < ActiveRecord::Migration[7.0] + def change + create_table :llm_models do |t| + t.string :display_name + t.string :name, null: false + t.string :provider, null: false + t.string :tokenizer, null: false + t.integer :max_prompt_tokens, null: false + t.timestamps + end + end +end diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb index 3e3c932e..d2ca71f5 100644 --- a/lib/ai_bot/bot.rb +++ b/lib/ai_bot/bot.rb @@ -150,59 +150,73 @@ module DiscourseAi def self.guess_model(bot_user) # HACK(roman): We'll do this until we define how we represent different providers in the bot settings - case bot_user.id - when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID - if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") - "aws_bedrock:claude-2" + guess = + case bot_user.id + when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID + if 
DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2") + "aws_bedrock:claude-2" + else + "anthropic:claude-2" + end + when DiscourseAi::AiBot::EntryPoint::GPT4_ID + "open_ai:gpt-4" + when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID + "open_ai:gpt-4-turbo" + when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID + "open_ai:gpt-3.5-turbo-16k" + when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID + mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1" + if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model) + "vllm:#{mixtral_model}" + elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?( + mixtral_model, + ) + "hugging_face:#{mixtral_model}" + else + "ollama:mistral" + end + when DiscourseAi::AiBot::EntryPoint::GEMINI_ID + "google:gemini-pro" + when DiscourseAi::AiBot::EntryPoint::FAKE_ID + "fake:fake" + when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID + if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?( + "claude-3-opus", + ) + "aws_bedrock:claude-3-opus" + else + "anthropic:claude-3-opus" + end + when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS + "cohere:command-r-plus" + when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID + if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?( + "claude-3-sonnet", + ) + "aws_bedrock:claude-3-sonnet" + else + "anthropic:claude-3-sonnet" + end + when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID + if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?( + "claude-3-haiku", + ) + "aws_bedrock:claude-3-haiku" + else + "anthropic:claude-3-haiku" + end else - "anthropic:claude-2" + nil end - when DiscourseAi::AiBot::EntryPoint::GPT4_ID - "open_ai:gpt-4" - when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID - "open_ai:gpt-4-turbo" - when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID - "open_ai:gpt-3.5-turbo-16k" - when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID - mixtral_model = 
"mistralai/Mixtral-8x7B-Instruct-v0.1" - if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model) - "vllm:#{mixtral_model}" - elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?( - mixtral_model, - ) - "hugging_face:#{mixtral_model}" - else - "ollama:mistral" - end - when DiscourseAi::AiBot::EntryPoint::GEMINI_ID - "google:gemini-pro" - when DiscourseAi::AiBot::EntryPoint::FAKE_ID - "fake:fake" - when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID - if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-opus") - "aws_bedrock:claude-3-opus" - else - "anthropic:claude-3-opus" - end - when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS - "cohere:command-r-plus" - when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID - if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?( - "claude-3-sonnet", - ) - "aws_bedrock:claude-3-sonnet" - else - "anthropic:claude-3-sonnet" - end - when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID - if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-haiku") - "aws_bedrock:claude-3-haiku" - else - "anthropic:claude-3-haiku" - end - else - nil + + if guess + provider, model_name = guess.split(":") + llm_model = LlmModel.find_by(provider: provider, name: model_name) + + return "custom:#{llm_model.id}" if llm_model end + + guess end def build_placeholder(summary, details, custom_raw: nil) diff --git a/lib/automation.rb b/lib/automation.rb index b755f1db..71cde97b 100644 --- a/lib/automation.rb +++ b/lib/automation.rb @@ -25,6 +25,9 @@ module DiscourseAi ] def self.translate_model(model) + llm_model = LlmModel.find_by(name: model) + return "custom:#{llm_model.id}" if llm_model + return "google:#{model}" if model.start_with? "gemini" return "open_ai:#{model}" if model.start_with? "gpt" return "cohere:#{model}" if model.start_with? 
"command" diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 9383019c..f6142d09 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -30,6 +30,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? + # provide a buffer of 120 tokens - our function counting is not # 100% accurate and getting numbers to align exactly is very hard buffer = (opts[:max_tokens] || 2500) + 50 diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 9a15b293..8be67e54 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -50,6 +50,7 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? # Longer term it will have over 1 million 200_000 # Claude-3 has a 200k context window for now end diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb index f119aba8..62240372 100644 --- a/lib/completions/dialects/command.rb +++ b/lib/completions/dialects/command.rb @@ -38,6 +38,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? 
+ case model_name when "command-light" 4096 @@ -62,10 +64,6 @@ module DiscourseAi self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) end - def tools_dialect - @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools) - end - def system_msg(msg) cmd_msg = { role: "SYSTEM", message: msg[:content] } diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 865b5509..13275ad4 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -9,14 +9,22 @@ module DiscourseAi raise NotImplemented end - def dialect_for(model_name) - dialects = [ + def all_dialects + [ DiscourseAi::Completions::Dialects::ChatGpt, DiscourseAi::Completions::Dialects::Gemini, DiscourseAi::Completions::Dialects::Mistral, DiscourseAi::Completions::Dialects::Claude, DiscourseAi::Completions::Dialects::Command, ] + end + + def available_tokenizers + all_dialects.map(&:tokenizer) + end + + def dialect_for(model_name) + dialects = all_dialects if Rails.env.test? || Rails.env.development? dialects << DiscourseAi::Completions::Dialects::Fake diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 678dc0cd..fde9cdda 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -68,6 +68,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? + if model_name == "gemini-1.5-pro" # technically we support 1 million tokens, but we're being conservative 800_000 diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb index 7752a876..d34130f9 100644 --- a/lib/completions/dialects/mistral.rb +++ b/lib/completions/dialects/mistral.rb @@ -23,6 +23,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? 
+ 32_000 end diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb index 47190644..a2f70172 100644 --- a/lib/completions/llm.rb +++ b/lib/completions/llm.rb @@ -18,6 +18,14 @@ module DiscourseAi UNKNOWN_MODEL = Class.new(StandardError) class << self + def provider_names + %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure] + end + + def tokenizer_names + DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq + end + def models_by_provider # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure. # However, since they use the same URL/key settings, there's no reason to duplicate them. @@ -80,36 +88,54 @@ module DiscourseAi end def proxy(model_name) + # We are in the process of transitioning to always use objects here. + # We'll live with this hack for a while. provider_and_model_name = model_name.split(":") - provider_name = provider_and_model_name.first model_name_without_prov = provider_and_model_name[1..].join + is_custom_model = provider_name == "custom" + + if is_custom_model + llm_model = LlmModel.find(model_name_without_prov) + provider_name = llm_model.provider + model_name_without_prov = llm_model.name + end dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov) + if is_custom_model + tokenizer = llm_model.tokenizer_class + else + tokenizer = dialect_klass.tokenizer + end + if @canned_response if @canned_llm && @canned_llm != model_name raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}" end - return new(dialect_klass, nil, model_name, gateway: @canned_response) + return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response }) end + opts = {} + opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model + gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for( provider_name, model_name_without_prov, ) - new(dialect_klass, gateway_klass, 
model_name_without_prov) + new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts) end end - def initialize(dialect_klass, gateway_klass, model_name, gateway: nil) + def initialize(dialect_klass, gateway_klass, model_name, opts: {}) @dialect_klass = dialect_klass @gateway_klass = gateway_klass @model_name = model_name - @gateway = gateway + @gateway = opts[:gateway] + @max_prompt_tokens = opts[:max_prompt_tokens] end # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object @@ -166,11 +192,18 @@ module DiscourseAi model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? } gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer) - dialect = dialect_klass.new(prompt, model_name, opts: model_params) + dialect = + dialect_klass.new( + prompt, + model_name, + opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens), + ) gateway.perform_completion!(dialect, user, model_params, &partial_read_blk) end def max_prompt_tokens + return @max_prompt_tokens if @max_prompt_tokens.present? 
+ dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens end diff --git a/lib/configuration/llm_enumerator.rb b/lib/configuration/llm_enumerator.rb index d7618795..fd870b57 100644 --- a/lib/configuration/llm_enumerator.rb +++ b/lib/configuration/llm_enumerator.rb @@ -10,14 +10,22 @@ module DiscourseAi end def self.values - # do not cache cause settings can change this - DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models| - endpoint = - DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first) + begin + llm_models = + DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models| + endpoint = + DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first) - models.map do |model_name| - { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" } + models.map do |model_name| + { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" } + end + end + + LlmModel.all.each do |model| + llm_models << { name: model.display_name, value: "custom:#{model.id}" } end + + llm_models end end end diff --git a/plugin.rb b/plugin.rb index ad255b70..c5a571ba 100644 --- a/plugin.rb +++ b/plugin.rb @@ -26,6 +26,8 @@ register_asset "stylesheets/modules/sentiment/common/dashboard.scss" register_asset "stylesheets/modules/sentiment/desktop/dashboard.scss", :desktop register_asset "stylesheets/modules/sentiment/mobile/dashboard.scss", :mobile +register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss" + module ::DiscourseAi PLUGIN_NAME = "discourse-ai" end diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb new file mode 100644 index 00000000..c419341e --- /dev/null +++ b/spec/fabricators/llm_model_fabricator.rb @@ -0,0 +1,9 @@ +# frozen_string_literal: true + +Fabricator(:llm_model) do + display_name "A good model" + name "gpt-4-turbo" + provider "open_ai" + tokenizer 
"DiscourseAi::Tokenizers::OpenAi" + max_prompt_tokens 32_000 +end diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb index 5b9bd9f5..f84fc43a 100644 --- a/spec/lib/completions/endpoints/hugging_face_spec.rb +++ b/spec/lib/completions/endpoints/hugging_face_spec.rb @@ -81,7 +81,10 @@ end RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do subject(:endpoint) do - described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer) + described_class.new( + "mistralai/Mistral-7B-Instruct-v0.2", + DiscourseAi::Tokenizer::MixtralTokenizer, + ) end before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" } diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb index e91a4a60..c3107dc0 100644 --- a/spec/lib/completions/llm_spec.rb +++ b/spec/lib/completions/llm_spec.rb @@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Llm do DiscourseAi::Completions::Dialects::Mistral, canned_response, "hugging_face:Upstage-Llama-2-*-instruct-v2", - gateway: canned_response, + opts: { + gateway: canned_response, + }, ) end diff --git a/spec/requests/admin/ai_llms_controller_spec.rb b/spec/requests/admin/ai_llms_controller_spec.rb new file mode 100644 index 00000000..e4747834 --- /dev/null +++ b/spec/requests/admin/ai_llms_controller_spec.rb @@ -0,0 +1,68 @@ +# frozen_string_literal: true + +RSpec.describe DiscourseAi::Admin::AiLlmsController do + fab!(:admin) + + before { sign_in(admin) } + + describe "GET #index" do + it "includes all available providers metadata" do + get "/admin/plugins/discourse-ai/ai-llms.json" + expect(response).to be_successful + + expect(response.parsed_body["meta"]["providers"]).to contain_exactly( + *DiscourseAi::Completions::Llm.provider_names, + ) + end + end + + describe "POST #create" do + context "with valid attributes" do + let(:valid_attrs) do + { + display_name: "My cool LLM", + name: "gpt-3.5", + provider: "open_ai", + 
tokenizer: "DiscourseAi::Tokenizers::OpenAiTokenizer", + max_prompt_tokens: 16_000, + } + end + + it "creates a new LLM model" do + post "/admin/plugins/discourse-ai/ai-llms.json", params: { ai_llm: valid_attrs } + + created_model = LlmModel.last + + expect(created_model.display_name).to eq(valid_attrs[:display_name]) + expect(created_model.name).to eq(valid_attrs[:name]) + expect(created_model.provider).to eq(valid_attrs[:provider]) + expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer]) + expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens]) + end + end + end + + describe "PUT #update" do + fab!(:llm_model) + + context "with valid update params" do + let(:update_attrs) { { provider: "anthropic" } } + + it "updates the model" do + put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json", + params: { + ai_llm: update_attrs, + } + + expect(response.status).to eq(200) + expect(llm_model.reload.provider).to eq(update_attrs[:provider]) + end + + it "returns a 404 if there is no model with the given Id" do + put "/admin/plugins/discourse-ai/ai-llms/9999999.json" + + expect(response.status).to eq(404) + end + end + end +end