mirror of
				https://github.com/discourse/discourse-ai.git
				synced 2025-11-04 08:28:46 +00:00 
			
		
		
		
	
		
			
	
	
		
			84 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
		
		
			
		
	
	
			84 lines
		
	
	
		
			2.5 KiB
		
	
	
	
		
			Ruby
		
	
	
	
	
	
| 
								 | 
							
								# frozen_string_literal: true
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								module DiscourseAi
							 | 
						||
| 
								 | 
							
								  module Embeddings
							 | 
						||
| 
								 | 
							
								    class Model
							 | 
						||
| 
								 | 
							
								      AVAILABLE_MODELS_TEMPLATES = {
							 | 
						||
| 
								 | 
							
								        "all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"],
							 | 
						||
| 
								 | 
							
								        "all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"],
							 | 
						||
| 
								 | 
							
								        "multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"],
							 | 
						||
| 
								 | 
							
								        "paraphrase-multilingual-mpnet-base-v2" => [
							 | 
						||
| 
								 | 
							
								          768,
							 | 
						||
| 
								 | 
							
								          128,
							 | 
						||
| 
								 | 
							
								          %i[cosine],
							 | 
						||
| 
								 | 
							
								          %i[symmetric],
							 | 
						||
| 
								 | 
							
								          "discourse",
							 | 
						||
| 
								 | 
							
								        ],
							 | 
						||
| 
								 | 
							
								        "msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
							 | 
						||
| 
								 | 
							
								        "msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
							 | 
						||
| 
								 | 
							
								        "text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      SEARCH_FUNCTION_TO_PG_INDEX = {
							 | 
						||
| 
								 | 
							
								        dot: "vector_ip_ops",
							 | 
						||
| 
								 | 
							
								        cosine: "vector_cosine_ops",
							 | 
						||
| 
								 | 
							
								        euclidean: "vector_l2_ops",
							 | 
						||
| 
								 | 
							
								      }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      class << self
							 | 
						||
| 
								 | 
							
								        def instantiate(model_name)
							 | 
						||
| 
								 | 
							
								          new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name])
							 | 
						||
| 
								 | 
							
								        end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								        def enabled_models
							 | 
						||
| 
								 | 
							
								          SiteSetting
							 | 
						||
| 
								 | 
							
								            .ai_embeddings_models
							 | 
						||
| 
								 | 
							
								            .split("|")
							 | 
						||
| 
								 | 
							
								            .map { |model_name| instantiate(model_name.strip) }
							 | 
						||
| 
								 | 
							
								        end
							 | 
						||
| 
								 | 
							
								      end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      def initialize(name, dimensions, max_sequence_lenght, functions, type, provider)
							 | 
						||
| 
								 | 
							
								        @name = name
							 | 
						||
| 
								 | 
							
								        @dimensions = dimensions
							 | 
						||
| 
								 | 
							
								        @max_sequence_lenght = max_sequence_lenght
							 | 
						||
| 
								 | 
							
								        @functions = functions
							 | 
						||
| 
								 | 
							
								        @type = type
							 | 
						||
| 
								 | 
							
								        @provider = provider
							 | 
						||
| 
								 | 
							
								      end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      def generate_embedding(input)
							 | 
						||
| 
								 | 
							
								        send("#{provider}_embeddings", input)
							 | 
						||
| 
								 | 
							
								      end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      def pg_function
							 | 
						||
| 
								 | 
							
								        SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first]
							 | 
						||
| 
								 | 
							
								      end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      def pg_index
							 | 
						||
| 
								 | 
							
								        SEARCH_FUNCTION_TO_PG_INDEX[functions.first]
							 | 
						||
| 
								 | 
							
								      end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      attr_reader :name, :dimensions, :max_sequence_lenght, :functions, :type, :provider
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      private
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      def discourse_embeddings(input)
							 | 
						||
| 
								 | 
							
								        DiscourseAi::Inference::DiscourseClassifier.perform!(
							 | 
						||
| 
								 | 
							
								          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
							 | 
						||
| 
								 | 
							
								          name.to_s,
							 | 
						||
| 
								 | 
							
								          input,
							 | 
						||
| 
								 | 
							
								          SiteSetting.ai_embeddings_discourse_service_api_key,
							 | 
						||
| 
								 | 
							
								        )
							 | 
						||
| 
								 | 
							
								      end
							 | 
						||
| 
								 | 
							
								
							 | 
						||
| 
								 | 
							
								      def openai_embeddings(input)
							 | 
						||
| 
								 | 
							
								        response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(input)
							 | 
						||
| 
								 | 
							
								        response[:data].first[:embedding]
							 | 
						||
| 
								 | 
							
								      end
							 | 
						||
| 
								 | 
							
								    end
							 | 
						||
| 
								 | 
							
								  end
							 | 
						||
| 
								 | 
							
								end
							 |