FEATURE: basic progress for image generation (#133)

previously you would have to wait quite a while to see the prompt this implements
a very basic implementation of progress so you can see the API is working.

Also: 

- Fix google progress.
- Handle the incredibly rare, zero results from google.
- Simplify command so it is less error prone
- replace invoke and attache results with a invoke
- ensure invoke can only ever be run once
- pass in all the information a command needs in constructor
- use new pattern throughout
- test invocation in isolation
This commit is contained in:
Sam 2023-08-14 16:30:12 +10:00 committed by GitHub
parent b076e43d67
commit 20c1f2d788
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
14 changed files with 137 additions and 43 deletions

View File

@ -151,9 +151,14 @@ module DiscourseAi
name, args = function[:name], function[:arguments] name, args = function[:name], function[:arguments]
if command_klass = available_commands.detect { |cmd| cmd.invoked?(name) } if command_klass = available_commands.detect { |cmd| cmd.invoked?(name) }
command = command_klass.new(bot_user, args) command =
chain_intermediate, bot_reply_post = command_klass.new(
command.invoke_and_attach_result_to(bot_reply_post, post) bot_user: bot_user,
args: args,
post: bot_reply_post,
parent_post: post,
)
chain_intermediate, bot_reply_post = command.invoke!
chain ||= chain_intermediate chain ||= chain_intermediate
standalone ||= command.standalone? standalone ||= command.standalone?
end end

View File

@ -15,6 +15,9 @@ module DiscourseAi
end end
class Command class Command
CARET = "<!-- caret -->"
PROGRESS_CARET = "<!-- progress -->"
class << self class << self
def name def name
raise NotImplemented raise NotImplemented
@ -36,11 +39,25 @@ module DiscourseAi
end end
end end
attr_reader :bot_user, :args attr_reader :bot_user
def initialize(bot_user, args) def initialize(bot_user:, args:, post: nil, parent_post: nil)
@bot_user = bot_user @bot_user = bot_user
@args = args @args = args
@post = post
@parent_post = parent_post
@placeholder = +(<<~HTML).strip
<details>
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
<p>
#{CARET}
</p>
</details>
#{PROGRESS_CARET}
HTML
@invoked = false
end end
def bot def bot
@ -78,44 +95,59 @@ module DiscourseAi
true true
end end
def invoke_and_attach_result_to(post, parent_post) def show_progress(text, progress_caret: false)
placeholder = (<<~HTML).strip # during tests we may have none
<details> caret = progress_caret ? PROGRESS_CARET : CARET
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary> new_placeholder = @placeholder.sub(caret, text + caret)
</details> raw = @post.raw.sub(@placeholder, new_placeholder)
HTML @placeholder = new_placeholder
if !post @post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
post = end
def localized_description
I18n.t(
"discourse_ai.ai_bot.command_description.#{self.class.name}",
self.description_args,
)
end
def invoke!
raise StandardError.new("Command can only be invoked once!") if @invoked
@invoked = true
if !@post
@post =
PostCreator.create!( PostCreator.create!(
bot_user, bot_user,
raw: placeholder, raw: @placeholder,
topic_id: parent_post.topic_id, topic_id: @parent_post.topic_id,
skip_validations: true, skip_validations: true,
skip_rate_limiter: true, skip_rate_limiter: true,
) )
else else
post.revise( @post.revise(
bot_user, bot_user,
{ raw: post.raw + "\n\n" + placeholder + "\n\n" }, { raw: @post.raw + "\n\n" + @placeholder + "\n\n" },
skip_validations: true, skip_validations: true,
skip_revision: true, skip_revision: true,
) )
end end
post.post_custom_prompt ||= post.build_post_custom_prompt(custom_prompt: []) @post.post_custom_prompt ||= @post.build_post_custom_prompt(custom_prompt: [])
prompt = post.post_custom_prompt.custom_prompt || [] prompt = @post.post_custom_prompt.custom_prompt || []
parsed_args = JSON.parse(args).symbolize_keys parsed_args = JSON.parse(@args).symbolize_keys
prompt << [process(**parsed_args).to_json, self.class.name, "function"] prompt << [process(**parsed_args).to_json, self.class.name, "function"]
post.post_custom_prompt.update!(custom_prompt: prompt) @post.post_custom_prompt.update!(custom_prompt: prompt)
raw = +(<<~HTML) raw = +(<<~HTML)
<details> <details>
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary> <summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
<p> <p>
#{I18n.t("discourse_ai.ai_bot.command_description.#{self.class.name}", self.description_args)} #{localized_description}
</p> </p>
</details> </details>
@ -123,29 +155,29 @@ module DiscourseAi
raw << custom_raw if custom_raw.present? raw << custom_raw if custom_raw.present?
raw = post.raw.sub(placeholder, raw) raw = @post.raw.sub(@placeholder, raw)
post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true) @post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
if chain_next_response if chain_next_response
# somewhat annoying but whitespace was stripped in revise # somewhat annoying but whitespace was stripped in revise
# so we need to save again # so we need to save again
post.raw = raw @post.raw = raw
post.save!(validate: false) @post.save!(validate: false)
end end
[chain_next_response, post] [chain_next_response, @post]
end end
def format_results(rows, column_names = nil, args: nil) def format_results(rows, column_names = nil, args: nil)
rows = rows.map { |row| yield row } if block_given? rows = rows&.map { |row| yield row } if block_given?
if !column_names if !column_names
index = -1 index = -1
column_indexes = {} column_indexes = {}
rows = rows =
rows.map do |data| rows&.map do |data|
new_row = [] new_row = []
data.each do |key, value| data.each do |key, value|
found_index = column_indexes[key.to_s] ||= (index += 1) found_index = column_indexes[key.to_s] ||= (index += 1)

