FEATURE: basic progress for image generation (#133)
previously you would have to wait quite a while to see the prompt this implements a very basic implementation of progress so you can see the API is working. Also: - Fix google progress. - Handle the incredibly rare, zero results from google. - Simplify command so it is less error prone - replace invoke and attache results with a invoke - ensure invoke can only ever be run once - pass in all the information a command needs in constructor - use new pattern throughout - test invocation in isolation
This commit is contained in:
parent
b076e43d67
commit
20c1f2d788
|
@ -151,9 +151,14 @@ module DiscourseAi
|
||||||
name, args = function[:name], function[:arguments]
|
name, args = function[:name], function[:arguments]
|
||||||
|
|
||||||
if command_klass = available_commands.detect { |cmd| cmd.invoked?(name) }
|
if command_klass = available_commands.detect { |cmd| cmd.invoked?(name) }
|
||||||
command = command_klass.new(bot_user, args)
|
command =
|
||||||
chain_intermediate, bot_reply_post =
|
command_klass.new(
|
||||||
command.invoke_and_attach_result_to(bot_reply_post, post)
|
bot_user: bot_user,
|
||||||
|
args: args,
|
||||||
|
post: bot_reply_post,
|
||||||
|
parent_post: post,
|
||||||
|
)
|
||||||
|
chain_intermediate, bot_reply_post = command.invoke!
|
||||||
chain ||= chain_intermediate
|
chain ||= chain_intermediate
|
||||||
standalone ||= command.standalone?
|
standalone ||= command.standalone?
|
||||||
end
|
end
|
||||||
|
|
|
@ -15,6 +15,9 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
|
|
||||||
class Command
|
class Command
|
||||||
|
CARET = "<!-- caret -->"
|
||||||
|
PROGRESS_CARET = "<!-- progress -->"
|
||||||
|
|
||||||
class << self
|
class << self
|
||||||
def name
|
def name
|
||||||
raise NotImplemented
|
raise NotImplemented
|
||||||
|
@ -36,11 +39,25 @@ module DiscourseAi
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
attr_reader :bot_user, :args
|
attr_reader :bot_user
|
||||||
|
|
||||||
def initialize(bot_user, args)
|
def initialize(bot_user:, args:, post: nil, parent_post: nil)
|
||||||
@bot_user = bot_user
|
@bot_user = bot_user
|
||||||
@args = args
|
@args = args
|
||||||
|
@post = post
|
||||||
|
@parent_post = parent_post
|
||||||
|
|
||||||
|
@placeholder = +(<<~HTML).strip
|
||||||
|
<details>
|
||||||
|
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
||||||
|
<p>
|
||||||
|
#{CARET}
|
||||||
|
</p>
|
||||||
|
</details>
|
||||||
|
#{PROGRESS_CARET}
|
||||||
|
HTML
|
||||||
|
|
||||||
|
@invoked = false
|
||||||
end
|
end
|
||||||
|
|
||||||
def bot
|
def bot
|
||||||
|
@ -78,44 +95,59 @@ module DiscourseAi
|
||||||
true
|
true
|
||||||
end
|
end
|
||||||
|
|
||||||
def invoke_and_attach_result_to(post, parent_post)
|
def show_progress(text, progress_caret: false)
|
||||||
placeholder = (<<~HTML).strip
|
# during tests we may have none
|
||||||
<details>
|
caret = progress_caret ? PROGRESS_CARET : CARET
|
||||||
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
new_placeholder = @placeholder.sub(caret, text + caret)
|
||||||
</details>
|
raw = @post.raw.sub(@placeholder, new_placeholder)
|
||||||
HTML
|
@placeholder = new_placeholder
|
||||||
|
|
||||||
if !post
|
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
||||||
post =
|
end
|
||||||
|
|
||||||
|
def localized_description
|
||||||
|
I18n.t(
|
||||||
|
"discourse_ai.ai_bot.command_description.#{self.class.name}",
|
||||||
|
self.description_args,
|
||||||
|
)
|
||||||
|
end
|
||||||
|
|
||||||
|
def invoke!
|
||||||
|
raise StandardError.new("Command can only be invoked once!") if @invoked
|
||||||
|
|
||||||
|
@invoked = true
|
||||||
|
|
||||||
|
if !@post
|
||||||
|
@post =
|
||||||
PostCreator.create!(
|
PostCreator.create!(
|
||||||
bot_user,
|
bot_user,
|
||||||
raw: placeholder,
|
raw: @placeholder,
|
||||||
topic_id: parent_post.topic_id,
|
topic_id: @parent_post.topic_id,
|
||||||
skip_validations: true,
|
skip_validations: true,
|
||||||
skip_rate_limiter: true,
|
skip_rate_limiter: true,
|
||||||
)
|
)
|
||||||
else
|
else
|
||||||
post.revise(
|
@post.revise(
|
||||||
bot_user,
|
bot_user,
|
||||||
{ raw: post.raw + "\n\n" + placeholder + "\n\n" },
|
{ raw: @post.raw + "\n\n" + @placeholder + "\n\n" },
|
||||||
skip_validations: true,
|
skip_validations: true,
|
||||||
skip_revision: true,
|
skip_revision: true,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
post.post_custom_prompt ||= post.build_post_custom_prompt(custom_prompt: [])
|
@post.post_custom_prompt ||= @post.build_post_custom_prompt(custom_prompt: [])
|
||||||
prompt = post.post_custom_prompt.custom_prompt || []
|
prompt = @post.post_custom_prompt.custom_prompt || []
|
||||||
|
|
||||||
parsed_args = JSON.parse(args).symbolize_keys
|
parsed_args = JSON.parse(@args).symbolize_keys
|
||||||
|
|
||||||
prompt << [process(**parsed_args).to_json, self.class.name, "function"]
|
prompt << [process(**parsed_args).to_json, self.class.name, "function"]
|
||||||
post.post_custom_prompt.update!(custom_prompt: prompt)
|
@post.post_custom_prompt.update!(custom_prompt: prompt)
|
||||||
|
|
||||||
raw = +(<<~HTML)
|
raw = +(<<~HTML)
|
||||||
<details>
|
<details>
|
||||||
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
<summary>#{I18n.t("discourse_ai.ai_bot.command_summary.#{self.class.name}")}</summary>
|
||||||
<p>
|
<p>
|
||||||
#{I18n.t("discourse_ai.ai_bot.command_description.#{self.class.name}", self.description_args)}
|
#{localized_description}
|
||||||
</p>
|
</p>
|
||||||
</details>
|
</details>
|
||||||
|
|
||||||
|
@ -123,29 +155,29 @@ module DiscourseAi
|
||||||
|
|
||||||
raw << custom_raw if custom_raw.present?
|
raw << custom_raw if custom_raw.present?
|
||||||
|
|
||||||
raw = post.raw.sub(placeholder, raw)
|
raw = @post.raw.sub(@placeholder, raw)
|
||||||
|
|
||||||
post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
@post.revise(bot_user, { raw: raw }, skip_validations: true, skip_revision: true)
|
||||||
|
|
||||||
if chain_next_response
|
if chain_next_response
|
||||||
# somewhat annoying but whitespace was stripped in revise
|
# somewhat annoying but whitespace was stripped in revise
|
||||||
# so we need to save again
|
# so we need to save again
|
||||||
post.raw = raw
|
@post.raw = raw
|
||||||
post.save!(validate: false)
|
@post.save!(validate: false)
|
||||||
end
|
end
|
||||||
|
|
||||||
[chain_next_response, post]
|
[chain_next_response, @post]
|
||||||
end
|
end
|
||||||
|
|
||||||
def format_results(rows, column_names = nil, args: nil)
|
def format_results(rows, column_names = nil, args: nil)
|
||||||
rows = rows.map { |row| yield row } if block_given?
|
rows = rows&.map { |row| yield row } if block_given?
|
||||||
|
|
||||||
if !column_names
|
if !column_names
|
||||||
index = -1
|
index = -1
|
||||||
column_indexes = {}
|
column_indexes = {}
|
||||||
|
|
||||||
rows =
|
rows =
|
||||||
rows.map do |data|
|
rows&.map do |data|
|
||||||
new_row = []
|
new_row = []
|
||||||
data.each do |key, value|
|
data.each do |key, value|
|
||||||
found_index = column_indexes[key.to_s] ||= (index += 1)
|
found_index = column_indexes[key.to_s] ||= (index += 1)
|
||||||
|
|
|
@ -41,6 +41,9 @@ module DiscourseAi::AiBot::Commands
|
||||||
|
|
||||||
def process(query:)
|
def process(query:)
|
||||||
@last_query = query
|
@last_query = query
|
||||||
|
|
||||||
|
show_progress(localized_description)
|
||||||
|
|
||||||
api_key = SiteSetting.ai_google_custom_search_api_key
|
api_key = SiteSetting.ai_google_custom_search_api_key
|
||||||
cx = SiteSetting.ai_google_custom_search_cx
|
cx = SiteSetting.ai_google_custom_search_cx
|
||||||
query = CGI.escape(query)
|
query = CGI.escape(query)
|
||||||
|
|
|
@ -42,7 +42,28 @@ module DiscourseAi::AiBot::Commands
|
||||||
|
|
||||||
def process(prompt:)
|
def process(prompt:)
|
||||||
@last_prompt = prompt
|
@last_prompt = prompt
|
||||||
|
|
||||||
|
show_progress(localized_description)
|
||||||
|
|
||||||
|
results = nil
|
||||||
|
|
||||||
|
# API is flaky, so try a few times
|
||||||
|
3.times do
|
||||||
|
begin
|
||||||
|
thread =
|
||||||
|
Thread.new do
|
||||||
|
begin
|
||||||
results = DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
|
results = DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
|
||||||
|
rescue => e
|
||||||
|
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
|
show_progress(".", progress_caret: true) while !thread.join(2)
|
||||||
|
|
||||||
|
break if results
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
uploads = []
|
uploads = []
|
||||||
|
|
||||||
|
|
|
@ -64,7 +64,7 @@ RSpec.describe DiscourseAi::AiBot::Bot do
|
||||||
|
|
||||||
result =
|
result =
|
||||||
DiscourseAi::AiBot::Commands::SearchCommand
|
DiscourseAi::AiBot::Commands::SearchCommand
|
||||||
.new(nil, nil)
|
.new(bot_user: nil, args: nil)
|
||||||
.process(query: "test search")
|
.process(query: "test search")
|
||||||
.to_json
|
.to_json
|
||||||
|
|
||||||
|
|
|
@ -7,7 +7,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::CategoriesCommand do
|
||||||
it "can generate correct info" do
|
it "can generate correct info" do
|
||||||
Fabricate(:category, name: "america", posts_year: 999)
|
Fabricate(:category, name: "america", posts_year: 999)
|
||||||
|
|
||||||
info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(nil, nil).process
|
info = DiscourseAi::AiBot::Commands::CategoriesCommand.new(bot_user: nil, args: nil).process
|
||||||
expect(info.to_s).to include("america")
|
expect(info.to_s).to include("america")
|
||||||
expect(info.to_s).to include("999")
|
expect(info.to_s).to include("999")
|
||||||
end
|
end
|
||||||
|
|
|
@ -4,7 +4,7 @@ require_relative "../../../../support/openai_completions_inference_stubs"
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Commands::Command do
|
RSpec.describe DiscourseAi::AiBot::Commands::Command do
|
||||||
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||||
let(:command) { DiscourseAi::AiBot::Commands::Command.new(bot_user, nil) }
|
let(:command) { DiscourseAi::AiBot::Commands::GoogleCommand.new(bot_user: bot_user, args: nil) }
|
||||||
|
|
||||||
describe "#format_results" do
|
describe "#format_results" do
|
||||||
it "can generate efficient tables of data" do
|
it "can generate efficient tables of data" do
|
||||||
|
|
|
@ -4,6 +4,26 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
||||||
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
fab!(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||||
|
|
||||||
describe "#process" do
|
describe "#process" do
|
||||||
|
it "will not explode if there are no results" do
|
||||||
|
post = Fabricate(:post)
|
||||||
|
|
||||||
|
SiteSetting.ai_google_custom_search_api_key = "abc"
|
||||||
|
SiteSetting.ai_google_custom_search_cx = "cx"
|
||||||
|
|
||||||
|
json_text = { searchInformation: { totalResults: "0" } }.to_json
|
||||||
|
|
||||||
|
stub_request(
|
||||||
|
:get,
|
||||||
|
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
|
||||||
|
).to_return(status: 200, body: json_text, headers: {})
|
||||||
|
|
||||||
|
google = described_class.new(bot_user: bot_user, post: post, args: {}.to_json)
|
||||||
|
info = google.process(query: "some search term").to_json
|
||||||
|
|
||||||
|
expect(google.description_args[:count]).to eq(0)
|
||||||
|
expect(info).to_not include("oops")
|
||||||
|
end
|
||||||
|
|
||||||
it "can generate correct info" do
|
it "can generate correct info" do
|
||||||
post = Fabricate(:post)
|
post = Fabricate(:post)
|
||||||
|
|
||||||
|
@ -31,7 +51,13 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
||||||
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
|
"https://www.googleapis.com/customsearch/v1?cx=cx&key=abc&num=10&q=some%20search%20term",
|
||||||
).to_return(status: 200, body: json_text, headers: {})
|
).to_return(status: 200, body: json_text, headers: {})
|
||||||
|
|
||||||
google = described_class.new(bot_user, post)
|
google =
|
||||||
|
described_class.new(
|
||||||
|
bot_user: bot_user,
|
||||||
|
post: post,
|
||||||
|
args: { query: "some search term" }.to_json,
|
||||||
|
)
|
||||||
|
|
||||||
info = google.process(query: "some search term").to_json
|
info = google.process(query: "some search term").to_json
|
||||||
|
|
||||||
expect(google.description_args[:count]).to eq(1)
|
expect(google.description_args[:count]).to eq(1)
|
||||||
|
@ -39,6 +65,12 @@ RSpec.describe DiscourseAi::AiBot::Commands::GoogleCommand do
|
||||||
expect(info).to include("snippet1")
|
expect(info).to include("snippet1")
|
||||||
expect(info).to include("some+search+term")
|
expect(info).to include("some+search+term")
|
||||||
expect(info).to_not include("oops")
|
expect(info).to_not include("oops")
|
||||||
|
|
||||||
|
google.invoke!
|
||||||
|
|
||||||
|
expect(post.reload.raw).to include("some search term")
|
||||||
|
|
||||||
|
expect { google.invoke! }.to raise_error(StandardError)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -24,7 +24,8 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
||||||
end
|
end
|
||||||
.to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json)
|
.to_return(status: 200, body: { artifacts: [{ base64: image }, { base64: image }] }.to_json)
|
||||||
|
|
||||||
image = described_class.new(bot_user, post)
|
image = described_class.new(bot_user: bot_user, post: post, args: nil)
|
||||||
|
|
||||||
info = image.process(prompt: "a pink cow").to_json
|
info = image.process(prompt: "a pink cow").to_json
|
||||||
|
|
||||||
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true)
|
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true)
|
||||||
|
|
|
@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::ReadCommand do
|
||||||
post1 = Fabricate(:post, raw: "hello there")
|
post1 = Fabricate(:post, raw: "hello there")
|
||||||
Fabricate(:post, raw: "mister sam", topic: post1.topic)
|
Fabricate(:post, raw: "mister sam", topic: post1.topic)
|
||||||
|
|
||||||
read = described_class.new(bot_user, post1)
|
read = described_class.new(bot_user: bot_user, args: nil, post: post1)
|
||||||
|
|
||||||
results = read.process(topic_id: post1.topic_id)
|
results = read.process(topic_id: post1.topic_id)
|
||||||
|
|
||||||
|
|
|
@ -11,7 +11,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
||||||
describe "#process" do
|
describe "#process" do
|
||||||
it "can handle no results" do
|
it "can handle no results" do
|
||||||
post1 = Fabricate(:post)
|
post1 = Fabricate(:post)
|
||||||
search = described_class.new(bot_user, post1)
|
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
|
||||||
|
|
||||||
results = search.process(query: "order:fake ABDDCDCEDGDG")
|
results = search.process(query: "order:fake ABDDCDCEDGDG")
|
||||||
|
|
||||||
|
@ -24,7 +24,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
||||||
|
|
||||||
post1 = Fabricate(:post)
|
post1 = Fabricate(:post)
|
||||||
|
|
||||||
search = described_class.new(bot_user, post1)
|
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
|
||||||
|
|
||||||
results = search.process(limit: 1, user: post1.user.username)
|
results = search.process(limit: 1, user: post1.user.username)
|
||||||
expect(results[:rows].to_s).to include("/subfolder" + post1.url)
|
expect(results[:rows].to_s).to include("/subfolder" + post1.url)
|
||||||
|
@ -36,7 +36,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SearchCommand do
|
||||||
_post3 = Fabricate(:post, user: post1.user)
|
_post3 = Fabricate(:post, user: post1.user)
|
||||||
|
|
||||||
# search has no built in support for limit: so handle it from the outside
|
# search has no built in support for limit: so handle it from the outside
|
||||||
search = described_class.new(bot_user, post1)
|
search = described_class.new(bot_user: bot_user, post: post1, args: nil)
|
||||||
|
|
||||||
results = search.process(limit: 2, user: post1.user.username)
|
results = search.process(limit: 2, user: post1.user.username)
|
||||||
|
|
||||||
|
|
|
@ -14,7 +14,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
|
||||||
body: JSON.dump({ choices: [{ message: { content: "summary stuff" } }] }),
|
body: JSON.dump({ choices: [{ message: { content: "summary stuff" } }] }),
|
||||||
)
|
)
|
||||||
|
|
||||||
summarizer = described_class.new(bot_user, post)
|
summarizer = described_class.new(bot_user: bot_user, args: nil, post: post)
|
||||||
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
||||||
|
|
||||||
expect(info).to include("Topic summarized")
|
expect(info).to include("Topic summarized")
|
||||||
|
@ -30,7 +30,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::SummarizeCommand do
|
||||||
topic = Fabricate(:topic, category_id: category.id)
|
topic = Fabricate(:topic, category_id: category.id)
|
||||||
post = Fabricate(:post, topic: topic)
|
post = Fabricate(:post, topic: topic)
|
||||||
|
|
||||||
summarizer = described_class.new(bot_user, post)
|
summarizer = described_class.new(bot_user: bot_user, post: post, args: nil)
|
||||||
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
info = summarizer.process(topic_id: post.topic_id, guidance: "why did it happen?")
|
||||||
|
|
||||||
expect(info).not_to include(post.raw)
|
expect(info).not_to include(post.raw)
|
||||||
|
|
|
@ -10,7 +10,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::TagsCommand do
|
||||||
Fabricate(:tag, name: "america", public_topic_count: 100)
|
Fabricate(:tag, name: "america", public_topic_count: 100)
|
||||||
Fabricate(:tag, name: "not_here", public_topic_count: 0)
|
Fabricate(:tag, name: "not_here", public_topic_count: 0)
|
||||||
|
|
||||||
info = DiscourseAi::AiBot::Commands::TagsCommand.new(nil, nil).process
|
info = DiscourseAi::AiBot::Commands::TagsCommand.new(bot_user: nil, args: nil).process
|
||||||
|
|
||||||
expect(info.to_s).to include("america")
|
expect(info.to_s).to include("america")
|
||||||
expect(info.to_s).not_to include("not_here")
|
expect(info.to_s).not_to include("not_here")
|
||||||
|
|
|
@ -8,7 +8,7 @@ RSpec.describe DiscourseAi::AiBot::Commands::TimeCommand do
|
||||||
freeze_time
|
freeze_time
|
||||||
|
|
||||||
args = { timezone: "America/Los_Angeles" }
|
args = { timezone: "America/Los_Angeles" }
|
||||||
info = DiscourseAi::AiBot::Commands::TimeCommand.new(nil, nil).process(**args)
|
info = DiscourseAi::AiBot::Commands::TimeCommand.new(bot_user: nil, args: nil).process(**args)
|
||||||
|
|
||||||
expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s })
|
expect(info).to eq({ args: args, time: Time.now.in_time_zone("America/Los_Angeles").to_s })
|
||||||
expect(info.to_s).not_to include("not_here")
|
expect(info.to_s).not_to include("not_here")
|
||||||
|
|
Loading…
Reference in New Issue