| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  | # frozen_string_literal: true | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  | module ::DiscourseAi | 
					
						
							|  |  |  |   module Inference | 
					
						
							|  |  |  |     class HuggingFaceTextGeneration | 
					
						
							|  |  |  |       CompletionFailed = Class.new(StandardError) | 
					
						
							| 
									
										
										
										
											2023-09-05 11:08:23 -03:00
										 |  |  |       TIMEOUT = 120
 | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |       def self.perform!( | 
					
						
							|  |  |  |         prompt, | 
					
						
							|  |  |  |         model, | 
					
						
							|  |  |  |         temperature: 0.7, | 
					
						
							|  |  |  |         top_p: nil, | 
					
						
							|  |  |  |         top_k: nil, | 
					
						
							|  |  |  |         typical_p: nil, | 
					
						
							|  |  |  |         max_tokens: 2000, | 
					
						
							|  |  |  |         repetition_penalty: 1.1, | 
					
						
							| 
									
										
										
										
											2023-08-02 17:00:00 -03:00
										 |  |  |         user_id: nil, | 
					
						
							|  |  |  |         tokenizer: DiscourseAi::Tokenizer::Llama2Tokenizer, | 
					
						
							| 
									
										
										
										
											2023-08-03 15:29:30 -03:00
										 |  |  |         token_limit: nil | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |       ) | 
					
						
							|  |  |  |         raise CompletionFailed if model.blank? | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         url = URI(SiteSetting.ai_hugging_face_api_url) | 
					
						
							|  |  |  |         if block_given? | 
					
						
							|  |  |  |           url.path = "/generate_stream" | 
					
						
							|  |  |  |         else | 
					
						
							|  |  |  |           url.path = "/generate" | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  |         headers = { "Content-Type" => "application/json" } | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-02 17:00:00 -03:00
										 |  |  |         if SiteSetting.ai_hugging_face_api_key.present? | 
					
						
							|  |  |  |           headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_api_key}" | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-03 15:29:30 -03:00
										 |  |  |         token_limit = token_limit || SiteSetting.ai_hugging_face_token_limit | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |         parameters = {} | 
					
						
							|  |  |  |         payload = { inputs: prompt, parameters: parameters } | 
					
						
							| 
									
										
										
										
											2023-08-02 17:00:00 -03:00
										 |  |  |         prompt_size = tokenizer.size(prompt) | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |         parameters[:top_p] = top_p if top_p | 
					
						
							|  |  |  |         parameters[:top_k] = top_k if top_k | 
					
						
							|  |  |  |         parameters[:typical_p] = typical_p if typical_p | 
					
						
							| 
									
										
										
										
											2023-08-02 17:00:00 -03:00
										 |  |  |         parameters[:max_new_tokens] = token_limit - prompt_size | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |         parameters[:temperature] = temperature if temperature | 
					
						
							|  |  |  |         parameters[:repetition_penalty] = repetition_penalty if repetition_penalty | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         Net::HTTP.start( | 
					
						
							|  |  |  |           url.host, | 
					
						
							|  |  |  |           url.port, | 
					
						
							|  |  |  |           use_ssl: url.scheme == "https", | 
					
						
							|  |  |  |           read_timeout: TIMEOUT, | 
					
						
							|  |  |  |           open_timeout: TIMEOUT, | 
					
						
							|  |  |  |           write_timeout: TIMEOUT, | 
					
						
							|  |  |  |         ) do |http| | 
					
						
							|  |  |  |           request = Net::HTTP::Post.new(url, headers) | 
					
						
							|  |  |  |           request_body = payload.to_json | 
					
						
							|  |  |  |           request.body = request_body | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |           http.request(request) do |response| | 
					
						
							|  |  |  |             if response.code.to_i != 200
 | 
					
						
							|  |  |  |               Rails.logger.error( | 
					
						
							|  |  |  |                 "HuggingFaceTextGeneration: status: #{response.code.to_i} - body: #{response.body}", | 
					
						
							|  |  |  |               ) | 
					
						
							|  |  |  |               raise CompletionFailed | 
					
						
							|  |  |  |             end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             log = | 
					
						
							|  |  |  |               AiApiAuditLog.create!( | 
					
						
							|  |  |  |                 provider_id: AiApiAuditLog::Provider::HuggingFaceTextGeneration, | 
					
						
							|  |  |  |                 raw_request_payload: request_body, | 
					
						
							|  |  |  |                 user_id: user_id, | 
					
						
							|  |  |  |               ) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |             if !block_given? | 
					
						
							|  |  |  |               response_body = response.read_body | 
					
						
							|  |  |  |               parsed_response = JSON.parse(response_body, symbolize_names: true) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |               log.update!( | 
					
						
							|  |  |  |                 raw_response_payload: response_body, | 
					
						
							| 
									
										
										
										
											2023-08-02 17:00:00 -03:00
										 |  |  |                 request_tokens: tokenizer.size(prompt), | 
					
						
							|  |  |  |                 response_tokens: tokenizer.size(parsed_response[:generated_text]), | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |               ) | 
					
						
							|  |  |  |               return parsed_response | 
					
						
							|  |  |  |             end | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-11 15:08:54 -03:00
										 |  |  |             response_data = +"" | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |             begin | 
					
						
							|  |  |  |               cancelled = false | 
					
						
							|  |  |  |               cancel = lambda { cancelled = true } | 
					
						
							|  |  |  |               response_raw = +"" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |               response.read_body do |chunk| | 
					
						
							|  |  |  |                 if cancelled | 
					
						
							|  |  |  |                   http.finish | 
					
						
							|  |  |  |                   return | 
					
						
							|  |  |  |                 end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 response_raw << chunk | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                 chunk | 
					
						
							|  |  |  |                   .split("\n") | 
					
						
							|  |  |  |                   .each do |line| | 
					
						
							| 
									
										
										
										
											2023-08-11 15:08:54 -03:00
										 |  |  |                     data = line.split("data:", 2)[1] | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |                     next if !data || data.squish == "[DONE]" | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                     if !cancelled | 
					
						
							|  |  |  |                       begin | 
					
						
							|  |  |  |                         # partial contains the entire payload till now | 
					
						
							|  |  |  |                         partial = JSON.parse(data, symbolize_names: true) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |                         # this is the last chunk and contains the full response | 
					
						
							|  |  |  |                         next if partial[:token][:special] == true | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-08-11 15:08:54 -03:00
										 |  |  |                         response_data << partial[:token][:text].to_s | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |                         yield partial, cancel | 
					
						
							|  |  |  |                       rescue JSON::ParserError | 
					
						
							|  |  |  |                         nil | 
					
						
							|  |  |  |                       end | 
					
						
							|  |  |  |                     end | 
					
						
							|  |  |  |                   end | 
					
						
							|  |  |  |               rescue IOError | 
					
						
							|  |  |  |                 raise if !cancelled | 
					
						
							|  |  |  |               ensure | 
					
						
							|  |  |  |                 log.update!( | 
					
						
							|  |  |  |                   raw_response_payload: response_raw, | 
					
						
							| 
									
										
										
										
											2023-08-02 17:00:00 -03:00
										 |  |  |                   request_tokens: tokenizer.size(prompt), | 
					
						
							|  |  |  |                   response_tokens: tokenizer.size(response_data), | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |                 ) | 
					
						
							|  |  |  |               end | 
					
						
							|  |  |  |             end | 
					
						
							| 
									
										
										
										
											2023-08-11 15:08:54 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |             return response_data | 
					
						
							| 
									
										
										
										
											2023-07-27 13:55:32 -03:00
										 |  |  |           end | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         def self.try_parse(data) | 
					
						
							|  |  |  |           JSON.parse(data, symbolize_names: true) | 
					
						
							|  |  |  |         rescue JSON::ParserError | 
					
						
							|  |  |  |           nil | 
					
						
							|  |  |  |         end | 
					
						
							|  |  |  |       end | 
					
						
							|  |  |  |     end | 
					
						
							|  |  |  |   end | 
					
						
							|  |  |  | end |