View File

@ -41,6 +41,9 @@ module DiscourseAi::AiBot::Commands
def process(query:) def process(query:)
@last_query = query @last_query = query
show_progress(localized_description)
api_key = SiteSetting.ai_google_custom_search_api_key api_key = SiteSetting.ai_google_custom_search_api_key
cx = SiteSetting.ai_google_custom_search_cx cx = SiteSetting.ai_google_custom_search_cx
query = CGI.escape(query) query = CGI.escape(query)

View File

@ -42,7 +42,28 @@ module DiscourseAi::AiBot::Commands
def process(prompt:) def process(prompt:)
@last_prompt = prompt @last_prompt = prompt
show_progress(localized_description)
results = nil
# API is flaky, so try a few times
3.times do
begin
thread =
Thread.new do
begin
results = DiscourseAi::Inference::StabilityGenerator.perform!(prompt) results = DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
rescue => e
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
end
end
show_progress(".", progress_caret: true) while !thread.join(2)
break if results
end
end
uploads = [] uploads = []

View File

@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
result = result =
DiscourseAi::AiBot::Commands::SearchCommand DiscourseAi::AiBot::Commands::SearchCommand
.new(nil, nil) .new(bot_user: nil, args: nil)
.process(query: "test search") .process(query: "test search")
.to_json .to_json

View File

@ -7,7 +7,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::CategoriesCommand do
it "can generate correct info" do it "can generate correct info" do
Fabricate(:category, name: "america", posts_year: 999) Fabricate(:category, name: "america", posts_year: 999)
info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(nil, nil).process info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(bot_user: nil, args: nil).process
expect(info.to_s).to include("america") expect(info.to_s).to include("america")
expect(info.to_s).to include("999") expect(info.to_s).to include("999")
end end

View File

@ -4,7 +4,7 @@ require_relative "../../../../support/openai_completions_inference_stubs"
RSpec.describe DiscourseAi::AiBot::Commands::Command do RSpec.describe DiscourseAi::AiBot::Commands::Command do
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
let(:command) { DiscourseAi::AiBot::Commands::Command.new(bot_user, nil) } let(:command) { DiscourseAi::AiBot::Commands::GoogleCommand.new(bot_user: bot_user, args: nil) }
describe "#format_results" do describe "#format_results" do
it "can generate efficient tables of data" do it "can generate efficient tables of data" do

View File

@ -4,6 +4,26 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
describe "#process" do describe "#process" do
it "will not explode if there are no results" do
post = Fabricate(:post)
SiteSetting.ai_google_custom_search_api_key = "abc"
SiteSetting.ai_google_custom_search_cx = "cx"
json_text = { searchInformation: { totalResults: "0" } }.to_json
stub_request(
:get,
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
).to_return(status: 200, body: json_text, headers: {})
google = described_class.new(bot_user: bot_user, post: post, args: {}.to_json)
info = google.process(query: "some search term").to_json
expect(google.description_args[:count]).to eq(0)
expect(info).to_not include("oops")
end
it "can generate correct info" do it "can generate correct info" do
post = Fabricate(:post) post = Fabricate(:post)
@ -31,7 +51,13 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term", "https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
).to_return(status: 200, body: json_text, headers: {}) ).to_return(status: 200, body: json_text, headers: {})
google = described_class.new(bot_user, post) google =
described_class.new(
bot_user: bot_user,
post: post,
args: { query: "some search term" }.to_json,
)
info = google.process(query: "some search term").to_json info = google.process(query: "some search term").to_json
expect(google.description_args[:count]).to eq(1) expect(google.description_args[:count]).to eq(1)
@ -39,6 +65,12 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
expect(info).to include("snippet1") expect(info).to include("snippet1")
expect(info).to include("some+search+term") expect(info).to include("some+search+term")
expect(info).to_not include("oops") expect(info).to_not include("oops")
google.invoke!
expect(post.reload.raw).to include("some search term")
expect { google.invoke! }.to raise_error(StandardError)
end end
end end
end end

