diff --git a/lib/ai_bot/personas/persona.rb b/lib/ai_bot/personas/persona.rb index 63d566ea..d07f5bf4 100644 --- a/lib/ai_bot/personas/persona.rb +++ b/lib/ai_bot/personas/persona.rb @@ -313,7 +313,7 @@ module DiscourseAi strategy = DiscourseAi::Embeddings::Strategies::Truncation.new vector_rep = DiscourseAi::Embeddings::VectorRepresentations::Base.current_representation(strategy) - reranker = DiscourseAi::Inference::HuggingFaceTextEmbeddings + reranker = DiscourseAi::Inference::HuggingFaceText interactions_vector = vector_rep.vector_from(consolidated_question) @@ -344,7 +344,7 @@ module DiscourseAi if reranker.reranker_configured? guidance = fragments.map { |fragment, _metadata| fragment } ranks = - DiscourseAi::Inference::HuggingFaceTextEmbeddings + DiscourseAi::Inference::HuggingFaceText .rerank(conversation_context.last[:content], guidance) .to_a .take(rag_conversation_chunks) diff --git a/lib/embeddings/semantic_search.rb b/lib/embeddings/semantic_search.rb index cae93958..d4d78cf6 100644 --- a/lib/embeddings/semantic_search.rb +++ b/lib/embeddings/semantic_search.rb @@ -159,10 +159,7 @@ module DiscourseAi .map { _1.truncate(2000, omission: "") } reranked_results = - DiscourseAi::Inference::HuggingFaceTextEmbeddings.rerank( - search_term, - rerank_posts_payload, - ) + DiscourseAi::Inference::HuggingFaceText.rerank(search_term, rerank_posts_payload) reordered_ids = reranked_results.map { _1[:index] }.map { filtered_results[_1].id }.take(5) diff --git a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb index 1a4b8002..7e4a2ad7 100644 --- a/lib/embeddings/vector_representations/all_mpnet_base_v2.rb +++ b/lib/embeddings/vector_representations/all_mpnet_base_v2.rb @@ -24,12 +24,7 @@ module DiscourseAi end def vector_from(text, asymetric: false) - DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{discourse_embeddings_endpoint}/api/v1/classify", - self.class.name, - text, - 
SiteSetting.ai_embeddings_discourse_service_api_key, - ) + inference_client.perform!(text) end def dimensions @@ -59,6 +54,10 @@ module DiscourseAi def tokenizer DiscourseAi::Tokenizer::AllMpnetBaseV2Tokenizer end + + def inference_client + DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name) + end end end end diff --git a/lib/embeddings/vector_representations/base.rb b/lib/embeddings/vector_representations/base.rb index d5c23d21..be6b46b5 100644 --- a/lib/embeddings/vector_representations/base.rb +++ b/lib/embeddings/vector_representations/base.rb @@ -426,16 +426,8 @@ module DiscourseAi end end - def discourse_embeddings_endpoint - if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? - service = - DiscourseAi::Utils::DnsSrv.lookup( - SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv, - ) - "https://#{service.target}:#{service.port}" - else - SiteSetting.ai_embeddings_discourse_service_api_endpoint - end + def inference_client + raise NotImplementedError end end end diff --git a/lib/embeddings/vector_representations/bge_large_en.rb b/lib/embeddings/vector_representations/bge_large_en.rb index 601c85a1..8c974b89 100644 --- a/lib/embeddings/vector_representations/bge_large_en.rb +++ b/lib/embeddings/vector_representations/bge_large_en.rb @@ -11,7 +11,7 @@ module DiscourseAi def correctly_configured? SiteSetting.ai_cloudflare_workers_api_token.present? || - DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? || + DiscourseAi::Inference::TeiEmbeddings.configured? || ( SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? @@ -33,24 +33,12 @@ module DiscourseAi def vector_from(text, asymetric: false) text = "#{asymmetric_query_prefix} #{text}" if asymetric - if SiteSetting.ai_cloudflare_workers_api_token.present? 
- DiscourseAi::Inference::CloudflareWorkersAi - .perform!(inference_model_name, { text: text }) - .dig(:result, :data) - .first - elsif DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? - truncated_text = tokenizer.truncate(text, max_sequence_length - 2) - DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first - elsif discourse_embeddings_endpoint.present? - DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{discourse_embeddings_endpoint}/api/v1/classify", - inference_model_name.split("/").last, - text, - SiteSetting.ai_embeddings_discourse_service_api_key, - ) - else - raise "No inference endpoint configured" - end + client = inference_client + + needs_truncation = client.class.name.include?("TeiEmbeddings") + text = tokenizer.truncate(text, max_sequence_length - 2) if needs_truncation + + client.perform!(text) end def inference_model_name @@ -88,6 +76,21 @@ module DiscourseAi end def asymmetric_query_prefix "Represent this sentence for searching relevant passages:" end + + def inference_client + if SiteSetting.ai_cloudflare_workers_api_token.present? + DiscourseAi::Inference::CloudflareWorkersAi.instance(inference_model_name) + elsif DiscourseAi::Inference::TeiEmbeddings.configured? + DiscourseAi::Inference::TeiEmbeddings.instance + elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || + SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? + DiscourseAi::Inference::DiscourseClassifier.instance( + inference_model_name.split("/").last, + ) + else + raise "No inference endpoint configured" + end + end end end end diff --git a/lib/embeddings/vector_representations/bge_m3.rb b/lib/embeddings/vector_representations/bge_m3.rb index c220cf75..98247730 100644 --- a/lib/embeddings/vector_representations/bge_m3.rb +++ b/lib/embeddings/vector_representations/bge_m3.rb @@ -10,7 +10,7 @@ module DiscourseAi end def correctly_configured? 
- DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? + DiscourseAi::Inference::TeiEmbeddings.configured? end def dependant_setting_names @@ -20,7 +20,7 @@ module DiscourseAi def vector_from(text, asymetric: false) truncated_text = tokenizer.truncate(text, max_sequence_length - 2) - DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first + inference_client.perform!(truncated_text) end def dimensions @@ -50,6 +50,10 @@ module DiscourseAi def tokenizer DiscourseAi::Tokenizer::BgeM3Tokenizer end + + def inference_client + DiscourseAi::Inference::TeiEmbeddings.instance + end end end end diff --git a/lib/embeddings/vector_representations/gemini.rb b/lib/embeddings/vector_representations/gemini.rb index 86b7afae..a693849d 100644 --- a/lib/embeddings/vector_representations/gemini.rb +++ b/lib/embeddings/vector_representations/gemini.rb @@ -43,8 +43,7 @@ module DiscourseAi end def vector_from(text, asymetric: false) - response = DiscourseAi::Inference::GeminiEmbeddings.perform!(text) - response[:embedding][:values] + inference_client.perform!(text).dig(:embedding, :values) end # There is no public tokenizer for Gemini, and from the ones we already ship in the plugin @@ -53,6 +52,10 @@ module DiscourseAi def tokenizer DiscourseAi::Tokenizer::OpenAiTokenizer end + + def inference_client + DiscourseAi::Inference::GeminiEmbeddings.instance + end end end end diff --git a/lib/embeddings/vector_representations/multilingual_e5_large.rb b/lib/embeddings/vector_representations/multilingual_e5_large.rb index 8267f938..166c546d 100644 --- a/lib/embeddings/vector_representations/multilingual_e5_large.rb +++ b/lib/embeddings/vector_representations/multilingual_e5_large.rb @@ -10,7 +10,7 @@ module DiscourseAi end def correctly_configured? - DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? || + DiscourseAi::Inference::TeiEmbeddings.configured? || ( SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? 
|| SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? @@ -29,19 +29,16 @@ module DiscourseAi end def vector_from(text, asymetric: false) - if DiscourseAi::Inference::HuggingFaceTextEmbeddings.configured? - truncated_text = tokenizer.truncate(text, max_sequence_length - 2) - DiscourseAi::Inference::HuggingFaceTextEmbeddings.perform!(truncated_text).first - elsif discourse_embeddings_endpoint.present? - DiscourseAi::Inference::DiscourseClassifier.perform!( - "#{discourse_embeddings_endpoint}/api/v1/classify", - self.class.name, - "query: #{text}", - SiteSetting.ai_embeddings_discourse_service_api_key, - ) + client = inference_client + + needs_truncation = client.class.name.include?("TeiEmbeddings") + if needs_truncation + text = tokenizer.truncate(text, max_sequence_length - 2) else - raise "No inference endpoint configured" + text = "query: #{text}" end + + client.perform!(text) end def id @@ -71,6 +68,17 @@ module DiscourseAi end def tokenizer DiscourseAi::Tokenizer::MultilingualE5LargeTokenizer end + + def inference_client + if DiscourseAi::Inference::TeiEmbeddings.configured? + DiscourseAi::Inference::TeiEmbeddings.instance + elsif SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? || + SiteSetting.ai_embeddings_discourse_service_api_endpoint.present? 
+ DiscourseAi::Inference::DiscourseClassifier.instance(self.class.name) + else + raise "No inference endpoint configured" + end + end end end end diff --git a/lib/embeddings/vector_representations/text_embedding_3_large.rb b/lib/embeddings/vector_representations/text_embedding_3_large.rb index 626428c9..202d66de 100644 --- a/lib/embeddings/vector_representations/text_embedding_3_large.rb +++ b/lib/embeddings/vector_representations/text_embedding_3_large.rb @@ -45,18 +45,19 @@ module DiscourseAi end def vector_from(text, asymetric: false) - response = - DiscourseAi::Inference::OpenAiEmbeddings.perform!( - text, - model: self.class.name, - dimensions: dimensions, - ) - response[:data].first[:embedding] + inference_client.perform!(text) end def tokenizer DiscourseAi::Tokenizer::OpenAiTokenizer end + + def inference_client + DiscourseAi::Inference::OpenAiEmbeddings.instance( + model: self.class.name, + dimensions: dimensions, + ) + end end end end diff --git a/lib/embeddings/vector_representations/text_embedding_3_small.rb b/lib/embeddings/vector_representations/text_embedding_3_small.rb index fbac4bc7..87f31185 100644 --- a/lib/embeddings/vector_representations/text_embedding_3_small.rb +++ b/lib/embeddings/vector_representations/text_embedding_3_small.rb @@ -43,13 +43,16 @@ module DiscourseAi end def vector_from(text, asymetric: false) - response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name) - response[:data].first[:embedding] + inference_client.perform!(text) end def tokenizer DiscourseAi::Tokenizer::OpenAiTokenizer end + + def inference_client + DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name) + end end end end diff --git a/lib/embeddings/vector_representations/text_embedding_ada_002.rb b/lib/embeddings/vector_representations/text_embedding_ada_002.rb index 2079e028..1e570b98 100644 --- a/lib/embeddings/vector_representations/text_embedding_ada_002.rb +++ 
b/lib/embeddings/vector_representations/text_embedding_ada_002.rb @@ -43,13 +43,16 @@ module DiscourseAi end def vector_from(text, asymetric: false) - response = DiscourseAi::Inference::OpenAiEmbeddings.perform!(text, model: self.class.name) - response[:data].first[:embedding] + inference_client.perform!(text) end def tokenizer DiscourseAi::Tokenizer::OpenAiTokenizer end + + def inference_client + DiscourseAi::Inference::OpenAiEmbeddings.instance(model: self.class.name) + end end end end diff --git a/lib/inference/cloudflare_workers_ai.rb b/lib/inference/cloudflare_workers_ai.rb index 099ae5be..b0cd5926 100644 --- a/lib/inference/cloudflare_workers_ai.rb +++ b/lib/inference/cloudflare_workers_ai.rb @@ -3,25 +3,38 @@ module ::DiscourseAi module Inference class CloudflareWorkersAi - def self.perform!(model, content) - headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } + def initialize(account_id, api_token, model, referer = Discourse.base_url) + @account_id = account_id + @api_token = api_token + @model = model + @referer = referer + end - account_id = SiteSetting.ai_cloudflare_workers_account_id - token = SiteSetting.ai_cloudflare_workers_api_token + def self.instance(model) + new( + SiteSetting.ai_cloudflare_workers_account_id, + SiteSetting.ai_cloudflare_workers_api_token, + model, + ) + end - base_url = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/" - headers["Authorization"] = "Bearer #{token}" + attr_reader :account_id, :api_token, :model, :referer - endpoint = "#{base_url}#{model}" + def perform!(content) + headers = { + "Referer" => referer, + "Content-Type" => "application/json", + "Authorization" => "Bearer #{api_token}", + } + + endpoint = "https://api.cloudflare.com/client/v4/accounts/#{account_id}/ai/run/@cf/#{model}" conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } response = conn.post(endpoint, content.to_json, headers) - raise Net::HTTPBadResponse if 
![200].include?(response.status) - case response.status when 200 - JSON.parse(response.body, symbolize_names: true) + JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first when 429 # TODO add a AdminDashboard Problem? else diff --git a/lib/inference/discourse_classifier.rb b/lib/inference/discourse_classifier.rb index 3784a190..46f912dd 100644 --- a/lib/inference/discourse_classifier.rb +++ b/lib/inference/discourse_classifier.rb @@ -3,9 +3,36 @@ module ::DiscourseAi module Inference class DiscourseClassifier - def self.perform!(endpoint, model, content, api_key) - headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } + def initialize(endpoint, api_key, model, referer = Discourse.base_url) + @endpoint = endpoint + @api_key = api_key + @model = model + @referer = referer + end + def self.instance(model) + endpoint = + if SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv.present? + service = + DiscourseAi::Utils::DnsSrv.lookup( + SiteSetting.ai_embeddings_discourse_service_api_endpoint_srv, + ) + "https://#{service.target}:#{service.port}" + else + SiteSetting.ai_embeddings_discourse_service_api_endpoint + end + + new( + "#{endpoint}/api/v1/classify", + SiteSetting.ai_embeddings_discourse_service_api_key, + model, + ) + end + + attr_reader :endpoint, :api_key, :model, :referer + + def perform!(content) + headers = { "Referer" => referer, "Content-Type" => "application/json" } headers["X-API-KEY"] = api_key if api_key.present? 
conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } diff --git a/lib/inference/gemini_embeddings.rb b/lib/inference/gemini_embeddings.rb index cedda24c..13fb62c5 100644 --- a/lib/inference/gemini_embeddings.rb +++ b/lib/inference/gemini_embeddings.rb @@ -3,12 +3,17 @@ module ::DiscourseAi module Inference class GeminiEmbeddings - def self.perform!(content) - headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } + def initialize(api_key, referer = Discourse.base_url) + @api_key = api_key + @referer = referer + end + attr_reader :api_key, :referer + + def perform!(content) + headers = { "Referer" => referer, "Content-Type" => "application/json" } url = - "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{SiteSetting.ai_gemini_api_key}" - + "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}" body = { content: { parts: [{ text: content }] } } conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } diff --git a/lib/inference/hugging_face_text_embeddings.rb b/lib/inference/hugging_face_text.rb similarity index 66% rename from lib/inference/hugging_face_text_embeddings.rb rename to lib/inference/hugging_face_text.rb index 0e904a94..47c83c0a 100644 --- a/lib/inference/hugging_face_text_embeddings.rb +++ b/lib/inference/hugging_face_text.rb @@ -2,33 +2,8 @@ module ::DiscourseAi module Inference - class HuggingFaceTextEmbeddings + class HuggingFaceText class << self - def perform!(content) - headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } - body = { inputs: content, truncate: true }.to_json - - if SiteSetting.ai_hugging_face_tei_endpoint_srv.present? 
- service = - DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv) - api_endpoint = "https://#{service.target}:#{service.port}" - else - api_endpoint = SiteSetting.ai_hugging_face_tei_endpoint - end - - if SiteSetting.ai_hugging_face_tei_api_key.present? - headers["X-API-KEY"] = SiteSetting.ai_hugging_face_tei_api_key - headers["Authorization"] = "Bearer #{SiteSetting.ai_hugging_face_tei_api_key}" - end - - conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } - response = conn.post(api_endpoint, body, headers) - - raise Net::HTTPBadResponse if ![200].include?(response.status) - - JSON.parse(response.body, symbolize_names: true) - end - def rerank(content, candidates) headers = { "Referer" => Discourse.base_url, "Content-Type" => "application/json" } body = { query: content, texts: candidates, truncate: true }.to_json @@ -85,11 +60,6 @@ module ::DiscourseAi SiteSetting.ai_hugging_face_tei_reranker_endpoint.present? || SiteSetting.ai_hugging_face_tei_reranker_endpoint_srv.present? end - - def configured? - SiteSetting.ai_hugging_face_tei_endpoint.present? || - SiteSetting.ai_hugging_face_tei_endpoint_srv.present? 
- end end end end diff --git a/lib/inference/open_ai_embeddings.rb b/lib/inference/open_ai_embeddings.rb index 9ffcaa49..e3e6551c 100644 --- a/lib/inference/open_ai_embeddings.rb +++ b/lib/inference/open_ai_embeddings.rb @@ -3,13 +3,26 @@ module ::DiscourseAi module Inference class OpenAiEmbeddings - def self.perform!(content, model:, dimensions: nil) + def initialize(endpoint, api_key, model, dimensions) + @endpoint = endpoint + @api_key = api_key + @model = model + @dimensions = dimensions + end + + attr_reader :endpoint, :api_key, :model, :dimensions + + def self.instance(model:, dimensions: nil) + new(SiteSetting.ai_openai_embeddings_url, SiteSetting.ai_openai_api_key, model, dimensions) + end + + def perform!(content) headers = { "Content-Type" => "application/json" } - if SiteSetting.ai_openai_embeddings_url.include?("azure") - headers["api-key"] = SiteSetting.ai_openai_api_key + if endpoint.include?("azure") + headers["api-key"] = api_key else - headers["Authorization"] = "Bearer #{SiteSetting.ai_openai_api_key}" + headers["Authorization"] = "Bearer #{api_key}" end payload = { model: model, input: content } @@ -20,7 +33,7 @@ module ::DiscourseAi case response.status when 200 - JSON.parse(response.body, symbolize_names: true) + JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding) when 429 # TODO add a AdminDashboard Problem? else diff --git a/lib/inference/tei_embeddings.rb b/lib/inference/tei_embeddings.rb new file mode 100644 index 00000000..2d59fab0 --- /dev/null +++ b/lib/inference/tei_embeddings.rb @@ -0,0 +1,51 @@ +# frozen_string_literal: true + +module DiscourseAi + module Inference + class TeiEmbeddings + def initialize(endpoint, key, referer = Discourse.base_url) + @endpoint = endpoint + @key = key + @referer = referer + end + + attr_reader :endpoint, :key, :referer + + class << self + def configured? + SiteSetting.ai_hugging_face_tei_endpoint.present? || + SiteSetting.ai_hugging_face_tei_endpoint_srv.present? 
+ end + + def instance + if SiteSetting.ai_hugging_face_tei_endpoint_srv.present? + service = + DiscourseAi::Utils::DnsSrv.lookup(SiteSetting.ai_hugging_face_tei_endpoint_srv) + endpoint = "https://#{service.target}:#{service.port}" + else + endpoint = SiteSetting.ai_hugging_face_tei_endpoint + end + + new(endpoint, SiteSetting.ai_hugging_face_tei_api_key) + end + end + + def perform!(content) + headers = { "Referer" => referer, "Content-Type" => "application/json" } + body = { inputs: content, truncate: true }.to_json + + if key.present? + headers["X-API-KEY"] = key + headers["Authorization"] = "Bearer #{key}" + end + + conn = Faraday.new { |f| f.adapter FinalDestination::FaradayAdapter } + response = conn.post(endpoint, body, headers) + + raise Net::HTTPBadResponse if ![200].include?(response.status) + + JSON.parse(response.body, symbolize_names: true).first + end + end + end +end diff --git a/lib/nsfw/classification.rb b/lib/nsfw/classification.rb index c87ba8d1..a6f99439 100644 --- a/lib/nsfw/classification.rb +++ b/lib/nsfw/classification.rb @@ -54,12 +54,11 @@ module DiscourseAi upload_url = Discourse.store.cdn_url(upload.url) upload_url = "#{Discourse.base_url_no_prefix}#{upload_url}" if upload_url.starts_with?("/") - DiscourseAi::Inference::DiscourseClassifier.perform!( + DiscourseAi::Inference::DiscourseClassifier.new( "#{endpoint}/api/v1/classify", - model, - upload_url, SiteSetting.ai_nsfw_inference_service_api_key, - ) + model, + ).perform!(upload_url) end def available_models diff --git a/lib/sentiment/sentiment_classification.rb b/lib/sentiment/sentiment_classification.rb index f73447ca..5d71a6c1 100644 --- a/lib/sentiment/sentiment_classification.rb +++ b/lib/sentiment/sentiment_classification.rb @@ -45,7 +45,7 @@ module DiscourseAi private def request_with(content, model_config) - result = ::DiscourseAi::Inference::HuggingFaceTextEmbeddings.classify(content, model_config) + result = ::DiscourseAi::Inference::HuggingFaceText.classify(content, 
model_config) transform_result(result) end diff --git a/lib/toxicity/toxicity_classification.rb b/lib/toxicity/toxicity_classification.rb index c178d2e1..1756d3b2 100644 --- a/lib/toxicity/toxicity_classification.rb +++ b/lib/toxicity/toxicity_classification.rb @@ -42,12 +42,11 @@ module DiscourseAi def request(target_to_classify) data = - ::DiscourseAi::Inference::DiscourseClassifier.perform!( + ::DiscourseAi::Inference::DiscourseClassifier.new( "#{endpoint}/api/v1/classify", - SiteSetting.ai_toxicity_inference_service_api_model, - content_of(target_to_classify), SiteSetting.ai_toxicity_inference_service_api_key, - ) + SiteSetting.ai_toxicity_inference_service_api_model, + ).perform!(content_of(target_to_classify)) { available_model => data } end diff --git a/spec/shared/inference/openai_embeddings_spec.rb b/spec/shared/inference/openai_embeddings_spec.rb index e938b6a4..7db19a7e 100644 --- a/spec/shared/inference/openai_embeddings_spec.rb +++ b/spec/shared/inference/openai_embeddings_spec.rb @@ -26,10 +26,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do ).to_return(status: 200, body: body_json, headers: {}) result = - DiscourseAi::Inference::OpenAiEmbeddings.perform!("hello", model: "text-embedding-ada-002") + DiscourseAi::Inference::OpenAiEmbeddings.instance(model: "text-embedding-ada-002").perform!( + "hello", + ) - expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 }) - expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] }) + expect(result).to eq([0.0, 0.1]) end it "supports openai embeddings" do @@ -54,13 +55,11 @@ describe DiscourseAi::Inference::OpenAiEmbeddings do ).to_return(status: 200, body: body_json, headers: {}) result = - DiscourseAi::Inference::OpenAiEmbeddings.perform!( - "hello", + DiscourseAi::Inference::OpenAiEmbeddings.instance( model: "text-embedding-ada-002", dimensions: 1000, - ) + ).perform!("hello") - expect(result[:usage]).to eq({ prompt_tokens: 1, total_tokens: 1 }) - 
expect(result[:data].first).to eq({ object: "embedding", embedding: [0.0, 0.1] }) + expect(result).to eq([0.0, 0.1]) end end