diff --git a/libs/core/src/test/java/org/opensearch/common/util/concurrent/RefCountedTests.java b/libs/core/src/test/java/org/opensearch/common/util/concurrent/RefCountedTests.java index 47cf49b3e32..f784ef9d164 100644 --- a/libs/core/src/test/java/org/opensearch/common/util/concurrent/RefCountedTests.java +++ b/libs/core/src/test/java/org/opensearch/common/util/concurrent/RefCountedTests.java @@ -31,13 +31,13 @@ package org.opensearch.common.util.concurrent; +import org.opensearch.common.concurrent.OneWayGate; import org.opensearch.test.OpenSearchTestCase; import org.hamcrest.Matchers; import java.io.IOException; import java.util.concurrent.CopyOnWriteArrayList; import java.util.concurrent.CountDownLatch; -import java.util.concurrent.atomic.AtomicBoolean; import static org.hamcrest.Matchers.equalTo; import static org.hamcrest.Matchers.is; @@ -138,7 +138,7 @@ public class RefCountedTests extends OpenSearchTestCase { private final class MyRefCounted extends AbstractRefCounted { - private final AtomicBoolean closed = new AtomicBoolean(false); + private final OneWayGate gate = new OneWayGate(); MyRefCounted() { super("test"); @@ -146,11 +146,11 @@ public class RefCountedTests extends OpenSearchTestCase { @Override protected void closeInternal() { - this.closed.set(true); + gate.close(); } public void ensureOpen() { - if (closed.get()) { + if (gate.isClosed()) { assert this.refCount() == 0; throw new IllegalStateException("closed"); } diff --git a/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AmazonEc2Reference.java b/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AmazonEc2Reference.java index eac46356d91..2686c376213 100644 --- a/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AmazonEc2Reference.java +++ b/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AmazonEc2Reference.java @@ -33,42 +33,15 @@ package org.opensearch.discovery.ec2; import com.amazonaws.services.ec2.AmazonEC2; - -import org.opensearch.common.lease.Releasable; -import org.opensearch.common.util.concurrent.AbstractRefCounted; +import org.opensearch.common.concurrent.RefCountedReleasable; /** * Handles the shutdown of the wrapped {@link AmazonEC2} using reference * counting. */ -public class AmazonEc2Reference extends AbstractRefCounted implements Releasable { - - private final AmazonEC2 client; +public class AmazonEc2Reference extends RefCountedReleasable { AmazonEc2Reference(AmazonEC2 client) { - super("AWS_EC2_CLIENT"); - this.client = client; + super("AWS_EC2_CLIENT", client, client::shutdown); } - - /** - * Call when the client is not needed anymore. - */ - @Override - public void close() { - decRef(); - } - - /** - * Returns the underlying `AmazonEC2` client. All method calls are permitted BUT - * NOT shutdown. Shutdown is called when reference count reaches 0. - */ - public AmazonEC2 client() { - return client; - } - - @Override - protected void closeInternal() { - client.shutdown(); - } - } diff --git a/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AwsEc2SeedHostsProvider.java b/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AwsEc2SeedHostsProvider.java index 4b36a60bb27..f26ecfab501 100644 --- a/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AwsEc2SeedHostsProvider.java +++ b/plugins/discovery-ec2/src/main/java/org/opensearch/discovery/ec2/AwsEc2SeedHostsProvider.java @@ -129,7 +129,7 @@ class AwsEc2SeedHostsProvider implements SeedHostsProvider { // NOTE: we don't filter by security group during the describe instances request for two reasons: // 1. differences in VPCs require different parameters during query (ID vs Name) // 2. We want to use two different strategies: (all security groups vs. any security groups) - descInstances = SocketAccess.doPrivileged(() -> clientReference.client().describeInstances(buildDescribeInstancesRequest())); + descInstances = SocketAccess.doPrivileged(() -> clientReference.get().describeInstances(buildDescribeInstancesRequest())); } catch (final AmazonClientException e) { logger.info("Exception while retrieving instance list from AWS API: {}", e.getMessage()); logger.debug("Full exception:", e); diff --git a/plugins/discovery-ec2/src/test/java/org/opensearch/discovery/ec2/Ec2DiscoveryPluginTests.java b/plugins/discovery-ec2/src/test/java/org/opensearch/discovery/ec2/Ec2DiscoveryPluginTests.java index be6261583bd..cb19c0d4255 100644 --- a/plugins/discovery-ec2/src/test/java/org/opensearch/discovery/ec2/Ec2DiscoveryPluginTests.java +++ b/plugins/discovery-ec2/src/test/java/org/opensearch/discovery/ec2/Ec2DiscoveryPluginTests.java @@ -103,7 +103,7 @@ public class Ec2DiscoveryPluginTests extends OpenSearchTestCase { public void testDefaultEndpoint() throws IOException { try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(Settings.EMPTY)) { - final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().client()).endpoint; + final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().get()).endpoint; assertThat(endpoint, is("")); } } @@ -111,7 +111,7 @@ public class Ec2DiscoveryPluginTests extends OpenSearchTestCase { public void testSpecificEndpoint() throws IOException { final Settings settings = Settings.builder().put(Ec2ClientSettings.ENDPOINT_SETTING.getKey(), "ec2.endpoint").build(); try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(settings)) { - final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().client()).endpoint; + final String endpoint = ((AmazonEC2Mock) plugin.ec2Service.client().get()).endpoint; assertThat(endpoint, is("ec2.endpoint")); } } @@ -150,7 +150,7 @@ public class Ec2DiscoveryPluginTests extends OpenSearchTestCase { try (Ec2DiscoveryPluginMock plugin = new Ec2DiscoveryPluginMock(settings1)) { try (AmazonEc2Reference clientReference = plugin.ec2Service.client()) { { - final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials(); + final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials(); assertThat(credentials.getAWSAccessKeyId(), is("ec2_access_1")); assertThat(credentials.getAWSSecretKey(), is("ec2_secret_1")); if (mockSecure1HasSessionToken) { @@ -159,32 +159,32 @@ public class Ec2DiscoveryPluginTests extends OpenSearchTestCase { } else { assertThat(credentials, instanceOf(BasicAWSCredentials.class)); } - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_1")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_1")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_1")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(881)); - assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(881)); + assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_1")); } // reload secure settings2 plugin.reload(settings2); // client is not released, it is still using the old settings { - final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials(); + final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials(); if (mockSecure1HasSessionToken) { assertThat(credentials, instanceOf(BasicSessionCredentials.class)); assertThat(((BasicSessionCredentials) credentials).getSessionToken(), is("ec2_session_token_1")); } else { assertThat(credentials, instanceOf(BasicAWSCredentials.class)); } - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_1")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_1")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_1")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(881)); - assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_1")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(881)); + assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_1")); } } try (AmazonEc2Reference clientReference = plugin.ec2Service.client()) { - final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.client()).credentials.getCredentials(); + final AWSCredentials credentials = ((AmazonEC2Mock) clientReference.get()).credentials.getCredentials(); assertThat(credentials.getAWSAccessKeyId(), is("ec2_access_2")); assertThat(credentials.getAWSSecretKey(), is("ec2_secret_2")); if (mockSecure2HasSessionToken) { @@ -193,11 +193,11 @@ public class Ec2DiscoveryPluginTests extends OpenSearchTestCase { } else { assertThat(credentials, instanceOf(BasicAWSCredentials.class)); } - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyUsername(), is("proxy_username_2")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPassword(), is("proxy_password_2")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyHost(), is("proxy_host_2")); - assertThat(((AmazonEC2Mock) clientReference.client()).configuration.getProxyPort(), is(882)); - assertThat(((AmazonEC2Mock) clientReference.client()).endpoint, is("ec2_endpoint_2")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyUsername(), is("proxy_username_2")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPassword(), is("proxy_password_2")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyHost(), is("proxy_host_2")); + assertThat(((AmazonEC2Mock) clientReference.get()).configuration.getProxyPort(), is(882)); + assertThat(((AmazonEC2Mock) clientReference.get()).endpoint, is("ec2_endpoint_2")); } } } diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonS3Reference.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonS3Reference.java index 239918206f3..62e415705a0 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonS3Reference.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/AmazonS3Reference.java @@ -32,45 +32,17 @@ package org.opensearch.repositories.s3; -import org.opensearch.common.util.concurrent.AbstractRefCounted; - import com.amazonaws.services.s3.AmazonS3; import com.amazonaws.services.s3.AmazonS3Client; - -import org.opensearch.common.lease.Releasable; +import org.opensearch.common.concurrent.RefCountedReleasable; /** * Handles the shutdown of the wrapped {@link AmazonS3Client} using reference * counting. */ -public class AmazonS3Reference extends AbstractRefCounted implements Releasable { - - private final AmazonS3 client; +public class AmazonS3Reference extends RefCountedReleasable { AmazonS3Reference(AmazonS3 client) { - super("AWS_S3_CLIENT"); - this.client = client; + super("AWS_S3_CLIENT", client, client::shutdown); } - - /** - * Call when the client is not needed anymore. - */ - @Override - public void close() { - decRef(); - } - - /** - * Returns the underlying `AmazonS3` client. All method calls are permitted BUT - * NOT shutdown. Shutdown is called when reference count reaches 0. - */ - public AmazonS3 client() { - return client; - } - - @Override - protected void closeInternal() { - client.shutdown(); - } - } diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java index 5a9c03c0b2a..678be7c6f13 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3BlobContainer.java @@ -101,7 +101,7 @@ class S3BlobContainer extends AbstractBlobContainer { @Override public boolean blobExists(String blobName) { try (AmazonS3Reference clientReference = blobStore.clientReference()) { - return SocketAccess.doPrivileged(() -> clientReference.client().doesObjectExist(blobStore.bucket(), buildKey(blobName))); + return SocketAccess.doPrivileged(() -> clientReference.get().doesObjectExist(blobStore.bucket(), buildKey(blobName))); } catch (final Exception e) { throw new BlobStoreException("Failed to check if blob [" + blobName + "] exists", e); } @@ -169,13 +169,13 @@ class S3BlobContainer extends AbstractBlobContainer { ObjectListing list; if (prevListing != null) { final ObjectListing finalPrevListing = prevListing; - list = SocketAccess.doPrivileged(() -> clientReference.client().listNextBatchOfObjects(finalPrevListing)); + list = SocketAccess.doPrivileged(() -> clientReference.get().listNextBatchOfObjects(finalPrevListing)); } else { final ListObjectsRequest listObjectsRequest = new ListObjectsRequest(); listObjectsRequest.setBucketName(blobStore.bucket()); listObjectsRequest.setPrefix(keyPath); listObjectsRequest.setRequestMetricCollector(blobStore.listMetricCollector); - list = SocketAccess.doPrivileged(() -> clientReference.client().listObjects(listObjectsRequest)); + list = SocketAccess.doPrivileged(() -> clientReference.get().listObjects(listObjectsRequest)); } final List blobsToDelete = new ArrayList<>(); list.getObjectSummaries().forEach(s3ObjectSummary -> { @@ -236,7 +236,7 @@ class S3BlobContainer extends AbstractBlobContainer { .map(DeleteObjectsRequest.KeyVersion::getKey) .collect(Collectors.toList()); try { - clientReference.client().deleteObjects(deleteRequest); + clientReference.get().deleteObjects(deleteRequest); outstanding.removeAll(keysInRequest); } catch (MultiObjectDeleteException e) { // We are sending quiet mode requests so we can't use the deleted keys entry on the exception and instead @@ -324,9 +324,9 @@ class S3BlobContainer extends AbstractBlobContainer { ObjectListing list; if (prevListing != null) { final ObjectListing finalPrevListing = prevListing; - list = SocketAccess.doPrivileged(() -> clientReference.client().listNextBatchOfObjects(finalPrevListing)); + list = SocketAccess.doPrivileged(() -> clientReference.get().listNextBatchOfObjects(finalPrevListing)); } else { - list = SocketAccess.doPrivileged(() -> clientReference.client().listObjects(listObjectsRequest)); + list = SocketAccess.doPrivileged(() -> clientReference.get().listObjects(listObjectsRequest)); } results.add(list); if (list.isTruncated()) { @@ -374,7 +374,7 @@ class S3BlobContainer extends AbstractBlobContainer { putRequest.setRequestMetricCollector(blobStore.putMetricCollector); try (AmazonS3Reference clientReference = blobStore.clientReference()) { - SocketAccess.doPrivilegedVoid(() -> { clientReference.client().putObject(putRequest); }); + SocketAccess.doPrivilegedVoid(() -> { clientReference.get().putObject(putRequest); }); } catch (final AmazonClientException e) { throw new IOException("Unable to upload object [" + blobName + "] using a single upload", e); } @@ -413,7 +413,7 @@ class S3BlobContainer extends AbstractBlobContainer { } try (AmazonS3Reference clientReference = blobStore.clientReference()) { - uploadId.set(SocketAccess.doPrivileged(() -> clientReference.client().initiateMultipartUpload(initRequest).getUploadId())); + uploadId.set(SocketAccess.doPrivileged(() -> clientReference.get().initiateMultipartUpload(initRequest).getUploadId())); if (Strings.isEmpty(uploadId.get())) { throw new IOException("Failed to initialize multipart upload " + blobName); } @@ -439,7 +439,7 @@ class S3BlobContainer extends AbstractBlobContainer { } bytesCount += uploadRequest.getPartSize(); - final UploadPartResult uploadResponse = SocketAccess.doPrivileged(() -> clientReference.client().uploadPart(uploadRequest)); + final UploadPartResult uploadResponse = SocketAccess.doPrivileged(() -> clientReference.get().uploadPart(uploadRequest)); parts.add(uploadResponse.getPartETag()); } @@ -456,7 +456,7 @@ class S3BlobContainer extends AbstractBlobContainer { parts ); complRequest.setRequestMetricCollector(blobStore.multiPartUploadMetricCollector); - SocketAccess.doPrivilegedVoid(() -> clientReference.client().completeMultipartUpload(complRequest)); + SocketAccess.doPrivilegedVoid(() -> clientReference.get().completeMultipartUpload(complRequest)); success = true; } catch (final AmazonClientException e) { @@ -465,7 +465,7 @@ class S3BlobContainer extends AbstractBlobContainer { if ((success == false) && Strings.hasLength(uploadId.get())) { final AbortMultipartUploadRequest abortRequest = new AbortMultipartUploadRequest(bucketName, blobName, uploadId.get()); try (AmazonS3Reference clientReference = blobStore.clientReference()) { - SocketAccess.doPrivilegedVoid(() -> clientReference.client().abortMultipartUpload(abortRequest)); + SocketAccess.doPrivilegedVoid(() -> clientReference.get().abortMultipartUpload(abortRequest)); } } } diff --git a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RetryingInputStream.java b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RetryingInputStream.java index 82c3367679c..388f5b8d74a 100644 --- a/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RetryingInputStream.java +++ b/plugins/repository-s3/src/main/java/org/opensearch/repositories/s3/S3RetryingInputStream.java @@ -110,7 +110,7 @@ class S3RetryingInputStream extends InputStream { + end; getObjectRequest.setRange(Math.addExact(start, currentOffset), end); } - final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.client().getObject(getObjectRequest)); + final S3Object s3Object = SocketAccess.doPrivileged(() -> clientReference.get().getObject(getObjectRequest)); this.currentStreamLastOffset = Math.addExact(Math.addExact(start, currentOffset), getStreamLength(s3Object)); this.currentStream = s3Object.getObjectContent(); } catch (final AmazonClientException e) { diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java index 645fe5cf1d1..9c359d67db8 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/RepositoryCredentialsTests.java @@ -123,7 +123,7 @@ public class RepositoryCredentialsTests extends OpenSearchSingleNodeTestCase { assertThat(repositories.repository(repositoryName), instanceOf(S3Repository.class)); final S3Repository repository = (S3Repository) repositories.repository(repositoryName); - final AmazonS3 client = repository.createBlobStore().clientReference().client(); + final AmazonS3 client = repository.createBlobStore().clientReference().get(); assertThat(client, instanceOf(ProxyS3RepositoryPlugin.ClientAndCredentials.class)); final AWSCredentials credentials = ((ProxyS3RepositoryPlugin.ClientAndCredentials) client).credentials.getCredentials(); @@ -162,7 +162,7 @@ public class RepositoryCredentialsTests extends OpenSearchSingleNodeTestCase { final S3Repository repository = (S3Repository) repositories.repository(repositoryName); try (AmazonS3Reference clientReference = ((S3BlobStore) repository.blobStore()).clientReference()) { - final AmazonS3 client = clientReference.client(); + final AmazonS3 client = clientReference.get(); assertThat(client, instanceOf(ProxyS3RepositoryPlugin.ClientAndCredentials.class)); final AWSCredentials credentials = ((ProxyS3RepositoryPlugin.ClientAndCredentials) client).credentials.getCredentials(); @@ -202,7 +202,7 @@ public class RepositoryCredentialsTests extends OpenSearchSingleNodeTestCase { // check credentials have been updated try (AmazonS3Reference clientReference = ((S3BlobStore) repository.blobStore()).clientReference()) { - final AmazonS3 client = clientReference.client(); + final AmazonS3 client = clientReference.get(); assertThat(client, instanceOf(ProxyS3RepositoryPlugin.ClientAndCredentials.class)); final AWSCredentials newCredentials = ((ProxyS3RepositoryPlugin.ClientAndCredentials) client).credentials.getCredentials(); diff --git a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RetryingInputStreamTests.java b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RetryingInputStreamTests.java index c7d1cb43bd2..0f40a7b3392 100644 --- a/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RetryingInputStreamTests.java +++ b/plugins/repository-s3/src/test/java/org/opensearch/repositories/s3/S3RetryingInputStreamTests.java @@ -109,7 +109,7 @@ public class S3RetryingInputStreamTests extends OpenSearchTestCase { final AmazonS3 client = mock(AmazonS3.class); when(client.getObject(any(GetObjectRequest.class))).thenReturn(s3Object); final AmazonS3Reference clientReference = mock(AmazonS3Reference.class); - when(clientReference.client()).thenReturn(client); + when(clientReference.get()).thenReturn(client); final S3BlobStore blobStore = mock(S3BlobStore.class); when(blobStore.clientReference()).thenReturn(clientReference); diff --git a/server/src/internalClusterTest/java/org/opensearch/action/admin/indices/forcemerge/ForceMergeIT.java b/server/src/internalClusterTest/java/org/opensearch/action/admin/indices/forcemerge/ForceMergeIT.java index a31976c969a..5c5bb6c6224 100644 --- a/server/src/internalClusterTest/java/org/opensearch/action/admin/indices/forcemerge/ForceMergeIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/action/admin/indices/forcemerge/ForceMergeIT.java @@ -100,7 +100,7 @@ public class ForceMergeIT extends OpenSearchIntegTestCase { private static String getForceMergeUUID(IndexShard indexShard) throws IOException { try (Engine.IndexCommitRef indexCommitRef = indexShard.acquireLastIndexCommit(true)) { - return indexCommitRef.getIndexCommit().getUserData().get(Engine.FORCE_MERGE_UUID_KEY); + return indexCommitRef.get().getUserData().get(Engine.FORCE_MERGE_UUID_KEY); } } } diff --git a/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexRecoveryIT.java b/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexRecoveryIT.java index 042b98c3368..17e457bba64 100644 --- a/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexRecoveryIT.java +++ b/server/src/internalClusterTest/java/org/opensearch/indices/recovery/IndexRecoveryIT.java @@ -1601,7 +1601,7 @@ public class IndexRecoveryIT extends OpenSearchIntegTestCase { final long localCheckpointOfSafeCommit; try (Engine.IndexCommitRef safeCommitRef = shard.acquireSafeIndexCommit()) { localCheckpointOfSafeCommit = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( - safeCommitRef.getIndexCommit().getUserData().entrySet() + safeCommitRef.get().getUserData().entrySet() ).localCheckpoint; } final long maxSeqNo = shard.seqNoStats().getMaxSeqNo(); diff --git a/server/src/main/java/org/opensearch/common/bytes/ReleasableBytesReference.java b/server/src/main/java/org/opensearch/common/bytes/ReleasableBytesReference.java index e9466b47c3d..9ed47ef6cbf 100644 --- a/server/src/main/java/org/opensearch/common/bytes/ReleasableBytesReference.java +++ b/server/src/main/java/org/opensearch/common/bytes/ReleasableBytesReference.java @@ -34,9 +34,9 @@ package org.opensearch.common.bytes; import org.apache.lucene.util.BytesRef; import org.apache.lucene.util.BytesRefIterator; +import org.opensearch.common.concurrent.RefCountedReleasable; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.lease.Releasable; -import org.opensearch.common.util.concurrent.AbstractRefCounted; import org.opensearch.common.xcontent.XContentBuilder; import java.io.IOException; @@ -50,14 +50,14 @@ public final class ReleasableBytesReference implements Releasable, BytesReferenc public static final Releasable NO_OP = () -> {}; private final BytesReference delegate; - private final AbstractRefCounted refCounted; + private final RefCountedReleasable refCounted; public ReleasableBytesReference(BytesReference delegate, Releasable releasable) { this.delegate = delegate; - this.refCounted = new RefCountedReleasable(releasable); + this.refCounted = new RefCountedReleasable<>("bytes-reference", releasable, releasable::close); } - private ReleasableBytesReference(BytesReference delegate, AbstractRefCounted refCounted) { + private ReleasableBytesReference(BytesReference delegate, RefCountedReleasable refCounted) { this.delegate = delegate; this.refCounted = refCounted; refCounted.incRef(); @@ -82,7 +82,7 @@ public final class ReleasableBytesReference implements Releasable, BytesReferenc @Override public void close() { - refCounted.decRef(); + refCounted.close(); } @Override @@ -164,19 +164,4 @@ public final class ReleasableBytesReference implements Releasable, BytesReferenc public int hashCode() { return delegate.hashCode(); } - - private static final class RefCountedReleasable extends AbstractRefCounted { - - private final Releasable releasable; - - RefCountedReleasable(Releasable releasable) { - super("bytes-reference"); - this.releasable = releasable; - } - - @Override - protected void closeInternal() { - releasable.close(); - } - } } diff --git a/server/src/main/java/org/opensearch/common/concurrent/GatedAutoCloseable.java b/server/src/main/java/org/opensearch/common/concurrent/GatedAutoCloseable.java new file mode 100644 index 00000000000..cb819c0320e --- /dev/null +++ b/server/src/main/java/org/opensearch/common/concurrent/GatedAutoCloseable.java @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +/** + * Decorator class that wraps an object reference with a {@link Runnable} that is + * invoked when {@link #close()} is called. The internal {@link OneWayGate} instance ensures + * that this is invoked only once. See also {@link GatedCloseable} + */ +public class GatedAutoCloseable implements AutoCloseable { + + private final T ref; + private final Runnable onClose; + private final OneWayGate gate; + + public GatedAutoCloseable(T ref, Runnable onClose) { + this.ref = ref; + this.onClose = onClose; + gate = new OneWayGate(); + } + + public T get() { + return ref; + } + + @Override + public void close() { + if (gate.close()) { + onClose.run(); + } + } +} diff --git a/server/src/main/java/org/opensearch/common/concurrent/GatedCloseable.java b/server/src/main/java/org/opensearch/common/concurrent/GatedCloseable.java new file mode 100644 index 00000000000..d98e4cca8d5 --- /dev/null +++ b/server/src/main/java/org/opensearch/common/concurrent/GatedCloseable.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +import org.opensearch.common.CheckedRunnable; + +import java.io.Closeable; +import java.io.IOException; + +/** + * Decorator class that wraps an object reference with a {@link CheckedRunnable} that is + * invoked when {@link #close()} is called. The internal {@link OneWayGate} instance ensures + * that this is invoked only once. See also {@link GatedAutoCloseable} + */ +public class GatedCloseable implements Closeable { + + private final T ref; + private final CheckedRunnable onClose; + private final OneWayGate gate; + + public GatedCloseable(T ref, CheckedRunnable onClose) { + this.ref = ref; + this.onClose = onClose; + gate = new OneWayGate(); + } + + public T get() { + return ref; + } + + @Override + public void close() throws IOException { + if (gate.close()) { + onClose.run(); + } + } +} diff --git a/server/src/main/java/org/opensearch/common/concurrent/OneWayGate.java b/server/src/main/java/org/opensearch/common/concurrent/OneWayGate.java new file mode 100644 index 00000000000..76625094f3c --- /dev/null +++ b/server/src/main/java/org/opensearch/common/concurrent/OneWayGate.java @@ -0,0 +1,43 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +import java.util.concurrent.atomic.AtomicBoolean; + +/** + * Encapsulates logic for a one-way gate. Guarantees idempotency via the {@link AtomicBoolean} instance + * and the return value of the {@link #close()} function. + */ +public class OneWayGate { + + private final AtomicBoolean closed = new AtomicBoolean(); + + /** + * Closes the gate and sets the internal boolean value in an idempotent + * fashion. This is a one-way operation and cannot be reset. + * @return true if the gate was closed in this invocation, + * false if the gate was already closed + */ + public boolean close() { + return closed.compareAndSet(false, true); + } + + /** + * Indicates if the gate has been closed. + * @return true if the gate is closed, false otherwise + */ + public boolean isClosed() { + return closed.get(); + } +} diff --git a/server/src/main/java/org/opensearch/common/concurrent/RefCountedReleasable.java b/server/src/main/java/org/opensearch/common/concurrent/RefCountedReleasable.java new file mode 100644 index 00000000000..975f2295d7c --- /dev/null +++ b/server/src/main/java/org/opensearch/common/concurrent/RefCountedReleasable.java @@ -0,0 +1,48 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +import org.opensearch.common.lease.Releasable; +import org.opensearch.common.util.concurrent.AbstractRefCounted; + +/** + * Decorator class that wraps an object reference as a {@link AbstractRefCounted} instance. + * In addition to a {@link String} name, it accepts a {@link Runnable} shutdown hook that is + * invoked when the reference count reaches zero i.e. on {@link #closeInternal()}. + */ +public class RefCountedReleasable extends AbstractRefCounted implements Releasable { + + private final T ref; + private final Runnable shutdownRunnable; + + public RefCountedReleasable(String name, T ref, Runnable shutdownRunnable) { + super(name); + this.ref = ref; + this.shutdownRunnable = shutdownRunnable; + } + + @Override + public void close() { + decRef(); + } + + public T get() { + return ref; + } + + @Override + protected void closeInternal() { + shutdownRunnable.run(); + } +} diff --git a/server/src/main/java/org/opensearch/index/engine/Engine.java b/server/src/main/java/org/opensearch/index/engine/Engine.java index 2d9cba2ee09..cbaf43b14c7 100644 --- a/server/src/main/java/org/opensearch/index/engine/Engine.java +++ b/server/src/main/java/org/opensearch/index/engine/Engine.java @@ -59,6 +59,7 @@ import org.opensearch.common.CheckedRunnable; import org.opensearch.common.Nullable; import org.opensearch.common.bytes.BytesReference; import org.opensearch.common.collect.ImmutableOpenMap; +import org.opensearch.common.concurrent.GatedCloseable; import org.opensearch.common.lease.Releasable; import org.opensearch.common.lease.Releasables; import org.opensearch.common.logging.Loggers; @@ -1828,25 +1829,9 @@ public abstract class Engine implements Closeable { } } - public static class IndexCommitRef implements Closeable { - private final AtomicBoolean closed = new AtomicBoolean(); - private final CheckedRunnable onClose; - private final IndexCommit indexCommit; - + public static class IndexCommitRef extends GatedCloseable { public IndexCommitRef(IndexCommit indexCommit, CheckedRunnable onClose) { - this.indexCommit = indexCommit; - this.onClose = onClose; - } - - @Override - public void close() throws IOException { - if (closed.compareAndSet(false, true)) { - onClose.run(); - } - } - - public IndexCommit getIndexCommit() { - return indexCommit; + super(indexCommit, onClose); } } diff --git a/server/src/main/java/org/opensearch/index/shard/IndexShard.java b/server/src/main/java/org/opensearch/index/shard/IndexShard.java index df0edd02d4f..863c2684142 100644 --- a/server/src/main/java/org/opensearch/index/shard/IndexShard.java +++ b/server/src/main/java/org/opensearch/index/shard/IndexShard.java @@ -1462,7 +1462,7 @@ public class IndexShard extends AbstractIndexShardComponent implements IndicesCl return store.getMetadata(null, true); } } - return store.getMetadata(indexCommit.getIndexCommit()); + return store.getMetadata(indexCommit.get()); } finally { store.decRef(); IOUtils.close(indexCommit); diff --git a/server/src/main/java/org/opensearch/index/shard/LocalShardSnapshot.java b/server/src/main/java/org/opensearch/index/shard/LocalShardSnapshot.java index 148c39df070..d62d0358eb7 100644 --- a/server/src/main/java/org/opensearch/index/shard/LocalShardSnapshot.java +++ b/server/src/main/java/org/opensearch/index/shard/LocalShardSnapshot.java @@ -88,7 +88,7 @@ final class LocalShardSnapshot implements Closeable { return new FilterDirectory(store.directory()) { @Override public String[] listAll() throws IOException { - Collection fileNames = indexCommit.getIndexCommit().getFileNames(); + Collection fileNames = indexCommit.get().getFileNames(); final String[] fileNameArray = fileNames.toArray(new String[fileNames.size()]); return fileNameArray; } diff --git a/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java b/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java index 81a6b0f2c38..684c4017168 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java +++ b/server/src/main/java/org/opensearch/indices/recovery/PeerRecoveryTargetService.java @@ -222,7 +222,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { logger.trace("not running recovery with id [{}] - can not find it (probably finished)", recoveryId); return; } - final RecoveryTarget recoveryTarget = recoveryRef.target(); + final RecoveryTarget recoveryTarget = recoveryRef.get(); timer = recoveryTarget.state().getTimer(); cancellableThreads = recoveryTarget.cancellableThreads(); if (preExistingRequest == null) { @@ -363,7 +363,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { return; } - recoveryRef.target().prepareForTranslogOperations(request.totalTranslogOps(), listener); + recoveryRef.get().prepareForTranslogOperations(request.totalTranslogOps(), listener); } } } @@ -378,7 +378,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { return; } - recoveryRef.target().finalizeRecovery(request.globalCheckpoint(), request.trimAboveSeqNo(), listener); + recoveryRef.get().finalizeRecovery(request.globalCheckpoint(), request.trimAboveSeqNo(), listener); } } } @@ -389,7 +389,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { public void messageReceived(final RecoveryHandoffPrimaryContextRequest request, final TransportChannel channel, Task task) throws Exception { try (RecoveryRef recoveryRef = onGoingRecoveries.getRecoverySafe(request.recoveryId(), request.shardId())) { - recoveryRef.target().handoffPrimaryContext(request.primaryContext()); + recoveryRef.get().handoffPrimaryContext(request.primaryContext()); } channel.sendResponse(TransportResponse.Empty.INSTANCE); } @@ -402,7 +402,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { public void messageReceived(final RecoveryTranslogOperationsRequest request, final TransportChannel channel, Task task) throws IOException { try (RecoveryRef recoveryRef = onGoingRecoveries.getRecoverySafe(request.recoveryId(), request.shardId())) { - final RecoveryTarget recoveryTarget = recoveryRef.target(); + final RecoveryTarget recoveryTarget = recoveryRef.get(); final ActionListener listener = createOrFinishListener( recoveryRef, channel, @@ -423,7 +423,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { final ActionListener listener, final RecoveryRef recoveryRef ) { - final RecoveryTarget recoveryTarget = recoveryRef.target(); + final RecoveryTarget recoveryTarget = recoveryRef.get(); final ClusterStateObserver observer = new ClusterStateObserver(clusterService, null, logger, threadPool.getThreadContext()); final Consumer retryOnMappingException = exception -> { @@ -488,7 +488,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { return; } - recoveryRef.target() + recoveryRef.get() .receiveFileInfo( request.phase1FileNames, request.phase1FileSizes, @@ -511,7 +511,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { return; } - recoveryRef.target() + recoveryRef.get() .cleanFiles(request.totalTranslogOps(), request.getGlobalCheckpoint(), request.sourceMetaSnapshot(), listener); } } @@ -525,7 +525,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { @Override public void messageReceived(final RecoveryFileChunkRequest request, TransportChannel channel, Task task) throws Exception { try (RecoveryRef recoveryRef = onGoingRecoveries.getRecoverySafe(request.recoveryId(), request.shardId())) { - final RecoveryTarget recoveryTarget = recoveryRef.target(); + final RecoveryTarget recoveryTarget = recoveryRef.get(); final ActionListener listener = createOrFinishListener(recoveryRef, channel, Actions.FILE_CHUNK, request); if (listener == null) { return; @@ -575,7 +575,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { final RecoveryTransportRequest request, final CheckedFunction responseFn ) { - final RecoveryTarget recoveryTarget = recoveryRef.target(); + final RecoveryTarget recoveryTarget = recoveryRef.get(); final ActionListener channelListener = new ChannelActionListener<>(channel, action, request); final ActionListener voidListener = ActionListener.map(channelListener, responseFn); @@ -611,7 +611,7 @@ public class PeerRecoveryTargetService implements IndexEventListener { logger.error(() -> new ParameterizedMessage("unexpected error during recovery [{}], failing shard", recoveryId), e); onGoingRecoveries.failRecovery( recoveryId, - new RecoveryFailedException(recoveryRef.target().state(), "unexpected error", e), + new RecoveryFailedException(recoveryRef.get().state(), "unexpected error", e), true // be safe ); } else { diff --git a/server/src/main/java/org/opensearch/indices/recovery/RecoveriesCollection.java b/server/src/main/java/org/opensearch/indices/recovery/RecoveriesCollection.java index 0fa2bc29c09..3c197a8e33e 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RecoveriesCollection.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RecoveriesCollection.java @@ -36,6 +36,7 @@ import org.apache.logging.log4j.Logger; import org.apache.logging.log4j.message.ParameterizedMessage; import org.opensearch.OpenSearchTimeoutException; import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.concurrent.GatedAutoCloseable; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.util.concurrent.AbstractRunnable; import org.opensearch.common.util.concurrent.ConcurrentCollections; @@ -48,7 +49,6 @@ import java.util.ArrayList; import java.util.Iterator; import java.util.List; import java.util.concurrent.ConcurrentMap; -import java.util.concurrent.atomic.AtomicBoolean; /** * This class holds a collection of all on going recoveries on the current node (i.e., the node is the target node @@ -178,7 +178,7 @@ public class RecoveriesCollection { if (recoveryRef == null) { throw new IndexShardClosedException(shardId); } - assert recoveryRef.target().shardId().equals(shardId); + assert recoveryRef.get().shardId().equals(shardId); return recoveryRef; } @@ -273,29 +273,15 @@ public class RecoveriesCollection { * causes {@link RecoveryTarget#decRef()} to be called. This makes sure that the underlying resources * will not be freed until {@link RecoveryRef#close()} is called. */ - public static class RecoveryRef implements AutoCloseable { - - private final RecoveryTarget status; - private final AtomicBoolean closed = new AtomicBoolean(false); + public static class RecoveryRef extends GatedAutoCloseable { /** * Important: {@link RecoveryTarget#tryIncRef()} should * be *successfully* called on status before */ public RecoveryRef(RecoveryTarget status) { - this.status = status; - this.status.setLastAccessTime(); - } - - @Override - public void close() { - if (closed.compareAndSet(false, true)) { - status.decRef(); - } - } - - public RecoveryTarget target() { - return status; + super(status, status::decRef); + status.setLastAccessTime(); } } diff --git a/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java b/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java index dcb7024ae8c..710b01a6709 100644 --- a/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java +++ b/server/src/main/java/org/opensearch/indices/recovery/RecoverySourceHandler.java @@ -268,7 +268,7 @@ public class RecoverySourceHandler { // advances and not when creating a new safe commit. In any case this is a best-effort thing since future recoveries can // always fall back to file-based ones, and only really presents a problem if this primary fails before things have settled // down. - startingSeqNo = Long.parseLong(safeCommitRef.getIndexCommit().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)) + 1L; + startingSeqNo = Long.parseLong(safeCommitRef.get().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)) + 1L; logger.trace("performing file-based recovery followed by history replay starting at [{}]", startingSeqNo); try { @@ -307,7 +307,7 @@ public class RecoverySourceHandler { deleteRetentionLeaseStep.whenComplete(ignored -> { assert Transports.assertNotTransportThread(RecoverySourceHandler.this + "[phase1]"); - phase1(safeCommitRef.getIndexCommit(), startingSeqNo, () -> estimateNumOps, sendFileStep); + phase1(safeCommitRef.get(), startingSeqNo, () -> estimateNumOps, sendFileStep); }, onFailure); } catch (final Exception e) { @@ -470,7 +470,7 @@ public class RecoverySourceHandler { private Engine.IndexCommitRef acquireSafeCommit(IndexShard shard) { final Engine.IndexCommitRef commitRef = shard.acquireSafeIndexCommit(); final AtomicBoolean closed = new AtomicBoolean(false); - return new Engine.IndexCommitRef(commitRef.getIndexCommit(), () -> { + return new Engine.IndexCommitRef(commitRef.get(), () -> { if (closed.compareAndSet(false, true)) { runWithGenericThreadPool(commitRef::close); } diff --git a/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java b/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java index 3b765cf1798..06b17c679cb 100644 --- a/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java +++ b/server/src/main/java/org/opensearch/snapshots/SnapshotShardsService.java @@ -372,13 +372,13 @@ public class SnapshotShardsService extends AbstractLifecycleComponent implements try { // we flush first to make sure we get the latest writes snapshotted snapshotRef = indexShard.acquireLastIndexCommit(true); - final IndexCommit snapshotIndexCommit = snapshotRef.getIndexCommit(); + final IndexCommit snapshotIndexCommit = snapshotRef.get(); repository.snapshotShard( indexShard.store(), indexShard.mapperService(), snapshot.getSnapshotId(), indexId, - snapshotRef.getIndexCommit(), + snapshotRef.get(), getShardStateId(indexShard, snapshotIndexCommit), snapshotStatus, version, diff --git a/server/src/test/java/org/opensearch/common/concurrent/GatedAutoCloseableTests.java b/server/src/test/java/org/opensearch/common/concurrent/GatedAutoCloseableTests.java new file mode 100644 index 00000000000..63058da8f16 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/concurrent/GatedAutoCloseableTests.java @@ -0,0 +1,46 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +import org.junit.Before; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.atomic.AtomicInteger; + +public class GatedAutoCloseableTests extends OpenSearchTestCase { + + private AtomicInteger testRef; + private GatedAutoCloseable testObject; + + @Before + public void setup() { + testRef = new AtomicInteger(0); + testObject = new GatedAutoCloseable<>(testRef, testRef::incrementAndGet); + } + + public void testGet() { + assertEquals(0, testObject.get().get()); + } + + public void testClose() { + testObject.close(); + assertEquals(1, testObject.get().get()); + } + + public void testIdempotent() { + testObject.close(); + testObject.close(); + assertEquals(1, testObject.get().get()); + } +} diff --git a/server/src/test/java/org/opensearch/common/concurrent/GatedCloseableTests.java b/server/src/test/java/org/opensearch/common/concurrent/GatedCloseableTests.java new file mode 100644 index 00000000000..0645f971b8d --- /dev/null +++ b/server/src/test/java/org/opensearch/common/concurrent/GatedCloseableTests.java @@ -0,0 +1,60 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +import org.junit.Before; +import org.opensearch.test.OpenSearchTestCase; + +import java.io.IOException; +import java.nio.file.FileSystem; + +import static org.mockito.Mockito.atMostOnce; +import static org.mockito.Mockito.doThrow; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.never; +import static org.mockito.Mockito.verify; + +public class GatedCloseableTests extends OpenSearchTestCase { + + private FileSystem testRef; + GatedCloseable testObject; + + @Before + public void setup() { + testRef = mock(FileSystem.class); + testObject = new GatedCloseable<>(testRef, testRef::close); + } + + public void testGet() throws Exception { + assertNotNull(testObject.get()); + assertEquals(testRef, testObject.get()); + verify(testRef, never()).close(); + } + + public void testClose() throws IOException { + testObject.close(); + verify(testRef, atMostOnce()).close(); + } + + public void testIdempotent() throws IOException { + testObject.close(); + testObject.close(); + verify(testRef, atMostOnce()).close(); + } + + public void testException() throws IOException { + doThrow(new IOException()).when(testRef).close(); + assertThrows(IOException.class, () -> testObject.close()); + } +} diff --git a/server/src/test/java/org/opensearch/common/concurrent/OneWayGateTests.java b/server/src/test/java/org/opensearch/common/concurrent/OneWayGateTests.java new file mode 100644 index 00000000000..357bf3ae321 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/concurrent/OneWayGateTests.java @@ -0,0 +1,41 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +import org.junit.Before; +import org.opensearch.test.OpenSearchTestCase; + +public class OneWayGateTests extends OpenSearchTestCase { + + private OneWayGate testGate; + + @Before + public void setup() { + testGate = new OneWayGate(); + } + + public void testGateOpen() { + assertFalse(testGate.isClosed()); + } + + public void testGateClosed() { + testGate.close(); + assertTrue(testGate.isClosed()); + } + + public void testGateIdempotent() { + assertTrue(testGate.close()); + assertFalse(testGate.close()); + } +} diff --git a/server/src/test/java/org/opensearch/common/concurrent/RefCountedReleasableTests.java b/server/src/test/java/org/opensearch/common/concurrent/RefCountedReleasableTests.java new file mode 100644 index 00000000000..63c0873f159 --- /dev/null +++ b/server/src/test/java/org/opensearch/common/concurrent/RefCountedReleasableTests.java @@ -0,0 +1,68 @@ +/* + * SPDX-License-Identifier: Apache-2.0 + * + * The OpenSearch Contributors require contributions made to + * this file be licensed under the Apache-2.0 license or a + * compatible open source license. + */ + +/* + * Modifications Copyright OpenSearch Contributors. See + * GitHub history for details. + */ + +package org.opensearch.common.concurrent; + +import org.junit.Before; +import org.opensearch.test.OpenSearchTestCase; + +import java.util.concurrent.atomic.AtomicInteger; + +public class RefCountedReleasableTests extends OpenSearchTestCase { + + private AtomicInteger testRef; + private RefCountedReleasable testObject; + + @Before + public void setup() { + testRef = new AtomicInteger(0); + testObject = new RefCountedReleasable<>("test", testRef, testRef::incrementAndGet); + } + + public void testInitialState() { + assertEquals("test", testObject.getName()); + assertEquals(testRef, testObject.get()); + assertEquals(testRef, testObject.get()); + assertEquals(0, testObject.get().get()); + assertEquals(1, testObject.refCount()); + } + + public void testIncRef() { + testObject.incRef(); + assertEquals(2, testObject.refCount()); + assertEquals(0, testObject.get().get()); + } + + public void testCloseWithoutInternal() { + testObject.incRef(); + assertEquals(2, testObject.refCount()); + testObject.close(); + assertEquals(1, testObject.refCount()); + assertEquals(0, testObject.get().get()); + } + + public void testCloseWithInternal() { + assertEquals(1, testObject.refCount()); + testObject.close(); + assertEquals(0, testObject.refCount()); + assertEquals(1, testObject.get().get()); + } + + public void testIncRefAfterClose() { + assertEquals(1, testObject.refCount()); + testObject.close(); + assertEquals(0, testObject.refCount()); + assertEquals(1, testObject.get().get()); + assertThrows(IllegalStateException.class, () -> testObject.incRef()); + } +} diff --git a/server/src/test/java/org/opensearch/index/engine/InternalEngineTests.java b/server/src/test/java/org/opensearch/index/engine/InternalEngineTests.java index 0bd47902c89..745508135c6 100644 --- a/server/src/test/java/org/opensearch/index/engine/InternalEngineTests.java +++ b/server/src/test/java/org/opensearch/index/engine/InternalEngineTests.java @@ -1088,7 +1088,7 @@ public class InternalEngineTests extends EngineTestCase { assertThat(engine.getLastSyncedGlobalCheckpoint(), equalTo(globalCheckpoint.get())); try (Engine.IndexCommitRef safeCommit = engine.acquireSafeIndexCommit()) { SequenceNumbers.CommitInfo commitInfo = SequenceNumbers.loadSeqNoInfoFromLuceneCommit( - safeCommit.getIndexCommit().getUserData().entrySet() + safeCommit.get().getUserData().entrySet() ); assertThat(commitInfo.localCheckpoint, equalTo(engine.getProcessedLocalCheckpoint())); } @@ -1505,7 +1505,7 @@ public class InternalEngineTests extends EngineTestCase { engine.syncTranslog(); final long safeCommitCheckpoint; try (Engine.IndexCommitRef safeCommit = engine.acquireSafeIndexCommit()) { - safeCommitCheckpoint = Long.parseLong(safeCommit.getIndexCommit().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); + safeCommitCheckpoint = Long.parseLong(safeCommit.get().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); } engine.forceMerge(true, 1, false, false, false, UUIDs.randomBase64UUID()); assertConsistentHistoryBetweenTranslogAndLuceneIndex(engine, mapperService); @@ -1595,9 +1595,7 @@ public class InternalEngineTests extends EngineTestCase { engine.syncTranslog(); final long minSeqNoToRetain; try (Engine.IndexCommitRef safeCommit = engine.acquireSafeIndexCommit()) { - long safeCommitLocalCheckpoint = Long.parseLong( - safeCommit.getIndexCommit().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY) - ); + long safeCommitLocalCheckpoint = Long.parseLong(safeCommit.get().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)); minSeqNoToRetain = Math.min(globalCheckpoint.get() + 1 - retainedExtraOps, safeCommitLocalCheckpoint + 1); } engine.forceMerge(true, 1, false, false, false, UUIDs.randomBase64UUID()); @@ -2671,7 +2669,7 @@ public class InternalEngineTests extends EngineTestCase { long prevLocalCheckpoint = SequenceNumbers.NO_OPS_PERFORMED; long prevMaxSeqNo = SequenceNumbers.NO_OPS_PERFORMED; for (Engine.IndexCommitRef commitRef : commits) { - final IndexCommit commit = commitRef.getIndexCommit(); + final IndexCommit commit = commitRef.get(); Map userData = commit.getUserData(); long localCheckpoint = userData.containsKey(SequenceNumbers.LOCAL_CHECKPOINT_KEY) ? Long.parseLong(userData.get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)) @@ -5643,7 +5641,7 @@ public class InternalEngineTests extends EngineTestCase { globalCheckpoint.set(numDocs + moreDocs - 1); engine.flush(); // check that we can still read the commit that we captured - try (IndexReader reader = DirectoryReader.open(snapshot.getIndexCommit())) { + try (IndexReader reader = DirectoryReader.open(snapshot.get())) { assertThat(reader.numDocs(), equalTo(flushFirst && safeCommit == false ? numDocs : 0)); } assertThat(DirectoryReader.listCommits(engine.store.directory()), hasSize(2)); @@ -6325,7 +6323,7 @@ public class InternalEngineTests extends EngineTestCase { assertThat(actualOps, containsInAnyOrder(expectedOps)); } try (Engine.IndexCommitRef commitRef = engine.acquireSafeIndexCommit()) { - IndexCommit safeCommit = commitRef.getIndexCommit(); + IndexCommit safeCommit = commitRef.get(); if (safeCommit.getUserData().containsKey(Engine.MIN_RETAINED_SEQNO)) { lastMinRetainedSeqNo = Long.parseLong(safeCommit.getUserData().get(Engine.MIN_RETAINED_SEQNO)); } diff --git a/server/src/test/java/org/opensearch/index/engine/NoOpEngineTests.java b/server/src/test/java/org/opensearch/index/engine/NoOpEngineTests.java index 65b8a81b029..772cda9efa5 100644 --- a/server/src/test/java/org/opensearch/index/engine/NoOpEngineTests.java +++ b/server/src/test/java/org/opensearch/index/engine/NoOpEngineTests.java @@ -115,7 +115,7 @@ public class NoOpEngineTests extends EngineTestCase { assertThat(noOpEngine.getPersistedLocalCheckpoint(), equalTo(localCheckpoint)); assertThat(noOpEngine.getSeqNoStats(100L).getMaxSeqNo(), equalTo(maxSeqNo)); try (Engine.IndexCommitRef ref = noOpEngine.acquireLastIndexCommit(false)) { - try (IndexReader reader = DirectoryReader.open(ref.getIndexCommit())) { + try (IndexReader reader = DirectoryReader.open(ref.get())) { assertThat(reader.numDocs(), equalTo(docs)); } } diff --git a/server/src/test/java/org/opensearch/index/shard/IndexShardTests.java b/server/src/test/java/org/opensearch/index/shard/IndexShardTests.java index e08786e2c45..6485861f175 100644 --- a/server/src/test/java/org/opensearch/index/shard/IndexShardTests.java +++ b/server/src/test/java/org/opensearch/index/shard/IndexShardTests.java @@ -4127,10 +4127,10 @@ public class IndexShardTests extends IndexShardTestCase { readyToSnapshotLatch.await(); shard.snapshotStoreMetadata(); try (Engine.IndexCommitRef indexCommitRef = shard.acquireLastIndexCommit(randomBoolean())) { - shard.store().getMetadata(indexCommitRef.getIndexCommit()); + shard.store().getMetadata(indexCommitRef.get()); } try (Engine.IndexCommitRef indexCommitRef = shard.acquireSafeIndexCommit()) { - shard.store().getMetadata(indexCommitRef.getIndexCommit()); + shard.store().getMetadata(indexCommitRef.get()); } } catch (InterruptedException | IOException e) { throw new AssertionError(e); diff --git a/server/src/test/java/org/opensearch/recovery/RecoveriesCollectionTests.java b/server/src/test/java/org/opensearch/recovery/RecoveriesCollectionTests.java index 69923e4390e..6a08f5115d1 100644 --- a/server/src/test/java/org/opensearch/recovery/RecoveriesCollectionTests.java +++ b/server/src/test/java/org/opensearch/recovery/RecoveriesCollectionTests.java @@ -69,10 +69,10 @@ public class RecoveriesCollectionTests extends OpenSearchIndexLevelReplicationTe final RecoveriesCollection collection = new RecoveriesCollection(logger, threadPool); final long recoveryId = startRecovery(collection, shards.getPrimaryNode(), shards.addReplica()); try (RecoveriesCollection.RecoveryRef status = collection.getRecovery(recoveryId)) { - final long lastSeenTime = status.target().lastAccessTime(); + final long lastSeenTime = status.get().lastAccessTime(); assertBusy(() -> { try (RecoveriesCollection.RecoveryRef currentStatus = collection.getRecovery(recoveryId)) { - assertThat("access time failed to update", lastSeenTime, lessThan(currentStatus.target().lastAccessTime())); + assertThat("access time failed to update", lastSeenTime, lessThan(currentStatus.get().lastAccessTime())); } }); } finally { @@ -120,7 +120,7 @@ public class RecoveriesCollectionTests extends OpenSearchIndexLevelReplicationTe final long recoveryId = startRecovery(collection, shards.getPrimaryNode(), shards.addReplica()); final long recoveryId2 = startRecovery(collection, shards.getPrimaryNode(), shards.addReplica()); try (RecoveriesCollection.RecoveryRef recoveryRef = collection.getRecovery(recoveryId)) { - ShardId shardId = recoveryRef.target().shardId(); + ShardId shardId = recoveryRef.get().shardId(); assertTrue("failed to cancel recoveries", collection.cancelRecoveriesForShard(shardId, "test")); assertThat("all recoveries should be cancelled", collection.size(), equalTo(0)); } finally { @@ -160,8 +160,8 @@ public class RecoveriesCollectionTests extends OpenSearchIndexLevelReplicationTe assertEquals(currentAsTarget, shard.recoveryStats().currentAsTarget()); try (RecoveriesCollection.RecoveryRef newRecoveryRef = collection.getRecovery(resetRecoveryId)) { shards.recoverReplica(shard, (s, n) -> { - assertSame(s, newRecoveryRef.target().indexShard()); - return newRecoveryRef.target(); + assertSame(s, newRecoveryRef.get().indexShard()); + return newRecoveryRef.get(); }, false); } shards.assertAllEqual(numDocs); diff --git a/test/framework/src/main/java/org/opensearch/index/engine/EngineTestCase.java b/test/framework/src/main/java/org/opensearch/index/engine/EngineTestCase.java index 24d24cd9f1a..97d3490db4a 100644 --- a/test/framework/src/main/java/org/opensearch/index/engine/EngineTestCase.java +++ b/test/framework/src/main/java/org/opensearch/index/engine/EngineTestCase.java @@ -1389,7 +1389,7 @@ public abstract class EngineTestCase extends OpenSearchTestCase { final long seqNoForRecovery; if (engine.config().getIndexSettings().isSoftDeleteEnabled()) { try (Engine.IndexCommitRef safeCommit = engine.acquireSafeIndexCommit()) { - seqNoForRecovery = Long.parseLong(safeCommit.getIndexCommit().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)) + 1; + seqNoForRecovery = Long.parseLong(safeCommit.get().getUserData().get(SequenceNumbers.LOCAL_CHECKPOINT_KEY)) + 1; } } else { seqNoForRecovery = engine.getMinRetainedSeqNo(); diff --git a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java index 54b3ffbfd3a..b388ab8835a 100644 --- a/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java +++ b/test/framework/src/main/java/org/opensearch/index/shard/IndexShardTestCase.java @@ -1036,7 +1036,7 @@ public abstract class IndexShardTestCase extends OpenSearchTestCase { shard.mapperService(), snapshot.getSnapshotId(), indexId, - indexCommitRef.getIndexCommit(), + indexCommitRef.get(), null, snapshotStatus, Version.CURRENT,