2023-09-01 21:10:58 -03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								# frozen_string_literal: true
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								module DiscourseAi
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  module AiHelper
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    class SemanticCategorizer
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-02 16:36:56 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								      def initialize(input, user)
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-04 14:30:33 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        @user = user
							 | 
						
					
						
							
								
									
										
										
										
											2023-10-02 16:36:56 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        @text = input[:text]
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      end
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      def categories
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return [] if @text.blank?
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return [] unless SiteSetting.ai_embeddings_enabled
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-05 14:15:01 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        candidates = nearest_neighbors(limit: 100)
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        candidate_ids = candidates.map(&:first)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        ::Topic
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .joins(:category)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .where(id: candidate_ids)
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-04 14:30:33 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								          .where("categories.id IN (?)", Category.topic_create_allowed(@user.guardian).pluck(:id))
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .order("array_position(ARRAY#{candidate_ids}, topics.id)")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .pluck("categories.slug")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .map
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .with_index { |category, index| { name: category, score: candidates[index].last } }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .map do |c|
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            c[:score] = 1 / (c[:score] + 1) # inverse of the distance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            c
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          end
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .group_by { |c| c[:name] }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .map { |name, scores| { name: name, score: scores.sum { |s| s[:score] } } }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .sort_by { |c| -c[:score] }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .take(5)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      end
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      def tags
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return [] if @text.blank?
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        return [] unless SiteSetting.ai_embeddings_enabled
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-05 14:15:01 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								        candidates = nearest_neighbors(limit: 100)
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        candidate_ids = candidates.map(&:first)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        ::Topic
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .joins(:topic_tags, :tags)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .where(id: candidate_ids)
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-04 14:30:33 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								          .where("tags.id IN (?)", DiscourseTagging.visible_tags(@user.guardian).pluck(:id))
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .group("topics.id")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .order("array_position(ARRAY#{candidate_ids}, topics.id)")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .pluck("array_agg(tags.name)")
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .map(&:uniq)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .map
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .with_index { |tag_list, index| { tags: tag_list, score: candidates[index].last } }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .flat_map { |c| c[:tags].map { |t| { name: t, score: c[:score] } } }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .map do |c|
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            c[:score] = 1 / (c[:score] + 1) # inverse of the distance
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								            c
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          end
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .group_by { |c| c[:name] }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .map { |name, scores| { name: name, score: scores.sum { |s| s[:score] } } }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .sort_by { |c| -c[:score] }
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          .take(5)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      end
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-05 14:15:01 -03:00
										 
									 
								 
							 | 
							
								
									
										
									
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      private
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      def nearest_neighbors(limit: 100)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        strategy = DiscourseAi::Embeddings::Strategies::Truncation.new
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        vector_rep =
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        raw_vector = vector_rep.vector_from(@text)
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        vector_rep.asymmetric_topics_similarity_search(
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          raw_vector,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          limit: limit,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          offset: 0,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								          return_distance: true,
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								        )
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								      end
							 | 
						
					
						
							
								
									
										
										
										
											2023-09-01 21:10:58 -03:00
										 
									 
								 
							 | 
							
								
							 | 
							
								
							 | 
							
							
								    end
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								  end
							 | 
						
					
						
							| 
								
							 | 
							
								
							 | 
							
								
							 | 
							
							
								end
							 |