From 89bcf9b1f020bbf6977cbe216efd8dde3f314862 Mon Sep 17 00:00:00 2001
From: Roman Rizzi
Date: Thu, 10 Jul 2025 17:51:01 -0300
Subject: [PATCH] FIX: Process successfully generated embeddings even if some failed (#1500)

---
 lib/embeddings/vector.rb                     |  3 +-
 lib/inference/cloudflare_workers_ai.rb       |  2 -
 lib/inference/gemini_embeddings.rb           |  2 -
 lib/inference/open_ai_embeddings.rb          |  2 -
 .../inference/cloudflare_workers_ai_spec.rb  | 11 +----
 spec/lib/modules/embeddings/vector_spec.rb   | 44 ++++++++++++++++---
 spec/support/embeddings_generation_stubs.rb  | 16 +++----
 7 files changed, 48 insertions(+), 32 deletions(-)

diff --git a/lib/embeddings/vector.rb b/lib/embeddings/vector.rb
index 4e847f17..91ff4c2e 100644
--- a/lib/embeddings/vector.rb
+++ b/lib/embeddings/vector.rb
@@ -42,13 +42,14 @@ module DiscourseAi
                 .then_on(pool) do |w_prepared_text|
                   w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
                 end
+                .rescue { nil } # We log the error during #perform. Skip failed embeddings.
             end
             .compact
 
         Concurrent::Promises
           .zip(*promised_embeddings)
           .value!
-          .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
+          .each { |e| schema.store(e[:target], e[:embedding], e[:digest]) if e.present? }
       ensure
         pool.shutdown
         pool.wait_for_termination
diff --git a/lib/inference/cloudflare_workers_ai.rb b/lib/inference/cloudflare_workers_ai.rb
index 360725e7..adb9eb78 100644
--- a/lib/inference/cloudflare_workers_ai.rb
+++ b/lib/inference/cloudflare_workers_ai.rb
@@ -26,8 +26,6 @@ module ::DiscourseAi
         case response.status
         when 200
           JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
-        when 429
-          # TODO add a AdminDashboard Problem?
         else
           Rails.logger.warn(
             "Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
diff --git a/lib/inference/gemini_embeddings.rb b/lib/inference/gemini_embeddings.rb
index 6250f5a0..b95cf03d 100644
--- a/lib/inference/gemini_embeddings.rb
+++ b/lib/inference/gemini_embeddings.rb
@@ -22,8 +22,6 @@ module ::DiscourseAi
         case response.status
         when 200
           JSON.parse(response.body, symbolize_names: true).dig(:embedding, :values)
-        when 429
-          # TODO add a AdminDashboard Problem?
         else
           Rails.logger.warn(
             "Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
diff --git a/lib/inference/open_ai_embeddings.rb b/lib/inference/open_ai_embeddings.rb
index 5e7820cf..ec1845d9 100644
--- a/lib/inference/open_ai_embeddings.rb
+++ b/lib/inference/open_ai_embeddings.rb
@@ -30,8 +30,6 @@ module ::DiscourseAi
         case response.status
         when 200
           JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
-        when 429
-          # TODO add a AdminDashboard Problem?
         else
           Rails.logger.warn(
             "OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
diff --git a/spec/lib/inference/cloudflare_workers_ai_spec.rb b/spec/lib/inference/cloudflare_workers_ai_spec.rb
index 1199111d..6f6c273a 100644
--- a/spec/lib/inference/cloudflare_workers_ai_spec.rb
+++ b/spec/lib/inference/cloudflare_workers_ai_spec.rb
@@ -40,16 +40,7 @@ RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
     end
   end
 
-  context "when the response status is 429" do
-    let(:response_status) { 429 }
-    let(:response_body) { "" }
-
-    it "doesn't raises a Net::HTTPBadResponse error" do
-      expect { subject.perform!(content) }.not_to raise_error
-    end
-  end
-
-  context "when the response status is not 200 or 429" do
+  context "when the response status is not 200" do
     let(:response_status) { 500 }
     let(:response_body) { "Internal Server Error" }
 
diff --git a/spec/lib/modules/embeddings/vector_spec.rb b/spec/lib/modules/embeddings/vector_spec.rb
index d2bc4bbc..e16c7381 100644
--- a/spec/lib/modules/embeddings/vector_spec.rb
+++ b/spec/lib/modules/embeddings/vector_spec.rb
@@ -83,17 +83,34 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
 
         expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
       end
+
+      context "when one of the concurrently generated embeddings fails" do
+        it "still processes the successful ones" do
+          text = vdef.prepare_target_text(topic)
+
+          text2 = vdef.prepare_target_text(topic_2)
+
+          stub_vector_mapping(text, expected_embedding_1)
+          stub_vector_mapping(text2, expected_embedding_2, result_status: 429)
+
+          vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
+
+          expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
+          expect(topics_schema.find_by_target(topic_2)).to be_nil
+        end
+      end
     end
   end
 
   context "with open_ai as the provider" do
     fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
 
-    def stub_vector_mapping(text, expected_embedding)
+    def stub_vector_mapping(text, expected_embedding, result_status: 200)
       EmbeddingsGenerationStubs.openai_service(
         vdef.lookup_custom_param("model_name"),
         text,
         expected_embedding,
+        result_status: result_status,
       )
     end
 
@@ -123,8 +140,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
   context "with hugging_face as the provider" do
     fab!(:vdef) { Fabricate(:embedding_definition) }
 
-    def stub_vector_mapping(text, expected_embedding)
-      EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
+    def stub_vector_mapping(text, expected_embedding, result_status: 200)
+      EmbeddingsGenerationStubs.hugging_face_service(
+        text,
+        expected_embedding,
+        result_status: result_status,
+      )
     end
 
     it_behaves_like "generates and store embeddings using a vector definition"
@@ -133,8 +154,13 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
   context "with google as the provider" do
     fab!(:vdef) { Fabricate(:gemini_embedding_def) }
 
-    def stub_vector_mapping(text, expected_embedding)
-      EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding)
+    def stub_vector_mapping(text, expected_embedding, result_status: 200)
+      EmbeddingsGenerationStubs.gemini_service(
+        vdef.api_key,
+        text,
+        expected_embedding,
+        result_status: result_status,
+      )
     end
 
     it_behaves_like "generates and store embeddings using a vector definition"
@@ -143,8 +169,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
   context "with cloudflare as the provider" do
     fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
 
-    def stub_vector_mapping(text, expected_embedding)
-      EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding)
+    def stub_vector_mapping(text, expected_embedding, result_status: 200)
+      EmbeddingsGenerationStubs.cloudflare_service(
+        text,
+        expected_embedding,
+        result_status: result_status,
+      )
     end
 
     it_behaves_like "generates and store embeddings using a vector definition"
diff --git a/spec/support/embeddings_generation_stubs.rb b/spec/support/embeddings_generation_stubs.rb
index 48da06eb..50bed368 100644
--- a/spec/support/embeddings_generation_stubs.rb
+++ b/spec/support/embeddings_generation_stubs.rb
@@ -2,35 +2,35 @@
 
 class EmbeddingsGenerationStubs
   class << self
-    def hugging_face_service(string, embedding)
+    def hugging_face_service(string, embedding, result_status: 200)
       WebMock
         .stub_request(:post, "https://test.com/embeddings")
         .with(body: JSON.dump({ inputs: string, truncate: true }))
-        .to_return(status: 200, body: JSON.dump([embedding]))
+        .to_return(status: result_status, body: JSON.dump([embedding]))
     end
 
-    def openai_service(model, string, embedding, extra_args: {})
+    def openai_service(model, string, embedding, result_status: 200, extra_args: {})
       WebMock
         .stub_request(:post, "https://api.openai.com/v1/embeddings")
         .with(body: JSON.dump({ model: model, input: string }.merge(extra_args)))
-        .to_return(status: 200, body: JSON.dump({ data: [{ embedding: embedding }] }))
+        .to_return(status: result_status, body: JSON.dump({ data: [{ embedding: embedding }] }))
     end
 
-    def gemini_service(api_key, string, embedding)
+    def gemini_service(api_key, string, embedding, result_status: 200)
       WebMock
         .stub_request(
           :post,
           "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}",
        )
         .with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
-        .to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } }))
+        .to_return(status: result_status, body: JSON.dump({ embedding: { values: embedding } }))
     end
 
-    def cloudflare_service(string, embedding)
+    def cloudflare_service(string, embedding, result_status: 200)
       WebMock
         .stub_request(:post, "https://test.com/embeddings")
         .with(body: JSON.dump({ text: [string] }))
-        .to_return(status: 200, body: JSON.dump({ result: { data: [embedding] } }))
+        .to_return(status: result_status, body: JSON.dump({ result: { data: [embedding] } }))
     end
   end
 end
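
A minimal standalone sketch (separate from the patch above) of the Concurrent::Promises pattern the lib/embeddings/vector.rb change relies on: `.rescue { nil }` turns a failed future into one that fulfills with nil, so `zip(...).value!` resolves even when a single embedding request raised, and a nil guard then skips storage for that record. Only the concurrent-ruby gem is assumed; the `embed` lambda is a hypothetical stand-in for `embedding_gen.perform!`.

  require "concurrent"

  pool = Concurrent::FixedThreadPool.new(2)

  # Hypothetical stand-in for embedding_gen.perform!; fails for one input,
  # the way a 429 from an embeddings API would.
  embed = ->(text) do
    raise "simulated 429 from the embeddings API" if text == "bad"
    [0.1, 0.2, 0.3]
  end

  promised_embeddings =
    %w[good bad].map do |text|
      Concurrent::Promises
        .fulfilled_future(text, pool)
        .then_on(pool) { |t| { text: t, embedding: embed.call(t) } }
        .rescue { nil } # a failed branch fulfills with nil instead of raising
    end

  # zip + value! now resolves despite the failed branch; nil results are
  # skipped at store time, mirroring the patch's `if e.present?` guard.
  Concurrent::Promises
    .zip(*promised_embeddings)
    .value!
    .each { |e| puts "storing embedding for #{e[:text]}" if e }
  # => prints only "storing embedding for good"

  pool.shutdown
  pool.wait_for_termination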