FEATURE: Stable diffusion 3 support (#582)
- Adds support for the sd3 and sd3-turbo models (this requires new endpoints) - Adds a hack to normalize arrays in tool calls - Removes some leftover code - Adds support for aspect ratio as well, so you can generate wide or tall images
This commit is contained in:
parent
a223d18f1a
commit
bd6f5caeac
|
@ -122,6 +122,8 @@ discourse_ai:
|
|||
default: "stable-diffusion-xl-1024-v1-0"
|
||||
type: enum
|
||||
choices:
|
||||
- "sd3"
|
||||
- "sd3-turbo"
|
||||
- "stable-diffusion-xl-1024-v1-0"
|
||||
- "stable-diffusion-768-v2-1"
|
||||
- "stable-diffusion-v1-5"
|
||||
|
|
|
@ -203,10 +203,6 @@ module DiscourseAi
|
|||
end
|
||||
end
|
||||
|
||||
def tool_invocation?(partial)
|
||||
Nokogiri::HTML5.fragment(partial).at("invoke").present?
|
||||
end
|
||||
|
||||
def build_placeholder(summary, details, custom_raw: nil)
|
||||
placeholder = +(<<~HTML)
|
||||
<details>
|
||||
|
|
|
@ -188,7 +188,7 @@ module DiscourseAi
|
|||
begin
|
||||
JSON.parse(value)
|
||||
rescue JSON::ParserError
|
||||
nil
|
||||
[value.to_s]
|
||||
end
|
||||
end
|
||||
|
||||
|
|
|
@ -24,7 +24,13 @@ module DiscourseAi
|
|||
"The seed used to generate the image (optional) - can be used to retain image style on amended prompts",
|
||||
type: "array",
|
||||
item_type: "integer",
|
||||
required: true,
|
||||
},
|
||||
{
|
||||
name: "aspect_ratio",
|
||||
description: "The aspect ratio of the image (optional defaults to 1:1)",
|
||||
type: "string",
|
||||
required: false,
|
||||
enum: %w[16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21],
|
||||
},
|
||||
],
|
||||
}
|
||||
|
@ -35,7 +41,11 @@ module DiscourseAi
|
|||
end
|
||||
|
||||
def prompts
|
||||
JSON.parse(parameters[:prompts].to_s)
|
||||
parameters[:prompts]
|
||||
end
|
||||
|
||||
def aspect_ratio
|
||||
parameters[:aspect_ratio]
|
||||
end
|
||||
|
||||
def seeds
|
||||
|
@ -75,6 +85,7 @@ module DiscourseAi
|
|||
api_url: api_url,
|
||||
image_count: 1,
|
||||
seed: inner_seed,
|
||||
aspect_ratio: aspect_ratio,
|
||||
)
|
||||
rescue => e
|
||||
attempts += 1
|
||||
|
@ -116,7 +127,7 @@ module DiscourseAi
|
|||
#{
|
||||
uploads
|
||||
.map do |item|
|
||||
"![#{item[:prompt].gsub(/\|\'\"/, "")}|512x512, 50%](#{item[:upload].short_url})"
|
||||
"![#{item[:prompt].gsub(/\|\'\"/, "")}|50%](#{item[:upload].short_url})"
|
||||
end
|
||||
.join(" ")
|
||||
}
|
||||
|
|
|
@ -7,7 +7,6 @@ module DiscourseAi
|
|||
return [] if input.blank?
|
||||
|
||||
model = SiteSetting.ai_helper_illustrate_post_model
|
||||
attribution = "discourse_ai.ai_helper.painter.attribution.#{model}"
|
||||
|
||||
if model == "stable_diffusion_xl"
|
||||
stable_diffusion_prompt = diffusion_prompt(input, user)
|
||||
|
|
|
@ -3,10 +3,74 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class StabilityGenerator
|
||||
TIMEOUT = 120
|
||||
|
||||
# there is a new api for sd3
|
||||
# Generates one image through the SD3 endpoint
# (`v2beta/stable-image/generate/sd3`). The request is submitted as
# multipart/form-data and the reply is remapped to the legacy
# `{ artifacts: [...] }` shape.
#
# @param prompt [String] text prompt, sent as the `prompt` form field
# @param aspect_ratio [String, nil] one of the allowed ratios; nil or any
#   other value falls back to "1:1"
# @param api_key [String, nil] defaults to SiteSetting.ai_stability_api_key
# @param engine [String, nil] sent as the `model` field; defaults to
#   SiteSetting.ai_stability_engine
# @param api_url [String, nil] base URL; defaults to
#   SiteSetting.ai_stability_api_url
# @param output_format [String] sent as the `output_format` form field
# @param seed [Object, nil] only included in the form when provided
# @return [Hash] `{ artifacts: [{ base64: <image>, seed: <seed> }] }`
# @raise [Net::HTTPBadResponse] when the response status is not "200"
def self.perform_sd3!(
  prompt,
  aspect_ratio: nil,
  api_key: nil,
  engine: nil,
  api_url: nil,
  output_format: "png",
  seed: nil
)
  api_key ||= SiteSetting.ai_stability_api_key
  engine ||= SiteSetting.ai_stability_engine
  api_url ||= SiteSetting.ai_stability_api_url

  allowed_ratios = %w[16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21]

  # Unknown or missing ratios silently degrade to a square image.
  aspect_ratio = "1:1" unless allowed_ratios.include?(aspect_ratio)

  payload = {
    prompt: prompt,
    mode: "text-to-image",
    model: engine,
    output_format: output_format,
    aspect_ratio: aspect_ratio,
  }

  payload[:seed] = seed if seed

  endpoint = "v2beta/stable-image/generate/sd3"

  # set_form expects an array of [name, value] string pairs.
  form_data = payload.map { |key, value| [key.to_s, value.to_s] }

  uri = URI("#{api_url}/#{endpoint}")
  request = FinalDestination::HTTP::Post.new(uri)

  request["authorization"] = "Bearer #{api_key}"
  request["accept"] = "application/json"
  request["User-Agent"] = DiscourseAi::AiBot::USER_AGENT
  request.set_form form_data, "multipart/form-data"

  # NOTE(review): TLS is enabled for every port except 80, so plain HTTP on a
  # non-standard port would be attempted over SSL - confirm this is intended.
  response =
    FinalDestination::HTTP.start(
      uri.hostname,
      uri.port,
      use_ssl: uri.port != 80,
      read_timeout: TIMEOUT,
      open_timeout: TIMEOUT,
      write_timeout: TIMEOUT,
    ) { |http| http.request(request) }

  if response.code != "200"
    Rails.logger.error(
      "AI stability generator failed with status #{response.code}: #{response.body}",
    )
    raise Net::HTTPBadResponse
  end

  parsed = JSON.parse(response.body, symbolize_names: true)

  # remap to old format
  { artifacts: [{ base64: parsed[:image], seed: parsed[:seed] }] }
