diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index c3e08f4c..76a3b3b2 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -8,17 +8,6 @@ module DiscourseAi before_action :ensure_can_request_suggestions before_action :rate_limiter_performed!, except: %i[prompts] - def prompts - name_filter = params[:name_filter] - - render json: - ActiveModel::ArraySerializer.new( - DiscourseAi::AiHelper::Assistant.new.available_prompts(name_filter: name_filter), - root: false, - ), - status: 200 - end - def suggest input = get_text_param! diff --git a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js index 2b4a3fd9..07a32f03 100644 --- a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js +++ b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js @@ -19,7 +19,6 @@ export default class AiHelperContextMenu extends Component { @service siteSettings; @service modal; @service capabilities; - @tracked helperOptions = []; @tracked showContextMenu = false; @tracked caretCoords; @tracked virtualElement; @@ -56,15 +55,6 @@ export default class AiHelperContextMenu extends Component { @tracked _contextMenu; @tracked _activeAIRequest = null; - constructor() { - super(...arguments); - - // Fetch prompts only if it hasn't been fetched yet - if (this.helperOptions.length === 0) { - this.loadPrompts(); - } - } - willDestroy() { super.willDestroy(...arguments); document.removeEventListener("selectionchange", this.selectionChanged); @@ -81,8 +71,8 @@ export default class AiHelperContextMenu extends Component { this._menuState = newState; } - async loadPrompts() { - let prompts = await ajax("/discourse-ai/ai-helper/prompts"); + get helperOptions() { + let prompts = this.currentUser?.ai_helper_prompts; prompts = prompts .filter((p) => p.location.includes("composer")) @@ -109,7 +99,7 @@ export default class AiHelperContextMenu extends Component { memo[p.name] = p.prompt_type; return memo; }, {}); - this.helperOptions = prompts; + return prompts; } @bind @@ -338,10 +328,6 @@ export default class AiHelperContextMenu extends Component { @action toggleAiHelperOptions() { - // Fetch prompts only if it hasn't been fetched yet - if (this.helperOptions.length === 0) { - this.loadPrompts(); - } this.menuState = this.CONTEXT_MENU_STATES.options; } diff --git a/assets/javascripts/discourse/connectors/fast-edit-footer-after/ai-edit-suggestion-button.gjs b/assets/javascripts/discourse/connectors/fast-edit-footer-after/ai-edit-suggestion-button.gjs index 1531d310..c09bd43f 100644 --- a/assets/javascripts/discourse/connectors/fast-edit-footer-after/ai-edit-suggestion-button.gjs +++ b/assets/javascripts/discourse/connectors/fast-edit-footer-after/ai-edit-suggestion-button.gjs @@ -1,6 +1,7 @@ import Component from "@glimmer/component"; import { tracked } from "@glimmer/tracking"; import { action } from "@ember/object"; +import { inject as service } from "@ember/service"; import DButton from "discourse/components/d-button"; import { ajax } from "discourse/lib/ajax"; import { popupAjaxError } from "discourse/lib/ajax-error"; @@ -11,31 +12,19 @@ export default class AiEditSuggestionButton extends Component { return showPostAIHelper(outletArgs, helper); } + @service currentUser; @tracked loading = false; @tracked suggestion = ""; @tracked _activeAIRequest = null; - constructor() { - super(...arguments); - - if (!this.mode) { - this.loadMode(); - } - } - get disabled() { return this.loading || this.suggestion?.length > 0; } - async loadMode() { - let mode = await ajax("/discourse-ai/ai-helper/prompts", { - method: "GET", - data: { - name_filter: "proofread", - }, - }); - - this.mode = mode[0]; + get mode() { + return this.currentUser?.ai_helper_prompts.find( + (prompt) => prompt.name === "proofread" + ); } @action diff --git a/assets/javascripts/discourse/connectors/post-text-buttons/ai-helper-options-menu.gjs b/assets/javascripts/discourse/connectors/post-text-buttons/ai-helper-options-menu.gjs index 4e35bb2f..d00fa780 100644 --- a/assets/javascripts/discourse/connectors/post-text-buttons/ai-helper-options-menu.gjs +++ b/assets/javascripts/discourse/connectors/post-text-buttons/ai-helper-options-menu.gjs @@ -31,7 +31,6 @@ export default class AIHelperOptionsMenu extends Component { @service currentUser; @service menu; - @tracked helperOptions = []; @tracked menuState = this.MENU_STATES.triggers; @tracked loading = false; @tracked suggestion = ""; @@ -51,14 +50,6 @@ export default class AIHelperOptionsMenu extends Component { @tracked _activeAIRequest = null; - constructor() { - super(...arguments); - - if (this.helperOptions.length === 0) { - this.loadPrompts(); - } - } - @action async showAIHelperOptions() { this.showMainButtons = false; @@ -168,8 +159,8 @@ export default class AIHelperOptionsMenu extends Component { } } - async loadPrompts() { - let prompts = await ajax("/discourse-ai/ai-helper/prompts"); + get helperOptions() { + let prompts = this.currentUser?.ai_helper_prompts; prompts = prompts.filter((item) => item.location.includes("post")); @@ -191,7 +182,7 @@ export default class AIHelperOptionsMenu extends Component { prompts = prompts.filter((p) => p.name !== "proofread"); } - this.helperOptions = prompts; + return prompts; } _showUserCustomPrompts() { diff --git a/config/routes.rb b/config/routes.rb index 54d6820f..d91bc4d3 100644 --- a/config/routes.rb +++ b/config/routes.rb @@ -2,7 +2,6 @@ DiscourseAi::Engine.routes.draw do scope module: :ai_helper, path: "/ai-helper", defaults: { format: :json } do - get "prompts" => "assistant#prompts" post "suggest" => "assistant#suggest" post "suggest_title" => "assistant#suggest_title" post "suggest_category" => "assistant#suggest_category" diff --git a/lib/ai_helper/assistant.rb b/lib/ai_helper/assistant.rb index 0c4078b5..4fa8efe4 100644 --- a/lib/ai_helper/assistant.rb +++ b/lib/ai_helper/assistant.rb @@ -3,35 +3,37 @@ module DiscourseAi module AiHelper class Assistant - def available_prompts(name_filter: nil) - cp = CompletionPrompt - prompts = [] + AI_HELPER_PROMPTS_CACHE_KEY = "ai_helper_prompts" - if name_filter - prompts = [cp.enabled_by_name(name_filter)] - else - prompts = cp.where(enabled: true) - # Hide illustrate_post if disabled - prompts = - prompts.where.not( - name: "illustrate_post", - ) if SiteSetting.ai_helper_illustrate_post_model == "disabled" - end + def available_prompts + Discourse + .cache + .fetch(AI_HELPER_PROMPTS_CACHE_KEY) do + prompts = CompletionPrompt.where(enabled: true) - prompts.map do |prompt| - translation = - I18n.t("discourse_ai.ai_helper.prompts.#{prompt.name}", default: nil) || - prompt.translated_name || prompt.name + # Hide illustrate_post if disabled + prompts = + prompts.where.not( + name: "illustrate_post", + ) if SiteSetting.ai_helper_illustrate_post_model == "disabled" - { - id: prompt.id, - name: prompt.name, - translated_name: translation, - prompt_type: prompt.prompt_type, - icon: icon_map(prompt.name), - location: location_map(prompt.name), - } - end + prompts = + prompts.map do |prompt| + translation = + I18n.t("discourse_ai.ai_helper.prompts.#{prompt.name}", default: nil) || + prompt.translated_name || prompt.name + + { + id: prompt.id, + name: prompt.name, + translated_name: translation, + prompt_type: prompt.prompt_type, + icon: icon_map(prompt.name), + location: location_map(prompt.name), + } + end + prompts + end end def generate_prompt(completion_prompt, input, user, &block) diff --git a/lib/ai_helper/entry_point.rb b/lib/ai_helper/entry_point.rb index 87ebca24..840d2b55 100644 --- a/lib/ai_helper/entry_point.rb +++ b/lib/ai_helper/entry_point.rb @@ -19,6 +19,20 @@ module DiscourseAi thread_id: thread.id, ) end + + plugin.add_to_serializer( + :current_user, + :ai_helper_prompts, + include_condition: -> do + SiteSetting.composer_ai_helper_enabled && scope.authenticated? && + scope.user.in_any_groups?(SiteSetting.ai_helper_allowed_groups_map) + end, + ) do + ActiveModel::ArraySerializer.new( + DiscourseAi::AiHelper::Assistant.new.available_prompts, + root: false, + ) + end end end end diff --git a/spec/lib/modules/ai_helper/assistant_spec.rb b/spec/lib/modules/ai_helper/assistant_spec.rb index 875ae6a9..c7e59c96 100644 --- a/spec/lib/modules/ai_helper/assistant_spec.rb +++ b/spec/lib/modules/ai_helper/assistant_spec.rb @@ -13,33 +13,30 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do STRING describe("#available_prompts") do - context "when no name filter is provided" do - it "returns all available prompts" do - prompts = subject.available_prompts - - expect(prompts.length).to eq(6) - expect(prompts.map { |p| p[:name] }).to contain_exactly( - "translate", - "generate_titles", - "proofread", - "markdown_table", - "custom_prompt", - "explain", - ) - end + before do + SiteSetting.ai_helper_illustrate_post_model = "disabled" + Discourse.cache.delete(DiscourseAi::AiHelper::Assistant::AI_HELPER_PROMPTS_CACHE_KEY) end - context "when name filter is provided" do - it "returns the prompt with the given name" do - prompts = subject.available_prompts(name_filter: "translate") + it "returns all available prompts" do + prompts = subject.available_prompts - expect(prompts.length).to eq(1) - expect(prompts.first[:name]).to eq("translate") - end + expect(prompts.length).to eq(6) + expect(prompts.map { |p| p[:name] }).to contain_exactly( + "translate", + "generate_titles", + "proofread", + "markdown_table", + "custom_prompt", + "explain", + ) end context "when illustrate post model is enabled" do - before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" } + before do + SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" + Discourse.cache.delete(DiscourseAi::AiHelper::Assistant::AI_HELPER_PROMPTS_CACHE_KEY) + end it "returns the illustrate_post prompt in the list of all prompts" do prompts = subject.available_prompts diff --git a/spec/plugin_spec.rb b/spec/plugin_spec.rb index e50230a8..4cb0240d 100644 --- a/spec/plugin_spec.rb +++ b/spec/plugin_spec.rb @@ -23,4 +23,22 @@ describe Plugin::Instance do expect(accuracy.flags_agreed).to eq(1) end end + + describe "current_user_serializer#ai_helper_prompts" do + fab!(:user) + + before do + SiteSetting.ai_helper_model = "fake:fake" + SiteSetting.composer_ai_helper_enabled = true + SiteSetting.ai_helper_illustrate_post_model = "disabled" + Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user) + end + + let(:serializer) { CurrentUserSerializer.new(user, scope: Guardian.new(user)) } + + it "returns the available prompts" do + expect(serializer.ai_helper_prompts).to be_present + expect(serializer.ai_helper_prompts.object.count).to eq(6) + end + end end diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb index 5902ba1b..8b262856 100644 --- a/spec/requests/ai_helper/assistant_controller_spec.rb +++ b/spec/requests/ai_helper/assistant_controller_spec.rb @@ -107,51 +107,4 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do end end end - - describe "#prompts" do - context "when not logged in" do - it "returns a 403 response" do - get "/discourse-ai/ai-helper/prompts" - expect(response.status).to eq(403) - end - end - - context "when logged in as a user without enough privileges" do - fab!(:user) { Fabricate(:newuser) } - - before do - sign_in(user) - SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:staff] - end - - it "returns a 403 response" do - get "/discourse-ai/ai-helper/prompts" - expect(response.status).to eq(403) - end - end - - context "when logged in as an allowed user" do - fab!(:user) { Fabricate(:user) } - - before do - sign_in(user) - user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]] - SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1] - SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" - end - - it "returns a list of prompts when no name_filter is provided" do - get "/discourse-ai/ai-helper/prompts" - expect(response.status).to eq(200) - expect(response.parsed_body.length).to eq(7) - end - - it "returns a list with with filtered prompts when name_filter is provided" do - get "/discourse-ai/ai-helper/prompts", params: { name_filter: "proofread" } - expect(response.status).to eq(200) - expect(response.parsed_body.length).to eq(1) - expect(response.parsed_body.first["name"]).to eq("proofread") - end - end - end end