# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    # HTTP client for Stability AI text-to-image generation.
    #
    # Two API generations are supported:
    # * the multipart "v2beta/stable-image/generate/sd3" endpoint (see
    #   {perform_sd3!}), used for engines whose name starts with "sd3";
    # * the JSON "v1/generation/<engine>/text-to-image" endpoint, used for
    #   every other engine.
    class StabilityGenerator
      # Seconds. Applied to the open, read and write timeouts of SD3 requests.
      TIMEOUT = 120

      # Aspect ratios forwarded to the SD3 endpoint as-is; any other value
      # (including nil) is replaced by "1:1".
      SD3_ASPECT_RATIOS = %w[16:9 1:1 21:9 2:3 3:2 4:5 5:4 9:16 9:21].freeze

      # Aspect ratio => [width, height] in pixels, used for engines whose name
      # contains "xl" on the v1 endpoint.
      #
      # NOTE(review): apart from "1:1", each key is paired with the farther of
      # two candidate sizes, e.g. 1536x640 is 2.4:1 (nearer 21:9 than 16:9)
      # while 1344x768 is 1.75:1 (nearer 16:9). Values are left unchanged here;
      # confirm the intended mapping before touching them.
      XL_DIMENSIONS = {
        "16:9" => [1536, 640],
        "1:1" => [1024, 1024],
        "21:9" => [1344, 768],
        "2:3" => [896, 1152],
        "3:2" => [1152, 896],
        "4:5" => [832, 1216],
        "5:4" => [1216, 832],
        "9:16" => [640, 1536],
        "9:21" => [768, 1344],
      }.freeze

      private_constant :SD3_ASPECT_RATIOS, :XL_DIMENSIONS

      # Generates a single image through the SD3 endpoint.
      #
      # @param prompt [String] text prompt sent as the "prompt" form field
      # @param aspect_ratio [String, nil] one of the supported ratios; falls
      #   back to "1:1" when nil or unsupported
      # @param api_key [String, nil] defaults to SiteSetting.ai_stability_api_key
      # @param engine [String, nil] sent as the "model" form field; defaults to
      #   SiteSetting.ai_stability_engine
      # @param api_url [String, nil] base URL; defaults to
      #   SiteSetting.ai_stability_api_url
      # @param output_format [String] sent as the "output_format" form field
      # @param seed [Object, nil] sent as the "seed" form field when truthy
      # @return [Hash] `{ artifacts: [{ base64:, seed: }] }`, built from the
      #   `image` and `seed` keys of the JSON response
      # @raise [Net::HTTPBadResponse] when the response status is not 200
      def self.perform_sd3!(
        prompt,
        aspect_ratio: nil,
        api_key: nil,
        engine: nil,
        api_url: nil,
        output_format: "png",
        seed: nil
      )
        api_key ||= SiteSetting.ai_stability_api_key
        engine ||= SiteSetting.ai_stability_engine
        api_url ||= SiteSetting.ai_stability_api_url

        aspect_ratio = "1:1" unless SD3_ASPECT_RATIOS.include?(aspect_ratio)

        payload = {
          prompt: prompt,
          mode: "text-to-image",
          model: engine,
          output_format: output_format,
          aspect_ratio: aspect_ratio,
        }
        payload[:seed] = seed if seed

        uri = URI("#{api_url}/v2beta/stable-image/generate/sd3")

        request = FinalDestination::HTTP::Post.new(uri)
        request["authorization"] = "Bearer #{api_key}"
        request["accept"] = "application/json"
        request["User-Agent"] = DiscourseAi::AiBot::USER_AGENT
        # Multipart form fields are sent as string name/value pairs.
        request.set_form(
          payload.map { |name, value| [name.to_s, value.to_s] },
          "multipart/form-data",
        )

        response =
          FinalDestination::HTTP.start(
            uri.hostname,
            uri.port,
            # Decide on TLS from the scheme rather than the port, so plain-HTTP
            # endpoints on non-default ports are not wrapped in SSL.
            use_ssl: uri.scheme == "https",
            read_timeout: TIMEOUT,
            open_timeout: TIMEOUT,
            write_timeout: TIMEOUT,
          ) { |http| http.request(request) }

        check_response!(response.code, response.body)

        parsed = JSON.parse(response.body, symbolize_names: true)

        # Remap to the artifacts-list shape callers of the older API expect.
        { artifacts: [{ base64: parsed[:image], seed: parsed[:seed] }] }
      end

      # Generates up to four images for a prompt.
      #
      # Engines whose name starts with "sd3" are served by calling
      # {perform_sd3!} once per requested image (each call receives the same
      # `seed`) and collecting the first artifact of every call. All other
      # engines are served by one JSON POST to the v1 text-to-image endpoint.
      #
      # @param prompt [String] text prompt
      # @param aspect_ratio [String, nil] forwarded to SD3; for engines whose
      #   name contains "xl" it selects the image dimensions (1024x1024 when
      #   nil or unknown); other engines always use 512x512
      # @param api_key [String, nil] defaults to SiteSetting.ai_stability_api_key
      # @param engine [String, nil] defaults to SiteSetting.ai_stability_engine
      # @param api_url [String, nil] defaults to SiteSetting.ai_stability_api_url
      # @param image_count [Integer] number of images; capped at 4
      # @param seed [Object, nil] included in the request when truthy
      # @return [Hash] for SD3 engines `{ artifacts: [...] }`; otherwise the
      #   parsed JSON response body with symbolized keys
      # @raise [Net::HTTPBadResponse] when a response status is not 200
      def self.perform!(
        prompt,
        aspect_ratio: nil,
        api_key: nil,
        engine: nil,
        api_url: nil,
        image_count: 4,
        seed: nil
      )
        api_key ||= SiteSetting.ai_stability_api_key
        engine ||= SiteSetting.ai_stability_engine
        api_url ||= SiteSetting.ai_stability_api_url

        image_count = [image_count, 4].min

        if engine.start_with?("sd3")
          artifacts =
            image_count.times.map do
              perform_sd3!(
                prompt,
                api_key: api_key,
                engine: engine,
                api_url: api_url,
                aspect_ratio: aspect_ratio,
                seed: seed,
              )[:artifacts].first
            end

          return { artifacts: artifacts }
        end

        headers = {
          "Content-Type" => "application/json",
          "Accept" => "application/json",
          "Authorization" => "Bearer #{api_key}",
        }

        if engine.include?("xl")
          width, height = XL_DIMENSIONS[aspect_ratio] if aspect_ratio
          width, height = [1024, 1024] if !width || !height
        else
          width, height = [512, 512]
        end

        payload = {
          text_prompts: [{ text: prompt }],
          cfg_scale: 7,
          clip_guidance_preset: "FAST_BLUE",
          # The table stores [width, height]; these two were previously
          # swapped, which transposed every non-square request.
          height: height,
          width: width,
          samples: image_count,
          steps: 30,
        }
        payload[:seed] = seed if seed

        endpoint = "v1/generation/#{engine}/text-to-image"

        conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter }
        response = conn.post("#{api_url}/#{endpoint}", payload.to_json, headers)

        check_response!(response.status, response.body)

        JSON.parse(response.body, symbolize_names: true)
      end

      # Logs and raises unless the status is 200. Accepts both the String code
      # read from the SD3 response and the Integer status read from Faraday.
      def self.check_response!(status, body)
        return if status.to_s == "200"

        Rails.logger.error("AI stability generator failed with status #{status}: #{body}")
        raise Net::HTTPBadResponse
      end
      private_class_method :check_response!
    end
  end
end
|