FEATURE: basic progress for image generation (#133)
previously you would have to wait quite a while to see the prompt this implements a very basic implementation of progress so you can see the API is working. Also: - Fix google progress. - Handle the incredibly rare, zero results from google. - Simplify command so it is less error prone - replace invoke and attache results with a invoke - ensure invoke can only ever be run once - pass in all the information a command needs in constructor - use new pattern throughout - test invocation in isolation
This commit is contained in:
parent
b076e43d67
commit
20c1f2d788
|
@ -151,9 +151,14 @@ module DiscourseAi
|
|||
name, args = function[:name], function[:arguments]
|
||||
|
||||
if command_klass = available_commands.detect { |cmd| cmd.invoked?(name) }
|
||||
command = command_klass.new(bot_user, args)
|
||||
chain_intermediate, bot_reply_post =
|
||||
command.invoke_and_attach_result_to(bot_reply_post, post)
|
||||
command =
|
||||
command_klass.new(
|
||||
bot_user: bot_user,
|
||||
args: args,
|
||||
post: bot_reply_post,
|
||||
parent_post: post,
|
||||
)
|
||||
chain_intermediate, bot_reply_post = command.invoke!
|
||||
chain ||= chain_intermediate
|
||||
standalone ||= command.standalone?
|
||||
end
|
||||
|
|
|
@ -15,6 +15,9 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
class Command
|
||||
CARET = "<!-- caret -->"
|
||||
PROGRESS_CARET = "<!-- progress -->"
|
||||
|
||||
class << self
|
||||
def name
|
||||
raise NotImplemented
|
||||
|
@ -36,11 +39,25 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
attr_reader :bot_user, :args
|
||||
attr_reader :bot_user
|
||||
|
||||
def initialize(bot_user, args)
|
||||
def initialize(bot_user:, args:, post: nil, parent_post: nil)
|
||||
@bot_user = bot_user
|
||||
@args = args
|
||||
@post = post
|
||||
@parent_post = parent_post
|
||||
|
||||
@placeholder = +(<<~HTML).strip
|
||||
<details>
|
||||
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
||||
<p>
|
||||
#{CARET}
|
||||
</p>
|
||||
</details>
|
||||
#{PROGRESS_CARET}
|
||||
HTML
|
||||
|
||||
@invoked = false
|
||||
end
|
||||
|
||||
def bot
|
||||
|
@ -78,44 +95,59 @@ module DiscourseAi
|
|||
true
|
||||
end
|
||||
|
||||
def invoke_and_attach_result_to(post, parent_post)
|
||||
placeholder = (<<~HTML).strip
|
||||
<details>
|
||||
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
||||
</details>
|
||||
HTML
|
||||
def show_progress(text, progress_caret: false)
|
||||
# during tests we may have none
|
||||
caret = progress_caret ? PROGRESS_CARET : CARET
|
||||
new_placeholder = @placeholder.sub(caret, text + caret)
|
||||
raw = @post.raw.sub(@placeholder, new_placeholder)
|
||||
@placeholder = new_placeholder
|
||||
|
||||
if !post
|
||||
post =
|
||||
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
||||
end
|
||||
|
||||
def localized_description
|
||||
I18n.t(
|
||||
"discourse_ai.ai_bot.command_description.#{self.class.name}",
|
||||
self.description_args,
|
||||
)
|
||||
end
|
||||
|
||||
def invoke!
|
||||
raise StandardError.new("Command can only be invoked once!") if @invoked
|
||||
|
||||
@invoked = true
|
||||
|
||||
if !@post
|
||||
@post =
|
||||
PostCreator.create!(
|
||||
bot_user,
|
||||
raw: placeholder,
|
||||
topic_id: parent_post.topic_id,
|
||||
raw: @placeholder,
|
||||
topic_id: @parent_post.topic_id,
|
||||
skip_validations: true,
|
||||
skip_rate_limiter: true,
|
||||
)
|
||||
else
|
||||
post.revise(
|
||||
@post.revise(
|
||||
bot_user,
|
||||
{ raw: post.raw + "\n\n" + placeholder + "\n\n" },
|
||||
{ raw: @post.raw + "\n\n" + @placeholder + "\n\n" },
|
||||
skip_validations: true,
|
||||
skip_revision: true,
|
||||
)
|
||||
end
|
||||
|
||||
post.post_custom_prompt ||= post.build_post_custom_prompt(custom_prompt: [])
|
||||
prompt = post.post_custom_prompt.custom_prompt || []
|
||||
@post.post_custom_prompt ||= @post.build_post_custom_prompt(custom_prompt: [])
|
||||
prompt = @post.post_custom_prompt.custom_prompt || []
|
||||
|
||||
parsed_args = JSON.parse(args).symbolize_keys
|
||||
parsed_args = JSON.parse(@args).symbolize_keys
|
||||
|
||||
prompt << [process(**parsed_args).to_json, self.class.name, "function"]
|
||||
post.post_custom_prompt.update!(custom_prompt: prompt)
|
||||
@post.post_custom_prompt.update!(custom_prompt: prompt)
|
||||
|
||||
raw = +(<<~HTML)
|
||||
<details>
|
||||
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
||||
<p>
|
||||
#{I18n.t("discourse_ai.ai_bot.command_description.#{self.class.name}", self.description_args)}
|
||||
#{localized_description}
|
||||
</p>
|
||||
</details>
|
||||
|
||||
|
@ -123,29 +155,29 @@ module DiscourseAi
|
|||
|
||||
raw << custom_raw if custom_raw.present?
|
||||
|
||||
raw = post.raw.sub(placeholder, raw)
|
||||
raw = @post.raw.sub(@placeholder, raw)
|
||||
|
||||
post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
||||
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
||||
|
||||
if chain_next_response
|
||||
# somewhat annoying but whitespace was stripped in revise
|
||||
# so we need to save again
|
||||
post.raw = raw
|
||||
post.save!(validate: false)
|
||||
@post.raw = raw
|
||||
@post.save!(validate: false)
|
||||
end
|
||||
|
||||
[chain_next_response, post]
|
||||
[chain_next_response, @post]
|
||||
end
|
||||
|
||||
def format_results(rows, column_names = nil, args: nil)
|
||||
rows = rows.map { |row| yield row } if block_given?
|
||||
rows = rows&.map { |row| yield row } if block_given?
|
||||
|
||||
if !column_names
|
||||
index = -1
|
||||
column_indexes = {}
|
||||
|
||||
rows =
|
||||
rows.map do |data|
|
||||
rows&.map do |data|
|
||||
new_row = []
|
||||
data.each do |key, value|
|
||||
found_index = column_indexes[key.to_s] ||= (index += 1)
|
||||
|
|
|
@ -41,6 +41,9 @@ module DiscourseAi::AiBot::Commands
|
|||
|
||||
def process(query:)
|
||||
@last_query = query
|
||||
|
||||
show_progress(localized_description)
|
||||
|
||||
api_key = SiteSetting.ai_google_custom_search_api_key
|
||||
cx = SiteSetting.ai_google_custom_search_cx
|
||||
query = CGI.escape(query)
|
||||
|
|
|
@ -42,7 +42,28 @@ module DiscourseAi::AiBot::Commands
|
|||
|
||||
def process(prompt:)
|
||||
@last_prompt = prompt
|
||||
|
||||
show_progress(localized_description)
|
||||
|
||||
results = nil
|
||||
|
||||
# API is flaky, so try a few times
|
||||
3.times do
|
||||
begin
|
||||
thread =
|
||||
Thread.new do
|
||||
begin
|
||||
results = DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
|
||||
rescue => e
|
||||
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
|
||||
end
|
||||
end
|
||||
|
||||
show_progress(".", progress_caret: true) while !thread.join(2)
|
||||
|
||||
break if results
|
||||
end
|
||||
end
|
||||
|
||||
uploads = []
|
||||
|
||||
|
|
|
@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
|||
|
||||
result =
|
||||
DiscourseAi::AiBot::Commands::SearchCommand
|
||||
.new(nil, nil)
|
||||
.new(bot_user: nil, args: nil)
|
||||
.process(query: "test search")
|
||||
.to_json
|
||||
|
||||
|
|
|
@ -7,7 +7,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::CategoriesCommand do
|
|||
it "can generate correct info" do
|
||||
Fabricate(:category, name: "america", posts_year: 999)
|
||||
|
||||
info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(nil, nil).process
|
||||
info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(bot_user: nil, args: nil).process
|
||||
expect(info.to_s).to include("america")
|
||||
expect(info.to_s).to include("999")
|
||||
end
|
||||
|
|
|
@ -4,7 +4,7 @@ require_relative "../../../../support/openai_completions_inference_stubs"
|
|||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::Command do
|
||||
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
let(:command) { DiscourseAi::AiBot::Commands::Command.new(bot_user, nil) }
|
||||
let(:command) { DiscourseAi::AiBot::Commands::GoogleCommand.new(bot_user: bot_user, args: nil) }
|
||||
|
||||
describe "#format_results" do
|
||||
it "can generate efficient tables of data" do
|
||||
|
|
|
@ -4,6 +4,26 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
|||
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
|
||||
describe "#process" do
|
||||
it "will not explode if there are no results" do
|
||||
post = Fabricate(:post)
|
||||
|
||||
SiteSetting.ai_google_custom_search_api_key = "abc"
|
||||
SiteSetting.ai_google_custom_search_cx = "cx"
|
||||
|
||||
json_text = { searchInformation: { totalResults: "0" } }.to_json
|
||||
|
||||
stub_request(
|
||||
:get,
|
||||
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
|
||||
).to_return(status: 200, body: json_text, headers: {})
|
||||
|
||||
google = described_class.new(bot_user: bot_user, post: post, args: {}.to_json)
|
||||
info = google.process(query: "some search term").to_json
|
||||
|
||||
expect(google.description_args[:count]).to eq(0)
|
||||
expect(info).to_not include("oops")
|
||||
end
|
||||
|
||||
it "can generate correct info" do
|
||||
post = Fabricate(:post)
|
||||
|
||||
|
@ -31,7 +51,13 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
|||
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
|
||||
).to_return(status: 200, body: json_text, headers: {})
|
||||
|
||||
google = described_class.new(bot_user, post)
|
||||
google =
|
||||
described_class.new(
|
||||
bot_user: bot_user,
|
||||
post: post,
|
||||
args: { query: "some search term" }.to_json,
|
||||
)
|
||||
|
||||
info = google.process(query: "some search term").to_json
|
||||
|
||||
expect(google.description_args[:count]).to eq(1)
|
||||
|
@ -39,6 +65,12 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
|||
expect(info).to include("snippet1")
|
||||
expect(info).to include("some+search+term")
|
||||
expect(info).to_not include("oops")
|
||||
|
||||
google.invoke!
|
||||
|
||||
expect(post.reload.raw).to include("some search term")
|
||||
|
||||
expect { google.invoke! }.to raise_error(StandardError)
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -24,7 +24,8 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
|||
end
|
||||
.to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json)
|
||||
|
||||
image = described_class.new(bot_user, post)
|
||||
image = described_class.new(bot_user: bot_user, post: post, args: nil)
|
||||
|
||||
info = image.process(prompt: "a pink cow").to_json
|
||||
|
||||
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true)
|
||||
|
|
|
@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
|
|||
post1 = Fabricate(:post, raw: "hello there")
|
||||
Fabricate(:post, raw: "mister sam", topic: post1.topic)
|
||||
|
||||
read = described_class.new(bot_user, post1)
|
||||
read = described_class.new(bot_user: bot_user, args: nil, post: post1)
|
||||
|
||||
results = read.process(topic_id: post1.topic_id)
|
||||
|
||||
|
|
|
@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
|||
describe "#process" do
|
||||
it "can handle no results" do
|
||||
post1 = Fabricate(:post)
|
||||
search = described_class.new(bot_user, post1)
|
||||
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
|
||||
|
||||
results = search.process(query: "order:fake ABDDCDCEDGDG")
|
||||
|
||||
|
@ -24,7 +24,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
|||
|
||||
post1 = Fabricate(:post)
|
||||
|
||||
search = described_class.new(bot_user, post1)
|
||||
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
|
||||
|
||||
results = search.process(limit: 1, user: post1.user.username)
|
||||
expect(results[:rows].to_s).to include("/subfolder" + post1.url)
|
||||
|
@ -36,7 +36,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
|||
_post3 = Fabricate(:post, user: post1.user)
|
||||
|
||||
# search has no built in support for limit: so handle it from the outside
|
||||
search = described_class.new(bot_user, post1)
|
||||
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
|
||||
|
||||
results = search.process(limit: 2, user: post1.user.username)
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
|
|||
body: JSON.dump({ choices: [{ message: { content: "summary stuff" } }] }),
|
||||
)
|
||||
|
||||
summarizer = described_class.new(bot_user, post)
|
||||
summarizer = described_class.new(bot_user: bot_user, args: nil, post: post)
|
||||
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
||||
|
||||
expect(info).to include("Topic summarized")
|
||||
|
@ -30,7 +30,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
|
|||
topic = Fabricate(:topic, category_id: category.id)
|
||||
post = Fabricate(:post, topic: topic)
|
||||
|
||||
summarizer = described_class.new(bot_user, post)
|
||||
summarizer = described_class.new(bot_user: bot_user, post: post, args: nil)
|
||||
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
||||
|
||||
expect(info).not_to include(post.raw)
|
||||
|
|
|
@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::TagsCommand do
|
|||
Fabricate(:tag, name: "america", public_topic_count: 100)
|
||||
Fabricate(:tag, name: "not_here", public_topic_count: 0)
|
||||
|
||||
info = DiscourseAi::AiBot::Commands::TagsCommand.new(nil, nil).process
|
||||
info = DiscourseAi::AiBot::Commands::TagsCommand.new(bot_user: nil, args: nil).process
|
||||
|
||||
expect(info.to_s).to include("america")
|
||||
expect(info.to_s).not_to include("not_here")
|
||||
|
|
|
@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::TimeCommand do
|
|||
freeze_time
|
||||
|
||||
args = { timezone: "America/Los_Angeles" }
|
||||
info = DiscourseAi::AiBot::Commands::TimeCommand.new(nil, nil).process(**args)
|
||||
info = DiscourseAi::AiBot::Commands::TimeCommand.new(bot_user: nil, args: nil).process(**args)
|
||||
|
||||
expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s })
|
||||
expect(info.to_s).not_to include("not_here")
|
||||
|
|
Loading…
Reference in New Issue