FEATURE: Stable diffusion 3 support (#582)

- Adds support for sd3 and sd3 turbo models - this requires new endpoints
- Adds a hack to normalize arrays in the tool calls
- Removes some leftover code
- Adds support for aspect ratio as well so you can generate wide or tall images
This commit is contained in:
Sam 2024-04-19 18:08:16 +10:00 committed by GitHub
parent a223d18f1a
commit bd6f5caeac
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 152 additions and 28 deletions

View File

@ -122,6 +122,8 @@ discourse_ai:
default: "stable-diffusion-xl-1024-v1-0"
type: enum
choices:
- "sd3"
- "sd3-turbo"
- "stable-diffusion-xl-1024-v1-0"
- "stable-diffusion-768-v2-1"
- "stable-diffusion-v1-5"

View File

@ -203,10 +203,6 @@ module DiscourseAi
end
end
def tool_invocation?(partial)
Nokogiri::HTML5.fragment(partial).at("invoke").present?
end
def build_placeholder(summary, details, custom_raw: nil)
placeholder = +(<<~HTML)
<details>

View File

@ -188,7 +188,7 @@ module DiscourseAi
begin
JSON.parse(value)
rescue JSON::ParserError
nil
[value.to_s]
end
end

View File

@ -24,7 +24,13 @@ module DiscourseAi
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
type: "array",
item_type: "integer",
required: true,
},
{
name: "aspect_ratio",
description: "The aspect ratio of the image (optional defaults to 1:1)",
type: "string",
required: false,
enum: %w[16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21],
},
],
}
@ -35,7 +41,11 @@ module DiscourseAi
end
def prompts
JSON.parse(parameters[:prompts].to_s)
parameters[:prompts]
end
def aspect_ratio
parameters[:aspect_ratio]
end
def seeds
@ -75,6 +85,7 @@ module DiscourseAi
api_url: api_url,
image_count: 1,
seed: inner_seed,
aspect_ratio: aspect_ratio,
)
rescue => e
attempts += 1
@ -116,7 +127,7 @@ module DiscourseAi
#{
uploads
.map do |item|
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
"![#{item[:prompt].gsub(/\|\'\"/, "")}|50%](#{item[:upload].short_url})"
end
.join(" ")
}

View File

@ -7,7 +7,6 @@ module DiscourseAi
return [] if input.blank?
model = SiteSetting.ai_helper_illustrate_post_model
attribution = "discourse_ai.ai_helper.painter.attribution.#{model}"
if model == "stable_diffusion_xl"
stable_diffusion_prompt = diffusion_prompt(input, user)

View File

@ -3,10 +3,74 @@
module ::DiscourseAi
module Inference
class StabilityGenerator
TIMEOUT = 120
# there is a new api for sd3
def self.perform_sd3!(
prompt,
aspect_ratio: nil,
api_key: nil,
engine: nil,
api_url: nil,
output_format: "png",
seed: nil
)
api_key ||= SiteSetting.ai_stability_api_key
engine ||= SiteSetting.ai_stability_engine
api_url ||= SiteSetting.ai_stability_api_url
allowed_ratios = %w[16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21]
aspect_ratio = "1:1" if !aspect_ratio || !allowed_ratios.include?(aspect_ratio)
payload = {
prompt: prompt,
mode: "text-to-image",
model: engine,
output_format: output_format,
aspect_ratio: aspect_ratio,
}
payload[:seed] = seed if seed
endpoint = "v2beta/stable-image/generate/sd3"
form_data = payload.to_a.map { |k, v| [k.to_s, v.to_s] }
uri = URI("#{api_url}/#{endpoint}")
request = FinalDestination::HTTP::Post.new(uri)
request["authorization"] = "Bearer #{api_key}"
request["accept"] = "application/json"
request["User-Agent"] = DiscourseAi::AiBot::USER_AGENT
request.set_form form_data, "multipart/form-data"
response =
FinalDestination::HTTP.start(
uri.hostname,
uri.port,
use_ssl: uri.port != 80,
read_timeout: TIMEOUT,
open_timeout: TIMEOUT,
write_timeout: TIMEOUT,
) { |http| http.request(request) }
if response.code != "200"
Rails.logger.error(
"AI stability generator failed with status #{response.code}: #{response.body}}",
)
raise Net::HTTPBadResponse
end
parsed = JSON.parse(response.body, symbolize_names: true)
# remap to old format
{ artifacts: [{ base64: parsed[:image], seed: parsed[:seed] }] }
end
def self.perform!(
prompt,
width: nil,
height: nil,
aspect_ratio: nil,
api_key: nil,
engine: nil,
api_url: nil,
@ -17,31 +81,53 @@ module ::DiscourseAi
engine ||= SiteSetting.ai_stability_engine
api_url ||= SiteSetting.ai_stability_api_url
image_count = 4 if image_count > 4
if engine.start_with? "sd3"
artifacts =
image_count.times.map do
perform_sd3!(
prompt,
api_key: api_key,
engine: engine,
api_url: api_url,
aspect_ratio: aspect_ratio,
seed: seed,
)[
:artifacts
][
0
]
end
return { artifacts: artifacts }
end
headers = {
"Content-Type" => "application/json",
"Accept" => "application/json",
"Authorization" => "Bearer #{api_key}",
}
sdxl_allowed_dimensions = [
[1024, 1024],
[1152, 896],
[1216, 832],
[1344, 768],
[1536, 640],
[640, 1536],
[768, 1344],
[832, 1216],
[896, 1152],
]
ratio_to_dimension = {
"16:9" => [1536, 640],
"1:1" => [1024, 1024],
"21:9" => [1344, 768],
"2:3" => [896, 1152],
"3:2" => [1152, 896],
"4:5" => [832, 1216],
"5:4" => [1216, 832],
"9:16" => [640, 1536],
"9:21" => [768, 1344],
}
if (!width && !height)
if engine.include? "xl"
width, height = sdxl_allowed_dimensions[0]
width, height = ratio_to_dimension[aspect_ratio] if aspect_ratio
width, height = [1024, 1024] if !width || !height
else
width, height = [512, 512]
end
end
payload = {
text_prompts: [{ text: prompt }],

View File

@ -5,6 +5,36 @@ describe DiscourseAi::Inference::StabilityGenerator do
DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
end
let :sd3_response do
{ image: "BASE64", seed: 1 }.to_json
end
it "is able to generate sd3 images" do
SiteSetting.ai_stability_engine = "sd3"
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
SiteSetting.ai_stability_api_key = "123"
# webmock does not support multipart form data :(
stub_request(:post, "http://www.a.b.c/v2beta/stable-image/generate/sd3").with(
headers: {
"Accept" => "application/json",
"Authorization" => "Bearer 123",
"Content-Type" => "multipart/form-data",
"Host" => "www.a.b.c",
"User-Agent" => DiscourseAi::AiBot::USER_AGENT,
},
).to_return(status: 200, body: sd3_response, headers: {})
json =
DiscourseAi::Inference::StabilityGenerator.perform!(
"a cow",
aspect_ratio: "16:9",
image_count: 2,
)
expect(json).to eq(artifacts: [{ base64: "BASE64", seed: 1 }, { base64: "BASE64", seed: 1 }])
end
it "sets dimensions to 512x512 for non XL model" do
SiteSetting.ai_stability_engine = "stable-diffusion-v1-5"
SiteSetting.ai_stability_api_url = "http://www.a.b.c"