FEATURE: add support for final stable diffusion xl model (#122)

This commit is contained in:
Sam 2023-08-03 05:53:28 +10:00 committed by GitHub
parent 51fdf21143
commit 602bb843ea
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 82 additions and 6 deletions

View File

@ -103,10 +103,11 @@ plugins:
ai_stability_api_url: ai_stability_api_url:
default: "https://api.stability.ai" default: "https://api.stability.ai"
ai_stability_engine: ai_stability_engine:
default: "stable-diffusion-xl-beta-v2-2-2" default: "stable-diffusion-xl-1024-v1-0"
type: enum type: enum
choices: choices:
- "stable-diffusion-xl-beta-v2-2-2" - "stable-diffusion-xl-1024-v1-0"
- "stable-diffusion-768-v2-1"
- "stable-diffusion-v1-5" - "stable-diffusion-v1-5"
ai_hugging_face_api_url: ai_hugging_face_api_url:
default: "" default: ""

View File

@ -3,7 +3,7 @@
module ::DiscourseAi module ::DiscourseAi
module Inference module Inference
class StabilityGenerator class StabilityGenerator
def self.perform!(prompt) def self.perform!(prompt, width: nil, height: nil)
headers = { headers = {
"Referer" => Discourse.base_url, "Referer" => Discourse.base_url,
"Content-Type" => "application/json", "Content-Type" => "application/json",
@ -11,12 +11,32 @@ module ::DiscourseAi
"Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}", "Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}",
} }
sdxl_allowed_dimentions = [
[1024, 1024],
[1152, 896],
[1216, 832],
[1344, 768],
[1536, 640],
[640, 1536],
[768, 1344],
[832, 1216],
[896, 1152],
]
if (!width && !height)
if SiteSetting.ai_stability_engine.include? "xl"
width, height = sdxl_allowed_dimentions[0]
else
width, height = [512, 512]
end
end
payload = { payload = {
text_prompts: [{ text: prompt }], text_prompts: [{ text: prompt }],
cfg_scale: 7, cfg_scale: 7,
clip_guidance_preset: "FAST_BLUE", clip_guidance_preset: "FAST_BLUE",
height: 512, height: width,
width: 512, width: height,
samples: 4, samples: 4,
steps: 30, steps: 30,
} }
@ -27,7 +47,12 @@ module ::DiscourseAi
response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers) response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers)
raise Net::HTTPBadResponse if response.status != 200 if response.status != 200
Rails.logger.error(
"AI stability generator failed with status #{response.status}: #{response.body}}",
)
raise Net::HTTPBadResponse
end
JSON.parse(response.body, symbolize_names: true) JSON.parse(response.body, symbolize_names: true)
end end

View File

@ -0,0 +1,50 @@
# frozen_string_literal: true
require "rails_helper"
describe DiscourseAi::Inference::StabilityGenerator do
def gen(prompt)
DiscourseAi::Inference::StabilityGenerator.perform!(prompt)
end
it "sets dimentions to 512x512 for non XL model" do
SiteSetting.ai_stability_engine = "stable-diffusion-v1-5"
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
SiteSetting.ai_stability_api_key = "123"
stub_request(:post, "http://www.a.b.c/v1/generation/stable-diffusion-v1-5/text-to-image")
.with do |request|
json = JSON.parse(request.body)
expect(json["text_prompts"][0]["text"]).to eq("a cow")
expect(json["width"]).to eq(512)
expect(json["height"]).to eq(512)
expect(request.headers["Authorization"]).to eq("Bearer 123")
expect(request.headers["Content-Type"]).to eq("application/json")
true
end
.to_return(status: 200, body: "{}", headers: {})
gen("a cow")
end
it "sets dimentions to 1024x1024 for XL model" do
SiteSetting.ai_stability_engine = "stable-diffusion-xl-1024-v1-0"
SiteSetting.ai_stability_api_url = "http://www.a.b.c"
SiteSetting.ai_stability_api_key = "123"
stub_request(
:post,
"http://www.a.b.c/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image",
)
.with do |request|
json = JSON.parse(request.body)
expect(json["text_prompts"][0]["text"]).to eq("a cow")
expect(json["width"]).to eq(1024)
expect(json["height"]).to eq(1024)
expect(request.headers["Authorization"]).to eq("Bearer 123")
expect(request.headers["Content-Type"]).to eq("application/json")
true
end
.to_return(status: 200, body: "{}", headers: {})
gen("a cow")
end
end