diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index fdfa0b4e..f83f2359 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -21,10 +21,15 @@ module DiscourseAi input = get_text_param! prompt = CompletionPrompt.find_by(id: params[:mode]) + raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? + if prompt.prompt_type == "custom_prompt" && params[:custom_prompt].blank? + raise Discourse::InvalidParameters.new(:custom_prompt) + end hijack do - render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input), + render json: + DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, params), status: 200 end rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed, @@ -36,6 +41,7 @@ module DiscourseAi def suggest_title input = get_text_param! + input_hash = { text: input } llm_prompt = DiscourseAi::AiHelper::LlmPrompt @@ -46,7 +52,11 @@ module DiscourseAi raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? hijack do - render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input), + render json: + DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt( + prompt, + input_hash, + ), status: 200 end rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed, @@ -58,15 +68,21 @@ module DiscourseAi def suggest_category input = get_text_param! + input_hash = { text: input } - render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).categories, + render json: + DiscourseAi::AiHelper::SemanticCategorizer.new( + input_hash, + current_user, + ).categories, status: 200 end def suggest_tags input = get_text_param! + input_hash = { text: input } - render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).tags, + render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags, status: 200 end diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb index 054d94c1..b74627dc 100644 --- a/app/models/completion_prompt.rb +++ b/app/models/completion_prompt.rb @@ -10,13 +10,24 @@ class CompletionPrompt < ActiveRecord::Base validate :each_message_length def messages_with_user_input(user_input) + if user_input[:custom_prompt].present? + case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider + when "huggingface" + self.messages.each { |msg| msg.sub!("{{custom_prompt}}", user_input[:custom_prompt]) } + else + self.messages.each do |msg| + msg["content"].sub!("{{custom_prompt}}", user_input[:custom_prompt]) + end + end + end + case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider when "openai" - self.messages << { role: "user", content: user_input } + self.messages << { role: "user", content: user_input[:text] } when "anthropic" - self.messages << { "role" => "Input", "content" => "#{user_input}" } + self.messages << { "role" => "Input", "content" => "#{user_input[:text]}" } when "huggingface" - self.messages.first.sub("{{user_input}}", user_input) + self.messages.first.sub("{{user_input}}", user_input[:text]) end end diff --git a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs index 27a45d13..efd6a818 100644 --- a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs +++ b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs @@ -16,33 +16,54 @@ {{else if (eq this.menuState this.CONTEXT_MENU_STATES.options)}}
{{else if (eq this.menuState this.CONTEXT_MENU_STATES.loading)}} - + {{else if (eq this.menuState this.CONTEXT_MENU_STATES.review)}} - {{/if}} {{/if}} diff --git a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js index 365714ce..f8d9269d 100644 --- a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js +++ b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js @@ -15,10 +15,10 @@ export default class AiHelperContextMenu extends Component { return showAIHelper(outletArgs, helper); } + @service currentUser; @service siteSettings; @tracked helperOptions = []; @tracked showContextMenu = false; - @tracked menuState = this.CONTEXT_MENU_STATES.triggers; @tracked caretCoords; @tracked virtualElement; @tracked selectedText = ""; @@ -30,6 +30,8 @@ export default class AiHelperContextMenu extends Component { @tracked showDiffModal = false; @tracked diff; @tracked popperPlacement = "top-start"; + @tracked previousMenuState = null; + @tracked customPromptValue = ""; CONTEXT_MENU_STATES = { triggers: "TRIGGERS", @@ -41,8 +43,10 @@ export default class AiHelperContextMenu extends Component { prompts = []; promptTypes = {}; + @tracked _menuState = this.CONTEXT_MENU_STATES.triggers; @tracked _popper; @tracked _dEditorInput; + @tracked _customPromptInput; @tracked _contextMenu; @tracked _activeAIRequest = null; @@ -62,27 +66,42 @@ export default class AiHelperContextMenu extends Component { this._popper?.destroy(); } + get menuState() { + return this._menuState; + } + + set menuState(newState) { + this.previousMenuState = this._menuState; + this._menuState = newState; + } + async loadPrompts() { let prompts = await ajax("/discourse-ai/ai-helper/prompts"); - prompts - .filter((p) => p.name !== "generate_titles") - .map((p) => { - this.prompts[p.id] = p; - }); + prompts = prompts.filter((p) => p.name !== "generate_titles"); + + // Find the custom_prompt object and move it to the beginning of the array + const customPromptIndex = prompts.findIndex( + (p) => p.name === "custom_prompt" + ); + if (customPromptIndex !== -1) { + const customPrompt = prompts.splice(customPromptIndex, 1)[0]; + prompts.unshift(customPrompt); + } + + if (!this._showUserCustomPrompts()) { + prompts = prompts.filter((p) => p.name !== "custom_prompt"); + } + + prompts.forEach((p) => { + this.prompts[p.id] = p; + }); this.promptTypes = prompts.reduce((memo, p) => { memo[p.name] = p.prompt_type; return memo; }, {}); - this.helperOptions = prompts - .filter((p) => p.name !== "generate_titles") - .map((p) => { - return { - name: p.translated_name, - value: p.id, - }; - }); + this.helperOptions = prompts; } @bind @@ -153,6 +172,10 @@ export default class AiHelperContextMenu extends Component { } get canCloseContextMenu() { + if (document.activeElement === this._customPromptInput) { + return false; + } + if (this.loading && this._activeAIRequest !== null) { return false; } @@ -168,9 +191,9 @@ export default class AiHelperContextMenu extends Component { if (!this.canCloseContextMenu) { return; } - this.showContextMenu = false; this.menuState = this.CONTEXT_MENU_STATES.triggers; + this.customPromptValue = ""; } _updateSuggestedByAI(data) { @@ -200,6 +223,15 @@ export default class AiHelperContextMenu extends Component { return (this.loading = false); } + _showUserCustomPrompts() { + const allowedGroups = + this.siteSettings?.ai_helper_custom_prompts_allowed_groups + .split("|") + .map((id) => parseInt(id, 10)); + + return this.currentUser?.groups.some((g) => allowedGroups.includes(g.id)); + } + handleBoundaries() { const textAreaWrapper = document .querySelector(".d-editor-textarea-wrapper") @@ -276,6 +308,14 @@ export default class AiHelperContextMenu extends Component { } } + @action + setupCustomPrompt() { + this._customPromptInput = document.querySelector( + ".ai-custom-prompt__input" + ); + this._customPromptInput.focus(); + } + @action toggleAiHelperOptions() { // Fetch prompts only if it hasn't been fetched yet @@ -303,7 +343,11 @@ export default class AiHelperContextMenu extends Component { this._activeAIRequest = ajax("/discourse-ai/ai-helper/suggest", { method: "POST", - data: { mode: option, text: this.selectedText }, + data: { + mode: option.id, + text: this.selectedText, + custom_prompt: this.customPromptValue, + }, }); this._activeAIRequest @@ -340,4 +384,9 @@ export default class AiHelperContextMenu extends Component { this.closeContextMenu(); } } + + @action + togglePreviousMenu() { + this.menuState = this.previousMenuState; + } } diff --git a/assets/stylesheets/modules/ai-helper/common/ai-helper.scss b/assets/stylesheets/modules/ai-helper/common/ai-helper.scss index 1c8d1f91..b292571d 100644 --- a/assets/stylesheets/modules/ai-helper/common/ai-helper.scss +++ b/assets/stylesheets/modules/ai-helper/common/ai-helper.scss @@ -50,50 +50,51 @@ list-style: none; } - .btn { - justify-content: left; - text-align: left; - background: none; - width: 100%; - border-radius: 0; - margin: 0; + li { + .btn-flat { + justify-content: left; + text-align: left; + background: none; + width: 100%; + border-radius: 0; + margin: 0; + padding-block: 0.6rem; - &:focus, - &:hover { - color: var(--primary); - background: var(--d-hover); + &:focus, + &:hover { + color: var(--primary); + background: var(--d-hover); - .d-icon { - color: var(--primary-medium); + .d-icon { + color: var(--primary-medium); + } + } + + .d-button-label { + color: var(--primary-very-high); } } } - .d-button-label { - color: var(--primary-very-high); - } - &__options { padding: 0.25rem; + + li:not(:last-child) { + border-bottom: 1px solid var(--primary-low); + } } &__loading { + display: flex; + padding: 0.5rem; + gap: 1rem; + justify-content: flex-start; + align-items: center; + .dot-falling { margin-inline: 1rem; margin-left: 1.5rem; } - - li { - display: flex; - padding: 0.5rem; - gap: 1rem; - justify-content: flex-start; - align-items: center; - } - - .btn { - width: unset; - } } &__resets { @@ -107,6 +108,41 @@ align-items: center; flex-flow: row wrap; } + + &__custom-prompt { + display: flex; + flex-flow: row wrap; + padding: 0.5rem; + + &-header { + margin-bottom: 0.5rem; + + .btn { + padding: 0; + } + } + + .ai-custom-prompt-input { + min-height: 90px; + width: 100%; + } + } + + .ai-custom-prompt { + display: flex; + gap: 0.25rem; + margin-bottom: 0.5rem; + + &__input { + background: var(--primary-low); + border-color: var(--primary-low); + margin-bottom: 0; + + &::placeholder { + color: var(--primary-medium); + } + } + } } .d-editor-input.loading { diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index b5a02936..6e8cb7b9 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -19,6 +19,7 @@ en: suggest: "Suggest with AI" missing_content: "Please enter some content to generate suggestions." context_menu: + back: "Back" trigger: "AI" undo: "Undo" loading: "AI is generating" @@ -28,6 +29,10 @@ en: confirm: "Confirm" revert: "Revert" changes: "Changes" + custom_prompt: + title: "Custom Prompt" + placeholder: "Enter a custom prompt..." + submit: "Send Prompt" reviewables: model_used: "Model used:" accuracy: "Accuracy:" diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 93442e74..9a001229 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -21,7 +21,7 @@ en: ai_sentiment_models: "Models to use for inference. Sentiment classifies post on the positive/neutral/negative space. Emotion classifies on the anger/disgust/fear/joy/neutral/sadness/surprise space." ai_nsfw_detection_enabled: "Enable the NSFW module." - ai_nsfw_inference_service_api_endpoint: "URL where the API is running for the NSFW module" + ai_nsfw_inference_service_api_endpoint: "URL where the API is running for the NSFW module" ai_nsfw_inference_service_api_key: "API key for the NSFW API" ai_nsfw_flag_automatically: "Automatically flag NSFW posts that are above the configured thresholds." ai_nsfw_flag_threshold_general: "General Threshold for an image to be considered NSFW." @@ -44,6 +44,7 @@ en: ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer." ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs." ai_helper_model: "Model to use for the AI helper." + ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper." ai_embeddings_enabled: "Enable the embeddings module." ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module" @@ -76,9 +77,9 @@ en: ai_google_custom_search_cx: "CX for Google Custom Search API" reviewables: - reasons: - flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic. - flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW. + reasons: + flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic. + flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW. errors: prompt_message_length: The message %{idx} is over the 1000 character limit. @@ -93,6 +94,7 @@ en: generate_titles: Suggest topic titles proofread: Proofread text markdown_table: Generate Markdown table + custom_prompt: "Custom Prompt" ai_bot: personas: diff --git a/config/settings.yml b/config/settings.yml index 6663a133..1f6875b9 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -9,7 +9,7 @@ discourse_ai: ai_toxicity_inference_service_api_endpoint: default: "https://disorder-testing.demo-by-discourse.com" ai_toxicity_inference_service_api_key: - default: '' + default: "" secret: true ai_toxicity_inference_service_api_model: type: enum @@ -56,7 +56,7 @@ discourse_ai: ai_sentiment_inference_service_api_endpoint: default: "https://sentiment-testing.demo-by-discourse.com" ai_sentiment_inference_service_api_key: - default: '' + default: "" secret: true ai_sentiment_models: type: list @@ -64,8 +64,8 @@ discourse_ai: default: "emotion" allow_any: false choices: - - sentiment - - emotion + - sentiment + - emotion ai_nsfw_detection_enabled: false ai_nsfw_inference_service_api_endpoint: @@ -85,8 +85,8 @@ discourse_ai: default: "opennsfw2" allow_any: false choices: - - opennsfw2 - - nsfw_detector + - opennsfw2 + - nsfw_detector ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions" ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions" @@ -148,6 +148,13 @@ discourse_ai: - gpt-4 - claude-2 - stable-beluga-2 + ai_helper_custom_prompts_allowed_groups: + client: true + type: group_list + list_type: compact + default: "3" # 3: @staff + allow_any: false + refresh: true ai_embeddings_enabled: default: false @@ -162,9 +169,9 @@ discourse_ai: default: "all-mpnet-base-v2" allow_any: false choices: - - all-mpnet-base-v2 - - text-embedding-ada-002 - - multilingual-e5-large + - all-mpnet-base-v2 + - text-embedding-ada-002 + - multilingual-e5-large ai_embeddings_generate_for_pms: false ai_embeddings_semantic_related_topics_enabled: default: false @@ -183,11 +190,11 @@ discourse_ai: allow_any: false choices: - Llama2-*-chat-hf - - claude-instant-1 - - claude-2 + - claude-instant-1 + - claude-2 - gpt-3.5-turbo - - gpt-4 - - StableBeluga2 + - gpt-4 + - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 ai_summarization_discourse_service_api_endpoint: "" @@ -211,22 +218,22 @@ discourse_ai: default: "gpt-3.5-turbo" client: true choices: - - gpt-3.5-turbo - - gpt-4 - - claude-2 + - gpt-3.5-turbo + - gpt-4 + - claude-2 ai_bot_enabled_chat_commands: type: list default: "categories|google|image|search|tags|time|read" client: true choices: - - categories - - google - - image - - search - - summarize - - read - - tags - - time + - categories + - google + - image + - search + - summarize + - read + - tags + - time ai_bot_enabled_personas: type: list default: "general|artist|sql_helper|settings_explorer|researcher" diff --git a/db/fixtures/ai_helper/600_openai_completion_prompts.rb b/db/fixtures/ai_helper/600_openai_completion_prompts.rb index ec38dc8a..7c8849ff 100644 --- a/db/fixtures/ai_helper/600_openai_completion_prompts.rb +++ b/db/fixtures/ai_helper/600_openai_completion_prompts.rb @@ -123,3 +123,14 @@ CompletionPrompt.seed do |cp| TEXT ] end + +CompletionPrompt.seed do |cp| + cp.id = -5 + cp.provider = "openai" + cp.name = "custom_prompt" + cp.prompt_type = CompletionPrompt.prompt_types[:list] + cp.messages = [{ role: "system", content: <<~TEXT }] + You are a helpful assistant, I will provide you with a text below, + you will {{custom_prompt}} and you will reply with the result. + TEXT +end diff --git a/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb b/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb index c33157c8..db6e1011 100644 --- a/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb +++ b/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb @@ -54,3 +54,14 @@ CompletionPrompt.seed do |cp| please reply with the corrected text between