FIX: make stable diffusion multi site friendly (#265)

Previous to this change image generation did not work on multisite

There was a background thread generating the images and it was
getting site settings from the default site in the cluster

This also removes referer header which is not needed
This commit is contained in:
Sam 2023-10-25 11:04:16 +11:00 committed by GitHub
parent b02be91799
commit 426e348c8a
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
2 changed files with 23 additions and 8 deletions

View File

@ -47,13 +47,25 @@ module DiscourseAi::AiBot::Commands
results = nil results = nil
# this ensures multisite safety since background threads
# generate the images
api_key = SiteSetting.ai_stability_api_key
engine = SiteSetting.ai_stability_engine
api_url = SiteSetting.ai_stability_api_url
# API is flaky, so try a few times # API is flaky, so try a few times
3.times do 3.times do
begin begin
thread = thread =
Thread.new do Thread.new do
begin begin
results = DiscourseAi::Inference::StabilityGenerator.perform!(prompt) results =
DiscourseAi::Inference::StabilityGenerator.perform!(
prompt,
engine: engine,
api_key: api_key,
api_url: api_url,
)
rescue => e rescue => e
Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}") Rails.logger.warn("Failed to generate image for prompt #{prompt}: #{e}")
end end
@ -65,6 +77,8 @@ module DiscourseAi::AiBot::Commands
end end
end end
return { prompt: prompt, error: "Something went wrong, could not generate image" } if !results
uploads = [] uploads = []
results[:artifacts].each_with_index do |image, i| results[:artifacts].each_with_index do |image, i|

View File

@ -3,12 +3,15 @@
module ::DiscourseAi module ::DiscourseAi
module Inference module Inference
class StabilityGenerator class StabilityGenerator
def self.perform!(prompt, width: nil, height: nil) def self.perform!(prompt, width: nil, height: nil, api_key: nil, engine: nil, api_url: nil)
api_key ||= SiteSetting.ai_stability_api_key
engine ||= SiteSetting.ai_stability_engine
api_url ||= SiteSetting.ai_stability_api_url
headers = { headers = {
"Referer" => Discourse.base_url,
"Content-Type" => "application/json", "Content-Type" => "application/json",
"Accept" => "application/json", "Accept" => "application/json",
"Authorization" => "Bearer #{SiteSetting.ai_stability_api_key}", "Authorization" => "Bearer #{api_key}",
} }
sdxl_allowed_dimentions = [ sdxl_allowed_dimentions = [
@ -24,7 +27,7 @@ module ::DiscourseAi
] ]
if (!width && !height) if (!width && !height)
if SiteSetting.ai_stability_engine.include? "xl" if engine.include? "xl"
width, height = sdxl_allowed_dimentions[0] width, height = sdxl_allowed_dimentions[0]
else else
width, height = [512, 512] width, height = [512, 512]
@ -41,11 +44,9 @@ module ::DiscourseAi
steps: 30, steps: 30,
} }
base_url = SiteSetting.ai_stability_api_url
engine = SiteSetting.ai_stability_engine
endpoint = "v1/generation/#{engine}/text-to-image" endpoint = "v1/generation/#{engine}/text-to-image"
response = Faraday.post("#{base_url}/#{endpoint}", payload.to_json, headers) response = Faraday.post("#{api_url}/#{endpoint}", payload.to_json, headers)
if response.status != 200 if response.status != 200
Rails.logger.error( Rails.logger.error(