diff --git a/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationResponse.java b/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationResponse.java index 98556494191..e2cd8e3a817 100644 --- a/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationResponse.java +++ b/core/src/main/java/org/elasticsearch/action/support/replication/ReplicationResponse.java @@ -34,6 +34,7 @@ import org.elasticsearch.rest.RestStatus; import java.io.IOException; import java.util.Arrays; +import java.util.Objects; /** * Base class for write action responses. @@ -120,6 +121,25 @@ public class ReplicationResponse extends ActionResponse { return status; } + @Override + public boolean equals(Object that) { + if (this == that) { + return true; + } + if (that == null || getClass() != that.getClass()) { + return false; + } + ShardInfo other = (ShardInfo) that; + return Objects.equals(total, other.total) && + Objects.equals(successful, other.successful) && + Arrays.equals(failures, other.failures); + } + + @Override + public int hashCode() { + return Objects.hash(total, successful, failures); + } + @Override public void readFrom(StreamInput in) throws IOException { total = in.readVInt(); @@ -251,6 +271,27 @@ public class ReplicationResponse extends ActionResponse { return primary; } + @Override + public boolean equals(Object that) { + if (this == that) { + return true; + } + if (that == null || getClass() != that.getClass()) { + return false; + } + Failure failure = (Failure) that; + return Objects.equals(primary, failure.primary) && + Objects.equals(shardId, failure.shardId) && + Objects.equals(nodeId, failure.nodeId) && + Objects.equals(cause, failure.cause) && + Objects.equals(status, failure.status); + } + + @Override + public int hashCode() { + return Objects.hash(shardId, nodeId, cause, status, primary); + } + @Override public void readFrom(StreamInput in) throws IOException { shardId = ShardId.readShardId(in); diff --git a/core/src/test/java/org/elasticsearch/action/support/replication/ReplicationResponseTests.java b/core/src/test/java/org/elasticsearch/action/support/replication/ReplicationResponseTests.java index 3740f8dd5f7..0c805d6d32c 100644 --- a/core/src/test/java/org/elasticsearch/action/support/replication/ReplicationResponseTests.java +++ b/core/src/test/java/org/elasticsearch/action/support/replication/ReplicationResponseTests.java @@ -19,10 +19,21 @@ package org.elasticsearch.action.support.replication; +import org.elasticsearch.ElasticsearchException; +import org.elasticsearch.action.RoutingMissingException; +import org.elasticsearch.index.Index; +import org.elasticsearch.index.shard.IndexShardRecoveringException; +import org.elasticsearch.index.shard.ShardId; +import org.elasticsearch.rest.RestStatus; import org.elasticsearch.test.ESTestCase; +import org.elasticsearch.test.EqualsHashCodeTestUtils; +import java.util.ArrayList; +import java.util.List; import java.util.Locale; +import java.util.function.Supplier; +import static org.elasticsearch.test.EqualsHashCodeTestUtils.checkEqualsAndHashCode; import static org.hamcrest.CoreMatchers.equalTo; public class ReplicationResponseTests extends ESTestCase { @@ -36,4 +47,97 @@ public class ReplicationResponseTests extends ESTestCase { equalTo(String.format(Locale.ROOT, "ShardInfo{total=5, successful=%d, failures=[]}", successful))); } + public void testShardInfoEqualsAndHashcode() { + EqualsHashCodeTestUtils.CopyFunction copy = shardInfo -> + new ReplicationResponse.ShardInfo(shardInfo.getTotal(), shardInfo.getSuccessful(), shardInfo.getFailures()); + + EqualsHashCodeTestUtils.MutateFunction mutate = shardInfo -> { + List> mutations = new ArrayList<>(); + mutations.add(() -> + new ReplicationResponse.ShardInfo(shardInfo.getTotal() + 1, shardInfo.getSuccessful(), shardInfo.getFailures())); + mutations.add(() -> + new ReplicationResponse.ShardInfo(shardInfo.getTotal(), shardInfo.getSuccessful() + 1, shardInfo.getFailures())); + mutations.add(() -> { + int nbFailures = randomIntBetween(1, 5); + return new ReplicationResponse.ShardInfo(shardInfo.getTotal(), shardInfo.getSuccessful(), randomFailures(nbFailures)); + }); + return randomFrom(mutations).get(); + }; + + checkEqualsAndHashCode(randomShardInfo(), copy, mutate); + } + + public void testFailureEqualsAndHashcode() { + EqualsHashCodeTestUtils.CopyFunction copy = failure -> { + Index index = failure.fullShardId().getIndex(); + ShardId shardId = new ShardId(index.getName(), index.getUUID(), failure.shardId()); + Exception cause = (Exception) failure.getCause(); + return new ReplicationResponse.ShardInfo.Failure(shardId, failure.nodeId(), cause, failure.status(), failure.primary()); + }; + + EqualsHashCodeTestUtils.MutateFunction mutate = failure -> { + List> mutations = new ArrayList<>(); + + final Index index = failure.fullShardId().getIndex(); + final ShardId randomIndex = new ShardId(randomUnicodeOfCodepointLength(5), index.getUUID(), failure.shardId()); + mutations.add(() -> new ReplicationResponse.ShardInfo.Failure(randomIndex, failure.nodeId(), (Exception) failure.getCause(), + failure.status(), failure.primary())); + + final ShardId randomUUID = new ShardId(index.getName(), randomUnicodeOfCodepointLength(5), failure.shardId()); + mutations.add(() -> new ReplicationResponse.ShardInfo.Failure(randomUUID, failure.nodeId(), (Exception) failure.getCause(), + failure.status(), failure.primary())); + + final ShardId randomShardId = new ShardId(index.getName(),index.getUUID(), failure.shardId() + randomIntBetween(1, 3)); + mutations.add(() -> new ReplicationResponse.ShardInfo.Failure(randomShardId, failure.nodeId(), (Exception) failure.getCause(), + failure.status(), failure.primary())); + + final String randomNode = randomUnicodeOfLength(3); + mutations.add(() -> new ReplicationResponse.ShardInfo.Failure(failure.fullShardId(), randomNode, (Exception) failure.getCause(), + failure.status(), failure.primary())); + + final Exception randomException = randomFrom(new IllegalStateException("a"), new IllegalArgumentException("b")); + mutations.add(() -> new ReplicationResponse.ShardInfo.Failure(failure.fullShardId(), failure.nodeId(), randomException, + failure.status(), failure.primary())); + + final RestStatus randomStatus = randomFrom(RestStatus.values()); + mutations.add(() -> new ReplicationResponse.ShardInfo.Failure(failure.fullShardId(), failure.nodeId(), + (Exception) failure.getCause(), randomStatus, failure.primary())); + + final boolean randomPrimary = !failure.primary(); + mutations.add(() -> new ReplicationResponse.ShardInfo.Failure(failure.fullShardId(), failure.nodeId(), + (Exception) failure.getCause(), failure.status(), randomPrimary)); + + return randomFrom(mutations).get(); + }; + + checkEqualsAndHashCode(randomFailure(), copy, mutate); + } + + private static ReplicationResponse.ShardInfo randomShardInfo() { + int total = randomIntBetween(1, 10); + int successful = randomIntBetween(0, total); + return new ReplicationResponse.ShardInfo(total, successful, randomFailures(Math.max(0, (total - successful)))); + } + + private static ReplicationResponse.ShardInfo.Failure[] randomFailures(int nbFailures) { + List randomFailures = new ArrayList<>(nbFailures); + for (int i = 0; i < nbFailures; i++) { + randomFailures.add(randomFailure()); + } + return randomFailures.toArray(new ReplicationResponse.ShardInfo.Failure[nbFailures]); + } + + private static ReplicationResponse.ShardInfo.Failure randomFailure() { + return new ReplicationResponse.ShardInfo.Failure( + new ShardId(randomAsciiOfLength(5), randomAsciiOfLength(5), randomIntBetween(0, 5)), + randomAsciiOfLength(3), + randomFrom( + new IndexShardRecoveringException(new ShardId("_test", "_0", 5)), + new ElasticsearchException(new IllegalArgumentException("argument is wrong")), + new RoutingMissingException("_test", "_type", "_id") + ), + randomFrom(RestStatus.values()), + randomBoolean() + ); + } }