FIX: Process successfully generated embeddings even if some failed (#1500)

This commit is contained in:
Roman Rizzi 2025-07-10 17:51:01 -03:00 committed by GitHub
parent 6b5ea38644
commit 89bcf9b1f0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
7 changed files with 48 additions and 32 deletions

View File

@ -42,13 +42,14 @@ module DiscourseAi
.then_on(pool) do |w_prepared_text|
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
end
.rescue { nil } # We log the error during #perform. Skip failed embeddings.
end
.compact
Concurrent::Promises
.zip(*promised_embeddings)
.value!
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) if e.present? }
ensure
pool.shutdown
pool.wait_for_termination

View File

@ -26,8 +26,6 @@ module ::DiscourseAi
case response.status
when 200
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
when 429
# TODO add a AdminDashboard Problem?
else
Rails.logger.warn(
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",

View File

@ -22,8 +22,6 @@ module ::DiscourseAi
case response.status
when 200
JSON.parse(response.body, symbolize_names: true).dig(:embedding, :values)
when 429
# TODO add a AdminDashboard Problem?
else
Rails.logger.warn(
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",

View File

@ -30,8 +30,6 @@ module ::DiscourseAi
case response.status
when 200
JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
when 429
# TODO add a AdminDashboard Problem?
else
Rails.logger.warn(
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",

View File

@ -40,16 +40,7 @@ RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
end
end
context "when the response status is 429" do
let(:response_status) { 429 }
let(:response_body) { "" }
it "doesn't raises a Net::HTTPBadResponse error" do
expect { subject.perform!(content) }.not_to raise_error
end
end
context "when the response status is not 200 or 429" do
context "when the response status is not 200" do
let(:response_status) { 500 }
let(:response_body) { "Internal Server Error" }

View File

@ -83,17 +83,34 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
end
context "when one of the concurrently generated embeddings fails" do
it "still processes the succesful ones" do
text = vdef.prepare_target_text(topic)
text2 = vdef.prepare_target_text(topic_2)
stub_vector_mapping(text, expected_embedding_1)
stub_vector_mapping(text2, expected_embedding_2, result_status: 429)
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
expect(topics_schema.find_by_target(topic_2)).to be_nil
end
end
end
end
context "with open_ai as the provider" do
fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
def stub_vector_mapping(text, expected_embedding)
def stub_vector_mapping(text, expected_embedding, result_status: 200)
EmbeddingsGenerationStubs.openai_service(
vdef.lookup_custom_param("model_name"),
text,
expected_embedding,
result_status: result_status,
)
end
@ -123,8 +140,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
context "with hugging_face as the provider" do
fab!(:vdef) { Fabricate(:embedding_definition) }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
def stub_vector_mapping(text, expected_embedding, result_status: 200)
EmbeddingsGenerationStubs.hugging_face_service(
text,
expected_embedding,
result_status: result_status,
)
end
it_behaves_like "generates and store embeddings using a vector definition"
@ -133,8 +154,13 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
context "with google as the provider" do
fab!(:vdef) { Fabricate(:gemini_embedding_def) }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding)
def stub_vector_mapping(text, expected_embedding, result_status: 200)
EmbeddingsGenerationStubs.gemini_service(
vdef.api_key,
text,
expected_embedding,
result_status: result_status,
)
end
it_behaves_like "generates and store embeddings using a vector definition"
@ -143,8 +169,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
context "with cloudflare as the provider" do
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
def stub_vector_mapping(text, expected_embedding)
EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding)
def stub_vector_mapping(text, expected_embedding, result_status: 200)
EmbeddingsGenerationStubs.cloudflare_service(
text,
expected_embedding,
result_status: result_status,
)
end
it_behaves_like "generates and store embeddings using a vector definition"

View File

@ -2,35 +2,35 @@
# Test helpers that register WebMock stubs for the embedding-provider HTTP
# endpoints used in specs. Each helper stubs the exact request body the
# corresponding inference client sends and returns a canned embedding payload.
#
# All helpers accept +result_status:+ (default 200) so specs can simulate
# provider failures (e.g. 429/500) and assert that failed embeddings are
# skipped while successful ones are still stored.
class EmbeddingsGenerationStubs
  class << self
    # Stub the Hugging Face TEI endpoint.
    # Response body is a JSON array wrapping the embedding: [[...]].
    def hugging_face_service(string, embedding, result_status: 200)
      WebMock
        .stub_request(:post, "https://test.com/embeddings")
        .with(body: JSON.dump({ inputs: string, truncate: true }))
        .to_return(status: result_status, body: JSON.dump([embedding]))
    end

    # Stub the OpenAI embeddings endpoint.
    # +extra_args+ are merged into the expected request body (e.g. dimensions).
    def openai_service(model, string, embedding, result_status: 200, extra_args: {})
      WebMock
        .stub_request(:post, "https://api.openai.com/v1/embeddings")
        .with(body: JSON.dump({ model: model, input: string }.merge(extra_args)))
        .to_return(status: result_status, body: JSON.dump({ data: [{ embedding: embedding }] }))
    end

    # Stub the Google Gemini embedContent endpoint; the API key travels in the
    # query string, so it is part of the stubbed URL.
    def gemini_service(api_key, string, embedding, result_status: 200)
      WebMock
        .stub_request(
          :post,
          "https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}",
        )
        .with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
        .to_return(status: result_status, body: JSON.dump({ embedding: { values: embedding } }))
    end

    # Stub the Cloudflare Workers AI endpoint.
    # Note the request wraps the text in an array: { text: [string] }.
    def cloudflare_service(string, embedding, result_status: 200)
      WebMock
        .stub_request(:post, "https://test.com/embeddings")
        .with(body: JSON.dump({ text: [string] }))
        .to_return(status: result_status, body: JSON.dump({ result: { data: [embedding] } }))
    end
  end
end