From 5cb91217bd46f3b1b388d4618e179dc7cd20faa4 Mon Sep 17 00:00:00 2001 From: Roman Rizzi Date: Mon, 8 Jul 2024 18:47:10 -0300 Subject: [PATCH] FIX: Flaky SRV-backed model seeding. (#708) * Seeding the SRV-backed model should happen inside an initializer. * Keep the model up to date when the hidden setting changes. * Use the correct Mixtral model name and fix previous data migration. * URL validation should trigger only when we attempt to update it. --- app/models/llm_model.rb | 42 +++++++----- db/fixtures/602_srv_backed_llm_model.rb | 8 +++ .../20240708193243_fix_vllm_model_name.rb | 18 ++++++ plugin.rb | 6 +- spec/models/llm_model_spec.rb | 64 +++++++++++++++++++ 5 files changed, 121 insertions(+), 17 deletions(-) create mode 100644 db/fixtures/602_srv_backed_llm_model.rb create mode 100644 db/post_migrate/20240708193243_fix_vllm_model_name.rb create mode 100644 spec/models/llm_model_spec.rb diff --git a/app/models/llm_model.rb b/app/models/llm_model.rb index a818d6b2..180a7a84 100644 --- a/app/models/llm_model.rb +++ b/app/models/llm_model.rb @@ -7,24 +7,36 @@ class LlmModel < ActiveRecord::Base belongs_to :user - validates :url, exclusion: { in: [RESERVED_VLLM_SRV_URL] } + validates :url, exclusion: { in: [RESERVED_VLLM_SRV_URL] }, if: :url_changed? - def self.enable_or_disable_srv_llm! + def self.seed_srv_backed_model + srv = SiteSetting.ai_vllm_endpoint_srv srv_model = find_by(url: RESERVED_VLLM_SRV_URL) - if SiteSetting.ai_vllm_endpoint_srv.present? && srv_model.blank? - record = - new( - display_name: "vLLM SRV LLM", - name: "mistralai/Mixtral", - provider: "vllm", - tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer", - url: RESERVED_VLLM_SRV_URL, - max_prompt_tokens: 8000, - ) - record.save(validate: false) # Ignore reserved URL validation - elsif srv_model.present? - srv_model.destroy! + if srv.present? + if srv_model.present? 
+ current_key = SiteSetting.ai_vllm_api_key + srv_model.update!(api_key: current_key) if current_key != srv_model.api_key + else + record = + new( + display_name: "vLLM SRV LLM", + name: "mistralai/Mixtral-8x7B-Instruct-v0.1", + provider: "vllm", + tokenizer: "DiscourseAi::Tokenizer::MixtralTokenizer", + url: RESERVED_VLLM_SRV_URL, + max_prompt_tokens: 8000, + api_key: SiteSetting.ai_vllm_api_key, + ) + + record.save(validate: false) # Ignore reserved URL validation + end + else + # Clean up companion users + srv_model&.enabled_chat_bot = false + srv_model&.toggle_companion_user + + srv_model&.destroy! end end diff --git a/db/fixtures/602_srv_backed_llm_model.rb b/db/fixtures/602_srv_backed_llm_model.rb new file mode 100644 index 00000000..9be19266 --- /dev/null +++ b/db/fixtures/602_srv_backed_llm_model.rb @@ -0,0 +1,8 @@ +# frozen_string_literal: true + +begin + LlmModel.seed_srv_backed_model +rescue PG::UndefinedColumn => e + # If this code runs before migrations, an attribute might be missing. 
+ Rails.logger.warn("Failed to seed SRV-Backed LLM: #{e.message}") +end diff --git a/db/post_migrate/20240708193243_fix_vllm_model_name.rb b/db/post_migrate/20240708193243_fix_vllm_model_name.rb new file mode 100644 index 00000000..b5d715ce --- /dev/null +++ b/db/post_migrate/20240708193243_fix_vllm_model_name.rb @@ -0,0 +1,18 @@ +# frozen_string_literal: true +class FixVllmModelName < ActiveRecord::Migration[7.1] + def up + vllm_mixtral_model_id = DB.query_single(<<~SQL).first + SELECT id FROM llm_models WHERE name = 'mistralai/Mixtral' + SQL + + DB.exec(<<~SQL, target_id: vllm_mixtral_model_id) if vllm_mixtral_model_id + UPDATE llm_models + SET name = 'mistralai/Mixtral-8x7B-Instruct-v0.1' + WHERE id = :target_id + SQL + end + + def down + raise ActiveRecord::IrreversibleMigration + end +end diff --git a/plugin.rb b/plugin.rb index 5981d0e3..1b5ad2f9 100644 --- a/plugin.rb +++ b/plugin.rb @@ -47,8 +47,6 @@ after_initialize do add_admin_route("discourse_ai.title", "discourse-ai", { use_new_show_route: true }) - LlmModel.enable_or_disable_srv_llm! - [ DiscourseAi::Embeddings::EntryPoint.new, DiscourseAi::Nsfw::EntryPoint.new, @@ -82,4 +80,8 @@ after_initialize do nil end end + + on(:site_setting_changed) do |name, _old_value, _new_value| + LlmModel.seed_srv_backed_model if name == :ai_vllm_endpoint_srv || name == :ai_vllm_api_key + end end diff --git a/spec/models/llm_model_spec.rb b/spec/models/llm_model_spec.rb new file mode 100644 index 00000000..c828744b --- /dev/null +++ b/spec/models/llm_model_spec.rb @@ -0,0 +1,64 @@ +# frozen_string_literal: true + +RSpec.describe LlmModel do + describe ".seed_srv_backed_model" do + before do + SiteSetting.ai_vllm_endpoint_srv = "srv.llm.service." 
+ SiteSetting.ai_vllm_api_key = "123" + end + + context "when the model doesn't exist yet" do + it "creates it" do + described_class.seed_srv_backed_model + + llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL) + + expect(llm_model).to be_present + expect(llm_model.name).to eq("mistralai/Mixtral-8x7B-Instruct-v0.1") + expect(llm_model.api_key).to eq(SiteSetting.ai_vllm_api_key) + end + end + + context "when the model already exists" do + before { described_class.seed_srv_backed_model } + + context "when the API key setting changes" do + it "updates it" do + new_key = "456" + SiteSetting.ai_vllm_api_key = new_key + + described_class.seed_srv_backed_model + + llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL) + + expect(llm_model.api_key).to eq(new_key) + end + end + + context "when the SRV is no longer defined" do + it "deletes the LlmModel" do + llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL) + expect(llm_model).to be_present + + SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code + + expect { llm_model.reload }.to raise_exception(ActiveRecord::RecordNotFound) + end + + it "disabled the bot user" do + SiteSetting.ai_bot_enabled = true + llm_model = described_class.find_by(url: described_class::RESERVED_VLLM_SRV_URL) + llm_model.update!(enabled_chat_bot: true) + llm_model.toggle_companion_user + user = llm_model.user + + expect(user).to be_present + + SiteSetting.ai_vllm_endpoint_srv = "" # Triggers seed code + + expect { user.reload }.to raise_exception(ActiveRecord::RecordNotFound) + end + end + end + end +end