From 426e348c8a0b5695064129f21a7a422091e65f88 Mon Sep 17 00:00:00 2001 From: Sam Date: Wed, 25 Oct 2023 11:04:16 +1100 Subject: [PATCH] FIX: make stable diffusion multi site friendly (#265) Previous to this change image generation did not work on multisite There was a background thread generating the images and it was getting site settings from the default site in the cluster This also removes referer header which is not needed --- lib/modules/ai_bot/commands/image_command.rb | 16 +++++++++++++++- lib/shared/inference/stability_generator.rb | 15 ++++++++------- 2 files changed, 23 insertions(+), 8 deletions(-) diff --git a/lib/modules/ai_bot/commands/image_command.rb b/lib/modules/ai_bot/commands/image_command.rb index 044321a6..b5843b2a 100644 --- a/lib/modules/ai_bot/commands/image_command.rb +++ b/lib/modules/ai_bot/commands/image_command.rb @@ -47,13 +47,25 @@ module DiscourseAi::AiBot::Commands results = nil + # this ensures multisite safety since background threads + # generate the images + api_key = SiteSetting.ai_stability_api_key + engine = SiteSetting.ai_stability_engine + api_url = SiteSetting.ai_stability_api_url + # API is flaky, so try a few times 3.times do begin thread = Thread.new do begin - results = DiscourseAi::Inference::StabilityGenerator.perform!(prompt) + results = + DiscourseAi::Inference::StabilityGenerator.perform!( + prompt, + engine: engine, + api_key: api_key, + api_url: api_url, + ) rescue => e Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}") end @@ -65,6 +77,8 @@ module DiscourseAi::AiBot::Commands end end + return { prompt: prompt, error: "Something went wrong, could not generate image" } if !results + uploads = [] results[:artifacts].each_with_index do |image, i| diff --git a/lib/shared/inference/stability_generator.rb b/lib/shared/inference/stability_generator.rb index 157f65d7..04a8ca83 100644 --- a/lib/shared/inference/stability_generator.rb +++ b/lib/shared/inference/stability_generator.rb @@ -3,12 +3,15 @@ module ::DiscourseAi module Inference class StabilityGenerator - def self.perform!(prompt, width: nil, height: nil) + def self.perform!(prompt, width: nil, height: nil, api_key: nil, engine: nil, api_url: nil) + api_key ||= SiteSetting.ai_stability_api_key + engine ||= SiteSetting.ai_stability_engine + api_url ||= SiteSetting.ai_stability_api_url + headers = { - "Referer" => Discourse.base_url, "Content-Type" => "application/json", "Accept" => "application/json", - "Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}", + "Authorization" => "Bearer #{api_key}", } sdxl_allowed_dimentions = [ @@ -24,7 +27,7 @@ module ::DiscourseAi ] if (!width && !height) - if SiteSetting.ai_stability_engine.include? "xl" + if engine.include? "xl" width, height = sdxl_allowed_dimentions[0] else width, height = [512, 512] @@ -41,11 +44,9 @@ module ::DiscourseAi steps: 30, } - base_url = SiteSetting.ai_stability_api_url - engine = SiteSetting.ai_stability_engine endpoint = "v1/generation/#{engine}/text-to-image" - response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers) + response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers) if response.status != 200 Rails.logger.error(