From 602bb843ea727f25dc33c208529979c949b45b06 Mon Sep 17 00:00:00 2001 From: Sam Date: Thu, 3 Aug 2023 05:53:28 +1000 Subject: [PATCH] FEATURE: add support for final stable diffusion xl model (#122) --- config/settings.yml | 5 +- lib/shared/inference/stability_generator.rb | 33 ++++++++++-- .../inference/stability_generator_spec.rb | 50 +++++++++++++++++++ 3 files changed, 82 insertions(+), 6 deletions(-) create mode 100644 spec/shared/inference/stability_generator_spec.rb diff --git a/config/settings.yml b/config/settings.yml index da989435..00315a47 100644 --- a/config/settings.yml +++ b/config/settings.yml @@ -103,10 +103,11 @@ plugins: ai_stability_api_url: default: "https://api.stability.ai" ai_stability_engine: - default: "stable-diffusion-xl-beta-v2-2-2" + default: "stable-diffusion-xl-1024-v1-0" type: enum choices: - - "stable-diffusion-xl-beta-v2-2-2" + - "stable-diffusion-xl-1024-v1-0" + - "stable-diffusion-768-v2-1" - "stable-diffusion-v1-5" ai_hugging_face_api_url: default: "" diff --git a/lib/shared/inference/stability_generator.rb b/lib/shared/inference/stability_generator.rb index ead46f7d..157f65d7 100644 --- a/lib/shared/inference/stability_generator.rb +++ b/lib/shared/inference/stability_generator.rb @@ -3,7 +3,7 @@ module ::DiscourseAi module Inference class StabilityGenerator - def self.perform!(prompt) + def self.perform!(prompt, width: nil, height: nil) headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json", @@ -11,12 +11,32 @@ module ::DiscourseAi "Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}", } + sdxl_allowed_dimentions = [ + [1024, 1024], + [1152, 896], + [1216, 832], + [1344, 768], + [1536, 640], + [640, 1536], + [768, 1344], + [832, 1216], + [896, 1152], + ] + + if (!width && !height) + if SiteSetting.ai_stability_engine.include? "xl" + width, height = sdxl_allowed_dimentions[0] + else + width, height = [512, 512] + end + end + payload = { text_prompts: [{ text: prompt }], cfg_scale: 7, clip_guidance_preset: "FAST_BLUE", - height: 512, - width: 512, + height: width, + width: height, samples: 4, steps: 30, } @@ -27,7 +47,12 @@ module ::DiscourseAi response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers) - raise Net::HTTPBadResponse if response.status != 200 + if response.status != 200 + Rails.logger.error( + "AI stability generator failed with status #{response.status}: #{response.body}}", + ) + raise Net::HTTPBadResponse + end JSON.parse(response.body, symbolize_names: true) end diff --git a/spec/shared/inference/stability_generator_spec.rb b/spec/shared/inference/stability_generator_spec.rb new file mode 100644 index 00000000..d9b91b08 --- /dev/null +++ b/spec/shared/inference/stability_generator_spec.rb @@ -0,0 +1,50 @@ +# frozen_string_literal: true +require "rails_helper" + +describe DiscourseAi::Inference::StabilityGenerator do + def gen(prompt) + DiscourseAi::Inference::StabilityGenerator.perform!(prompt) + end + + it "sets dimentions to 512x512 for non XL model" do + SiteSetting.ai_stability_engine = "stable-diffusion-v1-5" + SiteSetting.ai_stability_api_url = "http://www.a.b.c" + SiteSetting.ai_stability_api_key = "123" + + stub_request(:post, "http://www.a.b.c/v1/generation/stable-diffusion-v1-5/text-to-image") + .with do |request| + json = JSON.parse(request.body) + expect(json["text_prompts"][0]["text"]).to eq("a cow") + expect(json["width"]).to eq(512) + expect(json["height"]).to eq(512) + expect(request.headers["Authorization"]).to eq("Bearer 123") + expect(request.headers["Content-Type"]).to eq("application/json") + true + end + .to_return(status: 200, body: "{}", headers: {}) + + gen("a cow") + end + + it "sets dimentions to 1024x1024 for XL model" do + SiteSetting.ai_stability_engine = "stable-diffusion-xl-1024-v1-0" + SiteSetting.ai_stability_api_url = "http://www.a.b.c" + SiteSetting.ai_stability_api_key = "123" + stub_request( + :post, + "http://www.a.b.c/v1/generation/stable-diffusion-xl-1024-v1-0/text-to-image", + ) + .with do |request| + json = JSON.parse(request.body) + expect(json["text_prompts"][0]["text"]).to eq("a cow") + expect(json["width"]).to eq(1024) + expect(json["height"]).to eq(1024) + expect(request.headers["Authorization"]).to eq("Bearer 123") + expect(request.headers["Content-Type"]).to eq("application/json") + true + end + .to_return(status: 200, body: "{}", headers: {}) + + gen("a cow") + end +end