| 
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 |  |  | # frozen_string_literal: true | 
					
						
							|  |  |  | module DiscourseAi | 
					
						
							|  |  |  |   module AiHelper | 
					
						
							|  |  |  |     class SemanticCategorizer | 
					
						
							| 
									
										
										
										
											2023-09-04 14:30:33 -03:00
										 |  |  |       def initialize(text, user) | 
					
						
							|  |  |  |         @user = user | 
					
						
							| 
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 |  |  |         @text = text | 
					
						
							|  |  |  |       end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       def categories | 
					
						
							|  |  |  |         return [] if @text.blank? | 
					
						
							|  |  |  |         return [] unless SiteSetting.ai_embeddings_enabled | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-05 14:15:01 -03:00
										 |  |  |         candidates = nearest_neighbors(limit: 100) | 
					
						
							| 
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 |  |  |         candidate_ids = candidates.map(&:first) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ::Topic | 
					
						
							|  |  |  |           .joins(:category) | 
					
						
							|  |  |  |           .where(id: candidate_ids) | 
					
						
							| 
									
										
										
										
											2023-09-04 14:30:33 -03:00
										 |  |  |           .where("categories.id IN (?)", Category.topic_create_allowed(@user.guardian).pluck(:id)) | 
					
						
							| 
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 |  |  |           .order("array_position(ARRAY#{candidate_ids}, topics.id)") | 
					
						
							|  |  |  |           .pluck("categories.slug") | 
					
						
							|  |  |  |           .map | 
					
						
							|  |  |  |           .with_index { |category, index| { name: category, score: candidates[index].last } } | 
					
						
							|  |  |  |           .map do |c| | 
					
						
							|  |  |  |             c[:score] = 1 / (c[:score] + 1) # inverse of the distance | 
					
						
							|  |  |  |             c | 
					
						
							|  |  |  |           end | 
					
						
							|  |  |  |           .group_by { |c| c[:name] } | 
					
						
							|  |  |  |           .map { |name, scores| { name: name, score: scores.sum { |s| s[:score] } } } | 
					
						
							|  |  |  |           .sort_by { |c| -c[:score] } | 
					
						
							|  |  |  |           .take(5) | 
					
						
							|  |  |  |       end | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       def tags | 
					
						
							|  |  |  |         return [] if @text.blank? | 
					
						
							|  |  |  |         return [] unless SiteSetting.ai_embeddings_enabled | 
					
						
							|  |  |  | 
 | 
					
						
							| 
									
										
										
										
											2023-09-05 14:15:01 -03:00
										 |  |  |         candidates = nearest_neighbors(limit: 100) | 
					
						
							| 
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 |  |  |         candidate_ids = candidates.map(&:first) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         ::Topic | 
					
						
							|  |  |  |           .joins(:topic_tags, :tags) | 
					
						
							|  |  |  |           .where(id: candidate_ids) | 
					
						
							| 
									
										
										
										
											2023-09-04 14:30:33 -03:00
										 |  |  |           .where("tags.id IN (?)", DiscourseTagging.visible_tags(@user.guardian).pluck(:id)) | 
					
						
							| 
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 |  |  |           .group("topics.id") | 
					
						
							|  |  |  |           .order("array_position(ARRAY#{candidate_ids}, topics.id)") | 
					
						
							|  |  |  |           .pluck("array_agg(tags.name)") | 
					
						
							|  |  |  |           .map(&:uniq) | 
					
						
							|  |  |  |           .map | 
					
						
							|  |  |  |           .with_index { |tag_list, index| { tags: tag_list, score: candidates[index].last } } | 
					
						
							|  |  |  |           .flat_map { |c| c[:tags].map { |t| { name: t, score: c[:score] } } } | 
					
						
							|  |  |  |           .map do |c| | 
					
						
							|  |  |  |             c[:score] = 1 / (c[:score] + 1) # inverse of the distance | 
					
						
							|  |  |  |             c | 
					
						
							|  |  |  |           end | 
					
						
							|  |  |  |           .group_by { |c| c[:name] } | 
					
						
							|  |  |  |           .map { |name, scores| { name: name, score: scores.sum { |s| s[:score] } } } | 
					
						
							|  |  |  |           .sort_by { |c| -c[:score] } | 
					
						
							|  |  |  |           .take(5) | 
					
						
							|  |  |  |       end | 
					
						
							| 
									
										
										
										
											2023-09-05 14:15:01 -03:00
										 |  |  | 
 | 
					
						
							|  |  |  |       private | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |       def nearest_neighbors(limit: 100) | 
					
						
							|  |  |  |         strategy = DiscourseAi::Embeddings::Strategies::Truncation.new | 
					
						
							|  |  |  |         vector_rep = | 
					
						
							|  |  |  |           DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         raw_vector = vector_rep.vector_from(@text) | 
					
						
							|  |  |  | 
 | 
					
						
							|  |  |  |         vector_rep.asymmetric_topics_similarity_search( | 
					
						
							|  |  |  |           raw_vector, | 
					
						
							|  |  |  |           limit: limit, | 
					
						
							|  |  |  |           offset: 0, | 
					
						
							|  |  |  |           return_distance: true, | 
					
						
							|  |  |  |         ) | 
					
						
							|  |  |  |       end | 
					
						
							| 
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 |  |  |     end | 
					
						
							|  |  |  |   end | 
					
						
							|  |  |  | end |