diff --git a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb index c32ea5d1..eea254ec 100644 --- a/app/controllers/discourse_ai/ai_helper/assistant_controller.rb +++ b/app/controllers/discourse_ai/ai_helper/assistant_controller.rb @@ -23,8 +23,11 @@ module DiscourseAi prompt = CompletionPrompt.find_by(id: params[:mode]) raise Discourse::InvalidParameters.new(:mode) if !prompt || !prompt.enabled? - if prompt.prompt_type == "custom_prompt" && params[:custom_prompt].blank? - raise Discourse::InvalidParameters.new(:custom_prompt) + + if prompt.id == CompletionPrompt::CUSTOM_PROMPT + raise Discourse::InvalidParameters.new(:custom_prompt) if params[:custom_prompt].blank? + + prompt.custom_instruction = params[:custom_prompt] end hijack do diff --git a/app/models/completion_prompt.rb b/app/models/completion_prompt.rb index 9167ffef..c4f4bf61 100644 --- a/app/models/completion_prompt.rb +++ b/app/models/completion_prompt.rb @@ -20,10 +20,23 @@ class CompletionPrompt < ActiveRecord::Base where(enabled: true).find_by(name: name) end + attr_accessor :custom_instruction + def messages_with_input(input) return unless input - messages_hash.merge(input: "#{input}") + user_input = + if id == CUSTOM_PROMPT && custom_instruction.present? + "#{custom_instruction}:\n#{input}" + else + input + end + + messages_hash.merge(input: <<~TEXT) + + #{user_input} + + TEXT end private diff --git a/lib/completions/endpoints/canned_response.rb b/lib/completions/endpoints/canned_response.rb index 2bdf7226..83300fcc 100644 --- a/lib/completions/endpoints/canned_response.rb +++ b/lib/completions/endpoints/canned_response.rb @@ -13,11 +13,13 @@ module DiscourseAi def initialize(responses) @responses = responses @completions = 0 + @prompt = nil end - attr_reader :responses, :completions + attr_reader :responses, :completions, :prompt - def perform_completion!(_prompt, _user, _model_params) + def perform_completion!(prompt, _user, _model_params) + @prompt = prompt response = responses[completions] if response.nil? raise CANNED_RESPONSE_ERROR, diff --git a/spec/models/completion_prompt_spec.rb b/spec/models/completion_prompt_spec.rb index 1eb356df..6f159f9e 100644 --- a/spec/models/completion_prompt_spec.rb +++ b/spec/models/completion_prompt_spec.rb @@ -18,4 +18,44 @@ RSpec.describe CompletionPrompt do end end end + + describe "messages_with_input" do + let(:user_input) { "A user wrote this." } + + context "when the record has the custom_prompt type" do + let(:custom_prompt) { described_class.find(described_class::CUSTOM_PROMPT) } + + it "wraps the user input with XML tags and adds a custom instruction if given" do + expected = <<~TEXT + + Translate to Turkish: + #{user_input} + + TEXT + + custom_prompt.custom_instruction = "Translate to Turkish" + + prompt = custom_prompt.messages_with_input(user_input) + + expect(prompt[:input]).to eq(expected) + end + end + + context "when the records don't have the custom_prompt type" do + let(:title_prompt) { described_class.find(described_class::GENERATE_TITLES) } + + it "wraps user input with XML tags" do + expected = <<~TEXT + + #{user_input} + + TEXT + title_prompt.custom_instruction = "Translate to Turkish" + + prompt = title_prompt.messages_with_input(user_input) + + expect(prompt[:input]).to eq(expected) + end + end + end end diff --git a/spec/requests/ai_helper/assistant_controller_spec.rb b/spec/requests/ai_helper/assistant_controller_spec.rb index 0f46a202..42ea139e 100644 --- a/spec/requests/ai_helper/assistant_controller_spec.rb +++ b/spec/requests/ai_helper/assistant_controller_spec.rb @@ -79,6 +79,33 @@ RSpec.describe DiscourseAi::AiHelper::AssistantController do expect(response.parsed_body["diff"]).to eq(expected_diff) end end + + it "uses custom instruction when using custom_prompt mode" do + translated_text = "Un usuario escribio esto" + expected_diff = + "

Un usuario escribio estoA user wrote this

" + + expected_input = <<~TEXT + + Translate to Spanish: + A user wrote this + + TEXT + + DiscourseAi::Completions::Llm.with_prepared_responses([translated_text]) do |spy| + post "/discourse-ai/ai-helper/suggest", + params: { + mode: CompletionPrompt::CUSTOM_PROMPT, + text: "A user wrote this", + custom_prompt: "Translate to Spanish", + } + + expect(response.status).to eq(200) + expect(response.parsed_body["suggestions"].first).to eq(translated_text) + expect(response.parsed_body["diff"]).to eq(expected_diff) + expect(spy.prompt.last[:content]).to eq(expected_input) + end + end end end end