# frozen_string_literal: true

module DiscourseAi
  module AiBot
    # Bot implementation backed by the OpenAI chat completions API
    # (GPT-4 and GPT-3.5 Turbo bot users).
    class OpenAiBot < Bot
      # Whether this class should reply as the given bot user.
      #
      # @param bot_user the bot account being replied as (only #id is read)
      # @return [Boolean] true for the GPT-4 and GPT-3.5 Turbo bot users
      def self.can_reply_as?(bot_user)
        open_ai_bot_ids = [
          DiscourseAi::AiBot::EntryPoint::GPT4_ID,
          DiscourseAi::AiBot::EntryPoint::GPT3_5_TURBO_ID,
        ]

        open_ai_bot_ids.include?(bot_user.id)
      end

      # Tokens left for the prompt itself: the model's context window
      # (8192 for GPT-4, 16_384 otherwise) minus the tokenized function
      # definitions, the reply budget (reply_params[:max_tokens]) and a
      # fixed 120-token safety buffer.
      #
      # @return [Integer]
      def prompt_limit
        # Memoized per instance. Note this is about 100 tokens over: OpenAI
        # has a more compact internal representation of the functions.
        @function_size ||= tokenize(available_functions.to_json).length

        # Provide a buffer of 120 tokens - our function counting is not
        # 100% accurate and getting numbers to align exactly is very hard.
        buffer = @function_size + reply_params[:max_tokens] + 120

        if bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID
          8192 - buffer
        else
          16_384 - buffer
        end
      end

      # Default sampling parameters for replies. #submit_prompt lets callers
      # override individual values.
      #
      # @return [Hash] :temperature, :top_p and :max_tokens
      def reply_params
        # Technically we could allow GPT-3.5 16k more tokens,
        # but let's just keep it here for now.
        { temperature: 0.4, top_p: 0.9, max_tokens: 2500 }
      end

      # Approximate token overhead OpenAI adds per chat message.
      #
      # @return [Integer]
      def extra_tokens_per_message
        # OpenAI defines about 4 tokens per message of overhead.
        4
      end

      # Sends a prompt to OpenAI via DiscourseAi::Inference::OpenAiCompletions.
      #
      # @param prompt the chat messages to send (passed through unchanged)
      # @param prefer_low_cost [Boolean] forwarded to #model_for as low_cost:
      # @param temperature [Numeric, nil] overrides reply_params when non-nil
      # @param top_p [Numeric, nil] overrides reply_params when non-nil
      # @param max_tokens [Integer, nil] overrides reply_params when non-nil
      # @param blk forwarded to OpenAiCompletions.perform!
      # @return the result of OpenAiCompletions.perform!
      def submit_prompt(
        prompt,
        prefer_low_cost: false,
        temperature: nil,
        top_p: nil,
        max_tokens: nil,
        &blk
      )
        # Only overrides that were actually supplied (non-nil) replace the
        # defaults from reply_params.
        params =
          reply_params.merge(
            temperature: temperature,
            top_p: top_p,
            max_tokens: max_tokens,
          ) { |_key, old_value, new_value| new_value.nil? ? old_value : new_value }

        model = model_for(low_cost: prefer_low_cost)

        # OpenAI receives function definitions as a dedicated parameter
        # rather than in the system prompt.
        params[:functions] = available_functions if available_functions.present?

        DiscourseAi::Inference::OpenAiCompletions.perform!(prompt, model, **params, &blk)
      end

      # Tokenizer used to measure prompt sizes for OpenAI models.
      def tokenizer
        DiscourseAi::Tokenizer::OpenAiTokenizer
      end

      # OpenAI model name to call.
      #
      # @param low_cost [Boolean] when true, never selects GPT-4
      # @return [String] "gpt-4" for the GPT-4 bot user unless low_cost,
      #   otherwise "gpt-3.5-turbo-16k"
      def model_for(low_cost: false)
        return "gpt-4" if bot_user.id == DiscourseAi::AiBot::EntryPoint::GPT4_ID && !low_cost
        "gpt-3.5-turbo-16k"
      end

      # Makes a username usable as an OpenAI message `name`: at most 64
      # characters drawn from a-z, A-Z, 0-9, "_" and "-".
      #
      # @param username [String]
      # @return [String] the username itself when it already consists only
      #   of 1-64 allowed characters. Otherwise a copy with every disallowed
      #   character replaced by "_", truncated to 64 characters.
      def clean_username(username)
        # \A and \z anchor the whole string, so only a fully valid name is
        # returned untouched.
        if username.match?(/\A[a-zA-Z0-9_-]{1,64}\z/)
          username
        else
          # not the best in the world, but this is what we have to work with
          # if sites enable unicode usernames this can get messy
          username.gsub(/[^a-zA-Z0-9_-]/, "_")[0..63]
        end
      end

      # OpenAI uses a bespoke system for function calls (the `functions`
      # request parameter), so no function instructions go in the system
      # prompt.
      def include_function_instructions_in_system_prompt?
        false
      end

      private

      # Accumulates a streamed function call from one partial response chunk.
      # Only `partial` and `functions` are used here; `reply:` and `done:` are
      # accepted but ignored (NOTE(review): presumably to match the signature
      # the base Bot calls — verify).
      def populate_functions(partial:, reply:, functions:, done:)
        return if !partial
        fn = partial.dig(:choices, 0, :delta, :function_call)
        if fn
          functions.add_function(fn[:name]) if fn[:name].present?
          # An empty-string fragment is still forwarded; only nil is skipped.
          functions.add_argument_fragment(fn[:arguments]) if !fn[:arguments].nil?
        end
      end

      # Builds one OpenAI chat message hash.
      #
      # @param poster_username [String, nil] author of the content; used
      #   verbatim as :name for function messages
      # @param content the message content
      # @param function [Boolean] build a "function" role message
      # @param system [Boolean] build a "system" role message
      # @return [Hash] { role:, content: } plus :name for function messages
      #   and for non-system messages from a present, non-bot username
      def build_message(poster_username, content, function: false, system: false)
        is_bot = poster_username == bot_user.username

        role =
          if function
            "function"
          elsif system
            "system"
          elsif is_bot
            "assistant"
          else
            "user"
          end

        result = { role: role, content: content }

        if function
          result[:name] = poster_username
        elsif !system && !is_bot && poster_username.present?
          # OpenAI restricts name to 64 chars of a-z, A-Z, 0-9, "_" and "-"
          # (work around)
          result[:name] = clean_username(poster_username)
        end

        result
      end

      # Text content of one streamed chunk ("" when the chunk carries none).
      def get_delta(partial, _context)
        partial.dig(:choices, 0, :delta, :content).to_s
      end

      # Asks the default model (#model_for with no arguments) for a short
      # completion of `prompt` and returns the first choice's message
      # content.
      def get_updated_title(prompt)
        DiscourseAi::Inference::OpenAiCompletions.perform!(
          prompt,
          model_for,
          temperature: 0.7,
          top_p: 0.9,
          max_tokens: 40,
        ).dig(:choices, 0, :message, :content)
      end
    end
  end
end