FEATURE: Make artist more creative (#266)

This allows for 2 big features:

1. Artist can ship up to 4 prompts for image generation
2. Artist can regenerate images cause it is aware of seed

This allows for iteration on images maintaining visual style
This commit is contained in:
Sam 2023-10-27 14:48:12 +11:00 committed by GitHub
parent 818b20fb6f
commit 6add06af8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 125 additions and 54 deletions

View File

@ -4,13 +4,14 @@ module DiscourseAi
module AiBot module AiBot
module Commands module Commands
class Parameter class Parameter
attr_reader :name, :description, :type, :enum, :required attr_reader :item_type, :name, :description, :type, :enum, :required
def initialize(name:, description:, type:, enum: nil, required: false) def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil)
@name = name @name = name
@description = description @description = description
@type = type @type = type
@enum = enum @enum = enum
@required = required @required = required
@item_type = item_type
end end
end end

View File

@ -14,12 +14,20 @@ module DiscourseAi::AiBot::Commands
def parameters def parameters
[ [
Parameter.new( Parameter.new(
name: "prompt", name: "prompts",
description: description:
"The prompt used to generate or create or draw the image (40 words or less, be creative)", "The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
type: "string", type: "array",
item_type: "string",
required: true, required: true,
), ),
Parameter.new(
name: "seeds",
description:
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
type: "array",
item_type: "integer",
),
] ]
end end
end end
@ -40,8 +48,12 @@ module DiscourseAi::AiBot::Commands
@custom_raw @custom_raw
end end
def process(prompt:) def process(prompts:, seeds: nil)
@last_prompt = prompt # max 4 prompts
prompts = prompts[0..3]
seeds = seeds[0..3] if seeds
@last_prompt = prompts[0]
show_progress(localized_description) show_progress(localized_description)
@ -53,41 +65,55 @@ module DiscourseAi::AiBot::Commands
engine = SiteSetting.ai_stability_engine engine = SiteSetting.ai_stability_engine
api_url = SiteSetting.ai_stability_api_url api_url = SiteSetting.ai_stability_api_url
# API is flaky, so try a few times threads = []
3.times do prompts.each_with_index do |prompt, index|
seed = seeds ? seeds[index] : nil
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
attempts = 0
begin begin
thread =
Thread.new do
begin
results =
DiscourseAi::Inference::StabilityGenerator.perform!( DiscourseAi::Inference::StabilityGenerator.perform!(
prompt, inner_prompt,
engine: engine, engine: engine,
api_key: api_key, api_key: api_key,
api_url: api_url, api_url: api_url,
image_count: 1,
seed: inner_seed,
) )
rescue => e rescue => e
attempts += 1
retry if attempts < 3
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}") Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
nil
end
end end
end end
show_progress(".", progress_caret: true) while !thread.join(2) while true
show_progress(".", progress_caret: true)
break if results break if threads.all? { |t| t.join(2) }
end
end end
return { prompt: prompt, error: "Something went wrong, could not generate image" } if !results results = threads.map(&:value).compact
if !results.present?
return { prompt: prompt, error: "Something went wrong, could not generate image" }
end
uploads = [] uploads = []
results[:artifacts].each_with_index do |image, i| results.each_with_index do |result, index|
f = Tempfile.new("v1_txt2img_#{i}.png") result[:artifacts].each do |image|
f.binmode Tempfile.create("v1_txt2img_#{index}.png") do |file|
f.write(Base64.decode64(image[:base64])) file.binmode
f.rewind file.write(Base64.decode64(image[:base64]))
uploads << UploadCreator.new(f, "image.png").create_for(bot_user.id) file.rewind
f.unlink uploads << {
prompt: prompts[index],
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
seed: image[:seed],
}
end
end
end end
@custom_raw = <<~RAW @custom_raw = <<~RAW
@ -95,13 +121,18 @@ module DiscourseAi::AiBot::Commands
[grid] [grid]
#{ #{
uploads uploads
.map { |upload| "![#{prompt.gsub(/\|\'\"/, "")}|512x512, 50%](#{upload.short_url})" } .map do |item|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
end
.join(" ") .join(" ")
} }
[/grid] [/grid]
RAW RAW
{ prompt: prompt, displayed_to_user: true } {
prompts: uploads.map { |item| { prompt: item[:prompt], seed: item[:seed] } },
displayed_to_user: true,
}
end end
end end
end end

View File

