diff --git a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java index 80cd4165f65..8316918893a 100644 --- a/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java +++ b/server/src/main/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseFactory.java @@ -27,7 +27,6 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.ObjectMapper; import com.fasterxml.jackson.jaxrs.smile.SmileMediaTypes; import com.google.common.annotations.VisibleForTesting; -import com.google.common.base.Optional; import com.google.common.base.Preconditions; import com.google.common.base.Throwables; import com.google.common.collect.ImmutableMap; @@ -69,8 +68,11 @@ import java.util.Collection; import java.util.List; import java.util.Map; import java.util.NoSuchElementException; +import java.util.Optional; import java.util.concurrent.ArrayBlockingQueue; import java.util.concurrent.BlockingQueue; +import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.TimeUnit; import java.util.concurrent.atomic.AtomicLong; @@ -81,8 +83,10 @@ import java.util.concurrent.atomic.AtomicLong; */ public class EventReceiverFirehoseFactory implements FirehoseFactory { + public static final int MAX_FIREHOSE_PRODUCERS = 10_000; + private static final EmittingLogger log = new EmittingLogger(EventReceiverFirehoseFactory.class); - private static final int DEFAULT_BUFFER_SIZE = 100000; + private static final int DEFAULT_BUFFER_SIZE = 100_000; private final String serviceName; private final int bufferSize; @@ -107,7 +111,7 @@ public class EventReceiverFirehoseFactory implements FirehoseFactory producerSequences = new ConcurrentHashMap<>(); public EventReceiverFirehose(MapInputRowParser parser) { @@ -173,7 +178,7 @@ public class EventReceiverFirehoseFactory implements FirehoseFactory producerSequenceResponse = checkProducerSequence(req, reqContentType, objectMapper); + if (producerSequenceResponse.isPresent()) { + return producerSequenceResponse.get(); + } + CountingInputStream countingInputStream = new CountingInputStream(in); Collection> events = null; try { @@ -393,5 +404,81 @@ public class EventReceiverFirehoseFactory implements FirehoseFactory checkProducerSequence( + final HttpServletRequest req, + final String responseContentType, + final ObjectMapper responseMapper + ) + { + final String producerId = req.getHeader("X-Firehose-Producer-Id"); + + if (producerId == null) { + return Optional.empty(); + } + + final String sequenceValue = req.getHeader("X-Firehose-Producer-Seq"); + + if (sequenceValue == null) { + return Optional.of( + Response.status(Response.Status.BAD_REQUEST) + .entity(ImmutableMap.of("error", "Producer sequence value is missing")) + .build() + ); + } + + Long producerSequence = producerSequences.computeIfAbsent(producerId, key -> Long.MIN_VALUE); + + if (producerSequences.size() >= MAX_FIREHOSE_PRODUCERS) { + return Optional.of( + Response.status(Response.Status.FORBIDDEN) + .entity( + ImmutableMap.of( + "error", + "Too many individual producer IDs for this firehose. Max is " + MAX_FIREHOSE_PRODUCERS + ) + ) + .build() + ); + } + + try { + Long newSequence = Long.parseLong(sequenceValue); + if (newSequence <= producerSequence) { + return Optional.of( + Response.ok( + responseMapper.writeValueAsString( + ImmutableMap.of("eventCount", 0, "skipped", true) + ), + responseContentType + ).build() + ); + } + + producerSequences.put(producerId, newSequence); + } + catch (JsonProcessingException ex) { + throw Throwables.propagate(ex); + } + catch (NumberFormatException ex) { + return Optional.of( + Response.status(Response.Status.BAD_REQUEST) + .entity(ImmutableMap.of("error", "Producer sequence must be a number")) + .build() + ); + } + + return Optional.empty(); + } } } diff --git a/server/src/test/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseTest.java b/server/src/test/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseTest.java index a2329cb44f7..27b2ec18d0d 100644 --- a/server/src/test/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseTest.java +++ b/server/src/test/java/io/druid/segment/realtime/firehose/EventReceiverFirehoseTest.java @@ -41,8 +41,10 @@ import org.junit.Before; import org.junit.Test; import javax.servlet.http.HttpServletRequest; +import javax.ws.rs.core.Response; import java.io.IOException; import java.io.InputStream; +import java.nio.charset.StandardCharsets; import java.util.Map; import java.util.concurrent.Callable; import java.util.concurrent.ExecutionException; @@ -99,19 +101,9 @@ public class EventReceiverFirehoseTest @Test public void testSingleThread() throws IOException { - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) - .andReturn(null) - .anyTimes(); - EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) - .andReturn(AllowAllAuthenticator.ALLOW_ALL_RESULT) - .anyTimes(); - req.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); - EasyMock.expectLastCall().anyTimes(); - EasyMock.expect(req.getContentType()).andReturn("application/json").times(NUM_EVENTS); - EasyMock.replay(req); - for (int i = 0; i < NUM_EVENTS; ++i) { - final InputStream inputStream = IOUtils.toInputStream(inputRow); + setUpRequestExpectations(null, null); + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); firehose.addAll(inputStream, req); Assert.assertEquals(i + 1, firehose.getCurrentBufferSize()); inputStream.close(); @@ -159,6 +151,7 @@ public class EventReceiverFirehoseTest EasyMock.expectLastCall().anyTimes(); EasyMock.expect(req.getContentType()).andReturn("application/json").times(2 * NUM_EVENTS); + EasyMock.expect(req.getHeader("X-Firehose-Producer-Id")).andReturn(null).times(2 * NUM_EVENTS); EasyMock.replay(req); final ExecutorService executorService = Execs.singleThreaded("single_thread"); @@ -169,7 +162,7 @@ public class EventReceiverFirehoseTest public Boolean call() throws Exception { for (int i = 0; i < NUM_EVENTS; ++i) { - final InputStream inputStream = IOUtils.toInputStream(inputRow); + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); firehose.addAll(inputStream, req); inputStream.close(); } @@ -179,7 +172,7 @@ public class EventReceiverFirehoseTest ); for (int i = 0; i < NUM_EVENTS; ++i) { - final InputStream inputStream = IOUtils.toInputStream(inputRow); + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); firehose.addAll(inputStream, req); inputStream.close(); } @@ -284,4 +277,145 @@ public class EventReceiverFirehoseTest Thread.sleep(50); } } + + @Test + public void testProducerSequence() throws IOException + { + for (int i = 0; i < NUM_EVENTS; ++i) { + setUpRequestExpectations("producer", String.valueOf(i)); + + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); + firehose.addAll(inputStream, req); + Assert.assertEquals(i + 1, firehose.getCurrentBufferSize()); + inputStream.close(); + } + + EasyMock.verify(req); + + final Iterable> metrics = register.getMetrics(); + Assert.assertEquals(1, Iterables.size(metrics)); + + final Map.Entry entry = Iterables.getLast(metrics); + Assert.assertEquals(SERVICE_NAME, entry.getKey()); + Assert.assertEquals(CAPACITY, entry.getValue().getCapacity()); + Assert.assertEquals(CAPACITY, firehose.getCapacity()); + Assert.assertEquals(NUM_EVENTS, entry.getValue().getCurrentBufferSize()); + Assert.assertEquals(NUM_EVENTS, firehose.getCurrentBufferSize()); + + for (int i = NUM_EVENTS - 1; i >= 0; --i) { + Assert.assertTrue(firehose.hasMore()); + Assert.assertNotNull(firehose.nextRow()); + Assert.assertEquals(i, firehose.getCurrentBufferSize()); + } + + Assert.assertEquals(CAPACITY, entry.getValue().getCapacity()); + Assert.assertEquals(CAPACITY, firehose.getCapacity()); + Assert.assertEquals(0, entry.getValue().getCurrentBufferSize()); + Assert.assertEquals(0, firehose.getCurrentBufferSize()); + + firehose.close(); + Assert.assertFalse(firehose.hasMore()); + Assert.assertEquals(0, Iterables.size(register.getMetrics())); + + } + + @Test + public void testLowProducerSequence() throws IOException + { + for (int i = 0; i < NUM_EVENTS; ++i) { + setUpRequestExpectations("producer", "1"); + + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); + final Response response = firehose.addAll(inputStream, req); + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + Assert.assertEquals(1, firehose.getCurrentBufferSize()); + inputStream.close(); + } + + EasyMock.verify(req); + + firehose.close(); + } + + @Test + public void testMissingProducerSequence() throws IOException + { + setUpRequestExpectations("producer", null); + + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); + final Response response = firehose.addAll(inputStream, req); + + Assert.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), response.getStatus()); + + inputStream.close(); + + EasyMock.verify(req); + + firehose.close(); + } + + @Test + public void testTooManyProducerIds() throws IOException + { + for (int i = 0; i < EventReceiverFirehoseFactory.MAX_FIREHOSE_PRODUCERS - 1; i++) { + setUpRequestExpectations("producer-" + i, "0"); + + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); + final Response response = firehose.addAll(inputStream, req); + Assert.assertEquals(Response.Status.OK.getStatusCode(), response.getStatus()); + inputStream.close(); + Assert.assertTrue(firehose.hasMore()); + Assert.assertNotNull(firehose.nextRow()); + } + + setUpRequestExpectations("toomany", "0"); + + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); + final Response response = firehose.addAll(inputStream, req); + Assert.assertEquals(Response.Status.FORBIDDEN.getStatusCode(), response.getStatus()); + inputStream.close(); + + EasyMock.verify(req); + + firehose.close(); + } + + @Test + public void testNaNProducerSequence() throws IOException + { + setUpRequestExpectations("producer", "foo"); + + final InputStream inputStream = IOUtils.toInputStream(inputRow, StandardCharsets.UTF_8); + final Response response = firehose.addAll(inputStream, req); + + Assert.assertEquals(Response.Status.BAD_REQUEST.getStatusCode(), response.getStatus()); + + inputStream.close(); + + EasyMock.verify(req); + + firehose.close(); + } + + private void setUpRequestExpectations(String producerId, String producerSequenceValue) + { + EasyMock.reset(req); + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED)) + .andReturn(null) + .anyTimes(); + EasyMock.expect(req.getAttribute(AuthConfig.DRUID_AUTHENTICATION_RESULT)) + .andReturn(AllowAllAuthenticator.ALLOW_ALL_RESULT) + .anyTimes(); + req.setAttribute(AuthConfig.DRUID_AUTHORIZATION_CHECKED, true); + EasyMock.expectLastCall().anyTimes(); + + EasyMock.expect(req.getContentType()).andReturn("application/json"); + EasyMock.expect(req.getHeader("X-Firehose-Producer-Id")).andReturn(producerId); + + if (producerId != null) { + EasyMock.expect(req.getHeader("X-Firehose-Producer-Seq")).andReturn(producerSequenceValue); + } + + EasyMock.replay(req); + } }