FEATURE: add support for final stable diffusion xl model (#122)
This commit is contained in:
parent
51fdf21143
commit
602bb843ea
|
@ -103,10 +103,11 @@ plugins:
|
||||||
ai_stability_api_url:
|
ai_stability_api_url:
|
||||||
default: "https://api.stability.ai"
|
default: "https://api.stability.ai"
|
||||||
ai_stability_engine:
|
ai_stability_engine:
|
||||||
default: "stable-diffusion-xl-beta-v2-2-2"
|
default: "stable-diffusion-xl-1024-v1-0"
|
||||||
type: enum
|
type: enum
|
||||||
choices:
|
choices:
|
||||||
- "stable-diffusion-xl-beta-v2-2-2"
|
- "stable-diffusion-xl-1024-v1-0"
|
||||||
|
- "stable-diffusion-768-v2-1"
|
||||||
- "stable-diffusion-v1-5"
|
- "stable-diffusion-v1-5"
|
||||||
ai_hugging_face_api_url:
|
ai_hugging_face_api_url:
|
||||||
default: ""
|
default: ""
|
||||||
|
|
|
@ -3,7 +3,7 @@
|
||||||
module ::DiscourseAi
|
module ::DiscourseAi
|
||||||
module Inference
|
module Inference
|
||||||
class StabilityGenerator
|
class StabilityGenerator
|
||||||
def self.perform!(prompt)
|
def self.perform!(prompt, width: nil, height: nil)
|
||||||
headers = {
|
headers = {
|
||||||
"Referer" => Discourse.base_url,
|
"Referer" => Discourse.base_url,
|
||||||
"Content-Type" => "application/json",
|
"Content-Type" => "application/json",
|
||||||
|
@ -11,12 +11,32 @@ module ::DiscourseAi
|
||||||
"Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}",
|
"Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}",
|
||||||
}
|
}
|
||||||
|
|
||||||
|
sdxl_allowed_dimentions = [
|
||||||
|
[1024, 1024],
|
||||||
|
[1152, 896],
|
||||||
|
[1216, 832],
|
||||||
|
[1344, 768],
|
||||||
|
[1536, 640],
|
||||||
|
[640, 1536],
|
||||||
|
[768, 1344],
|
||||||
|
[832, 1216],
|
||||||
|
[896, 1152],
|
||||||
|
]
|
||||||
|
|
||||||
|
if (!width && !height)
|
||||||
|
if SiteSetting.ai_stability_engine.include? "xl"
|
||||||
|
width, height = sdxl_allowed_dimentions[0]
|
||||||
|
else
|
||||||
|
width, height = [512, 512]
|
||||||
|
end
|
||||||
|
end
|
||||||
|
|
||||||
payload = {
|
payload = {
|
||||||
text_prompts: [{ text: prompt }],
|
text_prompts: [{ text: prompt }],
|
||||||
cfg_scale: 7,
|
cfg_scale: 7,
|
||||||
clip_guidance_preset: "FAST_BLUE",
|
clip_guidance_preset: "FAST_BLUE",
|
||||||
height: 512,
|
height: width,
|
||||||
width: 512,
|
width: height,
|
||||||
samples: 4,
|
samples: 4,
|
||||||
steps: 30,
|
steps: 30,
|
||||||
}
|
}
|
||||||
|
@ -27,7 +47,12 @@ module ::DiscourseAi
|
||||||
|
|
||||||
response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers)
|
response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers)
|
||||||
|
|
||||||
raise Net::HTTPBadResponse if response.status != 200
|
if response.status != 200
|
||||||
|
Rails.logger.error(
|
||||||
|
"AI stability generator failed with status #{response.status}: #{response.body}}",
|
||||||
|
)
|
||||||
|
raise Net::HTTPBadResponse
|
||||||
|
end
|
||||||
|
|
||||||
JSON.parse(response.body, symbolize_names: true)
|
JSON.parse(response.body, symbolize_names: true)
|
||||||
end
|
end
|
||||||
|
|
|
@ -0,0 +1,50 @@
|
||||||
|
# frozen_string_literal: true
|
||||||
|
require "rails_helper"
|
||||||
|
|
||||||
|
describe DiscourseAi::Inference::StabilityGenerator do
|
||||||
|
def gen(prompt)
|
||||||
|
DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
|
||||||
|
end
|
||||||
|
|
||||||
|
it "sets dimentions to 512x512 for non XL model" do
|
||||||
|
SiteSetting.ai_stability_engine = "stable-diffusion-v1-5"
|
||||||
|
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
|
||||||
|
SiteSetting.ai_stability_api_key = "123"
|
||||||
|
|
||||||
|
stub_request(:post, "http://www.a.b.c/v1/generation/stable-diffusion-v1-5/text-to-image")
|
||||||
|
.with do |request|
|
||||||
|
json = JSON.parse(request.body)
|
||||||
|
expect(json["text_prompts"][0]["text"]).to eq("a cow")
|
||||||
|
expect(json["width"]).to eq(512)
|
||||||
|
expect(json["height"]).to eq(512)
|
||||||
|
expect(request.headers["Authorization"]).to eq("Bearer 123")
|
||||||
|
expect(request.headers["Content-Type"]).to eq("application/json")
|
||||||
|
true
|
||||||
|
end
|
||||||
|
.to_return(status: 200, body: "{}", headers: {})
|
||||||
|
|
||||||
|
gen("a cow")
|
||||||
|
end
|
||||||
|
|
||||||
|
it "sets dimentions to 1024x1024 for XL model" do
|
||||||
|
SiteSetting.ai_stability_engine = "stable-diffusion-xl-1024-v1-0"
|
||||||
|
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
|
||||||
|
SiteSetting.ai_stability_api_key = "123"
|
||||||
|
stub_request(
|
||||||
|
:post,
|
||||||
|
"http://www.a.b.c/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image",
|
||||||
|
)
|
||||||
|
.with do |request|
|
||||||
|
json = JSON.parse(request.body)
|
||||||
|
expect(json["text_prompts"][0]["text"]).to eq("a cow")
|
||||||
|
expect(json["width"]).to eq(1024)
|
||||||
|
expect(json["height"]).to eq(1024)
|
||||||
|
expect(request.headers["Authorization"]).to eq("Bearer 123")
|
||||||
|
expect(request.headers["Content-Type"]).to eq("application/json")
|
||||||
|
true
|
||||||
|
end
|
||||||
|
.to_return(status: 200, body: "{}", headers: {})
|
||||||
|
|
||||||
|
gen("a cow")
|
||||||
|
end
|
||||||
|
end
|
Loading…
Reference in New Issue