FIX: Custom instructions where missing when generating custom prompt input (#348)

This commit is contained in:
Roman Rizzi 2023-12-11 19:26:56 -03:00 committed by GitHub
parent a89549919d
commit 2798e4c86d
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
5 changed files with 90 additions and 5 deletions

View File

@ -23,8 +23,11 @@ module DiscourseAi
prompt = CompletionPrompt.find_by(id: params[:mode]) prompt = CompletionPrompt.find_by(id: params[:mode])
raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled?
if prompt.prompt_type == "custom_prompt" && params[:custom_prompt].blank?
raise Discourse::InvalidParameters.new(:custom_prompt) if prompt.id == CompletionPrompt::CUSTOM_PROMPT
raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank?
prompt.custom_instruction = params[:custom_prompt]
end end
hijack do hijack do

View File

@ -20,10 +20,23 @@ class CompletionPrompt < ActiveRecord::Base
where(enabled: true).find_by(name: name) where(enabled: true).find_by(name: name)
end end
attr_accessor :custom_instruction
def messages_with_input(input) def messages_with_input(input)
return unless input return unless input
messages_hash.merge(input: "<input>#{input}</input>") user_input =
if id == CUSTOM_PROMPT && custom_instruction.present?
"#{custom_instruction}:\n#{input}"
else
input
end
messages_hash.merge(input: <<~TEXT)
<input>
#{user_input}
</input>
TEXT
end end
private private

View File

@ -13,11 +13,13 @@ module DiscourseAi
def initialize(responses) def initialize(responses)
@responses = responses @responses = responses
@completions = 0 @completions = 0
@prompt = nil
end end
attr_reader :responses, :completions attr_reader :responses, :completions, :prompt
def perform_completion!(_prompt, _user, _model_params) def perform_completion!(prompt, _user, _model_params)
@prompt = prompt
response = responses[completions] response = responses[completions]
if response.nil? if response.nil?
raise CANNED_RESPONSE_ERROR, raise CANNED_RESPONSE_ERROR,

View File

@ -18,4 +18,44 @@ RSpec.describe CompletionPrompt do
end end
end end
end end
describe "messages_with_input" do
let(:user_input) { "A user wrote this." }
context "when the record has the custom_prompt type" do
let(:custom_prompt) { described_class.find(described_class::CUSTOM_PROMPT) }
it "wraps the user input with <input> XML tags and adds a custom instruction if given" do
expected = <<~TEXT
<input>
Translate to Turkish:
#{user_input}
</input>
TEXT
custom_prompt.custom_instruction = "Translate to Turkish"
prompt = custom_prompt.messages_with_input(user_input)
expect(prompt[:input]).to eq(expected)
end
end
context "when the records don't have the custom_prompt type" do
let(:title_prompt) { described_class.find(described_class::GENERATE_TITLES) }
it "wraps user input with <input> XML tags" do
expected = <<~TEXT
<input>
#{user_input}
</input>
TEXT
title_prompt.custom_instruction = "Translate to Turkish"
prompt = title_prompt.messages_with_input(user_input)
expect(prompt[:input]).to eq(expected)
end
end
end
end end

View File

@ -79,6 +79,33 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do
expect(response.parsed_body["diff"]).to eq(expected_diff) expect(response.parsed_body["diff"]).to eq(expected_diff)
end end
end end
it "uses custom instruction when using custom_prompt mode" do
translated_text = "Un usuario escribio esto"
expected_diff =
"<div class=\"inline-diff\"><p><ins>Un </ins><ins>usuario </ins><ins>escribio </ins><ins>esto</ins><del>A </del><del>user </del><del>wrote </del><del>this</del></p></div>"
expected_input = <<~TEXT
<input>
Translate to Spanish:
A user wrote this
</input>
TEXT
DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do |spy|
post "/discourse-ai/ai-helper/suggest",
params: {
mode: CompletionPrompt::CUSTOM_PROMPT,
text: "A user wrote this",
custom_prompt: "Translate to Spanish",
}
expect(response.status).to eq(200)
expect(response.parsed_body["suggestions"].first).to eq(translated_text)
expect(response.parsed_body["diff"]).to eq(expected_diff)
expect(spy.prompt.last[:content]).to eq(expected_input)
end
end
end end
end end
end end