diff --git a/app/controllers/discourse_ai/admin/ai_llms_controller.rb b/app/controllers/discourse_ai/admin/ai_llms_controller.rb
index c7e5d5cf..897fb152 100644
--- a/app/controllers/discourse_ai/admin/ai_llms_controller.rb
+++ b/app/controllers/discourse_ai/admin/ai_llms_controller.rb
@@ -49,6 +49,22 @@ module DiscourseAi
         end
       end
 
+      def test
+        RateLimiter.new(current_user, "llm_test_#{current_user.id}", 3, 1.minute).performed!
+
+        llm_model = LlmModel.new(ai_llm_params)
+
+        DiscourseAi::Completions::Llm.proxy_from_obj(llm_model).generate(
+          "How much is 1 + 1?",
+          user: current_user,
+          feature_name: "llm_validator",
+        )
+
+        render json: { success: true }
+      rescue DiscourseAi::Completions::Endpoints::Base::CompletionFailed => e
+        render json: { success: false, error: e.message }
+      end
+
       private
 
       def ai_llm_params
diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb
index aefb9202..08055802 100644
--- a/app/models/llm_model.rb
+++ b/app/models/llm_model.rb
@@ -5,3 +5,19 @@ class LlmModel < ActiveRecord::Base
     tokenizer.constantize
   end
 end
+
+# == Schema Information
+#
+# Table name: llm_models
+#
+#  id                :bigint           not null, primary key
+#  display_name      :string
+#  name              :string           not null
+#  provider          :string           not null
+#  tokenizer         :string           not null
+#  max_prompt_tokens :integer          not null
+#  created_at        :datetime         not null
+#  updated_at        :datetime         not null
+#  url               :string
+#  api_key           :string
+#
diff --git a/assets/javascripts/discourse/admin/models/ai-llm.js b/assets/javascripts/discourse/admin/models/ai-llm.js
index 59d76daa..e46bf8f2 100644
--- a/assets/javascripts/discourse/admin/models/ai-llm.js
+++ b/assets/javascripts/discourse/admin/models/ai-llm.js
@@ -1,3 +1,4 @@
+import { ajax } from "discourse/lib/ajax";
 import RestModel from "discourse/models/rest";
 
 export default class AiLlm extends RestModel {
@@ -20,4 +21,10 @@ export default class AiLlm extends RestModel {
 
     return attrs;
   }
+
+  async testConfig() {
+    return await ajax(`/admin/plugins/discourse-ai/ai-llms/test.json`, {
+      data: { ai_llm: this.createProperties() },
+    });
+  }
 }
diff --git a/assets/javascripts/discourse/components/ai-llm-editor.gjs b/assets/javascripts/discourse/components/ai-llm-editor.gjs
index 5fa79aef..dd92abc2 100644
--- a/assets/javascripts/discourse/components/ai-llm-editor.gjs
+++ b/assets/javascripts/discourse/components/ai-llm-editor.gjs
@@ -7,6 +7,8 @@ import { inject as service } from "@ember/service";
+import { later } from "@ember/runloop";
 import BackButton from "discourse/components/back-button";
 import DButton from "discourse/components/d-button";
 import { popupAjaxError } from "discourse/lib/ajax-error";
+import icon from "discourse-common/helpers/d-icon";
 import i18n from "discourse-common/helpers/i18n";
 import I18n from "discourse-i18n";
 import ComboBox from "select-kit/components/combo-box";
@@ -18,6 +20,10 @@ export default class AiLlmEditor extends Component {
   @tracked isSaving = false;
 
+  @tracked testRunning = false;
+  @tracked testResult = null;
+  @tracked testError = null;
+
   get selectedProviders() {
     const t = (provName) => {
       return I18n.t(`discourse_ai.llms.providers.${provName}`);
     };
@@ -59,6 +65,36 @@ export default class AiLlmEditor extends Component {
     }
   }
 
+  @action
+  async test() {
+    this.testRunning = true;
+
+    try {
+      const configTestResult = await this.args.model.testConfig();
+      this.testResult = configTestResult.success;
+
+      if (this.testResult) {
+        this.testError = null;
+      } else {
+        this.testError = configTestResult.error;
+      }
+    } catch (e) {
+      popupAjaxError(e);
+    } finally {
+      later(() => {
+        this.testRunning = false;
+      }, 1000);
+    }
+  }
+
+  get testErrorMessage() {
+    return I18n.t("discourse_ai.llms.tests.failure", { error: this.testError });
+  }
+
+  get displayTestResult() {
+    return this.testRunning || this.testResult !== null;
+  }
+
 }
diff --git a/assets/stylesheets/modules/llms/common/ai-llms-editor.scss b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss
index f2c30a07..f6a3d495 100644
--- a/assets/stylesheets/modules/llms/common/ai-llms-editor.scss
+++ b/assets/stylesheets/modules/llms/common/ai-llms-editor.scss
@@ -22,4 +22,14 @@
     padding-left: 0.25em;
     color: var(--primary-medium);
   }
