FIX: properly truncate !command prompts (#227)

* FIX: properly truncate !command prompts

### What is going on here?

Prior to this change, when a command was issued by the LLM, it
could hallucinate a continuation, e.g.:

```
This is what tags are

!tags

some nonsense here
```

This change introduces safeguards so `some nonsense here` does not
creep into the prompt history, poisoning the LLM results.

This grounds the LLM much better in practice, so it forgets less
about the results of the commands it runs.

The change only impacts Claude at the moment, but will also improve
stuff for llama 2 in future.

Also, this makes it significantly easier to test the bot framework
without an LLM, because we avoid a whole bunch of complex stubbing.

* blank is not a valid bot response, do not inject into prompt
This commit is contained in:
Sam 2023-09-15 07:02:37 +10:00 committed by GitHub
parent f57c1bb0f6
commit 316ea9624e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 134 additions and 33 deletions

View File

@ -47,6 +47,24 @@ module DiscourseAi
# @return [Array] the collected function definitions, in insertion order
def to_a
  @functions
end
# Truncates a partial LLM reply so that hallucinated text appearing after a
# !command does not leak into the prompt history.
#
# Lines are collected up to and including any command line (a line starting
# with "!"); once a command has been seen, the first subsequent line that
# contains non-command content stops the scan.
#
# @param partial_reply [String] the raw (possibly multi-line) LLM output
# @return [String] the reply truncated at the first post-command content
def truncate(partial_reply)
  lines = []
  found_command = false
  partial_reply
    .split("\n")
    .each do |line|
      if line.match?(/^!/)
        found_command = true
        lines << line
      # match? instead of match: only a boolean is needed, avoids allocating
      # a MatchData and is consistent with the check above
      elsif found_command && line.match?(/^\s*[^!]+/)
        # non-command content after a command is a hallucinated
        # continuation — stop collecting here
        break
      else
        lines << line
      end
    end
  lines.join("\n")
end
end
attr_reader :bot_user
@ -142,7 +160,6 @@ module DiscourseAi
bot_reply_post.update!(raw: reply, cooked: PrettyText.cook(reply)) if bot_reply_post
end
next if reply.length < SiteSetting.min_personal_message_post_length
# Minor hack to skip the delay during tests.
next if (Time.now - start < 0.5) && !Rails.env.test?
@ -157,7 +174,7 @@ module DiscourseAi
bot_user,
topic_id: post.topic_id,
raw: reply,
skip_validations: false,
skip_validations: true,
)
end
@ -181,7 +198,14 @@ module DiscourseAi
bot_reply_post.post_custom_prompt ||= post.build_post_custom_prompt(custom_prompt: [])
prompt = post.post_custom_prompt.custom_prompt || []
prompt << [partial_reply, bot_user.username]
truncated_reply = partial_reply
if functions.found? && functions.cancel_completion?
# we need to truncate the partial_reply
truncated_reply = functions.truncate(partial_reply)
end
prompt << [truncated_reply, bot_user.username] if truncated_reply.present?
post.post_custom_prompt.update!(custom_prompt: prompt)
end
@ -321,10 +345,10 @@ module DiscourseAi
def populate_functions(partial:, reply:, functions:, done:)
if !done
functions.found! if reply.match?(/^!/i)
if functions.found?
functions.cancel_completion! if reply.split("\n")[-1].match?(/^\s*[^!]+/)
end
functions.found! if reply.match?(/^!/i)
else
reply
.scan(/^!.*$/i)

View File

