FEATURE: Make artist more creative (#266)
This allows for 2 big features: 1. Artist can ship up to 4 prompts for image generation 2. Artist can regenerate images cause it is aware of seed This allows for iteration on images maintaining visual style
This commit is contained in:
parent
818b20fb6f
commit
6add06af8f
|
@ -4,13 +4,14 @@ module DiscourseAi
|
||||||
module AiBot
|
module AiBot
|
||||||
module Commands
|
module Commands
|
||||||
class Parameter
|
class Parameter
|
||||||
attr_reader :name, :description, :type, :enum, :required
|
attr_reader :item_type, :name, :description, :type, :enum, :required
|
||||||
def initialize(name:, description:, type:, enum: nil, required: false)
|
def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil)
|
||||||
@name = name
|
@name = name
|
||||||
@description = description
|
@description = description
|
||||||
@type = type
|
@type = type
|
||||||
@enum = enum
|
@enum = enum
|
||||||
@required = required
|
@required = required
|
||||||
|
@item_type = item_type
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
|
|
|
@ -14,12 +14,20 @@ module DiscourseAi::AiBot::Commands
|
||||||
def parameters
|
def parameters
|
||||||
[
|
[
|
||||||
Parameter.new(
|
Parameter.new(
|
||||||
name: "prompt",
|
name: "prompts",
|
||||||
description:
|
description:
|
||||||
"The prompt used to generate or create or draw the image (40 words or less, be creative)",
|
"The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
|
||||||
type: "string",
|
type: "array",
|
||||||
|
item_type: "string",
|
||||||
required: true,
|
required: true,
|
||||||
),
|
),
|
||||||
|
Parameter.new(
|
||||||
|
name: "seeds",
|
||||||
|
description:
|
||||||
|
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
|
||||||
|
type: "array",
|
||||||
|
item_type: "integer",
|
||||||
|
),
|
||||||
]
|
]
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
@ -40,8 +48,12 @@ module DiscourseAi::AiBot::Commands
|
||||||
@custom_raw
|
@custom_raw
|
||||||
end
|
end
|
||||||
|
|
||||||
def process(prompt:)
|
def process(prompts:, seeds: nil)
|
||||||
@last_prompt = prompt
|
# max 4 prompts
|
||||||
|
prompts = prompts[0..3]
|
||||||
|
seeds = seeds[0..3] if seeds
|
||||||
|
|
||||||
|
@last_prompt = prompts[0]
|
||||||
|
|
||||||
show_progress(localized_description)
|
show_progress(localized_description)
|
||||||
|
|
||||||
|
@ -53,41 +65,55 @@ module DiscourseAi::AiBot::Commands
|
||||||
engine = SiteSetting.ai_stability_engine
|
engine = SiteSetting.ai_stability_engine
|
||||||
api_url = SiteSetting.ai_stability_api_url
|
api_url = SiteSetting.ai_stability_api_url
|
||||||
|
|
||||||
# API is flaky, so try a few times
|
threads = []
|
||||||
3.times do
|
prompts.each_with_index do |prompt, index|
|
||||||
|
seed = seeds ? seeds[index] : nil
|
||||||
|
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
|
||||||
|
attempts = 0
|
||||||
begin
|
begin
|
||||||
thread =
|
|
||||||
Thread.new do
|
|
||||||
begin
|
|
||||||
results =
|
|
||||||
DiscourseAi::Inference::StabilityGenerator.perform!(
|
DiscourseAi::Inference::StabilityGenerator.perform!(
|
||||||
prompt,
|
inner_prompt,
|
||||||
engine: engine,
|
engine: engine,
|
||||||
api_key: api_key,
|
api_key: api_key,
|
||||||
api_url: api_url,
|
api_url: api_url,
|
||||||
|
image_count: 1,
|
||||||
|
seed: inner_seed,
|
||||||
)
|
)
|
||||||
rescue => e
|
rescue => e
|
||||||
|
attempts += 1
|
||||||
|
retry if attempts < 3
|
||||||
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
|
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
|
||||||
|
nil
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
show_progress(".", progress_caret: true) while !thread.join(2)
|
while true
|
||||||
|
show_progress(".", progress_caret: true)
|
||||||
break if results
|
break if threads.all? { |t| t.join(2) }
|
||||||
end
|
|
||||||
end
|
end
|
||||||
|
|
||||||
return { prompt: prompt, error: "Something went wrong, could not generate image" } if !results
|
results = threads.map(&:value).compact
|
||||||
|
|
||||||
|
if !results.present?
|
||||||
|
return { prompt: prompt, error: "Something went wrong, could not generate image" }
|
||||||
|
end
|
||||||
|
|
||||||
uploads = []
|
uploads = []
|
||||||
|
|
||||||
results[:artifacts].each_with_index do |image, i|
|
results.each_with_index do |result, index|
|
||||||
f = Tempfile.new("v1_txt2img_#{i}.png")
|
result[:artifacts].each do |image|
|
||||||
f.binmode
|
Tempfile.create("v1_txt2img_#{index}.png") do |file|
|
||||||
f.write(Base64.decode64(image[:base64]))
|
file.binmode
|
||||||
f.rewind
|
file.write(Base64.decode64(image[:base64]))
|
||||||
uploads << UploadCreator.new(f, "image.png").create_for(bot_user.id)
|
file.rewind
|
||||||
f.unlink
|
uploads << {
|
||||||
|
prompt: prompts[index],
|
||||||
|
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
|
||||||
|
seed: image[:seed],
|
||||||
|
}
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
@custom_raw = <<~RAW
|
@custom_raw = <<~RAW
|
||||||
|
@ -95,13 +121,18 @@ module DiscourseAi::AiBot::Commands
|
||||||
[grid]
|
[grid]
|
||||||
#{
|
#{
|
||||||
uploads
|
uploads
|
||||||
.map { |upload| "![#{prompt.gsub(/\|\'\"/, "")}|512x512, 50%](#{upload.short_url})" }
|
.map do |item|
|
||||||
|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
|
||||||
|
end
|
||||||
.join(" ")
|
.join(" ")
|
||||||
}
|
}
|
||||||
[/grid]
|
[/grid]
|
||||||
RAW
|
RAW
|
||||||
|
|
||||||
{ prompt: prompt, displayed_to_user: true }
|
{
|
||||||
|
prompts: uploads.map { |item| { prompt: item[:prompt], seed: item[:seed] } },
|
||||||
|
displayed_to_user: true,
|
||||||
|
}
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
|
@ -23,6 +23,10 @@ module DiscourseAi
|
||||||
- Do not include any connector words such as "and" or "but" etc.
|
- Do not include any connector words such as "and" or "but" etc.
|
||||||
- You are extremely creative, when given short non descriptive prompts from a user you add your own details
|
- You are extremely creative, when given short non descriptive prompts from a user you add your own details
|
||||||
|
|
||||||
|
- When generating images, usually opt to generate 4 images unless the user specifies otherwise.
|
||||||
|
- Be creative with your prompts, offer diverse options
|
||||||
|
- You can use the seeds to regenerate the same image and amend the prompt keeping general style
|
||||||
|
|
||||||
{commands}
|
{commands}
|
||||||
|
|
||||||
PROMPT
|
PROMPT
|
||||||
|
|
|
@ -20,20 +20,26 @@ module ::DiscourseAi
|
||||||
description: parameter.description,
|
description: parameter.description,
|
||||||
required: parameter.required,
|
required: parameter.required,
|
||||||
enum: parameter.enum,
|
enum: parameter.enum,
|
||||||
|
item_type: parameter.item_type,
|
||||||
)
|
)
|
||||||
else
|
else
|
||||||
add_parameter_kwargs(**kwargs)
|
add_parameter_kwargs(**kwargs)
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
def add_parameter_kwargs(name:, type:, description:, enum: nil, required: false)
|
def add_parameter_kwargs(
|
||||||
@parameters << {
|
name:,
|
||||||
name: name,
|
type:,
|
||||||
type: type,
|
description:,
|
||||||
description: description,
|
enum: nil,
|
||||||
enum: enum,
|
required: false,
|
||||||
required: required,
|
item_type: nil
|
||||||
}
|
)
|
||||||
|
param = { name: name, type: type, description: description, enum: enum, required: required }
|
||||||
|
param[:enum] = enum if enum
|
||||||
|
param[:item_type] = item_type if item_type
|
||||||
|
|
||||||
|
@parameters << param
|
||||||
end
|
end
|
||||||
|
|
||||||
def to_json(*args)
|
def to_json(*args)
|
||||||
|
@ -47,7 +53,7 @@ module ::DiscourseAi
|
||||||
parameters.each do |parameter|
|
parameters.each do |parameter|
|
||||||
definition = { type: parameter[:type], description: parameter[:description] }
|
definition = { type: parameter[:type], description: parameter[:description] }
|
||||||
definition[:enum] = parameter[:enum] if parameter[:enum]
|
definition[:enum] = parameter[:enum] if parameter[:enum]
|
||||||
|
definition[:items] = { type: parameter[:item_type] } if parameter[:item_type]
|
||||||
required_params << parameter[:name] if parameter[:required]
|
required_params << parameter[:name] if parameter[:required]
|
||||||
properties[parameter[:name]] = definition
|
properties[parameter[:name]] = definition
|
||||||
end
|
end
|
||||||
|
|
|
@ -3,7 +3,16 @@
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class StabilityGenerator
|
class StabilityGenerator
|
||||||
def self.perform!(prompt, width: nil, height: nil, api_key: nil, engine: nil, api_url: nil)
|
def self.perform!(
|
||||||
|
prompt,
|
||||||
|
width: nil,
|
||||||
|
height: nil,
|
||||||
|
api_key: nil,
|
||||||
|
engine: nil,
|
||||||
|
api_url: nil,
|
||||||
|
image_count: 4,
|
||||||
|
seed: nil
|
||||||
|
)
|
||||||
api_key ||= SiteSetting.ai_stability_api_key
|
api_key ||= SiteSetting.ai_stability_api_key
|
||||||
engine ||= SiteSetting.ai_stability_engine
|
engine ||= SiteSetting.ai_stability_engine
|
||||||
api_url ||= SiteSetting.ai_stability_api_url
|
api_url ||= SiteSetting.ai_stability_api_url
|
||||||
|
@ -40,10 +49,12 @@ module ::DiscourseAi
|
||||||
clip_guidance_preset: "FAST_BLUE",
|
clip_guidance_preset: "FAST_BLUE",
|
||||||
height: width,
|
height: width,
|
||||||
width: height,
|
width: height,
|
||||||
samples: 4,
|
samples: image_count,
|
||||||
steps: 30,
|
steps: 30,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
payload[:seed] = seed if seed
|
||||||
|
|
||||||
endpoint = "v1/generation/#{engine}/text-to-image"
|
endpoint = "v1/generation/#{engine}/text-to-image"
|
||||||
|
|
||||||
response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers)
|
response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers)
|
||||||
|
|
|
@ -1,7 +1,5 @@
|
||||||
#frozen_string_literal: true
|
#frozen_string_literal: true
|
||||||
|
|
||||||
require_relative "../../../../support/stable_difussion_stubs"
|
|
||||||
|
|
||||||
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
||||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||||
|
|
||||||
|
@ -17,16 +15,36 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
||||||
image =
|
image =
|
||||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||||
|
|
||||||
StableDiffusionStubs.new.stub_response("a pink cow", [image, image])
|
artifacts = [{ base64: image, seed: 99 }]
|
||||||
|
prompts = ["a pink cow", "a red cow"]
|
||||||
|
|
||||||
|
WebMock
|
||||||
|
.stub_request(
|
||||||
|
:post,
|
||||||
|
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
|
||||||
|
)
|
||||||
|
.with do |request|
|
||||||
|
json = JSON.parse(request.body, symbolize_names: true)
|
||||||
|
expect(prompts).to include(json[:text_prompts][0][:text])
|
||||||
|
true
|
||||||
|
end
|
||||||
|
.to_return(status: 200, body: { artifacts: artifacts }.to_json)
|
||||||
|
|
||||||
image = described_class.new(bot_user: bot_user, post: post, args: nil)
|
image = described_class.new(bot_user: bot_user, post: post, args: nil)
|
||||||
|
|
||||||
info = image.process(prompt: "a pink cow").to_json
|
info = image.process(prompts: prompts).to_json
|
||||||
|
|
||||||
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true)
|
expect(JSON.parse(info)).to eq(
|
||||||
|
"prompts" => [
|
||||||
|
{ "prompt" => "a pink cow", "seed" => 99 },
|
||||||
|
{ "prompt" => "a red cow", "seed" => 99 },
|
||||||
|
],
|
||||||
|
"displayed_to_user" => true,
|
||||||
|
)
|
||||||
expect(image.custom_raw).to include("upload://")
|
expect(image.custom_raw).to include("upload://")
|
||||||
expect(image.custom_raw).to include("[grid]")
|
expect(image.custom_raw).to include("[grid]")
|
||||||
expect(image.custom_raw).to include("a pink cow")
|
expect(image.custom_raw).to include("a pink cow")
|
||||||
|
expect(image.custom_raw).to include("a red cow")
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
Loading…
Reference in New Issue