DEV: Make prompts available on `CurrentUserSerializer` (#472)

This commit is contained in:
Keegan George 2024-02-16 10:57:14 -08:00 committed by GitHub
parent 3a8d95f6b2
commit d66915ecc1
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
10 changed files with 90 additions and 152 deletions

View File

@ -8,17 +8,6 @@ module DiscourseAi
before_action :ensure_can_request_suggestions
before_action :rate_limiter_performed!, except: %i[prompts]
def prompts
name_filter = params[:name_filter]
render json:
ActiveModel::ArraySerializer.new(
DiscourseAi::AiHelper::Assistant.new.available_prompts(name_filter: name_filter),
root: false,
),
status: 200
end
def suggest
input = get_text_param!

View File

@ -19,7 +19,6 @@ export default class AiHelperContextMenu extends Component {
@service siteSettings;
@service modal;
@service capabilities;
@tracked helperOptions = [];
@tracked showContextMenu = false;
@tracked caretCoords;
@tracked virtualElement;
@ -56,15 +55,6 @@ export default class AiHelperContextMenu extends Component {
@tracked _contextMenu;
@tracked _activeAIRequest = null;
constructor() {
super(...arguments);
// Fetch prompts only if it hasn't been fetched yet
if (this.helperOptions.length === 0) {
this.loadPrompts();
}
}
willDestroy() {
super.willDestroy(...arguments);
document.removeEventListener("selectionchange", this.selectionChanged);
@ -81,8 +71,8 @@ export default class AiHelperContextMenu extends Component {
this._menuState = newState;
}
async loadPrompts() {
let prompts = await ajax("/discourse-ai/ai-helper/prompts");
get helperOptions() {
let prompts = this.currentUser?.ai_helper_prompts;
prompts = prompts
.filter((p) => p.location.includes("composer"))
@ -109,7 +99,7 @@ export default class AiHelperContextMenu extends Component {
memo[p.name] = p.prompt_type;
return memo;
}, {});
this.helperOptions = prompts;
return prompts;
}
@bind
@ -338,10 +328,6 @@ export default class AiHelperContextMenu extends Component {
@action
toggleAiHelperOptions() {
// Fetch prompts only if it hasn't been fetched yet
if (this.helperOptions.length === 0) {
this.loadPrompts();
}
this.menuState = this.CONTEXT_MENU_STATES.options;
}

View File

@ -1,6 +1,7 @@
import Component from "@glimmer/component";
import { tracked } from "@glimmer/tracking";
import { action } from "@ember/object";
import { inject as service } from "@ember/service";
import DButton from "discourse/components/d-button";
import { ajax } from "discourse/lib/ajax";
import { popupAjaxError } from "discourse/lib/ajax-error";
@ -11,31 +12,19 @@ export default class AiEditSuggestionButton extends Component {
return showPostAIHelper(outletArgs, helper);
}
@service currentUser;
@tracked loading = false;
@tracked suggestion = "";
@tracked _activeAIRequest = null;
constructor() {
super(...arguments);
if (!this.mode) {
this.loadMode();
}
}
get disabled() {
return this.loading || this.suggestion?.length > 0;
}
async loadMode() {
let mode = await ajax("/discourse-ai/ai-helper/prompts", {
method: "GET",
data: {
name_filter: "proofread",
},
});
this.mode = mode[0];
get mode() {
return this.currentUser?.ai_helper_prompts.find(
(prompt) => prompt.name === "proofread"
);
}
@action

View File

@ -31,7 +31,6 @@ export default class AIHelperOptionsMenu extends Component {
@service currentUser;
@service menu;
@tracked helperOptions = [];
@tracked menuState = this.MENU_STATES.triggers;
@tracked loading = false;
@tracked suggestion = "";
@ -51,14 +50,6 @@ export default class AIHelperOptionsMenu extends Component {
@tracked _activeAIRequest = null;
constructor() {
super(...arguments);
if (this.helperOptions.length === 0) {
this.loadPrompts();
}
}
@action
async showAIHelperOptions() {
this.showMainButtons = false;
@ -168,8 +159,8 @@ export default class AIHelperOptionsMenu extends Component {
}
}
async loadPrompts() {
let prompts = await ajax("/discourse-ai/ai-helper/prompts");
get helperOptions() {
let prompts = this.currentUser?.ai_helper_prompts;
prompts = prompts.filter((item) => item.location.includes("post"));
@ -191,7 +182,7 @@ export default class AIHelperOptionsMenu extends Component {
prompts = prompts.filter((p) => p.name !== "proofread");
}
this.helperOptions = prompts;
return prompts;
}
_showUserCustomPrompts() {

View File

@ -2,7 +2,6 @@
DiscourseAi::Engine.routes.draw do
scope module: :ai_helper, path: "/ai-helper", defaults: { format: :json } do
get "prompts" => "assistant#prompts"
post "suggest" => "assistant#suggest"
post "suggest_title" => "assistant#suggest_title"
post "suggest_category" => "assistant#suggest_category"

View File

@ -3,35 +3,37 @@
module DiscourseAi
module AiHelper
class Assistant
def available_prompts(name_filter: nil)
cp = CompletionPrompt
prompts = []
AI_HELPER_PROMPTS_CACHE_KEY = "ai_helper_prompts"
if name_filter
prompts = [cp.enabled_by_name(name_filter)]
else
prompts = cp.where(enabled: true)
# Hide illustrate_post if disabled
prompts =
prompts.where.not(
name: "illustrate_post",
) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
end
def available_prompts
Discourse
.cache
.fetch(AI_HELPER_PROMPTS_CACHE_KEY) do
prompts = CompletionPrompt.where(enabled: true)
prompts.map do |prompt|
translation =
I18n.t("discourse_ai.ai_helper.prompts.#{prompt.name}", default: nil) ||
prompt.translated_name || prompt.name
# Hide illustrate_post if disabled
prompts =
prompts.where.not(
name: "illustrate_post",
) if SiteSetting.ai_helper_illustrate_post_model == "disabled"
{
id: prompt.id,
name: prompt.name,
translated_name: translation,
prompt_type: prompt.prompt_type,
icon: icon_map(prompt.name),
location: location_map(prompt.name),
}
end
prompts =
prompts.map do |prompt|
translation =
I18n.t("discourse_ai.ai_helper.prompts.#{prompt.name}", default: nil) ||
prompt.translated_name || prompt.name
{
id: prompt.id,
name: prompt.name,
translated_name: translation,
prompt_type: prompt.prompt_type,
icon: icon_map(prompt.name),
location: location_map(prompt.name),
}
end
prompts
end
end
def generate_prompt(completion_prompt, input, user, &block)

View File

@ -19,6 +19,20 @@ module DiscourseAi
thread_id: thread.id,
)
end
plugin.add_to_serializer(
:current_user,
:ai_helper_prompts,
include_condition: -> do
SiteSetting.composer_ai_helper_enabled && scope.authenticated? &&
scope.user.in_any_groups?(SiteSetting.ai_helper_allowed_groups_map)
end,
) do
ActiveModel::ArraySerializer.new(
DiscourseAi::AiHelper::Assistant.new.available_prompts,
root: false,
)
end
end
end
end

View File

@ -13,33 +13,30 @@ RSpec.describe DiscourseAi::AiHelper::Assistant do
STRING
describe("#available_prompts") do
context "when no name filter is provided" do
it "returns all available prompts" do
prompts = subject.available_prompts
expect(prompts.length).to eq(6)
expect(prompts.map { |p| p[:name] }).to contain_exactly(
"translate",
"generate_titles",
"proofread",
"markdown_table",
"custom_prompt",
"explain",
)
end
before do
SiteSetting.ai_helper_illustrate_post_model = "disabled"
Discourse.cache.delete(DiscourseAi::AiHelper::Assistant::AI_HELPER_PROMPTS_CACHE_KEY)
end
context "when name filter is provided" do
it "returns the prompt with the given name" do
prompts = subject.available_prompts(name_filter: "translate")
it "returns all available prompts" do
prompts = subject.available_prompts
expect(prompts.length).to eq(1)
expect(prompts.first[:name]).to eq("translate")
end
expect(prompts.length).to eq(6)
expect(prompts.map { |p| p[:name] }).to contain_exactly(
"translate",
"generate_titles",
"proofread",
"markdown_table",
"custom_prompt",
"explain",
)
end
context "when illustrate post model is enabled" do
before { SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl" }
before do
SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
Discourse.cache.delete(DiscourseAi::AiHelper::Assistant::AI_HELPER_PROMPTS_CACHE_KEY)
end
it "returns the illustrate_post prompt in the list of all prompts" do
prompts = subject.available_prompts

View File

@ -23,4 +23,22 @@ describe Plugin::Instance do
expect(accuracy.flags_agreed).to eq(1)
end
end
describe "current_user_serializer#ai_helper_prompts" do
fab!(:user)
before do
SiteSetting.ai_helper_model = "fake:fake"
SiteSetting.composer_ai_helper_enabled = true
SiteSetting.ai_helper_illustrate_post_model = "disabled"
Group.find_by(id: Group::AUTO_GROUPS[:admins]).add(user)
end
let(:serializer) { CurrentUserSerializer.new(user, scope: Guardian.new(user)) }
it "returns the available prompts" do
expect(serializer.ai_helper_prompts).to be_present
expect(serializer.ai_helper_prompts.object.count).to eq(6)
end
end
end

View File

@ -107,51 +107,4 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
end
end
end
describe "#prompts" do
context "when not logged in" do
it "returns a 403 response" do
get "/discourse-ai/ai-helper/prompts"
expect(response.status).to eq(403)
end
end
context "when logged in as a user without enough privileges" do
fab!(:user) { Fabricate(:newuser) }
before do
sign_in(user)
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:staff]
end
it "returns a 403 response" do
get "/discourse-ai/ai-helper/prompts"
expect(response.status).to eq(403)
end
end
context "when logged in as an allowed user" do
fab!(:user) { Fabricate(:user) }
before do
sign_in(user)
user.group_ids = [Group::AUTO_GROUPS[:trust_level_1]]
SiteSetting.ai_helper_allowed_groups = Group::AUTO_GROUPS[:trust_level_1]
SiteSetting.ai_helper_illustrate_post_model = "stable_diffusion_xl"
end
it "returns a list of prompts when no name_filter is provided" do
get "/discourse-ai/ai-helper/prompts"
expect(response.status).to eq(200)
expect(response.parsed_body.length).to eq(7)
end
it "returns a list with with filtered prompts when name_filter is provided" do
get "/discourse-ai/ai-helper/prompts", params: { name_filter: "proofread" }
expect(response.status).to eq(200)
expect(response.parsed_body.length).to eq(1)
expect(response.parsed_body.first["name"]).to eq("proofread")
end
end
end
end