FEATURE: Make artist more creative (#266)
This allows for 2 big features: 1. Artist can ship up to 4 prompts for image generation 2. Artist can regenerate images cause it is aware of seed This allows for iteration on images maintaining visual style
This commit is contained in:
parent
818b20fb6f
commit
6add06af8f
|
@ -4,13 +4,14 @@ module DiscourseAi
|
|||
module AiBot
|
||||
module Commands
|
||||
class Parameter
|
||||
attr_reader :name, :description, :type, :enum, :required
|
||||
def initialize(name:, description:, type:, enum: nil, required: false)
|
||||
attr_reader :item_type, :name, :description, :type, :enum, :required
|
||||
def initialize(name:, description:, type:, enum: nil, required: false, item_type: nil)
|
||||
@name = name
|
||||
@description = description
|
||||
@type = type
|
||||
@enum = enum
|
||||
@required = required
|
||||
@item_type = item_type
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -14,12 +14,20 @@ module DiscourseAi::AiBot::Commands
|
|||
def parameters
|
||||
[
|
||||
Parameter.new(
|
||||
name: "prompt",
|
||||
name: "prompts",
|
||||
description:
|
||||
"The prompt used to generate or create or draw the image (40 words or less, be creative)",
|
||||
type: "string",
|
||||
"The prompts used to generate or create or draw the image (40 words or less, be creative) up to 4 prompts",
|
||||
type: "array",
|
||||
item_type: "string",
|
||||
required: true,
|
||||
),
|
||||
Parameter.new(
|
||||
name: "seeds",
|
||||
description:
|
||||
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
|
||||
type: "array",
|
||||
item_type: "integer",
|
||||
),
|
||||
]
|
||||
end
|
||||
end
|
||||
|
@ -40,8 +48,12 @@ module DiscourseAi::AiBot::Commands
|
|||
@custom_raw
|
||||
end
|
||||
|
||||
def process(prompt:)
|
||||
@last_prompt = prompt
|
||||
def process(prompts:, seeds: nil)
|
||||
# max 4 prompts
|
||||
prompts = prompts[0..3]
|
||||
seeds = seeds[0..3] if seeds
|
||||
|
||||
@last_prompt = prompts[0]
|
||||
|
||||
show_progress(localized_description)
|
||||
|
||||
|
@ -53,41 +65,55 @@ module DiscourseAi::AiBot::Commands
|
|||
engine = SiteSetting.ai_stability_engine
|
||||
api_url = SiteSetting.ai_stability_api_url
|
||||
|
||||
# API is flaky, so try a few times
|
||||
3.times do
|
||||
threads = []
|
||||
prompts.each_with_index do |prompt, index|
|
||||
seed = seeds ? seeds[index] : nil
|
||||
threads << Thread.new(seed, prompt) do |inner_seed, inner_prompt|
|
||||
attempts = 0
|
||||
begin
|
||||
thread =
|
||||
Thread.new do
|
||||
begin
|
||||
results =
|
||||
DiscourseAi::Inference::StabilityGenerator.perform!(
|
||||
prompt,
|
||||
inner_prompt,
|
||||
engine: engine,
|
||||
api_key: api_key,
|
||||
api_url: api_url,
|
||||
image_count: 1,
|
||||
seed: inner_seed,
|
||||
)
|
||||
rescue => e
|
||||
attempts += 1
|
||||
retry if attempts < 3
|
||||
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
|
||||
nil
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
show_progress(".", progress_caret: true) while !thread.join(2)
|
||||
|
||||
break if results
|
||||
end
|
||||
while true
|
||||
show_progress(".", progress_caret: true)
|
||||
break if threads.all? { |t| t.join(2) }
|
||||
end
|
||||
|
||||
return { prompt: prompt, error: "Something went wrong, could not generate image" } if !results
|
||||
results = threads.map(&:value).compact
|
||||
|
||||
if !results.present?
|
||||
return { prompt: prompt, error: "Something went wrong, could not generate image" }
|
||||
end
|
||||
|
||||
uploads = []
|
||||
|
||||
results[:artifacts].each_with_index do |image, i|
|
||||
f = Tempfile.new("v1_txt2img_#{i}.png")
|
||||
f.binmode
|
||||
f.write(Base64.decode64(image[:base64]))
|
||||
f.rewind
|
||||
uploads << UploadCreator.new(f, "image.png").create_for(bot_user.id)
|
||||
f.unlink
|
||||
results.each_with_index do |result, index|
|
||||
result[:artifacts].each do |image|
|
||||
Tempfile.create("v1_txt2img_#{index}.png") do |file|
|
||||
file.binmode
|
||||
file.write(Base64.decode64(image[:base64]))
|
||||
file.rewind
|
||||
uploads << {
|
||||
prompt: prompts[index],
|
||||
upload: UploadCreator.new(file, "image.png").create_for(bot_user.id),
|
||||
seed: image[:seed],
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
@custom_raw = <<~RAW
|
||||
|
@ -95,13 +121,18 @@ module DiscourseAi::AiBot::Commands
|
|||
[grid]
|
||||
#{
|
||||
uploads
|
||||
.map { |upload| "![#{prompt.gsub(/\|\'\"/, "")}|512x512, 50%](#{upload.short_url})" }
|
||||
.map do |item|
|
||||
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
|
||||
end
|
||||
.join(" ")
|
||||
}
|
||||
[/grid]
|
||||
RAW
|
||||
|
||||
{ prompt: prompt, displayed_to_user: true }
|
||||
{
|
||||
prompts: uploads.map { |item| { prompt: item[:prompt], seed: item[:seed] } },
|
||||
displayed_to_user: true,
|
||||
}
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
|
@ -23,6 +23,10 @@ module DiscourseAi
|
|||
- Do not include any connector words such as "and" or "but" etc.
|
||||
- You are extremely creative, when given short non descriptive prompts from a user you add your own details
|
||||
|
||||
- When generating images, usually opt to generate 4 images unless the user specifies otherwise.
|
||||
- Be creative with your prompts, offer diverse options
|
||||
- You can use the seeds to regenerate the same image and amend the prompt keeping general style
|
||||
|
||||
{commands}
|
||||
|
||||
PROMPT
|
||||
|
|
|
@ -20,20 +20,26 @@ module ::DiscourseAi
|
|||
description: parameter.description,
|
||||
required: parameter.required,
|
||||
enum: parameter.enum,
|
||||
item_type: parameter.item_type,
|
||||
)
|
||||
else
|
||||
add_parameter_kwargs(**kwargs)
|
||||
end
|
||||
end
|
||||
|
||||
def add_parameter_kwargs(name:, type:, description:, enum: nil, required: false)
|
||||
@parameters << {
|
||||
name: name,
|
||||
type: type,
|
||||
description: description,
|
||||
enum: enum,
|
||||
required: required,
|
||||
}
|
||||
def add_parameter_kwargs(
|
||||
name:,
|
||||
type:,
|
||||
description:,
|
||||
enum: nil,
|
||||
required: false,
|
||||
item_type: nil
|
||||
)
|
||||
param = { name: name, type: type, description: description, enum: enum, required: required }
|
||||
param[:enum] = enum if enum
|
||||
param[:item_type] = item_type if item_type
|
||||
|
||||
@parameters << param
|
||||
end
|
||||
|
||||
def to_json(*args)
|
||||
|
@ -47,7 +53,7 @@ module ::DiscourseAi
|
|||
parameters.each do |parameter|
|
||||
definition = { type: parameter[:type], description: parameter[:description] }
|
||||
definition[:enum] = parameter[:enum] if parameter[:enum]
|
||||
|
||||
definition[:items] = { type: parameter[:item_type] } if parameter[:item_type]
|
||||
required_params << parameter[:name] if parameter[:required]
|
||||
properties[parameter[:name]] = definition
|
||||
end
|
||||
|
|
|
@ -3,7 +3,16 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class StabilityGenerator
|
||||
def self.perform!(prompt, width: nil, height: nil, api_key: nil, engine: nil, api_url: nil)
|
||||
def self.perform!(
|
||||
prompt,
|
||||
width: nil,
|
||||
height: nil,
|
||||
api_key: nil,
|
||||
engine: nil,
|
||||
api_url: nil,
|
||||
image_count: 4,
|
||||
seed: nil
|
||||
)
|
||||
api_key ||= SiteSetting.ai_stability_api_key
|
||||
engine ||= SiteSetting.ai_stability_engine
|
||||
api_url ||= SiteSetting.ai_stability_api_url
|
||||
|
@ -40,10 +49,12 @@ module ::DiscourseAi
|
|||
clip_guidance_preset: "FAST_BLUE",
|
||||
height: width,
|
||||
width: height,
|
||||
samples: 4,
|
||||
samples: image_count,
|
||||
steps: 30,
|
||||
}
|
||||
|
||||
payload[:seed] = seed if seed
|
||||
|
||||
endpoint = "v1/generation/#{engine}/text-to-image"
|
||||
|
||||
response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers)
|
||||
|
|
|
@ -1,7 +1,5 @@
|
|||
#frozen_string_literal: true
|
||||
|
||||
require_relative "../../../../support/stable_difussion_stubs"
|
||||
|
||||
RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
||||
let(:bot_user) { User.find(DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID) }
|
||||
|
||||
|
@ -17,16 +15,36 @@ RSpec.describe DiscourseAi::AiBot::Commands::ImageCommand do
|
|||
image =
|
||||
"iVBORw0KGgoAAAANSUhEUgAAAAEAAAABCAYAAAAfFcSJAAAADUlEQVR42mP8z8BQDwAEhQGAhKmMIQAAAABJRU5ErkJggg=="
|
||||
|
||||
StableDiffusionStubs.new.stub_response("a pink cow", [image, image])
|
||||
artifacts = [{ base64: image, seed: 99 }]
|
||||
prompts = ["a pink cow", "a red cow"]
|
||||
|
||||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"https://api.stability.dev/v1/generation/#{SiteSetting.ai_stability_engine}/text-to-image",
|
||||
)
|
||||
.with do |request|
|
||||
json = JSON.parse(request.body, symbolize_names: true)
|
||||
expect(prompts).to include(json[:text_prompts][0][:text])
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: { artifacts: artifacts }.to_json)
|
||||
|
||||
image = described_class.new(bot_user: bot_user, post: post, args: nil)
|
||||
|
||||
info = image.process(prompt: "a pink cow").to_json
|
||||
info = image.process(prompts: prompts).to_json
|
||||
|
||||
expect(JSON.parse(info)).to eq("prompt" => "a pink cow", "displayed_to_user" => true)
|
||||
expect(JSON.parse(info)).to eq(
|
||||
"prompts" => [
|
||||
{ "prompt" => "a pink cow", "seed" => 99 },
|
||||
{ "prompt" => "a red cow", "seed" => 99 },
|
||||
],
|
||||
"displayed_to_user" => true,
|
||||
)
|
||||
expect(image.custom_raw).to include("upload://")
|
||||
expect(image.custom_raw).to include("[grid]")
|
||||
expect(image.custom_raw).to include("a pink cow")
|
||||
expect(image.custom_raw).to include("a red cow")
|
||||
end
|
||||
end
|
||||
end
|
||||
|
|
Loading…
Reference in New Issue