mirror of https://github.com/discourse/discourse-ai.git (synced 2025-10-31 14:38:37 +00:00)
# frozen_string_literal: true

module DiscourseAi
  module Embeddings
    class Model
      # Each template maps a model name to
      # [dimensions, max sequence length, similarity functions, search type(s), provider].
      AVAILABLE_MODELS_TEMPLATES = {
        "all-mpnet-base-v2" => [768, 384, %i[dot cosine euclidean], %i[symmetric], "discourse"],
        "all-distilroberta-v1" => [768, 512, %i[dot cosine euclidean], %i[symmetric], "discourse"],
        "multi-qa-mpnet-base-dot-v1" => [768, 512, %i[dot], %i[symmetric], "discourse"],
        "paraphrase-multilingual-mpnet-base-v2" => [
          768,
          128,
          %i[cosine],
          %i[symmetric],
          "discourse",
        ],
        "msmarco-distilbert-base-tas-b" => [768, 512, %i[dot], %i[asymmetric], "discourse"],
        "msmarco-distilbert-base-v4" => [768, 512, %i[cosine], %i[asymmetric], "discourse"],
        "instructor-xl" => [768, 512, %i[cosine], %i[symmetric asymmetric], "discourse"],
        "text-embedding-ada-002" => [1536, 2048, %i[cosine], %i[symmetric asymmetric], "openai"],
      }

      # pgvector operator classes used when indexing embeddings for each similarity function.
      SEARCH_FUNCTION_TO_PG_INDEX = {
        dot: "vector_ip_ops",
        cosine: "vector_cosine_ops",
        euclidean: "vector_l2_ops",
      }

      # pgvector distance operators used when querying embeddings.
      SEARCH_FUNCTION_TO_PG_FUNCTION = { dot: "<#>", cosine: "<=>", euclidean: "<->" }

      class << self
        def instantiate(model_name)
          new(model_name, *AVAILABLE_MODELS_TEMPLATES[model_name])
        end

        def enabled_models
          SiteSetting
            .ai_embeddings_models
            .split("|")
            .map { |model_name| instantiate(model_name.strip) }
        end
      end

      def initialize(name, dimensions, max_sequence_length, functions, type, provider)
        @name = name
        @dimensions = dimensions
        @max_sequence_length = max_sequence_length
        @functions = functions
        @type = type
        @provider = provider
      end

      # Dispatches to the provider-specific private method,
      # e.g. discourse_embeddings or openai_embeddings.
      def generate_embedding(input)
        send("#{provider}_embeddings", input)
      end

      def pg_function
        SEARCH_FUNCTION_TO_PG_FUNCTION[functions.first]
      end

      def pg_index
        SEARCH_FUNCTION_TO_PG_INDEX[functions.first]
      end

      attr_reader :name, :dimensions, :max_sequence_length, :functions, :type, :provider

      private

      def discourse_embeddings(input)
        truncated_input = DiscourseAi::Tokenizer::BertTokenizer.truncate(input, max_sequence_length)

        # Instructor models expect [[instruction, text]] pairs instead of a plain string.
        if name.start_with?("instructor")
          truncated_input = [
            [SiteSetting.ai_embeddings_semantic_related_instruction, truncated_input],
          ]
        end

        DiscourseAi::Inference::DiscourseClassifier.perform!(
          "#{SiteSetting.ai_embeddings_discourse_service_api_endpoint}/api/v1/classify",
          name.to_s,
          truncated_input,
          SiteSetting.ai_embeddings_discourse_service_api_key,
        )
      end

      def openai_embeddings(input)
        truncated_input =
          DiscourseAi::Tokenizer::OpenAiTokenizer.truncate(input, max_sequence_length)
        response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(truncated_input)
        response[:data].first[:embedding]
      end
    end
  end
end
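
A minimal usage sketch of the class above, assuming Discourse is booted with the ai_embeddings_* site settings configured and the relevant inference service reachable. The topic_embeddings table, its embedding column, and the query shapes are hypothetical illustrations of how pg_function and pg_index plug into pgvector SQL; they are not part of the file above.

# Usage sketch (illustrative only; assumes a booted Discourse site).
model = DiscourseAi::Embeddings::Model.instantiate("all-mpnet-base-v2")

# Provider is "discourse", so this truncates the text with the BERT tokenizer
# and posts it to the classification endpoint; the result is assumed here to
# be an array of floats.
embedding = model.generate_embedding("How do I back up my forum?")

# For this model functions.first is :dot, so pg_function is "<#>" and
# pg_index is "vector_ip_ops". Table and column names are hypothetical.
search_sql = <<~SQL
  SELECT topic_id
  FROM topic_embeddings
  ORDER BY embedding #{model.pg_function} '[#{embedding.join(",")}]'
  LIMIT 10
SQL

index_sql = <<~SQL
  CREATE INDEX IF NOT EXISTS topic_embeddings_search
  ON topic_embeddings
  USING ivfflat (embedding #{model.pg_index})
SQL

The same pattern works for every entry in AVAILABLE_MODELS_TEMPLATES, since each template carries the similarity function that picks both the query operator and the index operator class.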