@ -23,6 +23,10 @@ module DiscourseAi
- Do not include any connector words such as "and" or "but" etc. - Do not include any connector words such as "and" or "but" etc.
- You are extremely creative, when given short non descriptive prompts from a user you add your own details - You are extremely creative, when given short non descriptive prompts from a user you add your own details
- When generating images, usually opt to generate 4 images unless the user specifies otherwise.
- Be creative with your prompts, offer diverse options
- You can use the seeds to regenerate the same image and amend the prompt keeping general style
{commands} {commands}
PROMPT PROMPT

View File

@ -20,20 +20,26 @@ module ::DiscourseAi
description: parameter.description, description: parameter.description,
required: parameter.required, required: parameter.required,
enum: parameter.enum, enum: parameter.enum,
item_type: parameter.item_type,
) )
else else
add_parameter_kwargs(**kwargs) add_parameter_kwargs(**kwargs)
end end
end end
def add_parameter_kwargs(name:, type:, description:, enum: nil, required: false) def add_parameter_kwargs(
@parameters << { name:,
name: name, type:,
type: type, description:,
description: description, enum: nil,
enum: enum, required: false,
required: required, item_type: nil
} )
param = { name: name, type: type, description: description, enum: enum, required: required }
param[:enum] = enum if enum
param[:item_type] = item_type if item_type
@parameters << param
end end
def to_json(*args) def to_json(*args)
@ -47,7 +53,7 @@ module ::DiscourseAi
parameters.each do |parameter| parameters.each do |parameter|
definition = { type: parameter[:type], description: parameter[:description] } definition = { type: parameter[:type], description: parameter[:description] }
definition[:enum] = parameter[:enum] if parameter[:enum] definition[:enum] = parameter[:enum] if parameter[:enum]
definition[:items] = { type: parameter[:item_type] } if parameter[:item_type]
required_params << parameter[:name] if parameter[:required] required_params << parameter[:name] if parameter[:required]
properties[parameter[:name]] = definition properties[parameter[:name]] = definition
end end

View File

@ -3,7 +3,16 @@
module ::DiscourseAi module ::DiscourseAi
module Inference module Inference
class StabilityGenerator class StabilityGenerator
def self.perform!(prompt, width: nil, height: nil, api_key: nil, engine: nil, api_url: nil) def self.perform!(
prompt,
width: nil,
height: nil,
api_key: nil,
engine: nil,
api_url: nil,
image_count: 4,
seed: nil
)
api_key ||= SiteSetting.ai_stability_api_key api_key ||= SiteSetting.ai_stability_api_key
engine ||= SiteSetting.ai_stability_engine engine ||= SiteSetting.ai_stability_engine
api_url ||= SiteSetting.ai_stability_api_url api_url ||= SiteSetting.ai_stability_api_url
@ -40,10 +49,12 @@ module ::DiscourseAi
clip_guidance_preset: "FAST_BLUE", clip_guidance_preset: "FAST_BLUE",
height: width, height: width,
width: height, width: height,
samples: 4, samples: image_count,
steps: 30, steps: 30,
} }
payload[:seed] = seed if seed
endpoint = "v1/generation/#{engine}/text-to-image" endpoint = "v1/generation/#{engine}/text-to-image"
response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers) response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers)

View File

@ -1,7 +1,5 @@
#frozen_string_literal: true #frozen_string_literal: true
require_relative "../../../../support/stable_difussion_stubs"
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) } let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
@ -17,16 +15,36 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
image = image =
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg==" "iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
StableDiffusionStubs.new.stub_response("a pink cow", [image, image]) artifacts = [{ base64: image, seed: 99 }]
prompts = ["a pink cow", "a red cow"]
WebMock
.stub_request(
:post,
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
)
.with do |request|
json = JSON.parse(request.body, symbolize_names: true)
expect(prompts).to include(json[:text_prompts][0][:text])
true
end
.to_return(status: 200, body: { artifacts: artifacts }.to_json)
image = described_class.new(bot_user: bot_user, post: post, args: nil) image = described_class.new(bot_user: bot_user, post: post, args: nil)
info = image.process(prompt: "a pink cow").to_json info = image.process(prompts: prompts).to_json
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true) expect(JSON.parse(info)).to eq(
"prompts" => [
{ "prompt" => "a pink cow", "seed" => 99 },
{ "prompt" => "a red cow", "seed" => 99 },
],
"displayed_to_user" => true,
)
expect(image.custom_raw).to include("upload://") expect(image.custom_raw).to include("upload://")
expect(image.custom_raw).to include("[grid]") expect(image.custom_raw).to include("[grid]")
expect(image.custom_raw).to include("a pink cow") expect(image.custom_raw).to include("a pink cow")
expect(image.custom_raw).to include("a red cow")
end end
end end
end end