# frozen_string_literal: true

module ::DiscourseAi
  module Inference
    # Thin client for OpenAI-style chat completion endpoints.
    #
    # Supports both a blocking call (the whole JSON response is parsed and
    # returned) and a streaming call (a block receives every decoded
    # server-sent event as it arrives). Every request that gets a 200 response
    # is recorded in AiApiAuditLog.
    class OpenAiCompletions
      # Value passed to Net::HTTP as read, open and write timeout.
      TIMEOUT = 60
      # How many times a 429 response is retried by default.
      DEFAULT_RETRIES = 3
      # Initial pause (passed to `sleep`) before the first retry.
      DEFAULT_RETRY_TIMEOUT_SECONDS = 3
      # The pause is multiplied by this factor after every retry.
      RETRY_TIMEOUT_BACKOFF_MULTIPLIER = 3

      # Raised when the endpoint answers with a status other than 200
      # (a 429 only raises once the retries are used up).
      CompletionFailed = Class.new(StandardError)

      # Sends a chat completion request.
      #
      # @param messages [Array<Hash>] sent verbatim as the `messages` payload key
      # @param model [String] model name; also selects the endpoint URL
      #   ("gpt-4" / "32k" / "16k" substrings are inspected)
      # @param temperature [Object, nil] added to the payload when truthy
      # @param top_p [Object, nil] added to the payload when truthy
      # @param max_tokens [Object, nil] added to the payload when truthy
      # @param functions [Object, nil] added to the payload when truthy
      # @param user_id [Object, nil] stored on the audit log entry
      # @param retries [Integer] remaining retries for 429 responses
      # @param retry_timeout [Numeric] pause before the next retry
      # @param post [Object, nil] when given, its `id` and `topic_id` are stored
      #   on the audit log entry
      # @yield [partial, cancel] streaming mode: `partial` is the parsed event
      #   (symbolized keys), `cancel` is a lambda that stops further processing
      # @return [Hash] the parsed response when no block is given
      # @return [String] the concatenated `content` / `function_call` deltas
      #   when a block is given
      # @return [nil] when the stream was cancelled and another chunk arrived
      # @raise [CompletionFailed] on a non-200 response
      def self.perform!(
        messages,
        model,
        temperature: nil,
        top_p: nil,
        max_tokens: nil,
        functions: nil,
        user_id: nil,
        retries: DEFAULT_RETRIES,
        retry_timeout: DEFAULT_RETRY_TIMEOUT_SECONDS,
        post: nil,
        &blk
      )
        log = nil
        response_data = +""
        response_raw = +""

        url =
          if model.include?("gpt-4")
            if model.include?("32k")
              URI(SiteSetting.ai_openai_gpt4_32k_url)
            else
              URI(SiteSetting.ai_openai_gpt4_url)
            end
          else
            if model.include?("16k")
              URI(SiteSetting.ai_openai_gpt35_16k_url)
            else
              URI(SiteSetting.ai_openai_gpt35_url)
            end
          end
        headers = { "Content-Type" => "application/json" }

        # Hosts containing "azure" get the key in an "api-key" header instead
        # of a bearer token.
        if url.host.include?("azure")
          headers["api-key"] = SiteSetting.ai_openai_api_key
        else
          headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}"
        end

        if SiteSetting.ai_openai_organization.present?
          headers["OpenAI-Organization"] = SiteSetting.ai_openai_organization
        end

        payload = { model: model, messages: messages }

        payload[:temperature] = temperature if temperature
        payload[:top_p] = top_p if top_p
        payload[:max_tokens] = max_tokens if max_tokens
        payload[:functions] = functions if functions
        payload[:stream] = true if block_given?

        Net::HTTP.start(
          url.host,
          url.port,
          use_ssl: true,
          read_timeout: TIMEOUT,
          open_timeout: TIMEOUT,
          write_timeout: TIMEOUT,
        ) do |http|
          request = Net::HTTP::Post.new(url, headers)
          request_body = payload.to_json
          request.body = request_body

          http.request(request) do |response|
            if retries > 0 && response.code.to_i == 429
              sleep(retry_timeout)
              retries -= 1
              retry_timeout *= RETRY_TIMEOUT_BACKOFF_MULTIPLIER
              # Forward every option, including `post`, so the retried attempt
              # is audited exactly like a first attempt would have been.
              return(
                perform!(
                  messages,
                  model,
                  temperature: temperature,
                  top_p: top_p,
                  max_tokens: max_tokens,
                  functions: functions,
                  user_id: user_id,
                  retries: retries,
                  retry_timeout: retry_timeout,
                  post: post,
                  &blk
                )
              )
            elsif response.code.to_i != 200
              Rails.logger.error(
                "OpenAiCompletions: status: #{response.code.to_i} - body: #{response.body}",
              )
              raise CompletionFailed, "status: #{response.code.to_i} - body: #{response.body}"
            end

            log =
              AiApiAuditLog.create!(
                provider_id: AiApiAuditLog::Provider::OpenAI,
                raw_request_payload: request_body,
                user_id: user_id,
                post_id: post&.id,
                topic_id: post&.topic_id,
              )

            if !blk
              response_body = response.read_body
              parsed_response = JSON.parse(response_body, symbolize_names: true)

              log.update!(
                raw_response_payload: response_body,
                request_tokens: parsed_response.dig(:usage, :prompt_tokens),
                response_tokens: parsed_response.dig(:usage, :completion_tokens),
              )
              return parsed_response
            end

            cancelled = false
            cancel = lambda { cancelled = true }

            # Bytes received so far that are not yet terminated by a newline.
            buffer = +""

            # Decodes one event line, accumulates its text and notifies the caller.
            handle_line =
              lambda do |line|
                partial = parse_stream_line(line)
                if partial
                  response_data << partial.dig(:choices, 0, :delta, :content).to_s
                  response_data << partial.dig(:choices, 0, :delta, :function_call).to_s

                  blk.call(partial, cancel)
                end
              end

            response.read_body do |chunk|
              if cancelled
                http.finish
                return
              end

              response_raw << chunk
              buffer << chunk

              # A chunk may end anywhere, even inside the "data: " prefix, so
              # only newline-terminated lines are processed; the trailing
              # fragment waits in the buffer for the next chunk.
              lines = buffer.split("\n", -1)
              buffer = lines.pop || +""

              lines.each { |line| handle_line.call(line) unless cancelled }
            rescue IOError
              # Errors caused by tearing the connection down after a cancel are
              # expected; anything else is re-raised.
              raise if !cancelled
            end

            # Handle a final event that was not followed by a newline.
            handle_line.call(buffer) if !cancelled && !buffer.empty?

            return response_data
          end
        end
      ensure
        # Streaming responses are audited here, with locally computed token
        # counts, so that partial output is recorded even on cancel or error.
        if log && block_given?
          request_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(extract_prompt(messages))
          response_tokens = DiscourseAi::Tokenizer::OpenAiTokenizer.size(response_data)
          log.update!(
            raw_response_payload: response_raw,
            request_tokens: request_tokens,
            response_tokens: response_tokens,
          )
        end
        if log && Rails.env.development?
          puts "OpenAiCompletions: request_tokens #{log.request_tokens} response_tokens #{log.response_tokens}"
        end
      end

      # Joins the content of all messages with newlines; used to count the
      # request tokens of streamed calls.
      #
      # @param messages [Array<Hash>] hashes with a :content or "content" key
      # @return [String]
      def self.extract_prompt(messages)
        messages.map { |message| message[:content] || message["content"] || "" }.join("\n")
      end

      # Decodes a single line of the event stream.
      #
      # @param line [String] one line without its trailing newline
      # @return [Object, nil] the parsed JSON (symbolized keys), or nil for
      #   lines without a "data: " field, the "[DONE]" marker and invalid JSON
      def self.parse_stream_line(line)
        data = line.split("data: ", 2)[1]
        return nil if !data || data == "[DONE]"

        JSON.parse(data, symbolize_names: true)
      rescue JSON::ParserError
        nil
      end
      private_class_method :parse_stream_line
    end
  end
end