FIX: properly truncate !command prompts (#227)
* FIX: properly truncate !command prompts ### What is going on here? Previous to this change where a command was issued by the LLM it could hallucinate a continuation eg: ``` This is what tags are !tags some nonsense here ``` This change introduces safeguards so `some nonsense here` does not creep in to the prompt history, poisoning the llm results This in effect grounds the llm a lot better and results in the llm forgetting less about results. The change only impacts Claude at the moment, but will also improve stuff for llama 2 in future. Also, this makes it significantly easier to test the bot framework without an llm cause we avoid a whole bunch of complex stubbing * blank is not a valid bot response, do not inject into prompt
This commit is contained in:
parent
f57c1bb0f6
commit
316ea9624e
|
@ -47,6 +47,24 @@ module DiscourseAi
|
|||
def to_a
|
||||
@functions
|
||||
end
|
||||
|
||||
def truncate(partial_reply)
|
||||
lines = []
|
||||
found_command = false
|
||||
partial_reply
|
||||
.split("\n")
|
||||
.each do |line|
|
||||
if line.match?(/^!/)
|
||||
found_command = true
|
||||
lines << line
|
||||
elsif found_command && line.match(/^\s*[^!]+/)
|
||||
break
|
||||
else
|
||||
lines << line
|
||||
end
|
||||
end
|
||||
lines.join("\n")
|
||||
end
|
||||
end
|
||||
|
||||
attr_reader :bot_user
|
||||
|
@ -142,7 +160,6 @@ module DiscourseAi
|
|||
bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post
|
||||
end
|
||||
|
||||
next if reply.length < SiteSetting.min_personal_message_post_length
|
||||
# Minor hack to skip the delay during tests.
|
||||
next if (Time.now - start < 0.5) && !Rails.env.test?
|
||||
|
||||
|
@ -157,7 +174,7 @@ module DiscourseAi
|
|||
bot_user,
|
||||
topic_id: post.topic_id,
|
||||
raw: reply,
|
||||
skip_validations: false,
|
||||
skip_validations: true,
|
||||
)
|
||||
end
|
||||
|
||||
|
@ -181,7 +198,14 @@ module DiscourseAi
|
|||
bot_reply_post.post_custom_prompt ||= post.build_post_custom_prompt(custom_prompt: [])
|
||||
prompt = post.post_custom_prompt.custom_prompt || []
|
||||
|
||||
prompt << [partial_reply, bot_user.username]
|
||||
truncated_reply = partial_reply
|
||||
|
||||
if functions.found? && functions.cancel_completion?
|
||||
# we need to truncate the partial_reply
|
||||
truncated_reply = functions.truncate(partial_reply)
|
||||
end
|
||||
|
||||
prompt << [truncated_reply, bot_user.username] if truncated_reply.present?
|
||||
|
||||
post.post_custom_prompt.update!(custom_prompt: prompt)
|
||||
end
|
||||
|
@ -321,10 +345,10 @@ module DiscourseAi
|
|||
|
||||
def populate_functions(partial:, reply:, functions:, done:)
|
||||
if !done
|
||||
functions.found! if reply.match?(/^!/i)
|
||||
if functions.found?
|
||||
functions.cancel_completion! if reply.split("\n")[-1].match?(/^\s*[^!]+/)
|
||||
end
|
||||
functions.found! if reply.match?(/^!/i)
|
||||
else
|
||||
reply
|
||||
.scan(/^!.*$/i)
|
||||
|
|
|
@ -27,38 +27,40 @@ module ::DiscourseAi
|
|||
function = @functions.find { |f| f.name == name }
|
||||
next if function.blank?
|
||||
|
||||
arguments = arguments[0..-2] if arguments.end_with?(")")
|
||||
|
||||
temp_string = +""
|
||||
in_string = nil
|
||||
replace = SecureRandom.hex(10)
|
||||
arguments.each_char do |char|
|
||||
if %w[" '].include?(char) && !in_string
|
||||
in_string = char
|
||||
elsif char == in_string
|
||||
in_string = nil
|
||||
elsif char == "," && in_string
|
||||
char = replace
|
||||
end
|
||||
temp_string << char
|
||||
end
|
||||
|
||||
arguments = temp_string.split(",").map { |s| s.gsub(replace, ",").strip }
|
||||
|
||||
parsed_arguments = {}
|
||||
arguments.each do |argument|
|
||||
key, value = argument.split(":", 2)
|
||||
# remove stuff that is bypasses spec
|
||||
param = function.parameters.find { |p| p[:name] == key.strip }
|
||||
next if !param
|
||||
if arguments
|
||||
arguments = arguments[0..-2] if arguments.end_with?(")")
|
||||
|
||||
value = value.strip.gsub(/(\A"(.*)"\Z)|(\A'(.*)'\Z)/m, '\2\4') if value.present?
|
||||
|
||||
if param[:enum]
|
||||
next if !param[:enum].include?(value)
|
||||
temp_string = +""
|
||||
in_string = nil
|
||||
replace = SecureRandom.hex(10)
|
||||
arguments.each_char do |char|
|
||||
if %w[" '].include?(char) && !in_string
|
||||
in_string = char
|
||||
elsif char == in_string
|
||||
in_string = nil
|
||||
elsif char == "," && in_string
|
||||
char = replace
|
||||
end
|
||||
temp_string << char
|
||||
end
|
||||
|
||||
parsed_arguments[key.strip.to_sym] = value.strip
|
||||
arguments = temp_string.split(",").map { |s| s.gsub(replace, ",").strip }
|
||||
|
||||
arguments.each do |argument|
|
||||
key, value = argument.split(":", 2)
|
||||
# remove stuff that is bypasses spec
|
||||
param = function.parameters.find { |p| p[:name] == key.strip }
|
||||
next if !param
|
||||
|
||||
value = value.strip.gsub(/(\A"(.*)"\Z)|(\A'(.*)'\Z)/m, '\2\4') if value.present?
|
||||
|
||||
if param[:enum]
|
||||
next if !param[:enum].include?(value)
|
||||
end
|
||||
|
||||
parsed_arguments[key.strip.to_sym] = value.strip
|
||||
end
|
||||
end
|
||||
|
||||
# ensure parsed_arguments has all required arguments
|
||||
|
|
|
@ -2,7 +2,82 @@
|
|||
|
||||
require_relative "../../../support/openai_completions_inference_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Bot do
|
||||
class FakeBot < DiscourseAi::AiBot::Bot
|
||||
class Tokenizer
|
||||
def tokenize(text)
|
||||
text.split(" ")
|
||||
end
|
||||
end
|
||||
|
||||
def tokenizer
|
||||
Tokenizer.new
|
||||
end
|
||||
|
||||
def prompt_limit
|
||||
10_000
|
||||
end
|
||||
|
||||
def build_message(poster_username, content, system: false, function: nil)
|
||||
role = poster_username == bot_user.username ? "Assistant" : "Human"
|
||||
|
||||
"#{role}: #{content}"
|
||||
end
|
||||
|
||||
def submit_prompt(prompt, prefer_low_cost: false)
|
||||
rows = @responses.shift
|
||||
rows.each { |data| yield data, lambda {} }
|
||||
end
|
||||
|
||||
def get_delta(partial, context)
|
||||
partial
|
||||
end
|
||||
|
||||
def add_response(response)
|
||||
@responses ||= []
|
||||
@responses << response
|
||||
end
|
||||
end
|
||||
|
||||
describe FakeBot do
|
||||
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
|
||||
fab!(:post) { Fabricate(:post, raw: "hello world") }
|
||||
|
||||
it "can handle command truncation for long messages" do
|
||||
bot = FakeBot.new(bot_user)
|
||||
|
||||
bot.add_response(["hello this is a big test I am testing 123\n", "!tags\nabc"])
|
||||
bot.add_response(["this is the reply"])
|
||||
|
||||
bot.reply_to(post)
|
||||
|
||||
reply = post.topic.posts.order(:post_number).last
|
||||
|
||||
expect(reply.raw).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
|
||||
expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq(
|
||||
"hello this is a big test I am testing 123\n!tags",
|
||||
)
|
||||
end
|
||||
|
||||
it "can handle command truncation for short bot messages" do
|
||||
bot = FakeBot.new(bot_user)
|
||||
|
||||
bot.add_response(["hello\n", "!tags\nabc"])
|
||||
bot.add_response(["this is the reply"])
|
||||
|
||||
bot.reply_to(post)
|
||||
|
||||
reply = post.topic.posts.order(:post_number).last
|
||||
|
||||
expect(reply.raw).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
|
||||
expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
|
||||
expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq("hello\n!tags")
|
||||
end
|
||||
end
|
||||
|
||||
describe DiscourseAi::AiBot::Bot do
|
||||
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
|
||||
fab!(:bot) { described_class.as(bot_user) }
|
||||
|
||||
|
|
Loading…
Reference in New Issue