| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  | # frozen_string_literal: true | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | module DiscourseAi | 
					
						
							|  |  |  |   module AiBot | 
					
						
							|  |  |  |     module Tools | 
					
						
							|  |  |  |       class DallE < Tool | 
					
						
							|  |  |  |         def self.signature | 
					
						
							|  |  |  |           { | 
					
						
							|  |  |  |             name: name, | 
					
						
							|  |  |  |             description: "Renders images from supplied descriptions", | 
					
						
							|  |  |  |             parameters: [ | 
					
						
							|  |  |  |               { | 
					
						
							|  |  |  |                 name: "prompts", | 
					
						
							|  |  |  |                 description: | 
					
						
							|  |  |  |                   "The prompts used to generate or create or draw the image (5000 chars or less, be creative) up to 4 prompts", | 
					
						
							|  |  |  |                 type: "array", | 
					
						
							|  |  |  |                 item_type: "string", | 
					
						
							|  |  |  |                 required: true, | 
					
						
							|  |  |  |               }, | 
					
						
							| 
									
										
										
										
											2024-05-28 16:21:40 +10:00
										 |  |  |               { | 
					
						
							|  |  |  |                 name: "aspect_ratio", | 
					
						
							|  |  |  |                 description: "The aspect ratio (optional, square by default)", | 
					
						
							|  |  |  |                 type: "string", | 
					
						
							|  |  |  |                 required: false, | 
					
						
							|  |  |  |                 enum: %w[tall square wide], | 
					
						
							|  |  |  |               }, | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |             ], | 
					
						
							|  |  |  |           } | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def self.name | 
					
						
							|  |  |  |           "dall_e" | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def prompts | 
					
						
							|  |  |  |           parameters[:prompts] | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-28 16:21:40 +10:00
										 |  |  |         def aspect_ratio | 
					
						
							|  |  |  |           parameters[:aspect_ratio] | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |         def chain_next_response? | 
					
						
							|  |  |  |           false | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-07 21:55:46 +10:00
										 |  |  |         def invoke | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |           # max 4 prompts | 
					
						
							|  |  |  |           max_prompts = prompts.take(4) | 
					
						
							| 
									
										
										
										
											2024-01-09 23:20:28 +11:00
										 |  |  |           progress = prompts.first | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |           yield(progress) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           results = nil | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           # this ensures multisite safety since background threads | 
					
						
							|  |  |  |           # generate the images | 
					
						
							|  |  |  |           api_key = SiteSetting.ai_openai_api_key | 
					
						
							|  |  |  |           api_url = SiteSetting.ai_openai_dall_e_3_url | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-05-28 16:21:40 +10:00
										 |  |  |           size = "1024x1024" | 
					
						
							|  |  |  |           if aspect_ratio == "tall" | 
					
						
							|  |  |  |             size = "1024x1792" | 
					
						
							|  |  |  |           elsif aspect_ratio == "wide" | 
					
						
							|  |  |  |             size = "1792x1024" | 
					
						
							|  |  |  |           end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |           threads = [] | 
					
						
							|  |  |  |           max_prompts.each_with_index do |prompt, index| | 
					
						
							|  |  |  |             threads << Thread.new(prompt) do |inner_prompt| | 
					
						
							|  |  |  |               attempts = 0
 | 
					
						
							|  |  |  |               begin | 
					
						
							|  |  |  |                 DiscourseAi::Inference::OpenAiImageGenerator.perform!( | 
					
						
							|  |  |  |                   inner_prompt, | 
					
						
							| 
									
										
										
										
											2024-05-28 16:21:40 +10:00
										 |  |  |                   size: size, | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |                   api_key: api_key, | 
					
						
							|  |  |  |                   api_url: api_url, | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |               rescue => e | 
					
						
							|  |  |  |                 attempts += 1
 | 
					
						
							|  |  |  |                 sleep 2
 | 
					
						
							|  |  |  |                 retry if attempts < 3
 | 
					
						
							|  |  |  |                 Discourse.warn_exception( | 
					
						
							|  |  |  |                   e, | 
					
						
							|  |  |  |                   message: "Failed to generate image for prompt #{prompt}", | 
					
						
							|  |  |  |                 ) | 
					
						
							|  |  |  |                 nil | 
					
						
							|  |  |  |               end | 
					
						
							|  |  |  |             end | 
					
						
							|  |  |  |           end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2024-01-09 23:20:28 +11:00
										 |  |  |           break if threads.all? { |t| t.join(2) } while true | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |           results = threads.filter_map(&:value) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           if results.blank? | 
					
						
							|  |  |  |             return { prompts: max_prompts, error: "Something went wrong, could not generate image" } | 
					
						
							|  |  |  |           end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           uploads = [] | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           results.each_with_index do |result, index| | 
					
						
							|  |  |  |             result[:data].each do |image| | 
					
						
							|  |  |  |               Tempfile.create("v1_txt2img_#{index}.png") do |file| | 
					
						
							|  |  |  |                 file.binmode | 
					
						
							|  |  |  |                 file.write(Base64.decode64(image[:b64_json])) | 
					
						
							|  |  |  |                 file.rewind | 
					
						
							|  |  |  |                 uploads << { | 
					
						
							|  |  |  |                   prompt: image[:revised_prompt], | 
					
						
							| 
									
										
										
										
											2024-05-07 21:55:46 +10:00
										 |  |  |                   upload: | 
					
						
							|  |  |  |                     UploadCreator.new( | 
					
						
							|  |  |  |                       file, | 
					
						
							|  |  |  |                       "image.png", | 
					
						
							|  |  |  |                       for_private_message: context[:private_message], | 
					
						
							|  |  |  |                     ).create_for(bot_user.id), | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |                 } | 
					
						
							|  |  |  |               end | 
					
						
							|  |  |  |             end | 
					
						
							|  |  |  |           end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           self.custom_raw = <<~RAW | 
					
						
							| 
									
										
										
										
											2024-01-05 14:39:32 +11:00
										 |  |  | 
 | 
					
						
							|  |  |  |             [grid] | 
					
						
							|  |  |  |             #{ | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |             uploads | 
					
						
							| 
									
										
										
										
											2024-05-07 21:55:46 +10:00
										 |  |  |               .map { |item| "![#{item[:prompt].gsub(/\|\'\"/, "")}](#{item[:upload].short_url})" } | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  |               .join(" ") | 
					
						
							|  |  |  |           } | 
					
						
							| 
									
										
										
										
											2024-01-05 14:39:32 +11:00
										 |  |  |             [/grid]
 | 
					
						
							|  |  |  |           RAW | 
					
						
							| 
									
										
										
										
											2024-01-04 10:44:07 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |           { prompts: uploads.map { |item| item[:prompt] } } | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         protected | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def description_args | 
					
						
							|  |  |  |           { prompt: prompts.first } | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  |       end | 
					
						
							|  |  |  |     end | 
					
						
							|  |  |  |   end | 
					
						
							|  |  |  | end |