@ -27,38 +27,40 @@ module ::DiscourseAi
function = @functions.find { |f| f.name == name }
next if function.blank?
arguments = arguments[0..-2] if arguments.end_with?(")")
temp_string = +""
in_string = nil
replace = SecureRandom.hex(10)
arguments.each_char do |char|
if %w[" '].include?(char) && !in_string
in_string = char
elsif char == in_string
in_string = nil
elsif char == "," && in_string
char = replace
end
temp_string << char
end
arguments = temp_string.split(",").map { |s| s.gsub(replace, ",").strip }
parsed_arguments = {}
arguments.each do |argument|
key, value = argument.split(":", 2)
# remove stuff that is bypasses spec
param = function.parameters.find { |p| p[:name] == key.strip }
next if !param
if arguments
arguments = arguments[0..-2] if arguments.end_with?(")")
value = value.strip.gsub(/(\A"(.*)"\Z)|(\A'(.*)'\Z)/m, '\2\4') if value.present?
if param[:enum]
next if !param[:enum].include?(value)
temp_string = +""
in_string = nil
replace = SecureRandom.hex(10)
arguments.each_char do |char|
if %w[" '].include?(char) && !in_string
in_string = char
elsif char == in_string
in_string = nil
elsif char == "," && in_string
char = replace
end
temp_string << char
end
parsed_arguments[key.strip.to_sym] = value.strip
arguments = temp_string.split(",").map { |s| s.gsub(replace, ",").strip }
arguments.each do |argument|
key, value = argument.split(":", 2)
# remove stuff that is bypasses spec
param = function.parameters.find { |p| p[:name] == key.strip }
next if !param
value = value.strip.gsub(/(\A"(.*)"\Z)|(\A'(.*)'\Z)/m, '\2\4') if value.present?
if param[:enum]
next if !param[:enum].include?(value)
end
parsed_arguments[key.strip.to_sym] = value.strip
end
end
# ensure parsed_arguments has all required arguments

View File

@ -2,7 +2,82 @@
require_relative "../../../support/openai_completions_inference_stubs"
RSpec.describe DiscourseAi::AiBot::Bot do
# Test double for the bot framework: makes no LLM calls. Canned response
# streams are queued with #add_response and replayed by #submit_prompt.
class FakeBot < DiscourseAi::AiBot::Bot
  # Whitespace tokenizer so token counting needs no model vocabulary.
  class Tokenizer
    def tokenize(text)
      text.split(" ")
    end
  end

  def tokenizer
    Tokenizer.new
  end

  # Large enough that no test prompt is ever trimmed.
  def prompt_limit
    10_000
  end

  def build_message(poster_username, content, system: false, function: nil)
    if poster_username == bot_user.username
      "Assistant: #{content}"
    else
      "Human: #{content}"
    end
  end

  # Streams the next queued response, chunk by chunk, to the caller's block.
  def submit_prompt(prompt, prefer_low_cost: false)
    @responses.shift.each { |chunk| yield chunk, -> {} }
  end

  def get_delta(partial, context)
    partial
  end

  def add_response(response)
    (@responses ||= []) << response
  end
end
# Exercises the command-truncation path end to end without stubbing an LLM:
# FakeBot streams the queued chunks, and anything after a !command ("abc"
# here) must be dropped from both the visible reply and the stored custom
# prompt.
describe FakeBot do
  fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
  fab!(:post) { Fabricate(:post, raw: "hello world") }
  it "can handle command truncation for long messages" do
    bot = FakeBot.new(bot_user)
    # second chunk carries hallucinated text ("abc") after the !tags command
    bot.add_response(["hello this is a big test I am testing 123\n", "!tags\nabc"])
    bot.add_response(["this is the reply"])
    bot.reply_to(post)
    reply = post.topic.posts.order(:post_number).last
    expect(reply.raw).not_to include("abc")
    expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
    # NOTE(review): 3 entries — presumably user post, truncated bot reply,
    # and the follow-up reply; confirm against reply_to's prompt bookkeeping
    expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
    expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq(
      "hello this is a big test I am testing 123\n!tags",
    )
  end
  it "can handle command truncation for short bot messages" do
    bot = FakeBot.new(bot_user)
    # same shape as above, but the pre-command text is a single short line
    bot.add_response(["hello\n", "!tags\nabc"])
    bot.add_response(["this is the reply"])
    bot.reply_to(post)
    reply = post.topic.posts.order(:post_number).last
    expect(reply.raw).not_to include("abc")
    expect(reply.post_custom_prompt.custom_prompt.to_s).not_to include("abc")
    expect(reply.post_custom_prompt.custom_prompt.length).to eq(3)
    expect(reply.post_custom_prompt.custom_prompt[0][0]).to eq("hello\n!tags")
  end
end
describe DiscourseAi::AiBot::Bot do
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT4_ID) }
fab!(:bot) { described_class.as(bot_user) }