mirror of
https://github.com/discourse/discourse-ai.git
synced 2025-07-12 17:13:29 +00:00
FIX: Process successfully generated embeddings even if some failed (#1500)
This commit is contained in:
parent
6b5ea38644
commit
89bcf9b1f0
@ -42,13 +42,14 @@ module DiscourseAi
|
||||
.then_on(pool) do |w_prepared_text|
|
||||
w_prepared_text.merge(embedding: embedding_gen.perform!(w_prepared_text[:text]))
|
||||
end
|
||||
.rescue { nil } # We log the error during #perform. Skip failed embeddings.
|
||||
end
|
||||
.compact
|
||||
|
||||
Concurrent::Promises
|
||||
.zip(*promised_embeddings)
|
||||
.value!
|
||||
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) }
|
||||
.each { |e| schema.store(e[:target], e[:embedding], e[:digest]) if e.present? }
|
||||
ensure
|
||||
pool.shutdown
|
||||
pool.wait_for_termination
|
||||
|
@ -26,8 +26,6 @@ module ::DiscourseAi
|
||||
case response.status
|
||||
when 200
|
||||
JSON.parse(response.body, symbolize_names: true).dig(:result, :data).first
|
||||
when 429
|
||||
# TODO add a AdminDashboard Problem?
|
||||
else
|
||||
Rails.logger.warn(
|
||||
"Cloudflare Workers AI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||
|
@ -22,8 +22,6 @@ module ::DiscourseAi
|
||||
case response.status
|
||||
when 200
|
||||
JSON.parse(response.body, symbolize_names: true).dig(:embedding, :values)
|
||||
when 429
|
||||
# TODO add a AdminDashboard Problem?
|
||||
else
|
||||
Rails.logger.warn(
|
||||
"Google Gemini Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||
|
@ -30,8 +30,6 @@ module ::DiscourseAi
|
||||
case response.status
|
||||
when 200
|
||||
JSON.parse(response.body, symbolize_names: true).dig(:data, 0, :embedding)
|
||||
when 429
|
||||
# TODO add a AdminDashboard Problem?
|
||||
else
|
||||
Rails.logger.warn(
|
||||
"OpenAI Embeddings failed with status: #{response.status} body: #{response.body}",
|
||||
|
@ -40,16 +40,7 @@ RSpec.describe DiscourseAi::Inference::CloudflareWorkersAi do
|
||||
end
|
||||
end
|
||||
|
||||
context "when the response status is 429" do
|
||||
let(:response_status) { 429 }
|
||||
let(:response_body) { "" }
|
||||
|
||||
it "doesn't raises a Net::HTTPBadResponse error" do
|
||||
expect { subject.perform!(content) }.not_to raise_error
|
||||
end
|
||||
end
|
||||
|
||||
context "when the response status is not 200 or 429" do
|
||||
context "when the response status is not 200" do
|
||||
let(:response_status) { 500 }
|
||||
let(:response_body) { "Internal Server Error" }
|
||||
|
||||
|
@ -83,17 +83,34 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
|
||||
expect(topics_schema.find_by_target(topic).updated_at).to eq_time(original_vector_gen)
|
||||
end
|
||||
|
||||
context "when one of the concurrently generated embeddings fails" do
|
||||
it "still processes the succesful ones" do
|
||||
text = vdef.prepare_target_text(topic)
|
||||
|
||||
text2 = vdef.prepare_target_text(topic_2)
|
||||
|
||||
stub_vector_mapping(text, expected_embedding_1)
|
||||
stub_vector_mapping(text2, expected_embedding_2, result_status: 429)
|
||||
|
||||
vector.gen_bulk_reprensentations(Topic.where(id: [topic.id, topic_2.id]))
|
||||
|
||||
expect(topics_schema.find_by_embedding(expected_embedding_1).topic_id).to eq(topic.id)
|
||||
expect(topics_schema.find_by_target(topic_2)).to be_nil
|
||||
end
|
||||
end
|
||||
end
|
||||
end
|
||||
|
||||
context "with open_ai as the provider" do
|
||||
fab!(:vdef) { Fabricate(:open_ai_embedding_def) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||
EmbeddingsGenerationStubs.openai_service(
|
||||
vdef.lookup_custom_param("model_name"),
|
||||
text,
|
||||
expected_embedding,
|
||||
result_status: result_status,
|
||||
)
|
||||
end
|
||||
|
||||
@ -123,8 +140,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
context "with hugging_face as the provider" do
|
||||
fab!(:vdef) { Fabricate(:embedding_definition) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.hugging_face_service(text, expected_embedding)
|
||||
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||
EmbeddingsGenerationStubs.hugging_face_service(
|
||||
text,
|
||||
expected_embedding,
|
||||
result_status: result_status,
|
||||
)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
@ -133,8 +154,13 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
context "with google as the provider" do
|
||||
fab!(:vdef) { Fabricate(:gemini_embedding_def) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.gemini_service(vdef.api_key, text, expected_embedding)
|
||||
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||
EmbeddingsGenerationStubs.gemini_service(
|
||||
vdef.api_key,
|
||||
text,
|
||||
expected_embedding,
|
||||
result_status: result_status,
|
||||
)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
@ -143,8 +169,12 @@ RSpec.describe DiscourseAi::Embeddings::Vector do
|
||||
context "with cloudflare as the provider" do
|
||||
fab!(:vdef) { Fabricate(:cloudflare_embedding_def) }
|
||||
|
||||
def stub_vector_mapping(text, expected_embedding)
|
||||
EmbeddingsGenerationStubs.cloudflare_service(text, expected_embedding)
|
||||
def stub_vector_mapping(text, expected_embedding, result_status: 200)
|
||||
EmbeddingsGenerationStubs.cloudflare_service(
|
||||
text,
|
||||
expected_embedding,
|
||||
result_status: result_status,
|
||||
)
|
||||
end
|
||||
|
||||
it_behaves_like "generates and store embeddings using a vector definition"
|
||||
|
@ -2,35 +2,35 @@
|
||||
|
||||
class EmbeddingsGenerationStubs
|
||||
class << self
|
||||
def hugging_face_service(string, embedding)
|
||||
def hugging_face_service(string, embedding, result_status: 200)
|
||||
WebMock
|
||||
.stub_request(:post, "https://test.com/embeddings")
|
||||
.with(body: JSON.dump({ inputs: string, truncate: true }))
|
||||
.to_return(status: 200, body: JSON.dump([embedding]))
|
||||
.to_return(status: result_status, body: JSON.dump([embedding]))
|
||||
end
|
||||
|
||||
def openai_service(model, string, embedding, extra_args: {})
|
||||
def openai_service(model, string, embedding, result_status: 200, extra_args: {})
|
||||
WebMock
|
||||
.stub_request(:post, "https://api.openai.com/v1/embeddings")
|
||||
.with(body: JSON.dump({ model: model, input: string }.merge(extra_args)))
|
||||
.to_return(status: 200, body: JSON.dump({ data: [{ embedding: embedding }] }))
|
||||
.to_return(status: result_status, body: JSON.dump({ data: [{ embedding: embedding }] }))
|
||||
end
|
||||
|
||||
def gemini_service(api_key, string, embedding)
|
||||
def gemini_service(api_key, string, embedding, result_status: 200)
|
||||
WebMock
|
||||
.stub_request(
|
||||
:post,
|
||||
"https://generativelanguage.googleapis.com/v1beta/models/embedding-001:embedContent\?key\=#{api_key}",
|
||||
)
|
||||
.with(body: JSON.dump({ content: { parts: [{ text: string }] } }))
|
||||
.to_return(status: 200, body: JSON.dump({ embedding: { values: embedding } }))
|
||||
.to_return(status: result_status, body: JSON.dump({ embedding: { values: embedding } }))
|
||||
end
|
||||
|
||||
def cloudflare_service(string, embedding)
|
||||
def cloudflare_service(string, embedding, result_status: 200)
|
||||
WebMock
|
||||
.stub_request(:post, "https://test.com/embeddings")
|
||||
.with(body: JSON.dump({ text: [string] }))
|
||||
.to_return(status: 200, body: JSON.dump({ result: { data: [embedding] } }))
|
||||
.to_return(status: result_status, body: JSON.dump({ result: { data: [embedding] } }))
|
||||
end
|
||||
end
|
||||
end
|
||||
|
Loading…
x
Reference in New Issue
Block a user