mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-07-13 01:23:29 +00:00
FIX: Process successfully generated embeddings even if some failed (#1500)
This commit is contained in:
parent
6b5ea38644
commit
89bcf9b1f0
@ -42,13 +42,14 @@ module DiscourseAi
|
|||||||
.then_on(pool) do |w_prepared_text|
|
.then_on(pool) do |w_prepared_text|
|
||||||
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
|
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
|
||||||
end
|
end
|
||||||
|
.rescue { nil } # We log the error during #perform. Skip failed embeddings.
|
||||||
end
|
end
|
||||||
.compact
|
.compact
|
||||||
|
|
||||||
Concurrent::Promises
|
Concurrent::Promises
|
||||||
.zip(*promised_embeddings)
|
.zip(*promised_embeddings)
|
||||||
.value!
|
.value!
|
||||||
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
|
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) if e.present? }
|
||||||
ensure
|
ensure
|
||||||
pool.shutdown
|
pool.shutdown
|
||||||
pool.wait_for_termination
|
pool.wait_for_termination
|
||||||
|
@ -26,8 +26,6 @@ module ::DiscourseAi
|
|||||||
case response.status
|
case response.status
|
||||||
when 200
|
when 200
|
||||||
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
|
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
|
||||||
when 429
|
|
||||||
# TODO add a AdminDashboard Problem?
|
|
||||||
else
|
else
|
||||||
Rails.logger.warn(
|
Rails.logger.warn(
|
||||||
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
|
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||||
|
@ -22,8 +22,6 @@ module ::DiscourseAi
|
|||||||
case response.status
|
case response.status
|
||||||
when 200
|
when 200
|
||||||
JSON.parse(response.body, symbolize_names: true).dig(:embedding, :values)
|
JSON.parse(response.body, symbolize_names: true).dig(:embedding, :values)
|
||||||
when 429
|
|
||||||
# TODO add a AdminDashboard Problem?
|
|
||||||
else
|
else
|
||||||
Rails.logger.warn(
|
Rails.logger.warn(
|
||||||
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
|
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||||
|
@ -30,8 +30,6 @@ module ::DiscourseAi
|
|||||||
case response.status
|
case response.status
|
||||||
when 200
|
when 200
|
||||||
JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
|
JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
|
||||||
when 429
|
|
||||||
# TODO add a AdminDashboard Problem?
|
|
||||||
else
|
else
|
||||||
Rails.logger.warn(
|
Rails.logger.warn(
|
||||||
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
|
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||||
|
@ -40,16 +40,7 @@ RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
|
|||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
context "when the response status is 429" do
|
context "when the response status is not 200" do
|
||||||
let(:response_status) { 429 }
|
|
||||||
let(:response_body) { "" }
|
|
||||||
|
|
||||||
it "doesn't raises a Net::HTTPBadResponse error" do
|
|
||||||
expect { subject.perform!(content) }.not_to raise_error
|
|
||||||
end
|
|
||||||
end
|
|
||||||
|
|
||||||
context "when the response status is not 200 or 429" do
|
|
||||||
let(:response_status) { 500 }
|
let(:response_status) { 500 }
|
||||||
let(:response_body) { "Internal Server Error" }
|
let(:response_body) { "Internal Server Error" }
|
||||||
|
|
||||||
|
@ -83,17 +83,34 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||||||
|
|
||||||
expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
|
expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
|
||||||
end
|
end
|
||||||
|
|
||||||
|
context "when one of the concurrently generated embeddings fails" do
|
||||||
|
it "still processes the succesful ones" do
|
||||||
|
text = vdef.prepare_target_text(topic)
|
||||||
|
|
||||||
|
text2 = vdef.prepare_target_text(topic_2)
|
||||||
|
|
||||||
|
stub_vector_mapping(text, expected_embedding_1)
|
||||||
|
stub_vector_mapping(text2, expected_embedding_2, result_status: 429)
|
||||||
|
|
||||||
|
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
|
||||||
|
|
||||||
|
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||||
|
expect(topics_schema.find_by_target(topic_2)).to be_nil
|
||||||
|
end
|
||||||
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
|
||||||
context "with open_ai as the provider" do
|
context "with open_ai as the provider" do
|
||||||
fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
|
fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||||
EmbeddingsGenerationStubs.openai_service(
|
EmbeddingsGenerationStubs.openai_service(
|
||||||
vdef.lookup_custom_param("model_name"),
|
vdef.lookup_custom_param("model_name"),
|
||||||
text,
|
text,
|
||||||
expected_embedding,
|
expected_embedding,
|
||||||
|
result_status: result_status,
|
||||||
)
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
@ -123,8 +140,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||||||
context "with hugging_face as the provider" do
|
context "with hugging_face as the provider" do
|
||||||
fab!(:vdef) { Fabricate(:embedding_definition) }
|
fab!(:vdef) { Fabricate(:embedding_definition) }
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||||
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
EmbeddingsGenerationStubs.hugging_face_service(
|
||||||
|
text,
|
||||||
|
expected_embedding,
|
||||||
|
result_status: result_status,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
it_behaves_like "generates and store embeddings using a vector definition"
|
||||||
@ -133,8 +154,13 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||||||
context "with google as the provider" do
|
context "with google as the provider" do
|
||||||
fab!(:vdef) { Fabricate(:gemini_embedding_def) }
|
fab!(:vdef) { Fabricate(:gemini_embedding_def) }
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||||
EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding)
|
EmbeddingsGenerationStubs.gemini_service(
|
||||||
|
vdef.api_key,
|
||||||
|
text,
|
||||||
|
expected_embedding,
|
||||||
|
result_status: result_status,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
it_behaves_like "generates and store embeddings using a vector definition"
|
||||||
@ -143,8 +169,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
|||||||
context "with cloudflare as the provider" do
|
context "with cloudflare as the provider" do
|
||||||
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
|
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
|
||||||
|
|
||||||
def stub_vector_mapping(text, expected_embedding)
|
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||||
EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding)
|
EmbeddingsGenerationStubs.cloudflare_service(
|
||||||
|
text,
|
||||||
|
expected_embedding,
|
||||||
|
result_status: result_status,
|
||||||
|
)
|
||||||
end
|
end
|
||||||
|
|
||||||
it_behaves_like "generates and store embeddings using a vector definition"
|
it_behaves_like "generates and store embeddings using a vector definition"
|
||||||
|
@ -2,35 +2,35 @@
|
|||||||
|
|
||||||
class EmbeddingsGenerationStubs
|
class EmbeddingsGenerationStubs
|
||||||
class << self
|
class << self
|
||||||
def hugging_face_service(string, embedding)
|
def hugging_face_service(string, embedding, result_status: 200)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://test.com/embeddings")
|
.stub_request(:post, "https://test.com/embeddings")
|
||||||
.with(body: JSON.dump({ inputs: string, truncate: true }))
|
.with(body: JSON.dump({ inputs: string, truncate: true }))
|
||||||
.to_return(status: 200, body: JSON.dump([embedding]))
|
.to_return(status: result_status, body: JSON.dump([embedding]))
|
||||||
end
|
end
|
||||||
|
|
||||||
def openai_service(model, string, embedding, extra_args: {})
|
def openai_service(model, string, embedding, result_status: 200, extra_args: {})
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://api.openai.com/v1/embeddings")
|
.stub_request(:post, "https://api.openai.com/v1/embeddings")
|
||||||
.with(body: JSON.dump({ model: model, input: string }.merge(extra_args)))
|
.with(body: JSON.dump({ model: model, input: string }.merge(extra_args)))
|
||||||
.to_return(status: 200, body: JSON.dump({ data: [{ embedding: embedding }] }))
|
.to_return(status: result_status, body: JSON.dump({ data: [{ embedding: embedding }] }))
|
||||||
end
|
end
|
||||||
|
|
||||||
def gemini_service(api_key, string, embedding)
|
def gemini_service(api_key, string, embedding, result_status: 200)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(
|
.stub_request(
|
||||||
:post,
|
:post,
|
||||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}",
|
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}",
|
||||||
)
|
)
|
||||||
.with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
|
.with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
|
||||||
.to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } }))
|
.to_return(status: result_status, body: JSON.dump({ embedding: { values: embedding } }))
|
||||||
end
|
end
|
||||||
|
|
||||||
def cloudflare_service(string, embedding)
|
def cloudflare_service(string, embedding, result_status: 200)
|
||||||
WebMock
|
WebMock
|
||||||
.stub_request(:post, "https://test.com/embeddings")
|
.stub_request(:post, "https://test.com/embeddings")
|
||||||
.with(body: JSON.dump({ text: [string] }))
|
.with(body: JSON.dump({ text: [string] }))
|
||||||
.to_return(status: 200, body: JSON.dump({ result: { data: [embedding] } }))
|
.to_return(status: result_status, body: JSON.dump({ result: { data: [embedding] } }))
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
end
|
end
|
||||||
|
Loading…
x
Reference in New Issue
Block a user