diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js
new file mode 100644
index 00000000..aafc69f2
--- /dev/null
+++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-new.js
@@ -0,0 +1,16 @@
+import DiscourseRoute from "discourse/routes/discourse";
+
+export default DiscourseRoute.extend({
+  async model() {
+    const record = this.store.createRecord("ai-llm");
+    return record;
+  },
+
+  setupController(controller, model) {
+    this._super(controller, model);
+    controller.set(
+      "allLlms",
+      this.modelFor("adminPlugins.show.discourse-ai-llms")
+    );
+  },
+});
diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js
new file mode 100644
index 00000000..7a9fa379
--- /dev/null
+++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms-show.js
@@ -0,0 +1,17 @@
+import DiscourseRoute from "discourse/routes/discourse";
+
+export default DiscourseRoute.extend({
+  async model(params) {
+    const allLlms = this.modelFor("adminPlugins.show.discourse-ai-llms");
+    const id = parseInt(params.id, 10);
+    return allLlms.findBy("id", id);
+  },
+
+  setupController(controller, model) {
+    this._super(controller, model);
+    controller.set(
+      "allLlms",
+      this.modelFor("adminPlugins.show.discourse-ai-llms")
+    );
+  },
+});
diff --git a/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js
new file mode 100644
index 00000000..21f8f44b
--- /dev/null
+++ b/admin/assets/javascripts/discourse/routes/admin-plugins-show-discourse-ai-llms.js
@@ -0,0 +1,7 @@
+import DiscourseRoute from "discourse/routes/discourse";
+
+export default class DiscourseAiAiLlmsRoute extends DiscourseRoute {
+  model() {
+    return this.store.findAll("ai-llm");
+  }
+}
diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs
new file mode 100644
index 00000000..e1ab7f35
--- /dev/null
+++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/index.hbs
@@ -0,0 +1 @@
+<AiLlmsListEditor @llms={{this.model}} />
\ No newline at end of file
diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs
new file mode 100644
index 00000000..77f3b0f3
--- /dev/null
+++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/new.hbs
@@ -0,0 +1 @@
+<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />
\ No newline at end of file
diff --git a/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs
new file mode 100644
index 00000000..77f3b0f3
--- /dev/null
+++ b/admin/assets/javascripts/discourse/templates/admin-plugins/show/discourse-ai-llms/show.hbs
@@ -0,0 +1 @@
+<AiLlmsListEditor @llms={{this.allLlms}} @currentLlm={{this.model}} />
\ No newline at end of file
diff --git a/app/controllers/discourse_ai/admin/ai_llms_controller.rb b/app/controllers/discourse_ai/admin/ai_llms_controller.rb
new file mode 100644
index 00000000..6e57f841
--- /dev/null
+++ b/app/controllers/discourse_ai/admin/ai_llms_controller.rb
@@ -0,0 +1,65 @@
+# frozen_string_literal: true
+
+module DiscourseAi
+  module Admin
+    class AiLlmsController < ::Admin::AdminController
+      requires_plugin ::DiscourseAi::PLUGIN_NAME
+
+      def index
+        llms = LlmModel.all
+
+        render json: {
+                 ai_llms:
+                   ActiveModel::ArraySerializer.new(
+                     llms,
+                     each_serializer: LlmModelSerializer,
+                     root: false,
+                   ).as_json,
+                 meta: {
+                   providers: DiscourseAi::Completions::Llm.provider_names,
+                   tokenizers:
+                     DiscourseAi::Completions::Llm.tokenizer_names.map { |tn|
+                       { id: tn, name: tn.split("::").last }
+                     },
+                 },
+               }
+      end
+
+      def show
+        llm_model = LlmModel.find(params[:id])
+        render json: LlmModelSerializer.new(llm_model)
+      end
+
+      def create
+        llm_model = LlmModel.new(ai_llm_params)
+
+        if llm_model.save
+          render json: { ai_llm: llm_model }, status: :created
+        else
+          render_json_error llm_model
+        end
+      end
+
+      def update
+        llm_model = LlmModel.find(params[:id])
+
+        if llm_model.update(ai_llm_params)
+          render json: llm_model
+        else
+          render_json_error llm_model
+        end
+      end
+
+      private
+
+      def ai_llm_params
+        params.require(:ai_llm).permit(
+          :display_name,
+          :name,
+          :provider,
+          :tokenizer,
+          :max_prompt_tokens,
+        )
+      end
+    end
+  end
+end
diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb
new file mode 100644
index 00000000..aefb9202
--- /dev/null
+++ b/app/models/llm_model.rb
@@ -0,0 +1,7 @@
+# frozen_string_literal: true
+
+class LlmModel < ActiveRecord::Base
+  def tokenizer_class
+    tokenizer.constantize
+  end
+end
diff --git a/app/serializers/llm_model_serializer.rb b/app/serializers/llm_model_serializer.rb
new file mode 100644
index 00000000..77f264b8
--- /dev/null
+++ b/app/serializers/llm_model_serializer.rb
@@ -0,0 +1,7 @@
+# frozen_string_literal: true
+
+class LlmModelSerializer < ApplicationSerializer
+  root "llm"
+
+  attributes :id, :display_name, :name, :provider, :max_prompt_tokens, :tokenizer
+end
diff --git a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js
index 54411f90..97ed05e5 100644
--- a/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js
+++ b/assets/javascripts/discourse/admin-discourse-ai-plugin-route-map.js
@@ -8,5 +8,10 @@ export default {
       this.route("new");
       this.route("show", { path: "/:id" });
     });
+
+    this.route("discourse-ai-llms", { path: "ai-llms" }, function () {
+      this.route("new");
+      this.route("show", { path: "/:id" });
+    });
   },
 };
diff --git a/assets/javascripts/discourse/admin/adapters/ai-llm.js b/assets/javascripts/discourse/admin/adapters/ai-llm.js
new file mode 100644
index 00000000..fe82163d
--- /dev/null
+++ b/assets/javascripts/discourse/admin/adapters/ai-llm.js
@@ -0,0 +1,21 @@
+import RestAdapter from "discourse/adapters/rest";
+
+export default class Adapter extends RestAdapter {
+  jsonMode = true;
+
+  basePath() {
+    return "/admin/plugins/discourse-ai/";
+  }
+
+  pathFor(store, type, findArgs) {
+    // the base implementation underscores the resource name; keep "ai-llms" hyphenated
+    let path =
+      this.basePath(store, type, findArgs) +
+      store.pluralize(this.apiNameFor(type));
+    return this.appendQueryParams(path, findArgs);
+  }
+
+  apiNameFor() {
+    return "ai-llm";
+  }
+}
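A note on how these admin-side pieces line up: `store.findAll("ai-llm")` resolves through the adapter above to `GET /admin/plugins/discourse-ai/ai-llms.json`, i.e. the `index` action, and the Ember store exposes the `meta` block as `resultSetMeta`. Below is a sketch of the payload that contract implies — the record values are invented for illustration, but the keys follow `LlmModelSerializer` and the `meta` block assembled in `AiLlmsController#index`:

```ruby
# Illustrative payload for GET /admin/plugins/discourse-ai/ai-llms.json.
# Shape follows AiLlmsController#index and LlmModelSerializer; values are examples.
{
  ai_llms: [
    {
      id: 1,
      display_name: "A good model",
      name: "gpt-4-turbo",
      provider: "open_ai",
      max_prompt_tokens: 32_000,
      tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer",
    },
  ],
  meta: {
    providers: %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure],
    tokenizers: [
      { id: "DiscourseAi::Tokenizer::OpenAiTokenizer", name: "OpenAiTokenizer" },
    ],
  },
}
```

The `meta.providers` and `meta.tokenizers` entries are what the editor component further down consumes via `resultSetMeta` to populate its combo boxes.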
"display_name", + "name", + "provider", + "tokenizer", + "max_prompt_tokens" + ); + } + + updateProperties() { + const attrs = this.createProperties(); + attrs.id = this.id; + + return attrs; + } +} diff --git a/assets/javascripts/discourse/components/ai-llm-editor.gjs b/assets/javascripts/discourse/components/ai-llm-editor.gjs new file mode 100644 index 00000000..bac97ee5 --- /dev/null +++ b/assets/javascripts/discourse/components/ai-llm-editor.gjs @@ -0,0 +1,122 @@ +import Component from "@glimmer/component"; +import { tracked } from "@glimmer/tracking"; +import { Input } from "@ember/component"; +import { action } from "@ember/object"; +import { later } from "@ember/runloop"; +import { inject as service } from "@ember/service"; +import DButton from "discourse/components/d-button"; +import { popupAjaxError } from "discourse/lib/ajax-error"; +import i18n from "discourse-common/helpers/i18n"; +import I18n from "discourse-i18n"; +import ComboBox from "select-kit/components/combo-box"; +import DTooltip from "float-kit/components/d-tooltip"; + +export default class AiLlmEditor extends Component { + @service toasts; + @service router; + + @tracked isSaving = false; + + get selectedProviders() { + const t = (provName) => { + return I18n.t(`discourse_ai.llms.providers.${provName}`); + }; + + return this.args.llms.resultSetMeta.providers.map((prov) => { + return { id: prov, name: t(prov) }; + }); + } + + @action + async save() { + this.isSaving = true; + const isNew = this.args.model.isNew; + + try { + await this.args.model.save(); + + if (isNew) { + this.args.llms.addObject(this.args.model); + this.router.transitionTo( + "adminPlugins.show.discourse-ai-llms.show", + this.args.model + ); + } else { + this.toasts.success({ + data: { message: I18n.t("discourse_ai.llms.saved") }, + duration: 2000, + }); + } + } catch (e) { + popupAjaxError(e); + } finally { + later(() => { + this.isSaving = false; + }, 1000); + } + } + + +} diff --git a/assets/javascripts/discourse/components/ai-llms-list-editor.gjs b/assets/javascripts/discourse/components/ai-llms-list-editor.gjs new file mode 100644 index 00000000..73ba310e --- /dev/null +++ b/assets/javascripts/discourse/components/ai-llms-list-editor.gjs @@ -0,0 +1,61 @@ +import Component from "@glimmer/component"; +import { LinkTo } from "@ember/routing"; +import icon from "discourse-common/helpers/d-icon"; +import i18n from "discourse-common/helpers/i18n"; +import I18n from "discourse-i18n"; +import AiLlmEditor from "./ai-llm-editor"; + +export default class AiLlmsListEditor extends Component { + get hasNoLLMElements() { + this.args.llms.length !== 0; + } + + +} diff --git a/assets/javascripts/initializers/admin-plugin-configuration-nav.js b/assets/javascripts/initializers/admin-plugin-configuration-nav.js index 2c2bac59..51d88222 100644 --- a/assets/javascripts/initializers/admin-plugin-configuration-nav.js +++ b/assets/javascripts/initializers/admin-plugin-configuration-nav.js @@ -16,6 +16,10 @@ export default { label: "discourse_ai.ai_persona.short_title", route: "adminPlugins.show.discourse-ai-personas", }, + { + label: "discourse_ai.llms.short_title", + route: "adminPlugins.show.discourse-ai-llms", + }, ]); }); }, diff --git a/assets/stylesheets/modules/llms/common/ai-llms-editor.scss b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss new file mode 100644 index 00000000..351deeef --- /dev/null +++ b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss @@ -0,0 +1,32 @@ +.ai-llms-list-editor { + &__header { + display: flex; + 
diff --git a/assets/stylesheets/modules/llms/common/ai-llms-editor.scss b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss
new file mode 100644
index 00000000..351deeef
--- /dev/null
+++ b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss
@@ -0,0 +1,31 @@
+.ai-llms-list-editor {
+  &__header {
+    display: flex;
+    justify-content: space-between;
+    align-items: center;
+    margin: 0 0 1em 0;
+
+    h3 {
+      margin: 0;
+    }
+  }
+
+  &__container {
+    display: flex;
+    flex-direction: row;
+    gap: 20px;
+    width: 100%;
+    align-items: stretch;
+  }
+
+  &__empty_list,
+  &__content_list {
+    min-width: 300px;
+  }
+
+  &__empty_list {
+    align-content: center;
+    text-align: center;
+    font-size: var(--font-up-1);
+  }
+}
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index 1c3d9c0c..044d00f3 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -194,6 +194,32 @@ en:
           uploading: "Uploading..."
           remove: "Remove upload"
 
+      llms:
+        short_title: "LLMs"
+        no_llms: "No LLMs yet"
+        new: "New Model"
+        display_name: "Name to display:"
+        name: "Model name:"
+        provider: "Service hosting the model:"
+        tokenizer: "Tokenizer:"
+        max_prompt_tokens: "Number of tokens for the prompt:"
+        save: "Save"
+        saved: "LLM Model Saved"
+
+        hints:
+          max_prompt_tokens: "Max number of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window."
+          name: "We include this in the API call to specify which model we'll use."
+
+        providers:
+          aws_bedrock: "AWS Bedrock"
+          anthropic: "Anthropic"
+          vllm: "vLLM"
+          hugging_face: "Hugging Face"
+          cohere: "Cohere"
+          open_ai: "OpenAI"
+          google: "Google"
+          azure: "Azure"
+
       related_topics:
         title: "Related Topics"
         pill: "Related"
diff --git a/config/routes.rb b/config/routes.rb
index eb20cf98..f33dfd68 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -45,5 +45,10 @@ Discourse::Application.routes.draw do
     post "/ai-personas/files/upload", to: "discourse_ai/admin/ai_personas#upload_file"
     put "/ai-personas/:id/files/remove", to: "discourse_ai/admin/ai_personas#remove_file"
     get "/ai-personas/:id/files/status", to: "discourse_ai/admin/ai_personas#indexing_status_check"
+
+    resources :ai_llms,
+              only: %i[index create show update],
+              path: "ai-llms",
+              controller: "discourse_ai/admin/ai_llms"
   end
 end
diff --git a/db/migrate/20240504222307_create_llm_model_table.rb b/db/migrate/20240504222307_create_llm_model_table.rb
new file mode 100644
index 00000000..96bcc3dd
--- /dev/null
+++ b/db/migrate/20240504222307_create_llm_model_table.rb
@@ -0,0 +1,14 @@
+# frozen_string_literal: true
+
+class CreateLlmModelTable < ActiveRecord::Migration[7.0]
+  def change
+    create_table :llm_models do |t|
+      t.string :display_name
+      t.string :name, null: false
+      t.string :provider, null: false
+      t.string :tokenizer, null: false
+      t.integer :max_prompt_tokens, null: false
+      t.timestamps
+    end
+  end
+end
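Since every column except `display_name` is `NOT NULL`, and `LlmModel#tokenizer_class` simply `constantize`s the stored string, a record is only usable when `tokenizer` names a real tokenizer class. A minimal sketch — the values are illustrative, and it assumes the plugin's existing `DiscourseAi::Tokenizer::AnthropicTokenizer`:

```ruby
# Minimal valid record under the schema above; display_name may be omitted.
LlmModel.create!(
  name: "claude-3-opus",
  provider: "anthropic",
  tokenizer: "DiscourseAi::Tokenizer::AnthropicTokenizer", # must constantize to a class
  max_prompt_tokens: 100_000,
)
```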
diff --git a/lib/ai_bot/bot.rb b/lib/ai_bot/bot.rb
index 3e3c932e..d2ca71f5 100644
--- a/lib/ai_bot/bot.rb
+++ b/lib/ai_bot/bot.rb
@@ -150,59 +150,73 @@ module DiscourseAi
       def self.guess_model(bot_user)
         # HACK(roman): We'll do this until we define how we represent different providers in the bot settings
-        case bot_user.id
-        when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
-          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
-            "aws_bedrock:claude-2"
+        guess =
+          case bot_user.id
+          when DiscourseAi::AiBot::EntryPoint::CLAUDE_V2_ID
+            if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-2")
+              "aws_bedrock:claude-2"
+            else
+              "anthropic:claude-2"
+            end
+          when DiscourseAi::AiBot::EntryPoint::GPT4_ID
+            "open_ai:gpt-4"
+          when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
+            "open_ai:gpt-4-turbo"
+          when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
+            "open_ai:gpt-3.5-turbo-16k"
+          when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
+            mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
+            if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
+              "vllm:#{mixtral_model}"
+            elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
+                    mixtral_model,
+                  )
+              "hugging_face:#{mixtral_model}"
+            else
+              "ollama:mistral"
+            end
+          when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
+            "google:gemini-pro"
+          when DiscourseAi::AiBot::EntryPoint::FAKE_ID
+            "fake:fake"
+          when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
+            if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+                 "claude-3-opus",
+               )
+              "aws_bedrock:claude-3-opus"
+            else
+              "anthropic:claude-3-opus"
+            end
+          when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
+            "cohere:command-r-plus"
+          when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
+            if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+                 "claude-3-sonnet",
+               )
+              "aws_bedrock:claude-3-sonnet"
+            else
+              "anthropic:claude-3-sonnet"
+            end
+          when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID
+            if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
+                 "claude-3-haiku",
+               )
+              "aws_bedrock:claude-3-haiku"
+            else
+              "anthropic:claude-3-haiku"
+            end
           else
-            "anthropic:claude-2"
+            nil
           end
-        when DiscourseAi::AiBot::EntryPoint::GPT4_ID
-          "open_ai:gpt-4"
-        when DiscourseAi::AiBot::EntryPoint::GPT4_TURBO_ID
-          "open_ai:gpt-4-turbo"
-        when DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID
-          "open_ai:gpt-3.5-turbo-16k"
-        when DiscourseAi::AiBot::EntryPoint::MIXTRAL_ID
-          mixtral_model = "mistralai/Mixtral-8x7B-Instruct-v0.1"
-          if DiscourseAi::Completions::Endpoints::Vllm.correctly_configured?(mixtral_model)
-            "vllm:#{mixtral_model}"
-          elsif DiscourseAi::Completions::Endpoints::HuggingFace.correctly_configured?(
-            mixtral_model,
-          )
-            "hugging_face:#{mixtral_model}"
-          else
-            "ollama:mistral"
-          end
-        when DiscourseAi::AiBot::EntryPoint::GEMINI_ID
-          "google:gemini-pro"
-        when DiscourseAi::AiBot::EntryPoint::FAKE_ID
-          "fake:fake"
-        when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_OPUS_ID
-          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-opus")
-            "aws_bedrock:claude-3-opus"
-          else
-            "anthropic:claude-3-opus"
-          end
-        when DiscourseAi::AiBot::EntryPoint::COHERE_COMMAND_R_PLUS
-          "cohere:command-r-plus"
-        when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_SONNET_ID
-          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?(
-            "claude-3-sonnet",
-          )
-            "aws_bedrock:claude-3-sonnet"
-          else
-            "anthropic:claude-3-sonnet"
-          end
-        when DiscourseAi::AiBot::EntryPoint::CLAUDE_3_HAIKU_ID
-          if DiscourseAi::Completions::Endpoints::AwsBedrock.correctly_configured?("claude-3-haiku")
-            "aws_bedrock:claude-3-haiku"
-          else
-            "anthropic:claude-3-haiku"
-          end
-        else
-          nil
+
+        if guess
+          provider, model_name = guess.split(":")
+          llm_model = LlmModel.find_by(provider: provider, name: model_name)
+
+          return "custom:#{llm_model.id}" if llm_model
         end
+
+        guess
       end
 
       def build_placeholder(summary, details, custom_raw: nil)
diff --git a/lib/automation.rb b/lib/automation.rb
index b755f1db..71cde97b 100644
--- a/lib/automation.rb
+++ b/lib/automation.rb
@@ -25,6 +25,9 @@ module DiscourseAi
    ]
 
    def self.translate_model(model)
+      llm_model = LlmModel.find_by(name: model)
+      return "custom:#{llm_model.id}" if llm_model
+
      return "google:#{model}" if model.start_with? "gemini"
      return "open_ai:#{model}" if model.start_with? "gpt"
      return "cohere:#{model}" if model.start_with? "command"
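Both `guess_model` and `translate_model` now check for an admin-defined record first and encode a hit as `custom:<id>`, which keeps user-created models distinguishable from the hardcoded `provider:model` pairs. A sketch of the precedence, assuming no other `LlmModel` rows exist:

```ruby
llm_model =
  LlmModel.create!(
    name: "gpt-4",
    provider: "open_ai",
    tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer",
    max_prompt_tokens: 8_192,
  )

# A matching row wins over the prefix-based mapping:
DiscourseAi::Automation.translate_model("gpt-4") # => "custom:#{llm_model.id}"
# With no matching row, the old prefix rules still apply:
DiscourseAi::Automation.translate_model("gemini-pro") # => "google:gemini-pro"
```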
"command" diff --git a/lib/completions/dialects/chat_gpt.rb b/lib/completions/dialects/chat_gpt.rb index 9383019c..f6142d09 100644 --- a/lib/completions/dialects/chat_gpt.rb +++ b/lib/completions/dialects/chat_gpt.rb @@ -30,6 +30,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? + # provide a buffer of 120 tokens - our function counting is not # 100% accurate and getting numbers to align exactly is very hard buffer = (opts[:max_tokens] || 2500) + 50 diff --git a/lib/completions/dialects/claude.rb b/lib/completions/dialects/claude.rb index 9a15b293..8be67e54 100644 --- a/lib/completions/dialects/claude.rb +++ b/lib/completions/dialects/claude.rb @@ -50,6 +50,7 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? # Longer term it will have over 1 million 200_000 # Claude-3 has a 200k context window for now end diff --git a/lib/completions/dialects/command.rb b/lib/completions/dialects/command.rb index f119aba8..62240372 100644 --- a/lib/completions/dialects/command.rb +++ b/lib/completions/dialects/command.rb @@ -38,6 +38,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? + case model_name when "command-light" 4096 @@ -62,10 +64,6 @@ module DiscourseAi self.class.tokenizer.size(context[:content].to_s + context[:name].to_s) end - def tools_dialect - @tools_dialect ||= DiscourseAi::Completions::Dialects::XmlTools.new(prompt.tools) - end - def system_msg(msg) cmd_msg = { role: "SYSTEM", message: msg[:content] } diff --git a/lib/completions/dialects/dialect.rb b/lib/completions/dialects/dialect.rb index 865b5509..13275ad4 100644 --- a/lib/completions/dialects/dialect.rb +++ b/lib/completions/dialects/dialect.rb @@ -9,14 +9,22 @@ module DiscourseAi raise NotImplemented end - def dialect_for(model_name) - dialects = [ + def all_dialects + [ DiscourseAi::Completions::Dialects::ChatGpt, DiscourseAi::Completions::Dialects::Gemini, DiscourseAi::Completions::Dialects::Mistral, DiscourseAi::Completions::Dialects::Claude, DiscourseAi::Completions::Dialects::Command, ] + end + + def available_tokenizers + all_dialects.map(&:tokenizer) + end + + def dialect_for(model_name) + dialects = all_dialects if Rails.env.test? || Rails.env.development? dialects << DiscourseAi::Completions::Dialects::Fake diff --git a/lib/completions/dialects/gemini.rb b/lib/completions/dialects/gemini.rb index 678dc0cd..fde9cdda 100644 --- a/lib/completions/dialects/gemini.rb +++ b/lib/completions/dialects/gemini.rb @@ -68,6 +68,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? + if model_name == "gemini-1.5-pro" # technically we support 1 million tokens, but we're being conservative 800_000 diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb index 7752a876..d34130f9 100644 --- a/lib/completions/dialects/mistral.rb +++ b/lib/completions/dialects/mistral.rb @@ -23,6 +23,8 @@ module DiscourseAi end def max_prompt_tokens + return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present? 
diff --git a/lib/completions/dialects/mistral.rb b/lib/completions/dialects/mistral.rb
index 7752a876..d34130f9 100644
--- a/lib/completions/dialects/mistral.rb
+++ b/lib/completions/dialects/mistral.rb
@@ -23,6 +23,8 @@ module DiscourseAi
       end
 
       def max_prompt_tokens
+        return opts[:max_prompt_tokens] if opts.dig(:max_prompt_tokens).present?
+
         32_000
       end
 
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 47190644..a2f70172 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -18,6 +18,14 @@ module DiscourseAi
     UNKNOWN_MODEL = Class.new(StandardError)
 
     class << self
+      def provider_names
+        %w[aws_bedrock anthropic vllm hugging_face cohere open_ai google azure]
+      end
+
+      def tokenizer_names
+        DiscourseAi::Completions::Dialects::Dialect.available_tokenizers.map(&:name).uniq
+      end
+
       def models_by_provider
         # ChatGPT models are listed under open_ai but they are actually available through OpenAI and Azure.
         # However, since they use the same URL/key settings, there's no reason to duplicate them.
@@ -80,36 +88,54 @@ module DiscourseAi
       end
 
       def proxy(model_name)
+        # We are in the process of transitioning to always use objects here.
+        # We'll live with this hack for a while.
         provider_and_model_name = model_name.split(":")
-
         provider_name = provider_and_model_name.first
         model_name_without_prov = provider_and_model_name[1..].join
+        is_custom_model = provider_name == "custom"
+
+        if is_custom_model
+          llm_model = LlmModel.find(model_name_without_prov)
+          provider_name = llm_model.provider
+          model_name_without_prov = llm_model.name
+        end
 
         dialect_klass =
           DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name_without_prov)
 
+        if is_custom_model
+          tokenizer = llm_model.tokenizer_class
+        else
+          tokenizer = dialect_klass.tokenizer
+        end
+
         if @canned_response
           if @canned_llm && @canned_llm != model_name
             raise "Invalid call LLM call, expected #{@canned_llm} but got #{model_name}"
           end
 
-          return new(dialect_klass, nil, model_name, gateway: @canned_response)
+          return new(dialect_klass, nil, model_name, opts: { gateway: @canned_response })
         end
 
+        opts = {}
+        opts[:max_prompt_tokens] = llm_model.max_prompt_tokens if is_custom_model
+
         gateway_klass =
           DiscourseAi::Completions::Endpoints::Base.endpoint_for(
             provider_name,
            model_name_without_prov,
          )
 
-        new(dialect_klass, gateway_klass, model_name_without_prov)
+        new(dialect_klass, gateway_klass, model_name_without_prov, opts: opts)
       end
     end
 
-    def initialize(dialect_klass, gateway_klass, model_name, gateway: nil)
+    def initialize(dialect_klass, gateway_klass, model_name, opts: {})
       @dialect_klass = dialect_klass
       @gateway_klass = gateway_klass
       @model_name = model_name
-      @gateway = gateway
+      @gateway = opts[:gateway]
+      @max_prompt_tokens = opts[:max_prompt_tokens]
     end
 
     # @param generic_prompt { DiscourseAi::Completions::Prompt } - Our generic prompt object
@@ -166,11 +192,18 @@ module DiscourseAi
       model_params.keys.each { |key| model_params.delete(key) if model_params[key].nil? }
 
       gateway = @gateway || gateway_klass.new(model_name, dialect_klass.tokenizer)
-      dialect = dialect_klass.new(prompt, model_name, opts: model_params)
+      dialect =
+        dialect_klass.new(
+          prompt,
+          model_name,
+          opts: model_params.merge(max_prompt_tokens: @max_prompt_tokens),
+        )
       gateway.perform_completion!(dialect, user, model_params, &partial_read_blk)
     end
 
     def max_prompt_tokens
+      return @max_prompt_tokens if @max_prompt_tokens.present?
+
       dialect_klass.new(DiscourseAi::Completions::Prompt.new(""), model_name).max_prompt_tokens
     end
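Putting the proxy changes together: a `custom:<id>` model name resolves provider, underlying model name, tokenizer, and token budget from the database row, while the dialect and gateway classes are still picked from the stored provider/name pair. A sketch, with illustrative record values:

```ruby
llm_model =
  LlmModel.create!(
    name: "gpt-4-turbo",
    provider: "open_ai",
    tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer",
    max_prompt_tokens: 32_000,
  )

llm = DiscourseAi::Completions::Llm.proxy("custom:#{llm_model.id}")
llm.max_prompt_tokens # => 32_000, read from the record instead of the dialect default
```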
diff --git a/lib/configuration/llm_enumerator.rb b/lib/configuration/llm_enumerator.rb
index d7618795..fd870b57 100644
--- a/lib/configuration/llm_enumerator.rb
+++ b/lib/configuration/llm_enumerator.rb
@@ -10,14 +10,22 @@ module DiscourseAi
     end
 
     def self.values
-      # do not cache cause settings can change this
-      DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
-        endpoint =
-          DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
+      begin
+        llm_models =
+          DiscourseAi::Completions::Llm.models_by_provider.flat_map do |provider, models|
+            endpoint =
+              DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider.to_s, models.first)
 
-        models.map do |model_name|
-          { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
+            models.map do |model_name|
+              { name: endpoint.display_name(model_name), value: "#{provider}:#{model_name}" }
+            end
+          end
+
+        LlmModel.all.each do |model|
+          llm_models << { name: model.display_name, value: "custom:#{model.id}" }
         end
+
+        llm_models
       end
     end
   end
diff --git a/plugin.rb b/plugin.rb
index ad255b70..c5a571ba 100644
--- a/plugin.rb
+++ b/plugin.rb
@@ -26,6 +26,8 @@ register_asset "stylesheets/modules/sentiment/common/dashboard.scss"
 register_asset "stylesheets/modules/sentiment/desktop/dashboard.scss", :desktop
 register_asset "stylesheets/modules/sentiment/mobile/dashboard.scss", :mobile
 
+register_asset "stylesheets/modules/llms/common/ai-llms-editor.scss"
+
 module ::DiscourseAi
   PLUGIN_NAME = "discourse-ai"
 end
diff --git a/spec/fabricators/llm_model_fabricator.rb b/spec/fabricators/llm_model_fabricator.rb
new file mode 100644
index 00000000..c419341e
--- /dev/null
+++ b/spec/fabricators/llm_model_fabricator.rb
@@ -0,0 +1,9 @@
+# frozen_string_literal: true
+
+Fabricator(:llm_model) do
+  display_name "A good model"
+  name "gpt-4-turbo"
+  provider "open_ai"
+  tokenizer "DiscourseAi::Tokenizer::OpenAiTokenizer"
+  max_prompt_tokens 32_000
+end
diff --git a/spec/lib/completions/endpoints/hugging_face_spec.rb b/spec/lib/completions/endpoints/hugging_face_spec.rb
index 5b9bd9f5..f84fc43a 100644
--- a/spec/lib/completions/endpoints/hugging_face_spec.rb
+++ b/spec/lib/completions/endpoints/hugging_face_spec.rb
@@ -81,7 +81,10 @@ end
 
 RSpec.describe DiscourseAi::Completions::Endpoints::HuggingFace do
   subject(:endpoint) do
-    described_class.new("Llama2-*-chat-hf", DiscourseAi::Tokenizer::Llama2Tokenizer)
+    described_class.new(
+      "mistralai/Mistral-7B-Instruct-v0.2",
+      DiscourseAi::Tokenizer::MixtralTokenizer,
+    )
   end
 
   before { SiteSetting.ai_hugging_face_api_url = "https://test.dev" }
diff --git a/spec/lib/completions/llm_spec.rb b/spec/lib/completions/llm_spec.rb
index e91a4a60..c3107dc0 100644
--- a/spec/lib/completions/llm_spec.rb
+++ b/spec/lib/completions/llm_spec.rb
@@ -6,7 +6,9 @@ RSpec.describe DiscourseAi::Completions::Llm do
       DiscourseAi::Completions::Dialects::Mistral,
       canned_response,
       "hugging_face:Upstage-Llama-2-*-instruct-v2",
-      gateway: canned_response,
+      opts: {
+        gateway: canned_response,
+      },
     )
   end
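The fabricator gives the request specs below a fully valid record without repeating the required attributes, and `LlmModel#tokenizer_class` round-trips the stored class name back to a tokenizer constant. For instance:

```ruby
model = Fabricate(:llm_model)

model.tokenizer_class   # => DiscourseAi::Tokenizer::OpenAiTokenizer
model.max_prompt_tokens # => 32_000
```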
"includes all available providers metadata" do + get "/admin/plugins/discourse-ai/ai-llms.json" + expect(response).to be_successful + + expect(response.parsed_body["meta"]["providers"]).to contain_exactly( + *DiscourseAi::Completions::Llm.provider_names, + ) + end + end + + describe "POST #create" do + context "with valid attributes" do + let(:valid_attrs) do + { + display_name: "My cool LLM", + name: "gpt-3.5", + provider: "open_ai", + tokenizer: "DiscourseAi::Tokenizers::OpenAiTokenizer", + max_prompt_tokens: 16_000, + } + end + + it "creates a new LLM model" do + post "/admin/plugins/discourse-ai/ai-llms.json", params: { ai_llm: valid_attrs } + + created_model = LlmModel.last + + expect(created_model.display_name).to eq(valid_attrs[:display_name]) + expect(created_model.name).to eq(valid_attrs[:name]) + expect(created_model.provider).to eq(valid_attrs[:provider]) + expect(created_model.tokenizer).to eq(valid_attrs[:tokenizer]) + expect(created_model.max_prompt_tokens).to eq(valid_attrs[:max_prompt_tokens]) + end + end + end + + describe "PUT #update" do + fab!(:llm_model) + + context "with valid update params" do + let(:update_attrs) { { provider: "anthropic" } } + + it "updates the model" do + put "/admin/plugins/discourse-ai/ai-llms/#{llm_model.id}.json", + params: { + ai_llm: update_attrs, + } + + expect(response.status).to eq(200) + expect(llm_model.reload.provider).to eq(update_attrs[:provider]) + end + + it "returns a 404 if there is no model with the given Id" do + put "/admin/plugins/discourse-ai/ai-llms/9999999.json" + + expect(response.status).to eq(404) + end + end + end +end