FEATURE: Create custom prompts with composer AI helper (#214)

* DEV: Add icon support

* DEV: Add basic setup for custom prompt menu

* FEATURE: custom prompt backend

* fix custom prompt param check

* fix custom prompt replace

* WIP

* fix custom prompt usage

* fixes

* DEV: Update front-end

* DEV: No more custom prompt state

* DEV: Add specs

* FIX: Title/Category/Tag suggestions

Suggestion dropdowns broke because `messages_with_user_input(user_input)` now expects a hash.
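
A minimal sketch of the shape change (values are illustrative; the keys match the controller changes below):

# Before: callers passed the selected text as a bare string
prompt.messages_with_user_input("Some selected text")

# After: callers pass a hash so the model can also read :custom_prompt
prompt.messages_with_user_input({ text: "Some selected text", custom_prompt: "" })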

* DEV: Apply syntax tree

* DEV: Restrict custom prompts to configured groups

* oops

* fix tests

* lint

* I love tests

* lint is cool tho

---------

Co-authored-by: Rafael dos Santos Silva <xfalcox@gmail.com>
Keegan George 2023-09-25 11:12:54 -07:00 committed by GitHub
parent 316ea9624e
commit 2e5a39360a
18 changed files with 442 additions and 125 deletions

View File

@@ -21,10 +21,15 @@ module DiscourseAi
input = get_text_param!
prompt = CompletionPrompt.find_by(id: params[:mode])
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
if prompt.prompt_type == "custom_prompt" && params[:custom_prompt].blank?
raise Discourse::InvalidParameters.new(:custom_prompt)
end
hijack do
render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
render json:
DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, params),
status: 200
end
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@@ -36,6 +41,7 @@ module DiscourseAi
def suggest_title
input = get_text_param!
input_hash = { text: input }
llm_prompt =
DiscourseAi::AiHelper::LlmPrompt
@@ -46,7 +52,11 @@ module DiscourseAi
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
hijack do
render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
render json:
DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
prompt,
input_hash,
),
status: 200
end
rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@@ -58,15 +68,21 @@ module DiscourseAi
def suggest_category
input = get_text_param!
input_hash = { text: input }
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).categories,
render json:
DiscourseAi::AiHelper::SemanticCategorizer.new(
input_hash,
current_user,
).categories,
status: 200
end
def suggest_tags
input = get_text_param!
input_hash = { text: input }
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).tags,
render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags,
status: 200
end
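
For reference, a hedged request-spec-style sketch of what the composer now sends to this endpoint (the route and field names come from the front-end change below; -5 is the id seeded for the OpenAI custom prompt, and the text values are illustrative):

post "/discourse-ai/ai-helper/suggest",
     params: { mode: -5, text: "Selected composer text", custom_prompt: "Translate to French" }
# With mode -5 (a custom_prompt type), a blank :custom_prompt raises Discourse::InvalidParameters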

View File

@@ -10,13 +10,24 @@ class CompletionPrompt < ActiveRecord::Base
validate :each_message_length
def messages_with_user_input(user_input)
if user_input[:custom_prompt].present?
case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider
when "huggingface"
self.messages.each { |msg| msg.sub!("{{custom_prompt}}", user_input[:custom_prompt]) }
else
self.messages.each do |msg|
msg["content"].sub!("{{custom_prompt}}", user_input[:custom_prompt])
end
end
end
case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider
when "openai"
self.messages << { role: "user", content: user_input }
self.messages << { role: "user", content: user_input[:text] }
when "anthropic"
self.messages << { "role" => "Input", "content" => "<input>#{user_input}</input>" }
self.messages << { "role" => "Input", "content" => "<input>#{user_input[:text]}</input>" }
when "huggingface"
self.messages.first.sub("{{user_input}}", user_input)
self.messages.first.sub("{{user_input}}", user_input[:text])
end
end
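
A minimal sketch of the new substitution flow, assuming the enabled provider is OpenAI and using the custom_prompt seed added later in this diff (input values are illustrative):

prompt = CompletionPrompt.find_by(name: "custom_prompt", provider: "openai")
messages = prompt.messages_with_user_input({ text: "Bonjour", custom_prompt: "translate to English" })
# {{custom_prompt}} in the seeded system message is replaced in place, then the
# user's text is appended as its own message:
# [{ role: "system", content: "... you will translate to English ..." },
#  { role: "user", content: "Bonjour" }]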

View File

@@ -16,33 +16,54 @@
{{else if (eq this.menuState this.CONTEXT_MENU_STATES.options)}}
<ul class="ai-helper-context-menu__options">
{{#each this.helperOptions as |option|}}
<li data-name={{option.name}} data-value={{option.value}}>
<DButton
@class="btn-flat"
@translatedLabel={{option.name}}
@action={{this.updateSelected}}
@actionParam={{option.value}}
/>
</li>
{{#if (eq option.name "custom_prompt")}}
<div
class="ai-custom-prompt"
{{did-insert this.setupCustomPrompt}}
>
<Input
@value={{this.customPromptValue}}
placeholder="Enter a custom prompt..."
class="ai-custom-prompt__input"
@enter={{action (fn this.updateSelected option)}}
/>
{{#if this.customPromptValue.length}}
<DButton
@class="ai-custom-prompt__submit btn-primary"
@icon="discourse-sparkles"
@action={{this.updateSelected}}
@actionParam={{option}}
/>
{{/if}}
</div>
{{else}}
<li data-name={{option.translated_name}} data-value={{option.id}}>
<DButton
@icon={{option.icon}}
@class="btn-flat"
@translatedLabel={{option.translated_name}}
@action={{this.updateSelected}}
@actionParam={{option}}
/>
</li>
{{/if}}
{{/each}}
</ul>
{{else if (eq this.menuState this.CONTEXT_MENU_STATES.loading)}}
<ul class="ai-helper-context-menu__loading">
<li>
<div class="dot-falling"></div>
<span>
{{i18n "discourse_ai.ai_helper.context_menu.loading"}}
</span>
<DButton
@icon="times"
@title="discourse_ai.ai_helper.context_menu.cancel"
@action={{this.cancelAIAction}}
class="btn-flat cancel-request"
/>
</li>
</ul>
<div class="ai-helper-context-menu__loading">
<div class="dot-falling"></div>
<span>
{{i18n "discourse_ai.ai_helper.context_menu.loading"}}
</span>
<DButton
@icon="times"
@title="discourse_ai.ai_helper.context_menu.cancel"
@action={{this.cancelAIAction}}
class="btn-flat cancel-request"
/>
</div>
{{else if (eq this.menuState this.CONTEXT_MENU_STATES.review)}}
<ul class="ai-helper-context-menu__review">
@@ -92,7 +113,6 @@
/>
</li>
</ul>
{{/if}}
</div>
{{/if}}

View File

@@ -15,10 +15,10 @@ export default class AiHelperContextMenu extends Component {
return showAIHelper(outletArgs, helper);
}
@service currentUser;
@service siteSettings;
@tracked helperOptions = [];
@tracked showContextMenu = false;
@tracked menuState = this.CONTEXT_MENU_STATES.triggers;
@tracked caretCoords;
@tracked virtualElement;
@tracked selectedText = "";
@@ -30,6 +30,8 @@ export default class AiHelperContextMenu extends Component {
@tracked showDiffModal = false;
@tracked diff;
@tracked popperPlacement = "top-start";
@tracked previousMenuState = null;
@tracked customPromptValue = "";
CONTEXT_MENU_STATES = {
triggers: "TRIGGERS",
@@ -41,8 +43,10 @@ export default class AiHelperContextMenu extends Component {
prompts = [];
promptTypes = {};
@tracked _menuState = this.CONTEXT_MENU_STATES.triggers;
@tracked _popper;
@tracked _dEditorInput;
@tracked _customPromptInput;
@tracked _contextMenu;
@tracked _activeAIRequest = null;
@@ -62,27 +66,42 @@ export default class AiHelperContextMenu extends Component {
this._popper?.destroy();
}
get menuState() {
return this._menuState;
}
set menuState(newState) {
this.previousMenuState = this._menuState;
this._menuState = newState;
}
async loadPrompts() {
let prompts = await ajax("/discourse-ai/ai-helper/prompts");
prompts
.filter((p) => p.name !== "generate_titles")
.map((p) => {
this.prompts[p.id] = p;
});
prompts = prompts.filter((p) => p.name !== "generate_titles");
// Find the custom_prompt object and move it to the beginning of the array
const customPromptIndex = prompts.findIndex(
(p) => p.name === "custom_prompt"
);
if (customPromptIndex !== -1) {
const customPrompt = prompts.splice(customPromptIndex, 1)[0];
prompts.unshift(customPrompt);
}
if (!this._showUserCustomPrompts()) {
prompts = prompts.filter((p) => p.name !== "custom_prompt");
}
prompts.forEach((p) => {
this.prompts[p.id] = p;
});
this.promptTypes = prompts.reduce((memo, p) => {
memo[p.name] = p.prompt_type;
return memo;
}, {});
this.helperOptions = prompts
.filter((p) => p.name !== "generate_titles")
.map((p) => {
return {
name: p.translated_name,
value: p.id,
};
});
this.helperOptions = prompts;
}
@bind
@@ -153,6 +172,10 @@
}
get canCloseContextMenu() {
if (document.activeElement === this._customPromptInput) {
return false;
}
if (this.loading && this._activeAIRequest !== null) {
return false;
}
@@ -168,9 +191,9 @@
if (!this.canCloseContextMenu) {
return;
}
this.showContextMenu = false;
this.menuState = this.CONTEXT_MENU_STATES.triggers;
this.customPromptValue = "";
}
_updateSuggestedByAI(data) {
@@ -200,6 +223,15 @@
return (this.loading = false);
}
_showUserCustomPrompts() {
const allowedGroups =
this.siteSettings?.ai_helper_custom_prompts_allowed_groups
.split("|")
.map((id) => parseInt(id, 10));
return this.currentUser?.groups.some((g) => allowedGroups.includes(g.id));
}
handleBoundaries() {
const textAreaWrapper = document
.querySelector(".d-editor-textarea-wrapper")
@@ -276,6 +308,14 @@
}
}
@action
setupCustomPrompt() {
this._customPromptInput = document.querySelector(
".ai-custom-prompt__input"
);
this._customPromptInput.focus();
}
@action
toggleAiHelperOptions() {
// Fetch prompts only if it hasn't been fetched yet
@@ -303,7 +343,11 @@
this._activeAIRequest = ajax("/discourse-ai/ai-helper/suggest", {
method: "POST",
data: { mode: option, text: this.selectedText },
data: {
mode: option.id,
text: this.selectedText,
custom_prompt: this.customPromptValue,
},
});
this._activeAIRequest
@@ -340,4 +384,9 @@
this.closeContextMenu();
}
}
@action
togglePreviousMenu() {
this.menuState = this.previousMenuState;
}
}

View File

@@ -50,50 +50,51 @@
list-style: none;
}
.btn {
justify-content: left;
text-align: left;
background: none;
width: 100%;
border-radius: 0;
margin: 0;
li {
.btn-flat {
justify-content: left;
text-align: left;
background: none;
width: 100%;
border-radius: 0;
margin: 0;
padding-block: 0.6rem;
&:focus,
&:hover {
color: var(--primary);
background: var(--d-hover);
&:focus,
&:hover {
color: var(--primary);
background: var(--d-hover);
.d-icon {
color: var(--primary-medium);
.d-icon {
color: var(--primary-medium);
}
}
.d-button-label {
color: var(--primary-very-high);
}
}
}
.d-button-label {
color: var(--primary-very-high);
}
&__options {
padding: 0.25rem;
li:not(:last-child) {
border-bottom: 1px solid var(--primary-low);
}
}
&__loading {
display: flex;
padding: 0.5rem;
gap: 1rem;
justify-content: flex-start;
align-items: center;
.dot-falling {
margin-inline: 1rem;
margin-left: 1.5rem;
}
li {
display: flex;
padding: 0.5rem;
gap: 1rem;
justify-content: flex-start;
align-items: center;
}
.btn {
width: unset;
}
}
&__resets {
@@ -107,6 +108,41 @@
align-items: center;
flex-flow: row wrap;
}
&__custom-prompt {
display: flex;
flex-flow: row wrap;
padding: 0.5rem;
&-header {
margin-bottom: 0.5rem;
.btn {
padding: 0;
}
}
.ai-custom-prompt-input {
min-height: 90px;
width: 100%;
}
}
.ai-custom-prompt {
display: flex;
gap: 0.25rem;
margin-bottom: 0.5rem;
&__input {
background: var(--primary-low);
border-color: var(--primary-low);
margin-bottom: 0;
&::placeholder {
color: var(--primary-medium);
}
}
}
}
.d-editor-input.loading {

View File

@@ -19,6 +19,7 @@ en:
suggest: "Suggest with AI"
missing_content: "Please enter some content to generate suggestions."
context_menu:
back: "Back"
trigger: "AI"
undo: "Undo"
loading: "AI is generating"
@@ -28,6 +29,10 @@ en:
confirm: "Confirm"
revert: "Revert"
changes: "Changes"
custom_prompt:
title: "Custom Prompt"
placeholder: "Enter a custom prompt..."
submit: "Send Prompt"
reviewables:
model_used: "Model used:"
accuracy: "Accuracy:"

View File

@@ -21,7 +21,7 @@ en:
ai_sentiment_models: "Models to use for inference. Sentiment classifies post on the positive/neutral/negative space. Emotion classifies on the anger/disgust/fear/joy/neutral/sadness/surprise space."
ai_nsfw_detection_enabled: "Enable the NSFW module."
ai_nsfw_inference_service_api_endpoint: "URL where the API is running for the NSFW module"
ai_nsfw_inference_service_api_endpoint: "URL where the API is running for the NSFW module"
ai_nsfw_inference_service_api_key: "API key for the NSFW API"
ai_nsfw_flag_automatically: "Automatically flag NSFW posts that are above the configured thresholds."
ai_nsfw_flag_threshold_general: "General Threshold for an image to be considered NSFW."
@@ -44,6 +44,7 @@ en:
ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."
ai_helper_model: "Model to use for the AI helper."
ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
ai_embeddings_enabled: "Enable the embeddings module."
ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
@@ -76,9 +77,9 @@ en:
ai_google_custom_search_cx: "CX for Google Custom Search API"
reviewables:
reasons:
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW.
reasons:
flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW.
errors:
prompt_message_length: The message %{idx} is over the 1000 character limit.
@@ -93,6 +94,7 @@ en:
generate_titles: Suggest topic titles
proofread: Proofread text
markdown_table: Generate Markdown table
custom_prompt: "Custom Prompt"
ai_bot:
personas:

View File

@@ -9,7 +9,7 @@ discourse_ai:
ai_toxicity_inference_service_api_endpoint:
default: "https://disorder-testing.demo-by-discourse.com"
ai_toxicity_inference_service_api_key:
default: ''
default: ""
secret: true
ai_toxicity_inference_service_api_model:
type: enum
@@ -56,7 +56,7 @@ discourse_ai:
ai_sentiment_inference_service_api_endpoint:
default: "https://sentiment-testing.demo-by-discourse.com"
ai_sentiment_inference_service_api_key:
default: ''
default: ""
secret: true
ai_sentiment_models:
type: list
@@ -64,8 +64,8 @@ discourse_ai:
default: "emotion"
allow_any: false
choices:
- sentiment
- emotion
- sentiment
- emotion
ai_nsfw_detection_enabled: false
ai_nsfw_inference_service_api_endpoint:
@@ -85,8 +85,8 @@ discourse_ai:
default: "opennsfw2"
allow_any: false
choices:
- opennsfw2
- nsfw_detector
- opennsfw2
- nsfw_detector
ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
@@ -148,6 +148,13 @@ discourse_ai:
- gpt-4
- claude-2
- stable-beluga-2
ai_helper_custom_prompts_allowed_groups:
client: true
type: group_list
list_type: compact
default: "3" # 3: @staff
allow_any: false
refresh: true
ai_embeddings_enabled:
default: false
@@ -162,9 +169,9 @@ discourse_ai:
default: "all-mpnet-base-v2"
allow_any: false
choices:
- all-mpnet-base-v2
- text-embedding-ada-002
- multilingual-e5-large
- all-mpnet-base-v2
- text-embedding-ada-002
- multilingual-e5-large
ai_embeddings_generate_for_pms: false
ai_embeddings_semantic_related_topics_enabled:
default: false
@@ -183,11 +190,11 @@ discourse_ai:
allow_any: false
choices:
- Llama2-*-chat-hf
- claude-instant-1
- claude-2
- claude-instant-1
- claude-2
- gpt-3.5-turbo
- gpt-4
- StableBeluga2
- gpt-4
- StableBeluga2
- Upstage-Llama-2-*-instruct-v2
ai_summarization_discourse_service_api_endpoint: ""
@@ -211,22 +218,22 @@ discourse_ai:
default: "gpt-3.5-turbo"
client: true
choices:
- gpt-3.5-turbo
- gpt-4
- claude-2
- gpt-3.5-turbo
- gpt-4
- claude-2
ai_bot_enabled_chat_commands:
type: list
default: "categories|google|image|search|tags|time|read"
client: true
choices:
- categories
- google
- image
- search
- summarize
- read
- tags
- time
- categories
- google
- image
- search
- summarize
- read
- tags
- time
ai_bot_enabled_personas:
type: list
default: "general|artist|sql_helper|settings_explorer|researcher"

View File

@@ -123,3 +123,14 @@ CompletionPrompt.seed do |cp|
TEXT
]
end
CompletionPrompt.seed do |cp|
cp.id = -5
cp.provider = "openai"
cp.name = "custom_prompt"
cp.prompt_type = CompletionPrompt.prompt_types[:list]
cp.messages = [{ role: "system", content: <<~TEXT }]
You are a helpful assistant, I will provide you with a text below,
you will {{custom_prompt}} and you will reply with the result.
TEXT
end

View File

@@ -54,3 +54,14 @@ CompletionPrompt.seed do |cp|
please reply with the corrected text between <ai></ai> tags.
TEXT
end
CompletionPrompt.seed do |cp|
cp.id = -105
cp.provider = "anthropic"
cp.name = "custom_prompt"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
cp.messages = [{ role: "Human", content: <<~TEXT }]
You are a helpful assistant, I will provide you with a text inside <input> tags,
you will {{custom_prompt}} and you will reply with the result between <ai></ai> tags.
TEXT
end

View File

@@ -109,3 +109,20 @@ CompletionPrompt.seed do |cp|
### Assistant:
TEXT
end
CompletionPrompt.seed do |cp|
cp.id = -205
cp.provider = "huggingface"
cp.name = "custom_prompt"
cp.prompt_type = CompletionPrompt.prompt_types[:diff]
cp.messages = [<<~TEXT]
### System:
You are a helpful assistant, I will provide you with a text below,
you will {{custom_prompt}} and you will reply with the result.
### User:
{{user_input}}
### Assistant:
TEXT
end
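
Unlike the OpenAI and Anthropic seeds, the Hugging Face template above is a single string rather than role/content hashes, so the `when "huggingface"` branches in the CompletionPrompt change substitute into it directly. A hedged sketch (input values are illustrative):

prompt = CompletionPrompt.find_by(name: "custom_prompt", provider: "huggingface")
filled = prompt.messages_with_user_input({ text: "Bonjour", custom_prompt: "translate to English" })
# `filled` is the template string above with {{custom_prompt}} and {{user_input}} replaced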

View File

@@ -12,7 +12,9 @@ module DiscourseAi
plugin.register_seedfu_fixtures(
Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_helper"),
)
plugin.register_svg_icon("discourse-sparkles")
additional_icons = %w[discourse-sparkles spell-check language]
additional_icons.each { |icon| plugin.register_svg_icon(icon) }
end
end
end

View File

@@ -19,18 +19,19 @@ module DiscourseAi
name: prompt.name,
translated_name: translation,
prompt_type: prompt.prompt_type,
icon: icon_map(prompt.name),
}
end
end
def generate_and_send_prompt(prompt, text)
def generate_and_send_prompt(prompt, params)
case enabled_provider
when "openai"
openai_call(prompt, text)
openai_call(prompt, params)
when "anthropic"
anthropic_call(prompt, text)
anthropic_call(prompt, params)
when "huggingface"
huggingface_call(prompt, text)
huggingface_call(prompt, params)
end
end
@@ -47,6 +48,27 @@ module DiscourseAi
private
def icon_map(name)
case name
when "translate"
"language"
when "generate_titles"
"heading"
when "proofread"
"spell-check"
when "markdown_table"
"table"
when "tone"
"microphone"
when "custom_prompt"
"comment"
when "rewrite"
"pen"
else
nil
end
end
def generate_diff(text, suggestion)
cooked_text = PrettyText.cook(text)
cooked_suggestion = PrettyText.cook(suggestion)
@@ -71,10 +93,10 @@
end
end
def openai_call(prompt, text)
def openai_call(prompt, params)
result = { type: prompt.prompt_type }
messages = prompt.messages_with_user_input(text)
messages = prompt.messages_with_user_input(params)
result[:suggestions] = DiscourseAi::Inference::OpenAiCompletions
.perform!(messages, SiteSetting.ai_helper_model)
@@ -83,15 +105,15 @@
.flat_map { |choice| parse_content(prompt, choice.dig(:message, :content).to_s) }
.compact_blank
result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff?
result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff?
result
end
def anthropic_call(prompt, text)
def anthropic_call(prompt, params)
result = { type: prompt.prompt_type }
filled_message = prompt.messages_with_user_input(text)
filled_message = prompt.messages_with_user_input(params)
message =
filled_message.map { |msg| "#{msg["role"]}: #{msg["content"]}" }.join("\n\n") +
@@ -101,15 +123,15 @@
result[:suggestions] = parse_content(prompt, response.dig(:completion))
result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff?
result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff?
result
end
def huggingface_call(prompt, text)
def huggingface_call(prompt, params)
result = { type: prompt.prompt_type }
message = prompt.messages_with_user_input(text)
message = prompt.messages_with_user_input(params)
response =
DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
@@ -119,7 +141,7 @@
result[:suggestions] = parse_content(prompt, response.dig(:generated_text))
result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff?
result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff?
result
end
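
Putting the pieces together, a hedged sketch of calling the updated entry point (assuming the OpenAI provider is enabled; the result keys follow this file, and the text values are illustrative):

result = DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
  prompt,
  { text: "Selected composer text", custom_prompt: "Translate to French" },
)
result[:type]        # the prompt's prompt_type, e.g. "list" or "diff"
result[:suggestions] # parsed completions returned by the provider
result[:diff]        # present only when prompt.diff?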

View File

@@ -36,7 +36,10 @@ module DiscourseAi
return "" if prompt_for_provider.nil?
llm_prompt.generate_and_send_prompt(prompt_for_provider, text).dig(:suggestions).first
llm_prompt
.generate_and_send_prompt(prompt_for_provider, { text: text })
.dig(:suggestions)
.first
end
def completion_prompts

View File

@@ -3,7 +3,7 @@
require_relative "../../../support/openai_completions_inference_stubs"
RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
let(:prompt) { CompletionPrompt.find_by(name: mode) }
let(:prompt) { CompletionPrompt.find_by(name: mode, provider: "openai") }
describe "#generate_and_send_prompt" do
context "when using the translate mode" do
@@ -13,7 +13,10 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
it "Sends the prompt to chatGPT and returns the response" do
response =
subject.generate_and_send_prompt(prompt, OpenAiCompletionsInferenceStubs.spanish_text)
subject.generate_and_send_prompt(
prompt,
{ text: OpenAiCompletionsInferenceStubs.spanish_text },
)
expect(response[:suggestions]).to contain_exactly(
OpenAiCompletionsInferenceStubs.translated_response.strip,
@@ -30,7 +33,7 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
response =
subject.generate_and_send_prompt(
prompt,
OpenAiCompletionsInferenceStubs.translated_response,
{ text: OpenAiCompletionsInferenceStubs.translated_response },
)
expect(response[:suggestions]).to contain_exactly(
@@ -56,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
response =
subject.generate_and_send_prompt(
prompt,
OpenAiCompletionsInferenceStubs.translated_response,
{ text: OpenAiCompletionsInferenceStubs.translated_response },
)
expect(response[:suggestions]).to contain_exactly(*expected)

View File

@@ -4,6 +4,7 @@ class OpenAiCompletionsInferenceStubs
TRANSLATE = "translate"
PROOFREAD = "proofread"
GENERATE_TITLES = "generate_titles"
CUSTOM_PROMPT = "custom_prompt"
class << self
def text_mode_to_id(mode)
@@ -14,6 +15,8 @@ class OpenAiCompletionsInferenceStubs
-3
when GENERATE_TITLES
-2
when CUSTOM_PROMPT
-5
end
end
@@ -30,6 +33,16 @@ class OpenAiCompletionsInferenceStubs
STRING
end
def custom_prompt_input
"Translate to French"
end
def custom_prompt_response
<<~STRING
Le destin favorise les répétitions, les variantes, les symétries ;
STRING
end
def translated_response
<<~STRING
"To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
@@ -90,13 +103,24 @@ class OpenAiCompletionsInferenceStubs
proofread_response
when GENERATE_TITLES
generated_titles
when CUSTOM_PROMPT
custom_prompt_response
end
end
def stub_prompt(type)
text = type == TRANSLATE ? spanish_text : translated_response
user_input = type == TRANSLATE ? spanish_text : translated_response
id = text_mode_to_id(type)
prompt_messages = CompletionPrompt.find_by(name: type).messages_with_user_input(text)
if type == CUSTOM_PROMPT
user_input = { mode: id, text: translated_response, custom_prompt: "Translate to French" }
elsif type == TRANSLATE
user_input = { mode: id, text: spanish_text, custom_prompt: "" }
else
user_input = { mode: id, text: translated_response, custom_prompt: "" }
end
prompt_messages = CompletionPrompt.find_by(id: id).messages_with_user_input(user_input)
stub_response(prompt_messages, response_text_for(type))
end

View File

@@ -54,6 +54,56 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
expect(ai_helper_context_menu).to have_no_context_menu
end
context "when using custom prompt" do
let(:mode) { OpenAiCompletionsInferenceStubs::CUSTOM_PROMPT }
before { OpenAiCompletionsInferenceStubs.stub_prompt(mode) }
it "shows custom prompt option" do
trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
ai_helper_context_menu.click_ai_button
expect(ai_helper_context_menu).to have_custom_prompt
end
it "shows the custom prompt button when input is filled" do
trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
ai_helper_context_menu.click_ai_button
expect(ai_helper_context_menu).to have_no_custom_prompt_button
ai_helper_context_menu.fill_custom_prompt(
OpenAiCompletionsInferenceStubs.custom_prompt_input,
)
expect(ai_helper_context_menu).to have_custom_prompt_button
end
it "replaces the composed message with AI generated content" do
trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
ai_helper_context_menu.click_ai_button
ai_helper_context_menu.fill_custom_prompt(
OpenAiCompletionsInferenceStubs.custom_prompt_input,
)
ai_helper_context_menu.click_custom_prompt_button
wait_for do
composer.composer_input.value ==
OpenAiCompletionsInferenceStubs.custom_prompt_response.strip
end
expect(composer.composer_input.value).to eq(
OpenAiCompletionsInferenceStubs.custom_prompt_response.strip,
)
end
end
context "when not a member of custom prompt group" do
let(:mode) { OpenAiCompletionsInferenceStubs::CUSTOM_PROMPT }
before { SiteSetting.ai_helper_custom_prompts_allowed_groups = non_member_group.id.to_s }
it "does not show custom prompt option" do
trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
ai_helper_context_menu.click_ai_button
expect(ai_helper_context_menu).to have_no_custom_prompt
end
end
context "when using translation mode" do
let(:mode) { OpenAiCompletionsInferenceStubs::TRANSLATE }
before { OpenAiCompletionsInferenceStubs.stub_prompt(mode) }

View File

@@ -10,6 +10,9 @@ module PageObjects
LOADING_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__loading"
RESETS_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__resets"
REVIEW_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__review"
CUSTOM_PROMPT_SELECTOR = "#{CONTEXT_MENU_SELECTOR} .ai-custom-prompt"
CUSTOM_PROMPT_INPUT_SELECTOR = "#{CUSTOM_PROMPT_SELECTOR}__input"
CUSTOM_PROMPT_BUTTON_SELECTOR = "#{CUSTOM_PROMPT_SELECTOR}__submit"
def click_ai_button
find("#{TRIGGER_STATE_SELECTOR} .btn").click
@@ -43,6 +46,15 @@ module PageObjects
find("body").send_keys(:escape)
end
def click_custom_prompt_button
find(CUSTOM_PROMPT_BUTTON_SELECTOR).click
end
def fill_custom_prompt(content)
find(CUSTOM_PROMPT_INPUT_SELECTOR).fill_in(with: content)
self
end
def has_context_menu?
page.has_css?(CONTEXT_MENU_SELECTOR)
end
@@ -70,6 +82,22 @@ module PageObjects
def not_showing_resets?
page.has_no_css?(RESETS_STATE_SELECTOR)
end
def has_custom_prompt?
page.has_css?(CUSTOM_PROMPT_SELECTOR)
end
def has_no_custom_prompt?
page.has_no_css?(CUSTOM_PROMPT_SELECTOR)
end
def has_custom_prompt_button?
page.has_css?(CUSTOM_PROMPT_BUTTON_SELECTOR)
end
def has_no_custom_prompt_button?
page.has_no_css?(CUSTOM_PROMPT_BUTTON_SELECTOR)
end
end
end
end