end
|
||||
|
||||
def self.perform!(
|
||||
prompt,
|
||||
width: nil,
|
||||
height: nil,
|
||||
aspect_ratio: nil,
|
||||
api_key: nil,
|
||||
engine: nil,
|
||||
api_url: nil,
|
||||
|
@ -17,30 +81,52 @@ module ::DiscourseAi
|
|||
engine ||= SiteSetting.ai_stability_engine
|
||||
api_url ||= SiteSetting.ai_stability_api_url
|
||||
|
||||
image_count = 4 if image_count > 4
|
||||
|
||||
if engine.start_with? "sd3"
|
||||
artifacts =
|
||||
image_count.times.map do
|
||||
perform_sd3!(
|
||||
prompt,
|
||||
api_key: api_key,
|
||||
engine: engine,
|
||||
api_url: api_url,
|
||||
aspect_ratio: aspect_ratio,
|
||||
seed: seed,
|
||||
)[
|
||||
:artifacts
|
||||
][
|
||||
0
|
||||
]
|
||||
end
|
||||
|
||||
return { artifacts: artifacts }
|
||||
end
|
||||
|
||||
headers = {
|
||||
"Content-Type" => "application/json",
|
||||
"Accept" => "application/json",
|
||||
"Authorization" => "Bearer #{api_key}",
|
||||
}
|
||||
|
||||
sdxl_allowed_dimensions = [
|
||||
[1024, 1024],
|
||||
[1152, 896],
|
||||
[1216, 832],
|
||||
[1344, 768],
|
||||
[1536, 640],
|
||||
[640, 1536],
|
||||
[768, 1344],
|
||||
[832, 1216],
|
||||
[896, 1152],
|
||||
]
|
||||
# Maps each supported aspect ratio to the SDXL [width, height] pair whose
# proportions are numerically closest to it, e.g. 16:9 (~1.78) -> 1344x768
# (1.75) and 21:9 (~2.33) -> 1536x640 (2.40). Portrait ratios mirror the
# landscape ones.
ratio_to_dimension = {
  "16:9" => [1344, 768],
  "1:1" => [1024, 1024],
  "21:9" => [1536, 640],
  "2:3" => [832, 1216],
  "3:2" => [1216, 832],
  "4:5" => [896, 1152],
  "5:4" => [1152, 896],
  "9:16" => [768, 1344],
  "9:21" => [640, 1536],
}
|
||||
|
||||
if (!width && !height)
|
||||
if engine.include? "xl"
|
||||
width, height = sdxl_allowed_dimensions[0]
|
||||
else
|
||||
width, height = [512, 512]
|
||||
end
|
||||
if engine.include? "xl"
|
||||
width, height = ratio_to_dimension[aspect_ratio] if aspect_ratio
|
||||
|
||||
width, height = [1024, 1024] if !width || !height
|
||||
else
|
||||
width, height = [512, 512]
|
||||
end
|
||||
|
||||
payload = {
|
||||
|
|
|
@ -5,6 +5,36 @@ describe DiscourseAi::Inference::StabilityGenerator do
|
|||
DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
|
||||
end
|
||||
|
||||
let :sd3_response do
|
||||
{ image: "BASE64", seed: 1 }.to_json
|
||||
end
|
||||
|
||||
it "is able to generate sd3 images" do
|
||||
SiteSetting.ai_stability_engine = "sd3"
|
||||
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
|
||||
SiteSetting.ai_stability_api_key = "123"
|
||||
|
||||
# webmock does not support multipart form data :(
|
||||
stub_request(:post, "http://www.a.b.c/v2beta/stable-image/generate/sd3").with(
|
||||
headers: {
|
||||
"Accept" => "application/json",
|
||||
"Authorization" => "Bearer 123",
|
||||
"Content-Type" => "multipart/form-data",
|
||||
"Host" => "www.a.b.c",
|
||||
"User-Agent" => DiscourseAi::AiBot::USER_AGENT,
|
||||
},
|
||||
).to_return(status: 200, body: sd3_response, headers: {})
|
||||
|
||||
json =
|
||||
DiscourseAi::Inference::StabilityGenerator.perform!(
|
||||
"a cow",
|
||||
aspect_ratio: "16:9",
|
||||
image_count: 2,
|
||||
)
|
||||
|
||||
expect(json).to eq(artifacts: [{ base64: "BASE64", seed: 1 }, { base64: "BASE64", seed: 1 }])
|
||||
end
|
||||
|
||||
it "sets dimensions to 512x512 for non XL model" do
|
||||
SiteSetting.ai_stability_engine = "stable-diffusion-v1-5"
|
||||
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
|
||||
|
|
Loading…
Reference in New Issue