# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    module Strategies
      # Embedding strategy that turns a Topic or a Post into a single text
      # blob (title, category, tags, then post content) and truncates it to
      # the token budget derived from the given model.
      #
      # Usage: build with a target and a model, call {#process!}, then read
      # {#processed_target} and {#digest}.
      class Truncation
        # Separator inserted between every section of the generated text.
        SEPARATOR = "\n\n"

        # @return [String, nil] the truncated text; nil until {#process!} runs
        # @return [String, nil] SHA1 hex digest of {#processed_target}; nil
        #   until {#process!} runs
        attr_reader :processed_target, :digest

        # @return [Integer] identifier of this strategy
        def self.id
          1
        end

        # @return [Integer] same value as {.id}
        def id
          self.class.id
        end

        # @return [Integer] version number of this strategy
        def version
          1
        end

        # @param target [Topic, Post] the record to turn into text
        # @param model [#tokenizer, #max_sequence_length] supplies the
        #   tokenizer and the token limit
        def initialize(target, model)
          @model = model
          @target = target
          @tokenizer = @model.tokenizer
          # NOTE(review): two tokens are held back from the model limit,
          # presumably for special start/end tokens - confirm against models.
          @max_length = @model.max_sequence_length - 2
          @processed_target = nil
        end

        # Builds the truncated text for the target and stores it together
        # with its digest.
        #
        # TODO: find a better name for this method.
        #
        # @return [String] the SHA1 hex digest of the processed text
        # @raise [ArgumentError] when the target is neither a Topic nor a Post
        def process!
          @processed_target =
            case @target
            when Topic
              topic_truncation(@target)
            when Post
              post_truncation(@target)
            else
              raise ArgumentError, "Invalid target type"
            end

          @digest = OpenSSL::Digest::SHA1.hexdigest(@processed_target)
        end

        # Concatenates the topic header with the raw text of its posts until
        # the token budget is reached, then truncates to that budget.
        #
        # @param topic [Topic]
        # @return [String] text truncated by the tokenizer to the max length
        def topic_truncation(topic)
          text = topic_header(topic)

          topic.posts.find_each do |post|
            text << post.raw
            # OPTIMIZE: the whole buffer is re-measured on every iteration;
            # maybe keep a partial counter to speed this up.
            break if @tokenizer.size(text) >= @max_length

            text << SEPARATOR
          end

          @tokenizer.truncate(text, @max_length)
        end

        # Concatenates the header of the post's topic with the post's raw
        # text, then truncates to the token budget. When the post has no
        # topic, only the raw text is used.
        #
        # @param post [Post]
        # @return [String] text truncated by the tokenizer to the max length
        def post_truncation(post)
          text = topic_header(post.topic)
          text << post.raw

          @tokenizer.truncate(text, @max_length)
        end

        private

        # Builds the shared "title / category / tags" prefix. Every section
        # that is emitted is followed by SEPARATOR.
        #
        # @param topic [Topic, nil]
        # @return [String] a new mutable string; empty when topic is nil
        def topic_header(topic)
          header = +""
          return header unless topic

          header << topic.title << SEPARATOR

          # A topic may have no category; skip the section instead of
          # raising NoMethodError on nil.
          category_name = topic.category&.name
          header << category_name << SEPARATOR if category_name

          # The tags section is emitted even when the topic has no tags so
          # the output (and its digest) stays the same as before.
          if SiteSetting.tagging_enabled
            header << topic.tags.pluck(:name).join(", ") << SEPARATOR
          end

          header
        end
      end
    end
  end
end