View File

@ -24,7 +24,8 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
end end
.to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json) .to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json)
image = described_class.new(bot_user, post) image = described_class.new(bot_user: bot_user, post: post, args: nil)
info = image.process(prompt: "a pink cow").to_json info = image.process(prompt: "a pink cow").to_json
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true) expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true)

View File

@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
post1 = Fabricate(:post, raw: "hello there") post1 = Fabricate(:post, raw: "hello there")
Fabricate(:post, raw: "mister sam", topic: post1.topic) Fabricate(:post, raw: "mister sam", topic: post1.topic)
read = described_class.new(bot_user, post1) read = described_class.new(bot_user: bot_user, args: nil, post: post1)
results = read.process(topic_id: post1.topic_id) results = read.process(topic_id: post1.topic_id)

View File

@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
describe "#process" do describe "#process" do
it "can handle no results" do it "can handle no results" do
post1 = Fabricate(:post) post1 = Fabricate(:post)
search = described_class.new(bot_user, post1) search = described_class.new(bot_user: bot_user, post: post1, args: nil)
results = search.process(query: "order:fake ABDDCDCEDGDG") results = search.process(query: "order:fake ABDDCDCEDGDG")
@ -24,7 +24,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
post1 = Fabricate(:post) post1 = Fabricate(:post)
search = described_class.new(bot_user, post1) search = described_class.new(bot_user: bot_user, post: post1, args: nil)
results = search.process(limit: 1, user: post1.user.username) results = search.process(limit: 1, user: post1.user.username)
expect(results[:rows].to_s).to include("/subfolder" + post1.url) expect(results[:rows].to_s).to include("/subfolder" + post1.url)
@ -36,7 +36,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
_post3 = Fabricate(:post, user: post1.user) _post3 = Fabricate(:post, user: post1.user)
# search has no built in support for limit: so handle it from the outside # search has no built in support for limit: so handle it from the outside
search = described_class.new(bot_user, post1) search = described_class.new(bot_user: bot_user, post: post1, args: nil)
results = search.process(limit: 2, user: post1.user.username) results = search.process(limit: 2, user: post1.user.username)

View File

@ -14,7 +14,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
body: JSON.dump({ choices: [{ message: { content: "summary stuff" } }] }), body: JSON.dump({ choices: [{ message: { content: "summary stuff" } }] }),
) )
summarizer = described_class.new(bot_user, post) summarizer = described_class.new(bot_user: bot_user, args: nil, post: post)
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?") info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
expect(info).to include("Topic summarized") expect(info).to include("Topic summarized")
@ -30,7 +30,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
topic = Fabricate(:topic, category_id: category.id) topic = Fabricate(:topic, category_id: category.id)
post = Fabricate(:post, topic: topic) post = Fabricate(:post, topic: topic)
summarizer = described_class.new(bot_user, post) summarizer = described_class.new(bot_user: bot_user, post: post, args: nil)
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?") info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
expect(info).not_to include(post.raw) expect(info).not_to include(post.raw)

View File

@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::TagsCommand do
Fabricate(:tag, name: "america", public_topic_count: 100) Fabricate(:tag, name: "america", public_topic_count: 100)
Fabricate(:tag, name: "not_here", public_topic_count: 0) Fabricate(:tag, name: "not_here", public_topic_count: 0)
info = DiscourseAi::AiBot::Commands::TagsCommand.new(nil, nil).process info = DiscourseAi::AiBot::Commands::TagsCommand.new(bot_user: nil, args: nil).process
expect(info.to_s).to include("america") expect(info.to_s).to include("america")
expect(info.to_s).not_to include("not_here") expect(info.to_s).not_to include("not_here")

View File

@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::TimeCommand do
freeze_time freeze_time
args = { timezone: "America/Los_Angeles" } args = { timezone: "America/Los_Angeles" }
info = DiscourseAi::AiBot::Commands::TimeCommand.new(nil, nil).process(**args) info = DiscourseAi::AiBot::Commands::TimeCommand.new(bot_user: nil, args: nil).process(**args)
expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s }) expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s })
expect(info.to_s).not_to include("not_here") expect(info.to_s).not_to include("not_here")