diff --git a/src/test/java/org/elasticsearch/update/UpdateTests.java b/src/test/java/org/elasticsearch/update/UpdateTests.java
index bfbc6eaf68b..423ea9bdb0d 100644
--- a/src/test/java/org/elasticsearch/update/UpdateTests.java
+++ b/src/test/java/org/elasticsearch/update/UpdateTests.java
@@ -20,11 +20,16 @@ package org.elasticsearch.update;
 
 import org.apache.lucene.util.LuceneTestCase.Slow;
+import org.elasticsearch.ElasticsearchTimeoutException;
+import org.elasticsearch.action.ActionListener;
 import org.elasticsearch.action.ActionRequestValidationException;
+import org.elasticsearch.action.delete.DeleteResponse;
 import org.elasticsearch.action.get.GetResponse;
 import org.elasticsearch.action.update.UpdateRequest;
+import org.elasticsearch.action.delete.DeleteRequest;
 import org.elasticsearch.action.update.UpdateRequestBuilder;
 import org.elasticsearch.action.update.UpdateResponse;
+import org.elasticsearch.common.unit.TimeValue;
 import org.elasticsearch.common.xcontent.XContentFactory;
 import org.elasticsearch.common.xcontent.XContentHelper;
 import org.elasticsearch.index.VersionType;
@@ -34,11 +39,11 @@ import org.elasticsearch.script.ScriptService;
 import org.elasticsearch.test.ElasticsearchIntegrationTest;
 import org.junit.Test;
 
-import java.util.HashMap;
-import java.util.List;
-import java.util.Map;
+import java.util.*;
 import java.util.concurrent.CopyOnWriteArrayList;
 import java.util.concurrent.CountDownLatch;
+import java.util.concurrent.Semaphore;
+import java.util.concurrent.TimeUnit;
 
 import static org.elasticsearch.common.xcontent.XContentFactory.jsonBuilder;
 import static org.elasticsearch.test.hamcrest.ElasticsearchAssertions.assertThrows;
@@ -518,4 +523,192 @@ public class UpdateTests extends ElasticsearchIntegrationTest {
         }
     }
 
+
+    @Test
+    @Slow
+    public void stressUpdateDeleteConcurrency() throws Exception {
+        final boolean useBulkApi = randomBoolean();
+        createIndex();
+        ensureGreen();
+        final int numberOfThreads = scaledRandomIntBetween(5, 20);
+        final int numberOfIdsPerThread = scaledRandomIntBetween(3, 10);
+        final int numberOfUpdatesPerId = scaledRandomIntBetween(100, 1000);
+        final int retryOnConflict = randomIntBetween(0, 1);
+        final CountDownLatch latch = new CountDownLatch(numberOfThreads);
+        final CountDownLatch startLatch = new CountDownLatch(1);
+        final List<Throwable> failures = new CopyOnWriteArrayList<>();
+
+        final class UpdateThread extends Thread {
+            final Map<Integer, Integer> failedMap = new HashMap<>();
+            final int numberOfIds;
+            final int updatesPerId;
+            final int maxUpdateRequests = numberOfIdsPerThread * numberOfUpdatesPerId;
+            final int maxDeleteRequests = numberOfIdsPerThread * numberOfUpdatesPerId;
+            private final Semaphore updateRequestsOutstanding = new Semaphore(maxUpdateRequests);
+            private final Semaphore deleteRequestsOutstanding = new Semaphore(maxDeleteRequests);
+
+            public UpdateThread(int numberOfIds, int updatesPerId) {
+                this.numberOfIds = numberOfIds;
+                this.updatesPerId = updatesPerId;
+            }
+
+            final class UpdateListener implements ActionListener<UpdateResponse> {
+                int id;
+
+                public UpdateListener(int id) {
+                    this.id = id;
+                }
+
+                @Override
+                public void onResponse(UpdateResponse updateResponse) {
+                    updateRequestsOutstanding.release(1);
+                }
+
+                @Override
+                public void onFailure(Throwable e) {
+                    synchronized (failedMap) {
+                        incrementMapValue(id, failedMap);
+                    }
+                    updateRequestsOutstanding.release(1);
+                }
+
+            }
+
+            final class DeleteListener implements ActionListener<DeleteResponse> {
+                int id;
+
+                public DeleteListener(int id) {
+                    this.id = id;
+                }
+
+                @Override
+                public void onResponse(DeleteResponse deleteResponse) {
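+                    // A delete request completed; hand its permit back so the issuing thread can send another one.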
+                    deleteRequestsOutstanding.release(1);
+                }
+
+                @Override
+                public void onFailure(Throwable e) {
+                    synchronized (failedMap) {
+                        incrementMapValue(id, failedMap);
+                    }
+                    deleteRequestsOutstanding.release(1);
+                }
+            }
+
+            @Override
+            public void run() {
+                try {
+                    startLatch.await();
+                    for (int j = 0; j < numberOfIds; j++) {
+                        for (int k = 0; k < numberOfUpdatesPerId; ++k) {
+                            updateRequestsOutstanding.acquire();
+                            UpdateRequest ur = client().prepareUpdate("test", "type1", Integer.toString(j))
+                                    .setScript("ctx._source.field += 1", ScriptService.ScriptType.INLINE)
+                                    .setRetryOnConflict(retryOnConflict)
+                                    .setUpsert(jsonBuilder().startObject().field("field", 1).endObject())
+                                    .setListenerThreaded(false)
+                                    .request();
+                            client().update(ur, new UpdateListener(j));
+
+                            deleteRequestsOutstanding.acquire();
+                            DeleteRequest dr = client().prepareDelete("test", "type1", Integer.toString(j))
+                                    .setListenerThreaded(false)
+                                    .setOperationThreaded(false).request();
+                            client().delete(dr, new DeleteListener(j));
+                        }
+                    }
+                } catch (Throwable e) {
+                    logger.error("Something went wrong", e);
+                    failures.add(e);
+                } finally {
+                    try {
+                        waitForOutstandingRequests(TimeValue.timeValueSeconds(60), updateRequestsOutstanding, maxUpdateRequests, "Update");
+                        waitForOutstandingRequests(TimeValue.timeValueSeconds(60), deleteRequestsOutstanding, maxDeleteRequests, "Delete");
+                    } catch (ElasticsearchTimeoutException ete) {
+                        failures.add(ete);
+                    }
+                    latch.countDown();
+                }
+            }
+
+            private void incrementMapValue(int j, Map<Integer, Integer> map) {
+                if (!map.containsKey(j)) {
+                    map.put(j, 0);
+                }
+                map.put(j, map.get(j) + 1);
+            }
+
+            private void waitForOutstandingRequests(TimeValue timeOut, Semaphore requestsOutstanding, int maxRequests, String name) {
+                long start = System.currentTimeMillis();
+                do {
+                    long msRemaining = timeOut.getMillis() - (System.currentTimeMillis() - start);
+                    logger.info("[{}] going to try and acquire [{}] in [{}]ms, [{}] available to acquire right now", name, maxRequests, msRemaining, requestsOutstanding.availablePermits());
+                    try {
+                        if (requestsOutstanding.tryAcquire(maxRequests, msRemaining, TimeUnit.MILLISECONDS)) {
+                            return;
+                        }
+                    } catch (InterruptedException ie) {
+                        // Just keep swimming
+                    }
+                } while ((System.currentTimeMillis() - start) < timeOut.getMillis());
+                throw new ElasticsearchTimeoutException("Requests were still outstanding after the timeout [" + timeOut + "] for type [" + name + "]");
+            }
+        }
+        final List<UpdateThread> threads = new ArrayList<>();
+
+        for (int i = 0; i < numberOfThreads; i++) {
+            UpdateThread ut = new UpdateThread(numberOfIdsPerThread, numberOfUpdatesPerId);
+            ut.start();
+            threads.add(ut);
+        }
+
+        startLatch.countDown();
+        latch.await();
+
+        for (UpdateThread ut : threads) {
+            ut.join(); // Threads should have finished because of the latch.await()
+        }
+
+        // If there are no errors, every request received a response; otherwise the test would have
+        // timed out acquiring the outstanding-request semaphores.
+        for (Throwable throwable : failures) {
+            logger.info("Captured failure on concurrent update:", throwable);
+        }
+
+        assertThat(failures.size(), equalTo(0));
+
+        // Upsert all the ids one last time to make sure they are available at get time.
+        // This means that we add 1 to the expected versions and attempts.
+        // All the previous operations should be complete or failed at this point.
+        for (int i = 0; i < numberOfIdsPerThread; ++i) {
+            UpdateResponse ur = client().prepareUpdate("test", "type1", Integer.toString(i))
+                    .setScript("ctx._source.field += 1", ScriptService.ScriptType.INLINE)
+                    .setRetryOnConflict(Integer.MAX_VALUE)
+                    .setUpsert(jsonBuilder().startObject().field("field", 1).endObject())
+                    .execute().actionGet();
+        }
+
+        refresh();
+
+        for (int i = 0; i < numberOfIdsPerThread; ++i) {
+            int totalFailures = 0;
+            GetResponse response = client().prepareGet("test", "type1", Integer.toString(i)).execute().actionGet();
+            if (response.isExists()) {
+                assertThat(response.getId(), equalTo(Integer.toString(i)));
+                int expectedVersion = (numberOfThreads * numberOfUpdatesPerId * 2) + 1;
+                for (UpdateThread ut : threads) {
+                    if (ut.failedMap.containsKey(i)) {
+                        totalFailures += ut.failedMap.get(i);
+                    }
+                }
+                expectedVersion -= totalFailures;
+                logger.error("Actual version [{}] Expected version [{}] Total failures [{}]", response.getVersion(), expectedVersion, totalFailures);
+                assertThat(response.getVersion(), equalTo((long) expectedVersion));
+                assertThat(response.getVersion() + totalFailures,
+                        equalTo(
+                                (long) ((numberOfUpdatesPerId * numberOfThreads * 2) + 1)
+                        ));
+            }
+        }
+    }
 }