FEATURE: Make artist more creative (#266)

This allows for 2 big features:

1. Artist can ship up to 4 prompts for image generation
2. Artist can regenerate images cause it is aware of seed

This allows for iteration on images maintaining visual style
This commit is contained in:
Sam 2023-10-27 14:48:12 +11:00 committed by GitHub
parent 818b20fb6f
commit 6add06af8f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
6 changed files with 125 additions and 54 deletions

View File

@ -4,13 +4,14 @@ module DiscourseAi
module AiBot
module Commands
class Parameter
attr_reader :name, :description, :type, :enum, :required
def initialize(name:, description:, type:, enum: nil, required: false)
attr_reader :item_type, :name, :description, :type, :enum, :required
def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil)
@name = name
@description = description
@type = type
@enum = enum
@required = required
@item_type = item_type
end
end

View File

@ -14,12 +14,20 @@ module DiscourseAi::AiBot::Commands
def parameters
[
Parameter.new(
name: "prompt",
name: "prompts",
description:
"The prompt used to generate or create or draw the image (40 words or less, be creative)",
type: "string",
"The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
type: "array",
item_type: "string",
required: true,
),
Parameter.new(
name: "seeds",
description:
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
type: "array",
item_type: "integer",
),
]
end
end
@ -40,8 +48,12 @@ module DiscourseAi::AiBot::Commands
@custom_raw
end
def process(prompt:)
@last_prompt = prompt
def process(prompts:, seeds: nil)
# max 4 prompts
prompts = prompts[0..3]
seeds = seeds[0..3] if seeds
@last_prompt = prompts[0]
show_progress(localized_description)
@ -53,41 +65,55 @@ module DiscourseAi::AiBot::Commands
engine = SiteSetting.ai_stability_engine
api_url = SiteSetting.ai_stability_api_url
# API is flaky, so try a few times
3.times do
threads = []
prompts.each_with_index do |prompt, index|
seed = seeds ? seeds[index] : nil
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
attempts = 0
begin
thread =
Thread.new do
begin
results =
DiscourseAi::Inference::StabilityGenerator.perform!(
prompt,
inner_prompt,
engine: engine,
api_key: api_key,
api_url: api_url,
image_count: 1,
seed: inner_seed,
)
rescue => e
attempts += 1
retry if attempts < 3
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
nil
end
end
end
show_progress(".", progress_caret: true) while !thread.join(2)
break if results
end
while true
show_progress(".", progress_caret: true)
break if threads.all? { |t| t.join(2) }
end
return { prompt: prompt, error: "Something went wrong, could not generate image" } if !results
results = threads.map(&:value).compact
if !results.present?
return { prompt: prompt, error: "Something went wrong, could not generate image" }
end
uploads = []
results[:artifacts].each_with_index do |image, i|
f = Tempfile.new("v1_txt2img_#{i}.png")
f.binmode
f.write(Base64.decode64(image[:base64]))
f.rewind
uploads << UploadCreator.new(f, "image.png").create_for(bot_user.id)
f.unlink
results.each_with_index do |result, index|
result[:artifacts].each do |image|
Tempfile.create("v1_txt2img_#{index}.png") do |file|
file.binmode
file.write(Base64.decode64(image[:base64]))
file.rewind
uploads << {
prompt: prompts[index],
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
seed: image[:seed],
}
end
end
end
@custom_raw = <<~RAW
@ -95,13 +121,18 @@ module DiscourseAi::AiBot::Commands
[grid]
#{
uploads
.map { |upload| "![#{prompt.gsub(/\|\'\"/, "")}|512x512, 50%](#{upload.short_url})" }
.map do |item|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
end
.join(" ")
}
[/grid]
RAW
{ prompt: prompt, displayed_to_user: true }
{
prompts: uploads.map { |item| { prompt: item[:prompt], seed: item[:seed] } },
displayed_to_user: true,
}
end
end
end

View File

@ -23,6 +23,10 @@ module DiscourseAi
- Do not include any connector words such as "and" or "but" etc.
- You are extremely creative, when given short non descriptive prompts from a user you add your own details
- When generating images, usually opt to generate 4 images unless the user specifies otherwise.
- Be creative with your prompts, offer diverse options
- You can use the seeds to regenerate the same image and amend the prompt keeping general style
{commands}
PROMPT

View File

@ -20,20 +20,26 @@ module ::DiscourseAi
description: parameter.description,
required: parameter.required,
enum: parameter.enum,
item_type: parameter.item_type,
)
else
add_parameter_kwargs(**kwargs)
end
end
def add_parameter_kwargs(name:, type:, description:, enum: nil, required: false)
@parameters << {
name: name,
type: type,
description: description,
enum: enum,
required: required,
}
def add_parameter_kwargs(
name:,
type:,
description:,
enum: nil,
required: false,
item_type: nil
)
param = { name: name, type: type, description: description, enum: enum, required: required }
param[:enum] = enum if enum
param[:item_type] = item_type if item_type
@parameters << param
end
def to_json(*args)
@ -47,7 +53,7 @@ module ::DiscourseAi
parameters.each do |parameter|
definition = { type: parameter[:type], description: parameter[:description] }
definition[:enum] = parameter[:enum] if parameter[:enum]
definition[:items] = { type: parameter[:item_type] } if parameter[:item_type]
required_params << parameter[:name] if parameter[:required]
properties[parameter[:name]] = definition
end

View File

@ -3,7 +3,16 @@
module ::DiscourseAi
module Inference
class StabilityGenerator
def self.perform!(prompt, width: nil, height: nil, api_key: nil, engine: nil, api_url: nil)
def self.perform!(
prompt,
width: nil,
height: nil,
api_key: nil,
engine: nil,
api_url: nil,
image_count: 4,
seed: nil
)
api_key ||= SiteSetting.ai_stability_api_key
engine ||= SiteSetting.ai_stability_engine
api_url ||= SiteSetting.ai_stability_api_url
@ -40,10 +49,12 @@ module ::DiscourseAi
clip_guidance_preset: "FAST_BLUE",
height: width,
width: height,
samples: 4,
samples: image_count,
steps: 30,
}
payload[:seed] = seed if seed
endpoint = "v1/generation/#{engine}/text-to-image"
response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers)

View File

@ -1,7 +1,5 @@
#frozen_string_literal: true
require_relative "../../../../support/stable_difussion_stubs"
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
@ -17,16 +15,36 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
image =
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
StableDiffusionStubs.new.stub_response("a pink cow", [image, image])
artifacts = [{ base64: image, seed: 99 }]
prompts = ["a pink cow", "a red cow"]
WebMock
.stub_request(
:post,
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
)
.with do |request|
json = JSON.parse(request.body, symbolize_names: true)
expect(prompts).to include(json[:text_prompts][0][:text])
true
end
.to_return(status: 200, body: { artifacts: artifacts }.to_json)
image = described_class.new(bot_user: bot_user, post: post, args: nil)
info = image.process(prompt: "a pink cow").to_json
info = image.process(prompts: prompts).to_json
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true)
expect(JSON.parse(info)).to eq(
"prompts" => [
{ "prompt" => "a pink cow", "seed" => 99 },
{ "prompt" => "a red cow", "seed" => 99 },
],
"displayed_to_user" => true,
)
expect(image.custom_raw).to include("upload://")
expect(image.custom_raw).to include("[grid]")
expect(image.custom_raw).to include("a pink cow")
expect(image.custom_raw).to include("a red cow")
end
end
end