2023-05-19 19:38:08 -04:00
|
|
|
# frozen_string_literal: true
|
|
|
|
|
|
|
|
module ::DiscourseAi
  module Inference
    # Client for the Stability text-to-image HTTP endpoint
    # (`v1/generation/<engine>/text-to-image`).
    class StabilityGenerator
      # Width/height pairs kept from the original `sdxl_allowed_dimensions`
      # list. Only the first pair is read by the code: it is the default size
      # when the engine name contains "xl".
      SDXL_ALLOWED_DIMENSIONS = [
        [1024, 1024],
        [1152, 896],
        [1216, 832],
        [1344, 768],
        [1536, 640],
        [640, 1536],
        [768, 1344],
        [832, 1216],
        [896, 1152],
      ].freeze

      # Default width/height for engines whose name does not contain "xl".
      DEFAULT_DIMENSIONS = [512, 512].freeze

      private_constant :SDXL_ALLOWED_DIMENSIONS, :DEFAULT_DIMENSIONS

      # Posts a text-to-image generation request and returns the parsed reply.
      #
      # @param prompt [String] text sent as the only entry of `text_prompts`
      # @param width [Integer, nil] requested image width
      # @param height [Integer, nil] requested image height; when both width
      #   and height are omitted a default pair is chosen from the engine name
      # @param api_key [String, nil] bearer token; falls back to
      #   `SiteSetting.ai_stability_api_key`
      # @param engine [String, nil] engine id interpolated into the endpoint
      #   path; falls back to `SiteSetting.ai_stability_engine`
      # @param api_url [String, nil] base URL; falls back to
      #   `SiteSetting.ai_stability_api_url`
      # @param image_count [Integer] value sent as `samples`
      # @param seed [Integer, nil] sent as `seed` only when given
      # @return [Hash] the JSON response body parsed with symbolized keys
      # @raise [Net::HTTPBadResponse] when the response status is not 200
      def self.perform!(
        prompt,
        width: nil,
        height: nil,
        api_key: nil,
        engine: nil,
        api_url: nil,
        image_count: 4,
        seed: nil
      )
        api_key ||= SiteSetting.ai_stability_api_key
        engine ||= SiteSetting.ai_stability_engine
        api_url ||= SiteSetting.ai_stability_api_url

        headers = {
          "Content-Type" => "application/json",
          "Accept" => "application/json",
          "Authorization" => "Bearer #{api_key}",
        }

        # NOTE(review): defaults apply only when BOTH dimensions are missing.
        # If a caller passes just one, the other is still sent as null --
        # confirm whether any caller relies on that.
        if !width && !height
          if engine.include?("xl")
            width, height = SDXL_ALLOWED_DIMENSIONS.first
          else
            width, height = DEFAULT_DIMENSIONS
          end
        end

        payload = {
          text_prompts: [{ text: prompt }],
          cfg_scale: 7,
          clip_guidance_preset: "FAST_BLUE",
          # Each dimension goes under its own key; these were previously
          # swapped, which transposed every non-square request.
          height: height,
          width: width,
          samples: image_count,
          steps: 30,
        }

        payload[:seed] = seed if seed

        endpoint = "v1/generation/#{engine}/text-to-image"

        conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
        response = conn.post("#{api_url}/#{endpoint}", payload.to_json, headers)

        if response.status != 200
          Rails.logger.error(
            "AI stability generator failed with status #{response.status}: #{response.body}",
          )
          raise Net::HTTPBadResponse
        end

        JSON.parse(response.body, symbolize_names: true)
      end
    end
  end
end
|