diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index fdfa0b4e..f83f2359 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -21,10 +21,15 @@ module DiscourseAi input = get_text_param! prompt = CompletionPrompt.find_by(id: params[:mode]) + raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? + if prompt.prompt_type == "custom_prompt" && params[:custom_prompt].blank? + raise Discourse::InvalidParameters.new(:custom_prompt) + end hijack do - render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input), + render json: + DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, params), status: 200 end rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed, @@ -36,6 +41,7 @@ module DiscourseAi def suggest_title input = get_text_param! + input_hash = { text: input } llm_prompt = DiscourseAi::AiHelper::LlmPrompt @@ -46,7 +52,11 @@ module DiscourseAi raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? hijack do - render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input), + render json: + DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt( + prompt, + input_hash, + ), status: 200 end rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed, @@ -58,15 +68,21 @@ module DiscourseAi def suggest_category input = get_text_param! + input_hash = { text: input } - render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).categories, + render json: + DiscourseAi::AiHelper::SemanticCategorizer.new( + input_hash, + current_user, + ).categories, status: 200 end def suggest_tags input = get_text_param! + input_hash = { text: input } - render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).tags, + render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags, status: 200 end diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb index 054d94c1..b74627dc 100644 --- a/app/models/completion_prompt.rb +++ b/app/models/completion_prompt.rb @@ -10,13 +10,24 @@ class CompletionPrompt < ActiveRecord::Base validate :each_message_length def messages_with_user_input(user_input) + if user_input[:custom_prompt].present? + case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider + when "huggingface" + self.messages.each { |msg| msg.sub!("{{custom_prompt}}", user_input[:custom_prompt]) } + else + self.messages.each do |msg| + msg["content"].sub!("{{custom_prompt}}", user_input[:custom_prompt]) + end + end + end + case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider when "openai" - self.messages << { role: "user", content: user_input } + self.messages << { role: "user", content: user_input[:text] } when "anthropic" - self.messages << { "role" => "Input", "content" => "#{user_input}" } + self.messages << { "role" => "Input", "content" => "#{user_input[:text]}" } when "huggingface" - self.messages.first.sub("{{user_input}}", user_input) + self.messages.first.sub("{{user_input}}", user_input[:text]) end end diff --git a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs index 27a45d13..efd6a818 100644 --- a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs +++ b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.hbs @@ -16,33 +16,54 @@ {{else if (eq this.menuState this.CONTEXT_MENU_STATES.options)}} {{else if (eq this.menuState this.CONTEXT_MENU_STATES.loading)}} - +
+
+ + {{i18n "discourse_ai.ai_helper.context_menu.loading"}} + + +
{{else if (eq this.menuState this.CONTEXT_MENU_STATES.review)}} - {{/if}} {{/if}} diff --git a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js index 365714ce..f8d9269d 100644 --- a/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js +++ b/assets/javascripts/discourse/connectors/after-d-editor/ai-helper-context-menu.js @@ -15,10 +15,10 @@ export default class AiHelperContextMenu extends Component { return showAIHelper(outletArgs, helper); } + @service currentUser; @service siteSettings; @tracked helperOptions = []; @tracked showContextMenu = false; - @tracked menuState = this.CONTEXT_MENU_STATES.triggers; @tracked caretCoords; @tracked virtualElement; @tracked selectedText = ""; @@ -30,6 +30,8 @@ export default class AiHelperContextMenu extends Component { @tracked showDiffModal = false; @tracked diff; @tracked popperPlacement = "top-start"; + @tracked previousMenuState = null; + @tracked customPromptValue = ""; CONTEXT_MENU_STATES = { triggers: "TRIGGERS", @@ -41,8 +43,10 @@ export default class AiHelperContextMenu extends Component { prompts = []; promptTypes = {}; + @tracked _menuState = this.CONTEXT_MENU_STATES.triggers; @tracked _popper; @tracked _dEditorInput; + @tracked _customPromptInput; @tracked _contextMenu; @tracked _activeAIRequest = null; @@ -62,27 +66,42 @@ export default class AiHelperContextMenu extends Component { this._popper?.destroy(); } + get menuState() { + return this._menuState; + } + + set menuState(newState) { + this.previousMenuState = this._menuState; + this._menuState = newState; + } + async loadPrompts() { let prompts = await ajax("/discourse-ai/ai-helper/prompts"); - prompts - .filter((p) => p.name !== "generate_titles") - .map((p) => { - this.prompts[p.id] = p; - }); + prompts = prompts.filter((p) => p.name !== "generate_titles"); + + // Find the custom_prompt object and move it to the beginning of the array + const customPromptIndex = prompts.findIndex( + (p) => p.name === "custom_prompt" + ); + if (customPromptIndex !== -1) { + const customPrompt = prompts.splice(customPromptIndex, 1)[0]; + prompts.unshift(customPrompt); + } + + if (!this._showUserCustomPrompts()) { + prompts = prompts.filter((p) => p.name !== "custom_prompt"); + } + + prompts.forEach((p) => { + this.prompts[p.id] = p; + }); this.promptTypes = prompts.reduce((memo, p) => { memo[p.name] = p.prompt_type; return memo; }, {}); - this.helperOptions = prompts - .filter((p) => p.name !== "generate_titles") - .map((p) => { - return { - name: p.translated_name, - value: p.id, - }; - }); + this.helperOptions = prompts; } @bind @@ -153,6 +172,10 @@ export default class AiHelperContextMenu extends Component { } get canCloseContextMenu() { + if (document.activeElement === this._customPromptInput) { + return false; + } + if (this.loading && this._activeAIRequest !== null) { return false; } @@ -168,9 +191,9 @@ export default class AiHelperContextMenu extends Component { if (!this.canCloseContextMenu) { return; } - this.showContextMenu = false; this.menuState = this.CONTEXT_MENU_STATES.triggers; + this.customPromptValue = ""; } _updateSuggestedByAI(data) { @@ -200,6 +223,15 @@ export default class AiHelperContextMenu extends Component { return (this.loading = false); } + _showUserCustomPrompts() { + const allowedGroups = + this.siteSettings?.ai_helper_custom_prompts_allowed_groups + .split("|") + .map((id) => parseInt(id, 10)); + + return this.currentUser?.groups.some((g) => allowedGroups.includes(g.id)); + } + handleBoundaries() { const textAreaWrapper = document .querySelector(".d-editor-textarea-wrapper") @@ -276,6 +308,14 @@ export default class AiHelperContextMenu extends Component { } } + @action + setupCustomPrompt() { + this._customPromptInput = document.querySelector( + ".ai-custom-prompt__input" + ); + this._customPromptInput.focus(); + } + @action toggleAiHelperOptions() { // Fetch prompts only if it hasn't been fetched yet @@ -303,7 +343,11 @@ export default class AiHelperContextMenu extends Component { this._activeAIRequest = ajax("/discourse-ai/ai-helper/suggest", { method: "POST", - data: { mode: option, text: this.selectedText }, + data: { + mode: option.id, + text: this.selectedText, + custom_prompt: this.customPromptValue, + }, }); this._activeAIRequest @@ -340,4 +384,9 @@ export default class AiHelperContextMenu extends Component { this.closeContextMenu(); } } + + @action + togglePreviousMenu() { + this.menuState = this.previousMenuState; + } } diff --git a/assets/stylesheets/modules/ai-helper/common/ai-helper.scss b/assets/stylesheets/modules/ai-helper/common/ai-helper.scss index 1c8d1f91..b292571d 100644 --- a/assets/stylesheets/modules/ai-helper/common/ai-helper.scss +++ b/assets/stylesheets/modules/ai-helper/common/ai-helper.scss @@ -50,50 +50,51 @@ list-style: none; } - .btn { - justify-content: left; - text-align: left; - background: none; - width: 100%; - border-radius: 0; - margin: 0; + li { + .btn-flat { + justify-content: left; + text-align: left; + background: none; + width: 100%; + border-radius: 0; + margin: 0; + padding-block: 0.6rem; - &:focus, - &:hover { - color: var(--primary); - background: var(--d-hover); + &:focus, + &:hover { + color: var(--primary); + background: var(--d-hover); - .d-icon { - color: var(--primary-medium); + .d-icon { + color: var(--primary-medium); + } + } + + .d-button-label { + color: var(--primary-very-high); } } } - .d-button-label { - color: var(--primary-very-high); - } - &__options { padding: 0.25rem; + + li:not(:last-child) { + border-bottom: 1px solid var(--primary-low); + } } &__loading { + display: flex; + padding: 0.5rem; + gap: 1rem; + justify-content: flex-start; + align-items: center; + .dot-falling { margin-inline: 1rem; margin-left: 1.5rem; } - - li { - display: flex; - padding: 0.5rem; - gap: 1rem; - justify-content: flex-start; - align-items: center; - } - - .btn { - width: unset; - } } &__resets { @@ -107,6 +108,41 @@ align-items: center; flex-flow: row wrap; } + + &__custom-prompt { + display: flex; + flex-flow: row wrap; + padding: 0.5rem; + + &-header { + margin-bottom: 0.5rem; + + .btn { + padding: 0; + } + } + + .ai-custom-prompt-input { + min-height: 90px; + width: 100%; + } + } + + .ai-custom-prompt { + display: flex; + gap: 0.25rem; + margin-bottom: 0.5rem; + + &__input { + background: var(--primary-low); + border-color: var(--primary-low); + margin-bottom: 0; + + &::placeholder { + color: var(--primary-medium); + } + } + } } .d-editor-input.loading { diff --git a/config/locales/client.en.yml b/config/locales/client.en.yml index b5a02936..6e8cb7b9 100644 --- a/config/locales/client.en.yml +++ b/config/locales/client.en.yml @@ -19,6 +19,7 @@ en: suggest: "Suggest with AI" missing_content: "Please enter some content to generate suggestions." context_menu: + back: "Back" trigger: "AI" undo: "Undo" loading: "AI is generating" @@ -28,6 +29,10 @@ en: confirm: "Confirm" revert: "Revert" changes: "Changes" + custom_prompt: + title: "Custom Prompt" + placeholder: "Enter a custom prompt..." + submit: "Send Prompt" reviewables: model_used: "Model used:" accuracy: "Accuracy:" diff --git a/config/locales/server.en.yml b/config/locales/server.en.yml index 93442e74..9a001229 100644 --- a/config/locales/server.en.yml +++ b/config/locales/server.en.yml @@ -21,7 +21,7 @@ en: ai_sentiment_models: "Models to use for inference. Sentiment classifies post on the positive/neutral/negative space. Emotion classifies on the anger/disgust/fear/joy/neutral/sadness/surprise space." ai_nsfw_detection_enabled: "Enable the NSFW module." - ai_nsfw_inference_service_api_endpoint: "URL where the API is running for the NSFW module" + ai_nsfw_inference_service_api_endpoint: "URL where the API is running for the NSFW module" ai_nsfw_inference_service_api_key: "API key for the NSFW API" ai_nsfw_flag_automatically: "Automatically flag NSFW posts that are above the configured thresholds." ai_nsfw_flag_threshold_general: "General Threshold for an image to be considered NSFW." @@ -44,6 +44,7 @@ en: ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer." ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs." ai_helper_model: "Model to use for the AI helper." + ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper." ai_embeddings_enabled: "Enable the embeddings module." ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module" @@ -76,9 +77,9 @@ en: ai_google_custom_search_cx: "CX for Google Custom Search API" reviewables: - reasons: - flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic. - flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW. + reasons: + flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic. + flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW. errors: prompt_message_length: The message %{idx} is over the 1000 character limit. @@ -93,6 +94,7 @@ en: generate_titles: Suggest topic titles proofread: Proofread text markdown_table: Generate Markdown table + custom_prompt: "Custom Prompt" ai_bot: personas: diff --git a/config/settings.yml b/config/settings.yml index 6663a133..1f6875b9 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -9,7 +9,7 @@ discourse_ai: ai_toxicity_inference_service_api_endpoint: default: "https://disorder-testing.demo-by-discourse.com" ai_toxicity_inference_service_api_key: - default: '' + default: "" secret: true ai_toxicity_inference_service_api_model: type: enum @@ -56,7 +56,7 @@ discourse_ai: ai_sentiment_inference_service_api_endpoint: default: "https://sentiment-testing.demo-by-discourse.com" ai_sentiment_inference_service_api_key: - default: '' + default: "" secret: true ai_sentiment_models: type: list @@ -64,8 +64,8 @@ discourse_ai: default: "emotion" allow_any: false choices: - - sentiment - - emotion + - sentiment + - emotion ai_nsfw_detection_enabled: false ai_nsfw_inference_service_api_endpoint: @@ -85,8 +85,8 @@ discourse_ai: default: "opennsfw2" allow_any: false choices: - - opennsfw2 - - nsfw_detector + - opennsfw2 + - nsfw_detector ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions" ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions" @@ -148,6 +148,13 @@ discourse_ai: - gpt-4 - claude-2 - stable-beluga-2 + ai_helper_custom_prompts_allowed_groups: + client: true + type: group_list + list_type: compact + default: "3" # 3: @staff + allow_any: false + refresh: true ai_embeddings_enabled: default: false @@ -162,9 +169,9 @@ discourse_ai: default: "all-mpnet-base-v2" allow_any: false choices: - - all-mpnet-base-v2 - - text-embedding-ada-002 - - multilingual-e5-large + - all-mpnet-base-v2 + - text-embedding-ada-002 + - multilingual-e5-large ai_embeddings_generate_for_pms: false ai_embeddings_semantic_related_topics_enabled: default: false @@ -183,11 +190,11 @@ discourse_ai: allow_any: false choices: - Llama2-*-chat-hf - - claude-instant-1 - - claude-2 + - claude-instant-1 + - claude-2 - gpt-3.5-turbo - - gpt-4 - - StableBeluga2 + - gpt-4 + - StableBeluga2 - Upstage-Llama-2-*-instruct-v2 ai_summarization_discourse_service_api_endpoint: "" @@ -211,22 +218,22 @@ discourse_ai: default: "gpt-3.5-turbo" client: true choices: - - gpt-3.5-turbo - - gpt-4 - - claude-2 + - gpt-3.5-turbo + - gpt-4 + - claude-2 ai_bot_enabled_chat_commands: type: list default: "categories|google|image|search|tags|time|read" client: true choices: - - categories - - google - - image - - search - - summarize - - read - - tags - - time + - categories + - google + - image + - search + - summarize + - read + - tags + - time ai_bot_enabled_personas: type: list default: "general|artist|sql_helper|settings_explorer|researcher" diff --git a/db/fixtures/ai_helper/600_openai_completion_prompts.rb b/db/fixtures/ai_helper/600_openai_completion_prompts.rb index ec38dc8a..7c8849ff 100644 --- a/db/fixtures/ai_helper/600_openai_completion_prompts.rb +++ b/db/fixtures/ai_helper/600_openai_completion_prompts.rb @@ -123,3 +123,14 @@ CompletionPrompt.seed do |cp| TEXT ] end + +CompletionPrompt.seed do |cp| + cp.id = -5 + cp.provider = "openai" + cp.name = "custom_prompt" + cp.prompt_type = CompletionPrompt.prompt_types[:list] + cp.messages = [{ role: "system", content: <<~TEXT }] + You are a helpful assistant, I will provide you with a text below, + you will {{custom_prompt}} and you will reply with the result. + TEXT +end diff --git a/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb b/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb index c33157c8..db6e1011 100644 --- a/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb +++ b/db/fixtures/ai_helper/601_anthropic_completion_prompts.rb @@ -54,3 +54,14 @@ CompletionPrompt.seed do |cp| please reply with the corrected text between tags. TEXT end + +CompletionPrompt.seed do |cp| + cp.id = -105 + cp.provider = "anthropic" + cp.name = "custom_prompt" + cp.prompt_type = CompletionPrompt.prompt_types[:diff] + cp.messages = [{ role: "Human", content: <<~TEXT }] + You are a helpful assistant, I will provide you with a text inside tags, + you will {{custom_prompt}} and you will reply with the result between tags. + TEXT +end diff --git a/db/fixtures/ai_helper/602_stablebeluga2_completion_prompts.rb b/db/fixtures/ai_helper/602_stablebeluga2_completion_prompts.rb index f72767e0..541836de 100644 --- a/db/fixtures/ai_helper/602_stablebeluga2_completion_prompts.rb +++ b/db/fixtures/ai_helper/602_stablebeluga2_completion_prompts.rb @@ -109,3 +109,20 @@ CompletionPrompt.seed do |cp| ### Assistant: TEXT end + +CompletionPrompt.seed do |cp| + cp.id = -205 + cp.provider = "huggingface" + cp.name = "custom_prompt" + cp.prompt_type = CompletionPrompt.prompt_types[:diff] + cp.messages = [<<~TEXT] + ### System: + You are a helpful assistant, I will provide you with a text below, + you will {{custom_prompt}} and you will reply with the result. + + ### User: + {{user_input}} + + ### Assistant: + TEXT +end diff --git a/lib/modules/ai_helper/entry_point.rb b/lib/modules/ai_helper/entry_point.rb index dcd012f6..901bd925 100644 --- a/lib/modules/ai_helper/entry_point.rb +++ b/lib/modules/ai_helper/entry_point.rb @@ -12,7 +12,9 @@ module DiscourseAi plugin.register_seedfu_fixtures( Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_helper"), ) - plugin.register_svg_icon("discourse-sparkles") + + additional_icons = %w[discourse-sparkles spell-check language] + additional_icons.each { |icon| plugin.register_svg_icon(icon) } end end end diff --git a/lib/modules/ai_helper/llm_prompt.rb b/lib/modules/ai_helper/llm_prompt.rb index 145e9007..07185e7a 100644 --- a/lib/modules/ai_helper/llm_prompt.rb +++ b/lib/modules/ai_helper/llm_prompt.rb @@ -19,18 +19,19 @@ module DiscourseAi name: prompt.name, translated_name: translation, prompt_type: prompt.prompt_type, + icon: icon_map(prompt.name), } end end - def generate_and_send_prompt(prompt, text) + def generate_and_send_prompt(prompt, params) case enabled_provider when "openai" - openai_call(prompt, text) + openai_call(prompt, params) when "anthropic" - anthropic_call(prompt, text) + anthropic_call(prompt, params) when "huggingface" - huggingface_call(prompt, text) + huggingface_call(prompt, params) end end @@ -47,6 +48,27 @@ module DiscourseAi private + def icon_map(name) + case name + when "translate" + "language" + when "generate_titles" + "heading" + when "proofread" + "spell-check" + when "markdown_table" + "table" + when "tone" + "microphone" + when "custom_prompt" + "comment" + when "rewrite" + "pen" + else + nil + end + end + def generate_diff(text, suggestion) cooked_text = PrettyText.cook(text) cooked_suggestion = PrettyText.cook(suggestion) @@ -71,10 +93,10 @@ module DiscourseAi end end - def openai_call(prompt, text) + def openai_call(prompt, params) result = { type: prompt.prompt_type } - messages = prompt.messages_with_user_input(text) + messages = prompt.messages_with_user_input(params) result[:suggestions] = DiscourseAi::Inference::OpenAiCompletions .perform!(messages, SiteSetting.ai_helper_model) @@ -83,15 +105,15 @@ module DiscourseAi .flat_map { |choice| parse_content(prompt, choice.dig(:message, :content).to_s) } .compact_blank - result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff? + result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff? result end - def anthropic_call(prompt, text) + def anthropic_call(prompt, params) result = { type: prompt.prompt_type } - filled_message = prompt.messages_with_user_input(text) + filled_message = prompt.messages_with_user_input(params) message = filled_message.map { |msg| "#{msg["role"]}: #{msg["content"]}" }.join("\n\n") + @@ -101,15 +123,15 @@ module DiscourseAi result[:suggestions] = parse_content(prompt, response.dig(:completion)) - result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff? + result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff? result end - def huggingface_call(prompt, text) + def huggingface_call(prompt, params) result = { type: prompt.prompt_type } - message = prompt.messages_with_user_input(text) + message = prompt.messages_with_user_input(params) response = DiscourseAi::Inference::HuggingFaceTextGeneration.perform!( @@ -119,7 +141,7 @@ module DiscourseAi result[:suggestions] = parse_content(prompt, response.dig(:generated_text)) - result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff? + result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff? result end diff --git a/lib/modules/ai_helper/painter.rb b/lib/modules/ai_helper/painter.rb index f62362c2..b9e5de48 100644 --- a/lib/modules/ai_helper/painter.rb +++ b/lib/modules/ai_helper/painter.rb @@ -36,7 +36,10 @@ module DiscourseAi return "" if prompt_for_provider.nil? - llm_prompt.generate_and_send_prompt(prompt_for_provider, text).dig(:suggestions).first + llm_prompt + .generate_and_send_prompt(prompt_for_provider, { text: text }) + .dig(:suggestions) + .first end def completion_prompts diff --git a/spec/lib/modules/ai_helper/llm_prompt_spec.rb b/spec/lib/modules/ai_helper/llm_prompt_spec.rb index ea00dcdd..a03857dd 100644 --- a/spec/lib/modules/ai_helper/llm_prompt_spec.rb +++ b/spec/lib/modules/ai_helper/llm_prompt_spec.rb @@ -3,7 +3,7 @@ require_relative "../../../support/openai_completions_inference_stubs" RSpec.describe DiscourseAi::AiHelper::LlmPrompt do - let(:prompt) { CompletionPrompt.find_by(name: mode) } + let(:prompt) { CompletionPrompt.find_by(name: mode, provider: "openai") } describe "#generate_and_send_prompt" do context "when using the translate mode" do @@ -13,7 +13,10 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do it "Sends the prompt to chatGPT and returns the response" do response = - subject.generate_and_send_prompt(prompt, OpenAiCompletionsInferenceStubs.spanish_text) + subject.generate_and_send_prompt( + prompt, + { text: OpenAiCompletionsInferenceStubs.spanish_text }, + ) expect(response[:suggestions]).to contain_exactly( OpenAiCompletionsInferenceStubs.translated_response.strip, @@ -30,7 +33,7 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do response = subject.generate_and_send_prompt( prompt, - OpenAiCompletionsInferenceStubs.translated_response, + { text: OpenAiCompletionsInferenceStubs.translated_response }, ) expect(response[:suggestions]).to contain_exactly( @@ -56,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do response = subject.generate_and_send_prompt( prompt, - OpenAiCompletionsInferenceStubs.translated_response, + { text: OpenAiCompletionsInferenceStubs.translated_response }, ) expect(response[:suggestions]).to contain_exactly(*expected) diff --git a/spec/support/openai_completions_inference_stubs.rb b/spec/support/openai_completions_inference_stubs.rb index afa1cd0a..fc5619ac 100644 --- a/spec/support/openai_completions_inference_stubs.rb +++ b/spec/support/openai_completions_inference_stubs.rb @@ -4,6 +4,7 @@ class OpenAiCompletionsInferenceStubs TRANSLATE = "translate" PROOFREAD = "proofread" GENERATE_TITLES = "generate_titles" + CUSTOM_PROMPT = "custom_prompt" class << self def text_mode_to_id(mode) @@ -14,6 +15,8 @@ class OpenAiCompletionsInferenceStubs -3 when GENERATE_TITLES -2 + when CUSTOM_PROMPT + -5 end end @@ -30,6 +33,16 @@ class OpenAiCompletionsInferenceStubs STRING end + def custom_prompt_input + "Translate to French" + end + + def custom_prompt_response + <<~STRING + Le destin favorise les répétitions, les variantes, les symétries ; + STRING + end + def translated_response <<~STRING "To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends, @@ -90,13 +103,24 @@ class OpenAiCompletionsInferenceStubs proofread_response when GENERATE_TITLES generated_titles + when CUSTOM_PROMPT + custom_prompt_response end end def stub_prompt(type) - text = type == TRANSLATE ? spanish_text : translated_response + user_input = type == TRANSLATE ? spanish_text : translated_response + id = text_mode_to_id(type) - prompt_messages = CompletionPrompt.find_by(name: type).messages_with_user_input(text) + if type == CUSTOM_PROMPT + user_input = { mode: id, text: translated_response, custom_prompt: "Translate to French" } + elsif type == TRANSLATE + user_input = { mode: id, text: spanish_text, custom_prompt: "" } + else + user_input = { mode: id, text: translated_response, custom_prompt: "" } + end + + prompt_messages = CompletionPrompt.find_by(id: id).messages_with_user_input(user_input) stub_response(prompt_messages, response_text_for(type)) end diff --git a/spec/system/ai_helper/ai_composer_helper_spec.rb b/spec/system/ai_helper/ai_composer_helper_spec.rb index 14b08710..48c8ae87 100644 --- a/spec/system/ai_helper/ai_composer_helper_spec.rb +++ b/spec/system/ai_helper/ai_composer_helper_spec.rb @@ -54,6 +54,56 @@ RSpec.describe "AI Composer helper", type: :system, js: true do expect(ai_helper_context_menu).to have_no_context_menu end + context "when using custom prompt" do + let(:mode) { OpenAiCompletionsInferenceStubs::CUSTOM_PROMPT } + before { OpenAiCompletionsInferenceStubs.stub_prompt(mode) } + + it "shows custom prompt option" do + trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response) + ai_helper_context_menu.click_ai_button + expect(ai_helper_context_menu).to have_custom_prompt + end + + it "shows the custom prompt button when input is filled" do + trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response) + ai_helper_context_menu.click_ai_button + expect(ai_helper_context_menu).to have_no_custom_prompt_button + ai_helper_context_menu.fill_custom_prompt( + OpenAiCompletionsInferenceStubs.custom_prompt_input, + ) + expect(ai_helper_context_menu).to have_custom_prompt_button + end + + it "replaces the composed message with AI generated content" do + trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response) + ai_helper_context_menu.click_ai_button + ai_helper_context_menu.fill_custom_prompt( + OpenAiCompletionsInferenceStubs.custom_prompt_input, + ) + ai_helper_context_menu.click_custom_prompt_button + + wait_for do + composer.composer_input.value == + OpenAiCompletionsInferenceStubs.custom_prompt_response.strip + end + + expect(composer.composer_input.value).to eq( + OpenAiCompletionsInferenceStubs.custom_prompt_response.strip, + ) + end + end + + context "when not a member of custom prompt group" do + let(:mode) { OpenAiCompletionsInferenceStubs::CUSTOM_PROMPT } + before { SiteSetting.ai_helper_custom_prompts_allowed_groups = non_member_group.id.to_s } + + it "does not show custom prompt option" do + trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response) + ai_helper_context_menu.click_ai_button + expect(ai_helper_context_menu).to have_no_custom_prompt + end + end + context "when using translation mode" do let(:mode) { OpenAiCompletionsInferenceStubs::TRANSLATE } before { OpenAiCompletionsInferenceStubs.stub_prompt(mode) } diff --git a/spec/system/page_objects/components/ai_helper_context_menu.rb b/spec/system/page_objects/components/ai_helper_context_menu.rb index 52623868..47f6355a 100644 --- a/spec/system/page_objects/components/ai_helper_context_menu.rb +++ b/spec/system/page_objects/components/ai_helper_context_menu.rb @@ -10,6 +10,9 @@ module PageObjects LOADING_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__loading" RESETS_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__resets" REVIEW_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__review" + CUSTOM_PROMPT_SELECTOR = "#{CONTEXT_MENU_SELECTOR} .ai-custom-prompt" + CUSTOM_PROMPT_INPUT_SELECTOR = "#{CUSTOM_PROMPT_SELECTOR}__input" + CUSTOM_PROMPT_BUTTON_SELECTOR = "#{CUSTOM_PROMPT_SELECTOR}__submit" def click_ai_button find("#{TRIGGER_STATE_SELECTOR} .btn").click @@ -43,6 +46,15 @@ module PageObjects find("body").send_keys(:escape) end + def click_custom_prompt_button + find(CUSTOM_PROMPT_BUTTON_SELECTOR).click + end + + def fill_custom_prompt(content) + find(CUSTOM_PROMPT_INPUT_SELECTOR).fill_in(with: content) + self + end + def has_context_menu? page.has_css?(CONTEXT_MENU_SELECTOR) end @@ -70,6 +82,22 @@ module PageObjects def not_showing_resets? page.has_no_css?(RESETS_STATE_SELECTOR) end + + def has_custom_prompt? + page.has_css?(CUSTOM_PROMPT_SELECTOR) + end + + def has_no_custom_prompt? + page.has_no_css?(CUSTOM_PROMPT_SELECTOR) + end + + def has_custom_prompt_button? + page.has_css?(CUSTOM_PROMPT_BUTTON_SELECTOR) + end + + def has_no_custom_prompt_button? + page.has_no_css?(CUSTOM_PROMPT_BUTTON_SELECTOR) + end end end end