+
+  .ai-llm-editor-tests {
+    &__failure {
+      color: var(--danger);
+    }
+
+    &__success {
+      color: var(--success);
+    }
+  }
 }
diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml
index e29668d6..38832723 100644
--- a/config/locales/client.en.yml
+++ b/config/locales/client.en.yml
@@ -210,6 +210,11 @@ en:
         edit: "Edit"
         saved: "LLM Model Saved"
         back: "Back"
+        tests:
+          title: "Run Test"
+          running: "Running test..."
+          success: "Success!"
+          failure: "Trying to contact the model returned this error: %{error}"
 
         hints:
           max_prompt_tokens: "Max numbers of tokens for the prompt. As a rule of thumb, this should be 50% of the model's context window."
diff --git a/config/routes.rb b/config/routes.rb
index f33dfd68..d1cc9ca1 100644
--- a/config/routes.rb
+++ b/config/routes.rb
@@ -49,6 +49,8 @@ Discourse::Application.routes.draw do
       resources :ai_llms,
                 only: %i[index create show update],
                 path: "ai-llms",
-                controller: "discourse_ai/admin/ai_llms"
+                controller: "discourse_ai/admin/ai_llms" do
+        collection { get :test }
+      end
     end
   end
diff --git a/lib/completions/endpoints/base.rb b/lib/completions/endpoints/base.rb
index 8e3db103..0df0ff0b 100644
--- a/lib/completions/endpoints/base.rb
+++ b/lib/completions/endpoints/base.rb
@@ -108,7 +108,7 @@ module DiscourseAi
           Rails.logger.error(
             "#{self.class.name}: status: #{response.code.to_i} - body: #{response.body}",
           )
-          raise CompletionFailed
+          raise CompletionFailed, response.body
         end
 
         log =
diff --git a/lib/completions/llm.rb b/lib/completions/llm.rb
index 570e5e66..531dacfa 100644
--- a/lib/completions/llm.rb
+++ b/lib/completions/llm.rb
@@ -124,6 +124,15 @@ module DiscourseAi
         model_name = llm_model.name
 
         dialect_klass = DiscourseAi::Completions::Dialects::Dialect.dialect_for(model_name)
+
+        if @canned_response
+          if @canned_llm && @canned_llm != model_name
+            raise "Invalid LLM call, expected #{@canned_llm} but got #{model_name}"
+          end
+
+          return new(dialect_klass, nil, model_name, gateway: @canned_response)
+        end
+
         gateway_klass = DiscourseAi::Completions::Endpoints::Base.endpoint_for(provider_name)
 
         new(dialect_klass, gateway_klass, model_name, llm_model: llm_model)
diff --git a/spec/requests/admin/ai_llms_controller_spec.rb b/spec/requests/admin/ai_llms_controller_spec.rb
index e4747834..8bad4b05 100644
--- a/spec/requests/admin/ai_llms_controller_spec.rb
+++ b/spec/requests/admin/ai_llms_controller_spec.rb
@@ -23,7 +23,9 @@ RSpec.describe DiscourseAi::Admin::AiLlmsController do
         display_name: "My cool LLM",
         name: "gpt-3.5",
         provider: "open_ai",
-        tokenizer: "DiscourseAi::Tokenizers::OpenAiTokenizer",
+        url: "https://test.test/v1/chat/completions",
+        api_key: "test",
+        tokenizer: "DiscourseAi::Tokenizer::OpenAiTokenizer",
         max_prompt_tokens: 16_000,
       }
     end
@@ -65,4 +67,49 @@
       end
     end
   end
+
+  describe "GET #test" do
+    let(:test_attrs) do
+      {
+        name: "llama3",
+        provider: "hugging_face",
"https://test.test/v1/chat/completions", + api_key: "test", + tokenizer: "DiscourseAi::Tokenizer::Llama3Tokenizer", + max_prompt_tokens: 2_000, + } + end + + context "when we can contact the model" do + it "returns a success true flag" do + DiscourseAi::Completions::Llm.with_prepared_responses(["a response"]) do + get "/admin/plugins/discourse-ai/ai-llms/test.json", params: { ai_llm: test_attrs } + + expect(response).to be_successful + expect(response.parsed_body["success"]).to eq(true) + end + end + end + + context "when we cannot contact the model" do + it "returns a success false flag and the error message" do + error_message = { + error: + "Input validation error: `inputs` tokens + `max_new_tokens` must be <= 1512. Given: 30 `inputs` tokens and 3984 `max_new_tokens`", + error_type: "validation", + } + + WebMock.stub_request(:post, test_attrs[:url]).to_return( + status: 422, + body: error_message.to_json, + ) + + get "/admin/plugins/discourse-ai/ai-llms/test.json", params: { ai_llm: test_attrs } + + expect(response).to be_successful + expect(response.parsed_body["success"]).to eq(false) + expect(response.parsed_body["error"]).to eq(error_message.to_json) + end + end + end end