FEATURE: add support for final stable diffusion xl model (#122)
This commit is contained in:
parent
51fdf21143
commit
602bb843ea
|
@ -103,10 +103,11 @@ plugins:
|
|||
ai_stability_api_url:
|
||||
default: "https://api.stability.ai"
|
||||
ai_stability_engine:
|
||||
default: "stable-diffusion-xl-beta-v2-2-2"
|
||||
default: "stable-diffusion-xl-1024-v1-0"
|
||||
type: enum
|
||||
choices:
|
||||
- "stable-diffusion-xl-beta-v2-2-2"
|
||||
- "stable-diffusion-xl-1024-v1-0"
|
||||
- "stable-diffusion-768-v2-1"
|
||||
- "stable-diffusion-v1-5"
|
||||
ai_hugging_face_api_url:
|
||||
default: ""
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
module ::DiscourseAi
|
||||
module Inference
|
||||
class StabilityGenerator
|
||||
def self.perform!(prompt)
|
||||
def self.perform!(prompt, width: nil, height: nil)
|
||||
headers = {
|
||||
"Referer" => Discourse.base_url,
|
||||
"Content-Type" => "application/json",
|
||||
|
@ -11,12 +11,32 @@ module ::DiscourseAi
|
|||
"Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}",
|
||||
}
|
||||
|
||||
sdxl_allowed_dimentions = [
|
||||
[1024, 1024],
|
||||
[1152, 896],
|
||||
[1216, 832],
|
||||
[1344, 768],
|
||||
[1536, 640],
|
||||
[640, 1536],
|
||||
[768, 1344],
|
||||
[832, 1216],
|
||||
[896, 1152],
|
||||
]
|
||||
|
||||
if (!width && !height)
|
||||
if SiteSetting.ai_stability_engine.include? "xl"
|
||||
width, height = sdxl_allowed_dimentions[0]
|
||||
else
|
||||
width, height = [512, 512]
|
||||
end
|
||||
end
|
||||
|
||||
payload = {
|
||||
text_prompts: [{ text: prompt }],
|
||||
cfg_scale: 7,
|
||||
clip_guidance_preset: "FAST_BLUE",
|
||||
height: 512,
|
||||
width: 512,
|
||||
height: width,
|
||||
width: height,
|
||||
samples: 4,
|
||||
steps: 30,
|
||||
}
|
||||
|
@ -27,7 +47,12 @@ module ::DiscourseAi
|
|||
|
||||
response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers)
|
||||
|
||||
raise Net::HTTPBadResponse if response.status != 200
|
||||
if response.status != 200
|
||||
Rails.logger.error(
|
||||
"AI stability generator failed with status #{response.status}: #{response.body}}",
|
||||
)
|
||||
raise Net::HTTPBadResponse
|
||||
end
|
||||
|
||||
JSON.parse(response.body, symbolize_names: true)
|
||||
end
|
||||
|
|
|
@ -0,0 +1,50 @@
|
|||
# frozen_string_literal: true
|
||||
require "rails_helper"
|
||||
|
||||
describe DiscourseAi::Inference::StabilityGenerator do
|
||||
def gen(prompt)
|
||||
DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
|
||||
end
|
||||
|
||||
it "sets dimentions to 512x512 for non XL model" do
|
||||
SiteSetting.ai_stability_engine = "stable-diffusion-v1-5"
|
||||
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
|
||||
SiteSetting.ai_stability_api_key = "123"
|
||||
|
||||
stub_request(:post, "http://www.a.b.c/v1/generation/stable-diffusion-v1-5/text-to-image")
|
||||
.with do |request|
|
||||
json = JSON.parse(request.body)
|
||||
expect(json["text_prompts"][0]["text"]).to eq("a cow")
|
||||
expect(json["width"]).to eq(512)
|
||||
expect(json["height"]).to eq(512)
|
||||
expect(request.headers["Authorization"]).to eq("Bearer 123")
|
||||
expect(request.headers["Content-Type"]).to eq("application/json")
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: "{}", headers: {})
|
||||
|
||||
gen("a cow")
|
||||
end
|
||||
|
||||
it "sets dimentions to 1024x1024 for XL model" do
|
||||
SiteSetting.ai_stability_engine = "stable-diffusion-xl-1024-v1-0"
|
||||
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
|
||||
SiteSetting.ai_stability_api_key = "123"
|
||||
stub_request(
|
||||
:post,
|
||||
"http://www.a.b.c/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image",
|
||||
)
|
||||
.with do |request|
|
||||
json = JSON.parse(request.body)
|
||||
expect(json["text_prompts"][0]["text"]).to eq("a cow")
|
||||
expect(json["width"]).to eq(1024)
|
||||
expect(json["height"]).to eq(1024)
|
||||
expect(request.headers["Authorization"]).to eq("Bearer 123")
|
||||
expect(request.headers["Content-Type"]).to eq("application/json")
|
||||
true
|
||||
end
|
||||
.to_return(status: 200, body: "{}", headers: {})
|
||||
|
||||
gen("a cow")
|
||||
end
|
||||
end
|
Loading…
Reference in New Issue