FEATURE: Create custom prompts with composer AI helper (#214)

* DEV: Add icon support

* DEV: Add basic setup for custom prompt menu

* FEATURE: custom prompt backend

* fix custom prompt param check

* fix custom prompt replace

* WIP

* fix custom prompt usage

* fixes

* DEV: Update front-end

* DEV: No more custom prompt state

* DEV: Add specs

* FIX: Title/Category/Tag suggestions

Suggestion dropdowns broke because `messages_with_user_input(user_input)` now expects a hash.

* DEV: Apply syntax tree

* DEV: Restrict custom prompts to configured groups

* oops

* fix tests

* lint

* I love tests

* lint is cool tho

---------

Co-authored-by: Rafael dos Santos Silva <xfalcox@gmail.com>
Keegan George 2023-09-25 11:12:54 -07:00 committed by GitHub
parent 316ea9624e
commit 2e5a39360a
18 changed files with 442 additions and 125 deletions

View File

@@ -21,10 +21,15 @@ module DiscourseAi
         input = get_text_param!
         prompt = CompletionPrompt.find_by(id: params[:mode])
         raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
+        if prompt.prompt_type == "custom_prompt" && params[:custom_prompt].blank?
+          raise Discourse::InvalidParameters.new(:custom_prompt)
+        end
+
         hijack do
-          render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
+          render json:
+                   DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, params),
                  status: 200
         end
       rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@@ -36,6 +41,7 @@ module DiscourseAi
       def suggest_title
         input = get_text_param!
+        input_hash = { text: input }
 
         llm_prompt =
           DiscourseAi::AiHelper::LlmPrompt
@@ -46,7 +52,11 @@ module DiscourseAi
         raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
 
         hijack do
-          render json: DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(prompt, input),
+          render json:
+                   DiscourseAi::AiHelper::LlmPrompt.new.generate_and_send_prompt(
+                     prompt,
+                     input_hash,
+                   ),
                  status: 200
         end
       rescue ::DiscourseAi::Inference::OpenAiCompletions::CompletionFailed,
@@ -58,15 +68,21 @@ module DiscourseAi
       def suggest_category
         input = get_text_param!
+        input_hash = { text: input }
 
-        render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).categories,
+        render json:
+                 DiscourseAi::AiHelper::SemanticCategorizer.new(
+                   input_hash,
+                   current_user,
+                 ).categories,
                status: 200
       end
 
       def suggest_tags
         input = get_text_param!
+        input_hash = { text: input }
 
-        render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input, current_user).tags,
+        render json: DiscourseAi::AiHelper::SemanticCategorizer.new(input_hash, current_user).tags,
               status: 200
       end

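For reviewers, a minimal sketch of the request the `suggest` action now accepts; the endpoint and parameter names come from this diff, while the literal values are made up:

    # Hypothetical payload for POST /discourse-ai/ai-helper/suggest.
    # `mode` is the CompletionPrompt id; `custom_prompt` is required only when
    # the selected prompt's prompt_type is "custom_prompt" — a blank value
    # raises Discourse::InvalidParameters.new(:custom_prompt) before hijack.
    params = { mode: -5, text: "Some selected text", custom_prompt: "Translate to French" }
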
View File

@@ -10,13 +10,24 @@ class CompletionPrompt < ActiveRecord::Base
   validate :each_message_length
 
   def messages_with_user_input(user_input)
+    if user_input[:custom_prompt].present?
+      case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider
+      when "huggingface"
+        self.messages.each { |msg| msg.sub!("{{custom_prompt}}", user_input[:custom_prompt]) }
+      else
+        self.messages.each do |msg|
+          msg["content"].sub!("{{custom_prompt}}", user_input[:custom_prompt])
+        end
+      end
+    end
+
     case ::DiscourseAi::AiHelper::LlmPrompt.new.enabled_provider
     when "openai"
-      self.messages << { role: "user", content: user_input }
+      self.messages << { role: "user", content: user_input[:text] }
     when "anthropic"
-      self.messages << { "role" => "Input", "content" => "<input>#{user_input}</input>" }
+      self.messages << { "role" => "Input", "content" => "<input>#{user_input[:text]}</input>" }
     when "huggingface"
-      self.messages.first.sub("{{user_input}}", user_input)
+      self.messages.first.sub("{{user_input}}", user_input[:text])
     end
   end

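A minimal sketch of how the new hash-based input flows through `messages_with_user_input`, assuming the OpenAI provider and the custom prompt seeded with id -5 further down in this diff (sample strings are illustrative):

    prompt = CompletionPrompt.find_by(id: -5)

    # The {{custom_prompt}} placeholder in the seeded system message is
    # substituted in place, then the selected text is appended as the
    # user message for the OpenAI provider.
    messages = prompt.messages_with_user_input(
      text: "El sentido de la vida",
      custom_prompt: "translate the text to French",
    )
    # => system message with {{custom_prompt}} replaced, plus
    #    { role: "user", content: "El sentido de la vida" }
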
View File

@@ -16,33 +16,54 @@
     {{else if (eq this.menuState this.CONTEXT_MENU_STATES.options)}}
       <ul class="ai-helper-context-menu__options">
         {{#each this.helperOptions as |option|}}
-          <li data-name={{option.name}} data-value={{option.value}}>
-            <DButton
-              @class="btn-flat"
-              @translatedLabel={{option.name}}
-              @action={{this.updateSelected}}
-              @actionParam={{option.value}}
-            />
-          </li>
+          {{#if (eq option.name "custom_prompt")}}
+            <div
+              class="ai-custom-prompt"
+              {{did-insert this.setupCustomPrompt}}
+            >
+              <Input
+                @value={{this.customPromptValue}}
+                placeholder="Enter a custom prompt..."
+                class="ai-custom-prompt__input"
+                @enter={{action (fn this.updateSelected option)}}
+              />
+
+              {{#if this.customPromptValue.length}}
+                <DButton
+                  @class="ai-custom-prompt__submit btn-primary"
+                  @icon="discourse-sparkles"
+                  @action={{this.updateSelected}}
+                  @actionParam={{option}}
+                />
+              {{/if}}
+            </div>
+          {{else}}
+            <li data-name={{option.translated_name}} data-value={{option.id}}>
+              <DButton
+                @icon={{option.icon}}
+                @class="btn-flat"
+                @translatedLabel={{option.translated_name}}
+                @action={{this.updateSelected}}
+                @actionParam={{option}}
+              />
+            </li>
+          {{/if}}
         {{/each}}
       </ul>
     {{else if (eq this.menuState this.CONTEXT_MENU_STATES.loading)}}
-      <ul class="ai-helper-context-menu__loading">
-        <li>
-          <div class="dot-falling"></div>
-          <span>
-            {{i18n "discourse_ai.ai_helper.context_menu.loading"}}
-          </span>
-          <DButton
-            @icon="times"
-            @title="discourse_ai.ai_helper.context_menu.cancel"
-            @action={{this.cancelAIAction}}
-            class="btn-flat cancel-request"
-          />
-        </li>
-      </ul>
+      <div class="ai-helper-context-menu__loading">
+        <div class="dot-falling"></div>
+        <span>
+          {{i18n "discourse_ai.ai_helper.context_menu.loading"}}
+        </span>
+        <DButton
+          @icon="times"
+          @title="discourse_ai.ai_helper.context_menu.cancel"
+          @action={{this.cancelAIAction}}
+          class="btn-flat cancel-request"
+        />
+      </div>
     {{else if (eq this.menuState this.CONTEXT_MENU_STATES.review)}}
       <ul class="ai-helper-context-menu__review">
@@ -92,7 +113,6 @@
           />
         </li>
       </ul>
-
     {{/if}}
   </div>
 {{/if}}

View File

@@ -15,10 +15,10 @@ export default class AiHelperContextMenu extends Component {
     return showAIHelper(outletArgs, helper);
   }
 
+  @service currentUser;
   @service siteSettings;
   @tracked helperOptions = [];
   @tracked showContextMenu = false;
-  @tracked menuState = this.CONTEXT_MENU_STATES.triggers;
   @tracked caretCoords;
   @tracked virtualElement;
   @tracked selectedText = "";
@@ -30,6 +30,8 @@ export default class AiHelperContextMenu extends Component {
   @tracked showDiffModal = false;
   @tracked diff;
   @tracked popperPlacement = "top-start";
+  @tracked previousMenuState = null;
+  @tracked customPromptValue = "";
 
   CONTEXT_MENU_STATES = {
     triggers: "TRIGGERS",
@@ -41,8 +43,10 @@ export default class AiHelperContextMenu extends Component {
   prompts = [];
   promptTypes = {};
 
+  @tracked _menuState = this.CONTEXT_MENU_STATES.triggers;
   @tracked _popper;
   @tracked _dEditorInput;
+  @tracked _customPromptInput;
   @tracked _contextMenu;
   @tracked _activeAIRequest = null;
@@ -62,27 +66,42 @@ export default class AiHelperContextMenu extends Component {
     this._popper?.destroy();
   }
 
+  get menuState() {
+    return this._menuState;
+  }
+
+  set menuState(newState) {
+    this.previousMenuState = this._menuState;
+    this._menuState = newState;
+  }
+
   async loadPrompts() {
     let prompts = await ajax("/discourse-ai/ai-helper/prompts");
 
-    prompts
-      .filter((p) => p.name !== "generate_titles")
-      .map((p) => {
-        this.prompts[p.id] = p;
-      });
+    prompts = prompts.filter((p) => p.name !== "generate_titles");
+
+    // Find the custom_prompt object and move it to the beginning of the array
+    const customPromptIndex = prompts.findIndex(
+      (p) => p.name === "custom_prompt"
+    );
+
+    if (customPromptIndex !== -1) {
+      const customPrompt = prompts.splice(customPromptIndex, 1)[0];
+      prompts.unshift(customPrompt);
+    }
+
+    if (!this._showUserCustomPrompts()) {
+      prompts = prompts.filter((p) => p.name !== "custom_prompt");
+    }
+
+    prompts.forEach((p) => {
+      this.prompts[p.id] = p;
+    });
 
     this.promptTypes = prompts.reduce((memo, p) => {
       memo[p.name] = p.prompt_type;
       return memo;
     }, {});
 
-    this.helperOptions = prompts
-      .filter((p) => p.name !== "generate_titles")
-      .map((p) => {
-        return {
-          name: p.translated_name,
-          value: p.id,
-        };
-      });
+    this.helperOptions = prompts;
   }
 
   @bind
@@ -153,6 +172,10 @@ export default class AiHelperContextMenu extends Component {
   }
 
   get canCloseContextMenu() {
+    if (document.activeElement === this._customPromptInput) {
+      return false;
+    }
+
     if (this.loading && this._activeAIRequest !== null) {
       return false;
     }
@@ -168,9 +191,9 @@ export default class AiHelperContextMenu extends Component {
     if (!this.canCloseContextMenu) {
      return;
     }
-
     this.showContextMenu = false;
     this.menuState = this.CONTEXT_MENU_STATES.triggers;
+    this.customPromptValue = "";
   }
 
   _updateSuggestedByAI(data) {
@@ -200,6 +223,15 @@ export default class AiHelperContextMenu extends Component {
     return (this.loading = false);
   }
 
+  _showUserCustomPrompts() {
+    const allowedGroups =
+      this.siteSettings?.ai_helper_custom_prompts_allowed_groups
+        .split("|")
+        .map((id) => parseInt(id, 10));
+
+    return this.currentUser?.groups.some((g) => allowedGroups.includes(g.id));
+  }
+
   handleBoundaries() {
     const textAreaWrapper = document
       .querySelector(".d-editor-textarea-wrapper")
@@ -276,6 +308,14 @@ export default class AiHelperContextMenu extends Component {
     }
   }
 
+  @action
+  setupCustomPrompt() {
+    this._customPromptInput = document.querySelector(
+      ".ai-custom-prompt__input"
+    );
+    this._customPromptInput.focus();
+  }
+
   @action
   toggleAiHelperOptions() {
     // Fetch prompts only if it hasn't been fetched yet
@@ -303,7 +343,11 @@ export default class AiHelperContextMenu extends Component {
     this._activeAIRequest = ajax("/discourse-ai/ai-helper/suggest", {
       method: "POST",
-      data: { mode: option, text: this.selectedText },
+      data: {
+        mode: option.id,
+        text: this.selectedText,
+        custom_prompt: this.customPromptValue,
+      },
     });
 
     this._activeAIRequest
@@ -340,4 +384,9 @@ export default class AiHelperContextMenu extends Component {
       this.closeContextMenu();
     }
   }
+
+  @action
+  togglePreviousMenu() {
+    this.menuState = this.previousMenuState;
+  }
 }

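The client-side gating in `_showUserCustomPrompts` has a straightforward server-side analogue. A sketch of the same pipe-delimited group check in Ruby, assuming a `user` with the standard Discourse `group_ids` association:

    # group_list settings are stored as pipe-delimited group ids.
    allowed_group_ids =
      SiteSetting.ai_helper_custom_prompts_allowed_groups.split("|").map(&:to_i)

    # True when the user belongs to at least one allowed group, mirroring the
    # component's currentUser.groups.some(...) check.
    user.group_ids.any? { |id| allowed_group_ids.include?(id) }
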
View File

@@ -50,50 +50,51 @@
     list-style: none;
   }
 
-  .btn {
-    justify-content: left;
-    text-align: left;
-    background: none;
-    width: 100%;
-    border-radius: 0;
-    margin: 0;
+  li {
+    .btn-flat {
+      justify-content: left;
+      text-align: left;
+      background: none;
+      width: 100%;
+      border-radius: 0;
+      margin: 0;
+      padding-block: 0.6rem;
 
-    &:focus,
-    &:hover {
-      color: var(--primary);
-      background: var(--d-hover);
-      .d-icon {
-        color: var(--primary-medium);
+      &:focus,
+      &:hover {
+        color: var(--primary);
+        background: var(--d-hover);
+        .d-icon {
+          color: var(--primary-medium);
+        }
+        .d-button-label {
+          color: var(--primary-very-high);
+        }
       }
     }
   }
 
-  .d-button-label {
-    color: var(--primary-very-high);
-  }
-
   &__options {
     padding: 0.25rem;
+
+    li:not(:last-child) {
+      border-bottom: 1px solid var(--primary-low);
+    }
   }
 
   &__loading {
+    display: flex;
+    padding: 0.5rem;
+    gap: 1rem;
+    justify-content: flex-start;
+    align-items: center;
+
     .dot-falling {
       margin-inline: 1rem;
       margin-left: 1.5rem;
     }
-
-    li {
-      display: flex;
-      padding: 0.5rem;
-      gap: 1rem;
-      justify-content: flex-start;
-      align-items: center;
-    }
-
-    .btn {
-      width: unset;
-    }
   }
 
   &__resets {
@@ -107,6 +108,41 @@
     align-items: center;
     flex-flow: row wrap;
   }
+
+  &__custom-prompt {
+    display: flex;
+    flex-flow: row wrap;
+    padding: 0.5rem;
+
+    &-header {
+      margin-bottom: 0.5rem;
+
+      .btn {
+        padding: 0;
+      }
+    }
+
+    .ai-custom-prompt-input {
+      min-height: 90px;
+      width: 100%;
+    }
+  }
+
+  .ai-custom-prompt {
+    display: flex;
+    gap: 0.25rem;
+    margin-bottom: 0.5rem;
+
+    &__input {
+      background: var(--primary-low);
+      border-color: var(--primary-low);
+      margin-bottom: 0;
+
+      &::placeholder {
+        color: var(--primary-medium);
+      }
+    }
+  }
 }
 
 .d-editor-input.loading {

View File

@@ -19,6 +19,7 @@ en:
         suggest: "Suggest with AI"
         missing_content: "Please enter some content to generate suggestions."
         context_menu:
+          back: "Back"
           trigger: "AI"
           undo: "Undo"
           loading: "AI is generating"
@@ -28,6 +29,10 @@ en:
           confirm: "Confirm"
           revert: "Revert"
          changes: "Changes"
+          custom_prompt:
+            title: "Custom Prompt"
+            placeholder: "Enter a custom prompt..."
+            submit: "Send Prompt"
       reviewables:
         model_used: "Model used:"
         accuracy: "Accuracy:"

View File

@@ -21,7 +21,7 @@ en:
     ai_sentiment_models: "Models to use for inference. Sentiment classifies post on the positive/neutral/negative space. Emotion classifies on the anger/disgust/fear/joy/neutral/sadness/surprise space."
     ai_nsfw_detection_enabled: "Enable the NSFW module."
     ai_nsfw_inference_service_api_endpoint: "URL where the API is running for the NSFW module"
     ai_nsfw_inference_service_api_key: "API key for the NSFW API"
     ai_nsfw_flag_automatically: "Automatically flag NSFW posts that are above the configured thresholds."
     ai_nsfw_flag_threshold_general: "General Threshold for an image to be considered NSFW."
@@ -44,6 +44,7 @@ en:
     ai_helper_allowed_groups: "Users on these groups will see the AI helper button in the composer."
     ai_helper_allowed_in_pm: "Enable the composer's AI helper in PMs."
     ai_helper_model: "Model to use for the AI helper."
+    ai_helper_custom_prompts_allowed_groups: "Users on these groups will see the custom prompt option in the AI helper."
 
     ai_embeddings_enabled: "Enable the embeddings module."
     ai_embeddings_discourse_service_api_endpoint: "URL where the API is running for the embeddings module"
@@ -76,9 +77,9 @@ en:
     ai_google_custom_search_cx: "CX for Google Custom Search API"
 
   reviewables:
     reasons:
       flagged_by_toxicity: The AI plugin flagged this after classifying it as toxic.
       flagged_by_nsfw: The AI plugin flagged this after classifying at least one of the attached images as NSFW.
 
   errors:
     prompt_message_length: The message %{idx} is over the 1000 character limit.
@@ -93,6 +94,7 @@ en:
           generate_titles: Suggest topic titles
           proofread: Proofread text
           markdown_table: Generate Markdown table
+          custom_prompt: "Custom Prompt"
       ai_bot:
         personas:

View File

@@ -9,7 +9,7 @@ discourse_ai:
   ai_toxicity_inference_service_api_endpoint:
     default: "https://disorder-testing.demo-by-discourse.com"
   ai_toxicity_inference_service_api_key:
-    default: ''
+    default: ""
     secret: true
   ai_toxicity_inference_service_api_model:
     type: enum
@@ -56,7 +56,7 @@ discourse_ai:
   ai_sentiment_inference_service_api_endpoint:
     default: "https://sentiment-testing.demo-by-discourse.com"
   ai_sentiment_inference_service_api_key:
-    default: ''
+    default: ""
     secret: true
   ai_sentiment_models:
     type: list
@@ -64,8 +64,8 @@ discourse_ai:
     default: "emotion"
     allow_any: false
     choices:
       - sentiment
       - emotion
   ai_nsfw_detection_enabled: false
   ai_nsfw_inference_service_api_endpoint:
@@ -85,8 +85,8 @@ discourse_ai:
     default: "opennsfw2"
     allow_any: false
     choices:
       - opennsfw2
       - nsfw_detector
   ai_openai_gpt35_url: "https://api.openai.com/v1/chat/completions"
   ai_openai_gpt35_16k_url: "https://api.openai.com/v1/chat/completions"
@@ -148,6 +148,13 @@ discourse_ai:
       - gpt-4
       - claude-2
       - stable-beluga-2
+  ai_helper_custom_prompts_allowed_groups:
+    client: true
+    type: group_list
+    list_type: compact
+    default: "3" # 3: @staff
+    allow_any: false
+    refresh: true
 
   ai_embeddings_enabled:
     default: false
@@ -162,9 +169,9 @@ discourse_ai:
     default: "all-mpnet-base-v2"
     allow_any: false
     choices:
       - all-mpnet-base-v2
       - text-embedding-ada-002
       - multilingual-e5-large
   ai_embeddings_generate_for_pms: false
   ai_embeddings_semantic_related_topics_enabled:
     default: false
@@ -183,11 +190,11 @@ discourse_ai:
     allow_any: false
     choices:
       - Llama2-*-chat-hf
       - claude-instant-1
       - claude-2
       - gpt-3.5-turbo
       - gpt-4
       - StableBeluga2
       - Upstage-Llama-2-*-instruct-v2
   ai_summarization_discourse_service_api_endpoint: ""
@@ -211,22 +218,22 @@ discourse_ai:
     default: "gpt-3.5-turbo"
     client: true
     choices:
       - gpt-3.5-turbo
       - gpt-4
       - claude-2
   ai_bot_enabled_chat_commands:
     type: list
     default: "categories|google|image|search|tags|time|read"
     client: true
     choices:
       - categories
       - google
       - image
       - search
       - summarize
       - read
       - tags
       - time
   ai_bot_enabled_personas:
     type: list
     default: "general|artist|sql_helper|settings_explorer|researcher"

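A sketch of how a site could widen access beyond the seeded staff-only default, e.g. from the Rails console (the second group id is illustrative, not from this PR):

    # group_list settings store pipe-delimited group ids; 3 is the @staff
    # automatic group, 10 (trust_level_0) is only an example here.
    SiteSetting.ai_helper_custom_prompts_allowed_groups = "3|10"
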
View File

@@ -123,3 +123,14 @@ CompletionPrompt.seed do |cp|
   TEXT
   ]
 end
+
+CompletionPrompt.seed do |cp|
+  cp.id = -5
+  cp.provider = "openai"
+  cp.name = "custom_prompt"
+  cp.prompt_type = CompletionPrompt.prompt_types[:list]
+  cp.messages = [{ role: "system", content: <<~TEXT }]
+    You are a helpful assistant, I will provide you with a text below,
+    you will {{custom_prompt}} and you will reply with the result.
+  TEXT
+end

View File

@@ -54,3 +54,14 @@ CompletionPrompt.seed do |cp|
     please reply with the corrected text between <ai></ai> tags.
   TEXT
 end
+
+CompletionPrompt.seed do |cp|
+  cp.id = -105
+  cp.provider = "anthropic"
+  cp.name = "custom_prompt"
+  cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+  cp.messages = [{ role: "Human", content: <<~TEXT }]
+    You are a helpful assistant, I will provide you with a text inside <input> tags,
+    you will {{custom_prompt}} and you will reply with the result between <ai></ai> tags.
+  TEXT
+end

View File

@@ -109,3 +109,20 @@ CompletionPrompt.seed do |cp|
     ### Assistant:
   TEXT
 end
+
+CompletionPrompt.seed do |cp|
+  cp.id = -205
+  cp.provider = "huggingface"
+  cp.name = "custom_prompt"
+  cp.prompt_type = CompletionPrompt.prompt_types[:diff]
+  cp.messages = [<<~TEXT]
+    ### System:
+    You are a helpful assistant, I will provide you with a text below,
+    you will {{custom_prompt}} and you will reply with the result.
+
+    ### User:
+    {{user_input}}
+
+    ### Assistant:
+  TEXT
+end

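Taken together, the three seeds register the same `custom_prompt` under one id per provider (-5 for OpenAI, -105 for Anthropic, -205 for Hugging Face). A sketch of what the Hugging Face template looks like once both placeholders are filled (the substitution values are sample strings, not from the test suite):

    template = CompletionPrompt.find_by(id: -205).messages.first

    # messages_with_user_input performs the equivalent of this chain;
    # {{custom_prompt}} carries the instruction, {{user_input}} the selection.
    filled =
      template
        .sub("{{custom_prompt}}", "rewrite this in a formal tone")
        .sub("{{user_input}}", "hey, what's up")
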
View File

@@ -12,7 +12,9 @@ module DiscourseAi
         plugin.register_seedfu_fixtures(
           Rails.root.join("plugins", "discourse-ai", "db", "fixtures", "ai_helper"),
         )
-        plugin.register_svg_icon("discourse-sparkles")
+
+        additional_icons = %w[discourse-sparkles spell-check language]
+        additional_icons.each { |icon| plugin.register_svg_icon(icon) }
       end
     end
   end

View File

@@ -19,18 +19,19 @@ module DiscourseAi
             name: prompt.name,
             translated_name: translation,
             prompt_type: prompt.prompt_type,
+            icon: icon_map(prompt.name),
           }
         end
       end
 
-      def generate_and_send_prompt(prompt, text)
+      def generate_and_send_prompt(prompt, params)
         case enabled_provider
         when "openai"
-          openai_call(prompt, text)
+          openai_call(prompt, params)
         when "anthropic"
-          anthropic_call(prompt, text)
+          anthropic_call(prompt, params)
         when "huggingface"
-          huggingface_call(prompt, text)
+          huggingface_call(prompt, params)
         end
       end
@@ -47,6 +48,27 @@ module DiscourseAi
 
       private
 
+      def icon_map(name)
+        case name
+        when "translate"
+          "language"
+        when "generate_titles"
+          "heading"
+        when "proofread"
+          "spell-check"
+        when "markdown_table"
+          "table"
+        when "tone"
+          "microphone"
+        when "custom_prompt"
+          "comment"
+        when "rewrite"
+          "pen"
+        else
+          nil
+        end
+      end
+
       def generate_diff(text, suggestion)
         cooked_text = PrettyText.cook(text)
         cooked_suggestion = PrettyText.cook(suggestion)
@@ -71,10 +93,10 @@ module DiscourseAi
         end
       end
 
-      def openai_call(prompt, text)
+      def openai_call(prompt, params)
         result = { type: prompt.prompt_type }
 
-        messages = prompt.messages_with_user_input(text)
+        messages = prompt.messages_with_user_input(params)
 
         result[:suggestions] = DiscourseAi::Inference::OpenAiCompletions
           .perform!(messages, SiteSetting.ai_helper_model)
@@ -83,15 +105,15 @@ module DiscourseAi
           .flat_map { |choice| parse_content(prompt, choice.dig(:message, :content).to_s) }
           .compact_blank
 
-        result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff?
+        result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff?
 
         result
       end
 
-      def anthropic_call(prompt, text)
+      def anthropic_call(prompt, params)
         result = { type: prompt.prompt_type }
 
-        filled_message = prompt.messages_with_user_input(text)
+        filled_message = prompt.messages_with_user_input(params)
 
         message =
           filled_message.map { |msg| "#{msg["role"]}: #{msg["content"]}" }.join("\n\n") +
@@ -101,15 +123,15 @@ module DiscourseAi
 
         result[:suggestions] = parse_content(prompt, response.dig(:completion))
 
-        result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff?
+        result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff?
 
         result
       end
 
-      def huggingface_call(prompt, text)
+      def huggingface_call(prompt, params)
         result = { type: prompt.prompt_type }
 
-        message = prompt.messages_with_user_input(text)
+        message = prompt.messages_with_user_input(params)
 
         response =
           DiscourseAi::Inference::HuggingFaceTextGeneration.perform!(
@@ -119,7 +141,7 @@ module DiscourseAi
 
         result[:suggestions] = parse_content(prompt, response.dig(:generated_text))
 
-        result[:diff] = generate_diff(text, result[:suggestions].first) if prompt.diff?
+        result[:diff] = generate_diff(params[:text], result[:suggestions].first) if prompt.diff?
 
         result
       end

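A sketch of the reworked call shape, assuming the seeded OpenAI custom prompt and an arbitrary selection; everything flows through the same params hash the controller now forwards:

    llm_prompt = DiscourseAi::AiHelper::LlmPrompt.new
    prompt = CompletionPrompt.find_by(id: -5)

    result =
      llm_prompt.generate_and_send_prompt(
        prompt,
        { text: "selected text", custom_prompt: "summarize this" },
      )

    result[:suggestions] # array of parsed completions
    result[:diff]        # present only when prompt.diff?
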
View File

@@ -36,7 +36,10 @@ module DiscourseAi
 
         return "" if prompt_for_provider.nil?
 
-        llm_prompt.generate_and_send_prompt(prompt_for_provider, text).dig(:suggestions).first
+        llm_prompt
+          .generate_and_send_prompt(prompt_for_provider, { text: text })
+          .dig(:suggestions)
+          .first
       end
 
       def completion_prompts
View File

@ -3,7 +3,7 @@
require_relative "../../../support/openai_completions_inference_stubs" require_relative "../../../support/openai_completions_inference_stubs"
RSpec.describe DiscourseAi::AiHelper::LlmPrompt do RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
let(:prompt) { CompletionPrompt.find_by(name: mode) } let(:prompt) { CompletionPrompt.find_by(name: mode, provider: "openai") }
describe "#generate_and_send_prompt" do describe "#generate_and_send_prompt" do
context "when using the translate mode" do context "when using the translate mode" do
@ -13,7 +13,10 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
it "Sends the prompt to chatGPT and returns the response" do it "Sends the prompt to chatGPT and returns the response" do
response = response =
subject.generate_and_send_prompt(prompt, OpenAiCompletionsInferenceStubs.spanish_text) subject.generate_and_send_prompt(
prompt,
{ text: OpenAiCompletionsInferenceStubs.spanish_text },
)
expect(response[:suggestions]).to contain_exactly( expect(response[:suggestions]).to contain_exactly(
OpenAiCompletionsInferenceStubs.translated_response.strip, OpenAiCompletionsInferenceStubs.translated_response.strip,
@ -30,7 +33,7 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
response = response =
subject.generate_and_send_prompt( subject.generate_and_send_prompt(
prompt, prompt,
OpenAiCompletionsInferenceStubs.translated_response, { text: OpenAiCompletionsInferenceStubs.translated_response },
) )
expect(response[:suggestions]).to contain_exactly( expect(response[:suggestions]).to contain_exactly(
@ -56,7 +59,7 @@ RSpec.describe DiscourseAi::AiHelper::LlmPrompt do
response = response =
subject.generate_and_send_prompt( subject.generate_and_send_prompt(
prompt, prompt,
OpenAiCompletionsInferenceStubs.translated_response, { text: OpenAiCompletionsInferenceStubs.translated_response },
) )
expect(response[:suggestions]).to contain_exactly(*expected) expect(response[:suggestions]).to contain_exactly(*expected)

View File

@@ -4,6 +4,7 @@ class OpenAiCompletionsInferenceStubs
   TRANSLATE = "translate"
   PROOFREAD = "proofread"
   GENERATE_TITLES = "generate_titles"
+  CUSTOM_PROMPT = "custom_prompt"
 
   class << self
     def text_mode_to_id(mode)
@@ -14,6 +15,8 @@ class OpenAiCompletionsInferenceStubs
         -3
       when GENERATE_TITLES
         -2
+      when CUSTOM_PROMPT
+        -5
       end
     end
@@ -30,6 +33,16 @@ class OpenAiCompletionsInferenceStubs
       STRING
     end
 
+    def custom_prompt_input
+      "Translate to French"
+    end
+
+    def custom_prompt_response
+      <<~STRING
+        Le destin favorise les répétitions, les variantes, les symétries ;
+      STRING
+    end
+
     def translated_response
       <<~STRING
         "To perfect his horror, Caesar, surrounded at the base of the statue by the impatient daggers of his friends,
@@ -90,13 +103,24 @@ class OpenAiCompletionsInferenceStubs
         proofread_response
       when GENERATE_TITLES
         generated_titles
+      when CUSTOM_PROMPT
+        custom_prompt_response
       end
     end
 
     def stub_prompt(type)
-      text = type == TRANSLATE ? spanish_text : translated_response
+      user_input = type == TRANSLATE ? spanish_text : translated_response
+      id = text_mode_to_id(type)
 
-      prompt_messages = CompletionPrompt.find_by(name: type).messages_with_user_input(text)
+      if type == CUSTOM_PROMPT
+        user_input = { mode: id, text: translated_response, custom_prompt: "Translate to French" }
+      elsif type == TRANSLATE
+        user_input = { mode: id, text: spanish_text, custom_prompt: "" }
+      else
+        user_input = { mode: id, text: translated_response, custom_prompt: "" }
+      end
+
+      prompt_messages = CompletionPrompt.find_by(id: id).messages_with_user_input(user_input)
 
       stub_response(prompt_messages, response_text_for(type))
     end

View File

@@ -54,6 +54,56 @@ RSpec.describe "AI Composer helper", type: :system, js: true do
     expect(ai_helper_context_menu).to have_no_context_menu
   end
 
+  context "when using custom prompt" do
+    let(:mode) { OpenAiCompletionsInferenceStubs::CUSTOM_PROMPT }
+    before { OpenAiCompletionsInferenceStubs.stub_prompt(mode) }
+
+    it "shows custom prompt option" do
+      trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
+      ai_helper_context_menu.click_ai_button
+      expect(ai_helper_context_menu).to have_custom_prompt
+    end
+
+    it "shows the custom prompt button when input is filled" do
+      trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
+      ai_helper_context_menu.click_ai_button
+      expect(ai_helper_context_menu).to have_no_custom_prompt_button
+      ai_helper_context_menu.fill_custom_prompt(
+        OpenAiCompletionsInferenceStubs.custom_prompt_input,
+      )
+      expect(ai_helper_context_menu).to have_custom_prompt_button
+    end
+
+    it "replaces the composed message with AI generated content" do
+      trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
+      ai_helper_context_menu.click_ai_button
+      ai_helper_context_menu.fill_custom_prompt(
+        OpenAiCompletionsInferenceStubs.custom_prompt_input,
+      )
+      ai_helper_context_menu.click_custom_prompt_button
+
+      wait_for do
+        composer.composer_input.value ==
+          OpenAiCompletionsInferenceStubs.custom_prompt_response.strip
+      end
+
+      expect(composer.composer_input.value).to eq(
+        OpenAiCompletionsInferenceStubs.custom_prompt_response.strip,
+      )
+    end
+  end
+
+  context "when not a member of custom prompt group" do
+    let(:mode) { OpenAiCompletionsInferenceStubs::CUSTOM_PROMPT }
+    before { SiteSetting.ai_helper_custom_prompts_allowed_groups = non_member_group.id.to_s }
+
+    it "does not show custom prompt option" do
+      trigger_context_menu(OpenAiCompletionsInferenceStubs.translated_response)
+      ai_helper_context_menu.click_ai_button
+      expect(ai_helper_context_menu).to have_no_custom_prompt
+    end
+  end
+
   context "when using translation mode" do
     let(:mode) { OpenAiCompletionsInferenceStubs::TRANSLATE }
     before { OpenAiCompletionsInferenceStubs.stub_prompt(mode) }

View File

@@ -10,6 +10,9 @@ module PageObjects
       LOADING_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__loading"
       RESETS_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__resets"
       REVIEW_STATE_SELECTOR = "#{CONTEXT_MENU_SELECTOR}__review"
+      CUSTOM_PROMPT_SELECTOR = "#{CONTEXT_MENU_SELECTOR} .ai-custom-prompt"
+      CUSTOM_PROMPT_INPUT_SELECTOR = "#{CUSTOM_PROMPT_SELECTOR}__input"
+      CUSTOM_PROMPT_BUTTON_SELECTOR = "#{CUSTOM_PROMPT_SELECTOR}__submit"
 
       def click_ai_button
         find("#{TRIGGER_STATE_SELECTOR} .btn").click
@@ -43,6 +46,15 @@ module PageObjects
         find("body").send_keys(:escape)
       end
 
+      def click_custom_prompt_button
+        find(CUSTOM_PROMPT_BUTTON_SELECTOR).click
+      end
+
+      def fill_custom_prompt(content)
+        find(CUSTOM_PROMPT_INPUT_SELECTOR).fill_in(with: content)
+
+        self
+      end
+
       def has_context_menu?
         page.has_css?(CONTEXT_MENU_SELECTOR)
       end
@@ -70,6 +82,22 @@ module PageObjects
       def not_showing_resets?
         page.has_no_css?(RESETS_STATE_SELECTOR)
       end
+
+      def has_custom_prompt?
+        page.has_css?(CUSTOM_PROMPT_SELECTOR)
+      end
+
+      def has_no_custom_prompt?
+        page.has_no_css?(CUSTOM_PROMPT_SELECTOR)
+      end
+
+      def has_custom_prompt_button?
+        page.has_css?(CUSTOM_PROMPT_BUTTON_SELECTOR)
+      end
+
+      def has_no_custom_prompt_button?
+        page.has_no_css?(CUSTOM_PROMPT_BUTTON_SELECTOR)
+      end
     end
   end
 end