mirror of
				https://github.com/discourse/discourse-ai.git
				synced 2025-10-25 03:28:40 +00:00 
			
		
		
		
	DALL-E 3 supports tall, square, and wide images. This adds support for all three variants (wide / tall / square).
		
			
				
	
	
		
			144 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
			
		
		
	
	
			144 lines
		
	
	
		
			3.8 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
# frozen_string_literal: true

module DiscourseAi
  module AiBot
    module Tools
      # Bot tool that renders images from text prompts through the DALL-E 3
      # image generator and posts them back as a [grid] of uploads.
      class DallE < Tool
        # At most this many prompts are rendered per invocation (matches the
        # "up to 4 prompts" advertised in the tool signature).
        MAX_PROMPTS = 4

        # Each prompt is tried this many times before it is given up on.
        MAX_ATTEMPTS = 3

        # Seconds to wait before retrying a failed prompt.
        RETRY_DELAY = 2

        # Image size requested for each non-square aspect ratio; any other value
        # (including a missing one) falls back to DEFAULT_SIZE (square).
        SIZES = { "tall" => "1024x1792", "wide" => "1792x1024" }.freeze
        DEFAULT_SIZE = "1024x1024"

        # Characters removed from a revised prompt before it is used as image
        # alt text (each one individually, not as a sequence).
        ALT_TEXT_STRIP_CHARS = "|'\""

        # Tool schema advertised to the LLM: a required array of prompts and an
        # optional aspect ratio.
        def self.signature
          {
            name: name,
            description: "Renders images from supplied descriptions",
            parameters: [
              {
                name: "prompts",
                description:
                  "The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts",
                type: "array",
                item_type: "string",
                required: true,
              },
              {
                name: "aspect_ratio",
                description: "The aspect ratio (optional, square by default)",
                type: "string",
                required: false,
                enum: %w[tall square wide],
              },
            ],
          }
        end

        def self.name
          "dall_e"
        end

        # Raw "prompts" parameter as supplied by the LLM. NOTE(review): may be
        # nil or a bare string if the model ignores the schema; callers in this
        # class wrap it with Array().
        def prompts
          parameters[:prompts]
        end

        # Raw "aspect_ratio" parameter ("tall", "square", "wide" or nil).
        def aspect_ratio
          parameters[:aspect_ratio]
        end

        # NOTE(review): returning false presumably stops the bot from asking the
        # LLM for a follow-up response after this tool runs (Tool not visible).
        def chain_next_response?
          false
        end

        # Renders up to MAX_PROMPTS images concurrently, uploads them as the bot
        # user and stores a [grid] of the images in custom_raw.
        #
        # Yields the first prompt once, as a progress update, before generating.
        #
        # Returns { prompts: [revised prompts] } on success, or
        # { prompts: [requested prompts], error: "..." } when no image could be
        # generated.
        def invoke
          # Array() tolerates a missing parameter or a single string prompt.
          requested = Array(prompts).take(MAX_PROMPTS)

          yield(requested.first)

          results = generate_images(requested, SIZES.fetch(aspect_ratio, DEFAULT_SIZE))

          if results.empty?
            return { prompts: requested, error: "Something went wrong, could not generate image" }
          end

          uploads = create_uploads(results)
          self.custom_raw = render_grid(uploads)

          { prompts: uploads.map { |item| item[:prompt] } }
        end

        protected

        def description_args
          { prompt: Array(prompts).first }
        end

        private

        # Requests one image per prompt, each on its own thread, retrying failed
        # requests up to MAX_ATTEMPTS times. Returns the successful generator
        # results; prompts that never succeed are logged and dropped.
        def generate_images(image_prompts, size)
          # this ensures multisite safety since background threads
          # generate the images
          api_key = SiteSetting.ai_openai_api_key
          api_url = SiteSetting.ai_openai_dall_e_3_url

          threads =
            image_prompts.map do |image_prompt|
              Thread.new(image_prompt) do |inner_prompt|
                attempts = 0
                begin
                  DiscourseAi::Inference::OpenAiImageGenerator.perform!(
                    inner_prompt,
                    size: size,
                    api_key: api_key,
                    api_url: api_url,
                  )
                rescue => e
                  attempts += 1
                  # Only wait when another attempt will actually follow.
                  if attempts < MAX_ATTEMPTS
                    sleep RETRY_DELAY
                    retry
                  end
                  Discourse.warn_exception(
                    e,
                    message: "Failed to generate image for prompt #{inner_prompt}",
                  )
                  nil
                end
              end
            end

          # Thread#value joins each thread; failed prompts produced nil.
          threads.filter_map(&:value)
        end

        # Decodes every base64 image in the generator results into a temporary
        # PNG and uploads it as the bot user.
        # Returns an Array of { prompt: revised_prompt, upload: Upload }.
        def create_uploads(results)
          results.each_with_index.flat_map do |result, index|
            result[:data].map do |image|
              # Array basename keeps ".png" at the end of the temp file name.
              Tempfile.create(["v1_txt2img_#{index}", ".png"]) do |file|
                file.binmode
                file.write(Base64.decode64(image[:b64_json]))
                file.rewind
                {
                  prompt: image[:revised_prompt],
                  upload:
                    UploadCreator.new(
                      file,
                      "image.png",
                      for_private_message: context[:private_message],
                    ).create_for(bot_user.id),
                }
              end
            end
          end
        end

        # Builds the post body: all uploads inside one [grid] block, each using
        # its sanitized revised prompt as alt text.
        def render_grid(uploads)
          images =
            uploads.map do |item|
              alt = item[:prompt].to_s.delete(ALT_TEXT_STRIP_CHARS)
              "![#{alt}](#{item[:upload].short_url})"
            end

          <<~RAW

            [grid]
            #{images.join(" ")}
            [/grid]
          RAW
        end
      end
    end
  end
end