DEV: Make prompts available on `CurrentUserSerializer` (#472)

This commit is contained in:
Keegan George 2024-02-16 10:57:14 -08:00 committed by GitHub
parent 3a8d95f6b2
commit d66915ecc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 90 additions and 152 deletions

View File

@ -8,17 +8,6 @@ module DiscourseAi
before_action :ensure_can_request_suggestions before_action :ensure_can_request_suggestions
before_action :rate_limiter_performed!, except: %i[prompts] before_action :rate_limiter_performed!, except: %i[prompts]
def prompts
name_filter = params[:name_filter]
render json:
ActiveModel::ArraySerializer.new(
DiscourseAi::AiHelper::Assistant.new.available_prompts(name_filter: name_filter),
root: false,
),
status: 200
end
def suggest def suggest
input = get_text_param! input = get_text_param!

View File

@ -19,7 +19,6 @@ export default class AiHelperContextMenu extends Component {
@service siteSettings; @service siteSettings;
@service modal; @service modal;
@service capabilities; @service capabilities;
@tracked helperOptions = [];
@tracked showContextMenu = false; @tracked showContextMenu = false;
@tracked caretCoords; @tracked caretCoords;
@tracked virtualElement; @tracked virtualElement;
@ -56,15 +55,6 @@ export default class AiHelperContextMenu extends Component {
@tracked _contextMenu; @tracked _contextMenu;
@tracked _activeAIRequest = null; @tracked _activeAIRequest = null;
constructor() {
super(...arguments);
// Fetch prompts only if it hasn't been fetched yet
if (this.helperOptions.length === 0) {
this.loadPrompts();
}
}
willDestroy() { willDestroy() {
super.willDestroy(...arguments); super.willDestroy(...arguments);
document.removeEventListener("selectionchange", this.selectionChanged); document.removeEventListener("selectionchange", this.selectionChanged);
@ -81,8 +71,8 @@ export default class AiHelperContextMenu extends Component {
this._menuState = newState; this._menuState = newState;
} }
async loadPrompts() { get helperOptions() {
let prompts = await ajax("/discourse-ai/ai-helper/prompts"); let prompts = this.currentUser?.ai_helper_prompts;
prompts = prompts prompts = prompts
.filter((p) => p.location.includes("composer")) .filter((p) => p.location.includes("composer"))
@ -109,7 +99,7 @@ export default class AiHelperContextMenu extends Component {
memo[p.name] = p.prompt_type; memo[p.name] = p.prompt_type;
return memo; return memo;
}, {}); }, {});
this.helperOptions = prompts; return prompts;
} }
@bind @bind
@ -338,10 +328,6 @@ export default class AiHelperContextMenu extends Component {
@action @action
toggleAiHelperOptions() { toggleAiHelperOptions() {
// Fetch prompts only if it hasn't been fetched yet
if (this.helperOptions.length === 0) {
this.loadPrompts();
}
this.menuState = this.CONTEXT_MENU_STATES.options; this.menuState = this.CONTEXT_MENU_STATES.options;
} }

View File

@ -1,6 +1,7 @@
import Component from "@glimmer/component"; import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking"; import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object"; import { action } from "@ember/object";
import { inject as service } from "@ember/service";
import DButton from "discourse/components/d-button"; import DButton from "discourse/components/d-button";
import { ajax } from "discourse/lib/ajax"; import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error"; import { popupAjaxError } from "discourse/lib/ajax-error";
@ -11,31 +12,19 @@ export default class AiEditSuggestionButton extends Component {
return showPostAIHelper(outletArgs, helper); return showPostAIHelper(outletArgs, helper);
} }
@service currentUser;
@tracked loading = false; @tracked loading = false;
@tracked suggestion = ""; @tracked suggestion = "";
@tracked _activeAIRequest = null; @tracked _activeAIRequest = null;
constructor() {
super(...arguments);
if (!this.mode) {
this.loadMode();
}
}
get disabled() { get disabled() {
return this.loading || this.suggestion?.length > 0; return this.loading || this.suggestion?.length > 0;
} }
async loadMode() { get mode() {
let mode = await ajax("/discourse-ai/ai-helper/prompts", { return this.currentUser?.ai_helper_prompts.find(
method: "GET", (prompt) => prompt.name === "proofread"
data: { );
name_filter: "proofread",
},
});
this.mode = mode[0];
} }
@action @action

View File

@ -31,7 +31,6 @@ export default class AIHelperOptionsMenu extends Component {
@service currentUser; @service currentUser;
@service menu; @service menu;
@tracked helperOptions = [];
@tracked menuState = this.MENU_STATES.triggers; @tracked menuState = this.MENU_STATES.triggers;
@tracked loading = false; @tracked loading = false;
@tracked suggestion = ""; @tracked suggestion = "";
@ -51,14 +50,6 @@ export default class AIHelperOptionsMenu extends Component {
@tracked _activeAIRequest = null; @tracked _activeAIRequest = null;
constructor() {
super(...arguments);
if (this.helperOptions.length === 0) {
this.loadPrompts();
}
}
@action @action
async showAIHelperOptions() { async showAIHelperOptions() {
this.showMainButtons = false; this.showMainButtons = false;
@ -168,8 +159,8 @@ export default class AIHelperOptionsMenu extends Component {
} }
} }
async loadPrompts() { get helperOptions() {
let prompts = await ajax("/discourse-ai/ai-helper/prompts"); let prompts = this.currentUser?.ai_helper_prompts;
prompts = prompts.filter((item) => item.location.includes("post")); prompts = prompts.filter((item) => item.location.includes("post"));
@ -191,7 +182,7 @@ export default class AIHelperOptionsMenu extends Component {
prompts = prompts.filter((p) => p.name !== "proofread"); prompts = prompts.filter((p) => p.name !== "proofread");
} }
this.helperOptions = prompts; return prompts;
} }
_showUserCustomPrompts() { _showUserCustomPrompts() {

View File

@ -2,7 +2,6 @@
DiscourseAi::Engine.routes.draw do DiscourseAi::Engine.routes.draw do
scope module: :ai_helper, path: "/ai-helper", defaults: { format: :json } do scope module: :ai_helper, path: "/ai-helper", defaults: { format: :json } do
get "prompts" => "assistant#prompts"
post "suggest" => "assistant#suggest" post "suggest" => "assistant#suggest"
post "suggest_title" => "assistant#suggest_title" post "suggest_title" => "assistant#suggest_title"
post "suggest_category" => "assistant#suggest_category" post "suggest_category" => "assistant#suggest_category"

View File

@ -3,35 +3,37 @@
module DiscourseAi module DiscourseAi
module AiHelper module AiHelper
class Assistant class Assistant
def available_prompts(name_filter: nil) AI_HELPER_PROMPTS_CACHE_KEY = "ai_helper_prompts"
cp = CompletionPrompt
prompts = []
if name_filter def available_prompts
prompts = [cp.enabled_by_name(name_filter)] Discourse
else .cache
prompts = cp.where(enabled: true) .fetch(AI_HELPER_PROMPTS_CACHE_KEY) do
# Hide illustrate_post if disabled prompts = CompletionPrompt.where(enabled: true)
prompts =
prompts.where.not(
name: "illustrate_post",
) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
end
prompts.map do |prompt| # Hide illustrate_post if disabled
translation = prompts =
I18n.t("discourse_ai.ai_helper.prompts.#{prompt.name}", default: nil) || prompts.where.not(
prompt.translated_name || prompt.name name: "illustrate_post",
) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
{ prompts =
id: prompt.id, prompts.map do |prompt|
name: prompt.name, translation =
translated_name: translation, I18n.t("discourse_ai.ai_helper.prompts.#{prompt.name}", default: nil) ||
prompt_type: prompt.prompt_type, prompt.translated_name || prompt.name
icon: icon_map(prompt.name),
location: location_map(prompt.name), {
} id: prompt.id,
end name: prompt.name,
translated_name: translation,
prompt_type: prompt.prompt_type,
icon: icon_map(prompt.name),
location: location_map(prompt.name),
}
end
prompts
end
end end
def generate_prompt(completion_prompt, input, user, &block) def generate_prompt(completion_prompt, input, user, &block)

View File

@ -19,6 +19,20 @@ module DiscourseAi
thread_id: thread.id, thread_id: thread.id,
) )
end end
plugin.add_to_serializer(
:current_user,
:ai_helper_prompts,
include_condition: -> do
SiteSetting.composer_ai_helper_enabled && scope.authenticated? &&
scope.user.in_any_groups?(SiteSetting.ai_helper_allowed_groups_map)
end,
) do
ActiveModel::ArraySerializer.new(
DiscourseAi::AiHelper::Assistant.new.available_prompts,
root: false,
)
end
end end
end end
end end

View File

@ -13,33 +13,30 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
STRING STRING
describe("#available_prompts") do describe("#available_prompts") do
context "when no name filter is provided" do before do
it "returns all available prompts" do SiteSetting.ai_helper_illustrate_post_model = "disabled"
prompts = subject.available_prompts Discourse.cache.delete(DiscourseAi::AiHelper::Assistant::AI_HELPER_PROMPTS_CACHE_KEY)
expect(prompts.length).to eq(6)
expect(prompts.map { |p| p[:name] }).to contain_exactly(
"translate",
"generate_titles",
"proofread",
"markdown_table",
"custom_prompt",
"explain",
)
end
end end
context "when name filter is provided" do it "returns all available prompts" do
it "returns the prompt with the given name" do prompts = subject.available_prompts
prompts = subject.available_prompts(name_filter: "translate")
expect(prompts.length).to eq(1) expect(prompts.length).to eq(6)
expect(prompts.first[:name]).to eq("translate") expect(prompts.map { |p| p[:name] }).to contain_exactly(
end "translate",
"generate_titles",
"proofread",
"markdown_table",
"custom_prompt",
"explain",
)
end end
context "when illustrate post model is enabled" do context "when illustrate post model is enabled" do
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" } before do
SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
Discourse.cache.delete(DiscourseAi::AiHelper::Assistant::AI_HELPER_PROMPTS_CACHE_KEY)
end
it "returns the illustrate_post prompt in the list of all prompts" do it "returns the illustrate_post prompt in the list of all prompts" do
prompts = subject.available_prompts prompts = subject.available_prompts

View File

@ -23,4 +23,22 @@ describe Plugin::Instance do
expect(accuracy.flags_agreed).to eq(1) expect(accuracy.flags_agreed).to eq(1)
end end
end end
describe "current_user_serializer#ai_helper_prompts" do
fab!(:user)
before do
SiteSetting.ai_helper_model = "fake:fake"
SiteSetting.composer_ai_helper_enabled = true
SiteSetting.ai_helper_illustrate_post_model = "disabled"
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
end
let(:serializer) { CurrentUserSerializer.new(user, scope: Guardian.new(user)) }
it "returns the available prompts" do
expect(serializer.ai_helper_prompts).to be_present
expect(serializer.ai_helper_prompts.object.count).to eq(6)
end
end
end end

View File

@ -107,51 +107,4 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
end end
end end
end end
describe "#prompts" do
context "when not logged in" do
it "returns a 403 response" do
get "/discourse-ai/ai-helper/prompts"
expect(response.status).to eq(403)
end
end
context "when logged in as a user without enough privileges" do
fab!(:user) { Fabricate(:newuser) }
before do
sign_in(user)
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:staff]
end
it "returns a 403 response" do
get "/discourse-ai/ai-helper/prompts"
expect(response.status).to eq(403)
end
end
context "when logged in as an allowed user" do
fab!(:user) { Fabricate(:user) }
before do
sign_in(user)
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
end
it "returns a list of prompts when no name_filter is provided" do
get "/discourse-ai/ai-helper/prompts"
expect(response.status).to eq(200)
expect(response.parsed_body.length).to eq(7)
end
it "returns a list with with filtered prompts when name_filter is provided" do
get "/discourse-ai/ai-helper/prompts", params: { name_filter: "proofread" }
expect(response.status).to eq(200)
expect(response.parsed_body.length).to eq(1)
expect(response.parsed_body.first["name"]).to eq("proofread")
end
end
end
end end