Drop mocksocket in favour of custom security manager checks (tests only) (#1205)

* Drop mocksocket in favour of custom security manager checks (tests only)

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>

* Slightly relaxed host checks to allow all local addresses

Signed-off-by: Andriy Redko <andriy.redko@aiven.io>
This commit is contained in:
Andriy Redko 2021-09-16 17:21:47 -04:00 committed by GitHub
parent cbbf967d76
commit b6c8bdf872
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
18 changed files with 83 additions and 38 deletions

View File

@ -37,7 +37,6 @@ commonslogging = 1.1.3
commonscodec = 1.13
hamcrest = 2.1
securemock = 1.2
mocksocket = 1.2
mockito = 1.9.5
objenesis = 1.0

View File

@ -35,7 +35,9 @@ package org.opensearch.secure_sm;
import java.security.AccessController;
import java.security.Permission;
import java.security.PrivilegedAction;
import java.util.Arrays;
import java.util.Objects;
import java.util.Set;
/**
* Extension of SecurityManager that works around a few design flaws in Java Security.
@ -105,10 +107,48 @@ public class SecureSM extends SecurityManager {
* <li><code>com.intellij.rt.execution.junit.</code></li>
* </ul>
*
* For testing purposes, the security manager grants network permissions "connect, accept"
* to following classes, granted they only access local network interfaces.
*
* <ul>
* <li><code>sun.net.httpserver.ServerImpl</code></li>
* <li><code>java.net.ServerSocket"</code></li>
* <li><code>java.net.Socket</code></li>
* </ul>
*
* @return an instance of SecureSM where test packages can halt or exit the virtual machine
*/
public static SecureSM createTestSecureSM() {
return new SecureSM(TEST_RUNNER_PACKAGES);
public static SecureSM createTestSecureSM(final Set<String> trustedHosts) {
return new SecureSM(TEST_RUNNER_PACKAGES) {
// Trust these callers inside the test suite only
final String[] TRUSTED_CALLERS = new String[] {
"sun.net.httpserver.ServerImpl",
"java.net.ServerSocket",
"java.net.Socket"
};
@Override
public void checkConnect(String host, int port) {
// Allow to connect from selected trusted classes to local addresses only
if (!hasTrustedCallerChain() || !trustedHosts.contains(host)) {
super.checkConnect(host, port);
}
}
@Override
public void checkAccept(String host, int port) {
// Allow to accept connections from selected trusted classes to local addresses only
if (!hasTrustedCallerChain() || !trustedHosts.contains(host)) {
super.checkAccept(host, port);
}
}
private boolean hasTrustedCallerChain() {
return Arrays
.stream(getClassContext())
.anyMatch(c -> Arrays.stream(TRUSTED_CALLERS).anyMatch(t -> c.getName().startsWith(t)));
}
};
}
static final String[] TEST_RUNNER_PACKAGES = new String[] {

View File

@ -37,6 +37,7 @@ import junit.framework.TestCase;
import java.security.Permission;
import java.security.Policy;
import java.security.ProtectionDomain;
import java.util.Collections;
import java.util.concurrent.atomic.AtomicBoolean;
/** Simple tests for SecureSM */
@ -57,7 +58,7 @@ public class SecureSMTests extends TestCase {
return true;
}
});
System.setSecurityManager(SecureSM.createTestSecureSM());
System.setSecurityManager(SecureSM.createTestSecureSM(Collections.emptySet()));
}
@SuppressForbidden(reason = "testing that System#exit is blocked")

View File

@ -47,7 +47,6 @@ import org.opensearch.common.ssl.PemKeyConfig;
import org.opensearch.common.ssl.PemTrustConfig;
import org.opensearch.env.Environment;
import org.opensearch.env.TestEnvironment;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.watcher.ResourceWatcherService;
import org.hamcrest.Matchers;
@ -91,7 +90,7 @@ public class ReindexRestClientSslTests extends OpenSearchTestCase {
public static void setupHttpServer() throws Exception {
InetSocketAddress address = new InetSocketAddress("localhost", 0);
SSLContext sslContext = buildServerSslContext();
server = MockHttpServer.createHttps(address, 0);
server = HttpsServer.create(address, 0);
server.setHttpsConfigurator(new ClientAuthHttpsConfigurator(sslContext));
server.start();
server.createContext("/", http -> {

View File

@ -37,7 +37,6 @@ import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.blobstore.BlobContainer;
import org.opensearch.common.blobstore.BlobPath;
import org.opensearch.common.settings.Settings;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.AfterClass;
import org.junit.Before;
@ -67,7 +66,7 @@ public class URLBlobStoreTests extends OpenSearchTestCase {
}
blobName = randomAlphaOfLength(8);
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 6001), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 6001), 0);
httpServer.createContext("/indices/" + blobName, (s) -> {
s.sendResponseHeaders(200, message.length);

View File

@ -40,7 +40,6 @@ import org.opensearch.common.transport.TransportAddress;
import org.opensearch.common.util.MockPageCacheRecycler;
import org.opensearch.common.util.PageCacheRecycler;
import org.opensearch.indices.breaker.NoneCircuitBreakerService;
import org.elasticsearch.mocksocket.MockSocket;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.threadpool.ThreadPool;
import org.opensearch.transport.SharedGroupFactory;
@ -100,7 +99,7 @@ public class Netty4SizeHeaderFrameDecoderTests extends OpenSearchTestCase {
String randomMethod = randomFrom("GET", "POST", "PUT", "DELETE", "HEAD", "OPTIONS", "PATCH");
String data = randomMethod + " / HTTP/1.1";
try (Socket socket = new MockSocket(host, port)) {
try (Socket socket = new Socket(host, port)) {
socket.getOutputStream().write(data.getBytes(StandardCharsets.UTF_8));
socket.getOutputStream().flush();
@ -111,7 +110,7 @@ public class Netty4SizeHeaderFrameDecoderTests extends OpenSearchTestCase {
}
public void testThatNothingIsReturnedForOtherInvalidPackets() throws Exception {
try (Socket socket = new MockSocket(host, port)) {
try (Socket socket = new Socket(host, port)) {
socket.getOutputStream().write("FOOBAR".getBytes(StandardCharsets.UTF_8));
socket.getOutputStream().flush();

View File

@ -46,7 +46,6 @@ import org.opensearch.common.settings.Setting;
import org.opensearch.common.settings.Settings;
import org.opensearch.discovery.DiscoveryModule;
import org.opensearch.env.Environment;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.node.Node;
import org.opensearch.plugin.discovery.azure.classic.AzureDiscoveryPlugin;
import org.opensearch.plugins.Plugin;
@ -163,7 +162,7 @@ public class AzureDiscoveryClusterFormationTests extends OpenSearchIntegTestCase
public static void startHttpd() throws Exception {
logDir = createTempDir();
SSLContext sslContext = getSSLContext();
httpsServer = MockHttpServer.createHttps(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);
httpsServer = HttpsServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);
httpsServer.setHttpsConfigurator(new HttpsConfigurator(sslContext));
httpsServer.createContext("/subscription/services/hostedservices/myservice", (s) -> {
Headers headers = s.getResponseHeaders();

View File

@ -40,7 +40,6 @@ import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.MockSecureSettings;
import org.opensearch.common.settings.Settings;
import org.opensearch.core.internal.io.IOUtils;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.opensearch.test.transport.MockTransportService;
import org.opensearch.threadpool.TestThreadPool;
@ -74,7 +73,7 @@ public abstract class AbstractEC2MockAPITestCase extends OpenSearchTestCase {
@Before
public void setUp() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.start();
threadPool = new TestThreadPool(EC2RetriesTests.class.getName());
transportService = createTransportService();

View File

@ -38,7 +38,6 @@ import org.opensearch.common.Strings;
import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.network.NetworkService;
import org.opensearch.common.settings.Settings;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.rest.RestStatus;
import org.opensearch.test.OpenSearchTestCase;
@ -74,7 +73,7 @@ public class Ec2NetworkTests extends OpenSearchTestCase {
@BeforeClass
public static void startHttp() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress().getHostAddress(), 0), 0);
BiConsumer<String, String> registerContext = (path, v) ->{
final byte[] message = v.getBytes(UTF_8);

View File

@ -32,7 +32,6 @@
package org.opensearch.example.resthandler;
import org.elasticsearch.mocksocket.MockSocket;
import org.opensearch.test.OpenSearchTestCase;
import java.io.BufferedReader;
@ -57,7 +56,7 @@ public class ExampleFixtureIT extends OpenSearchTestCase {
final URL url = new URL("http://" + externalAddress);
final InetAddress address = InetAddress.getByName(url.getHost());
try (
Socket socket = new MockSocket(address, url.getPort());
Socket socket = new Socket(address, url.getPort());
BufferedWriter writer = new BufferedWriter(new OutputStreamWriter(socket.getOutputStream(), StandardCharsets.UTF_8));
BufferedReader reader = new BufferedReader(new InputStreamReader(socket.getInputStream(), StandardCharsets.UTF_8))
) {

View File

@ -56,7 +56,6 @@ import org.opensearch.common.settings.Settings;
import org.opensearch.common.unit.ByteSizeUnit;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.CountDown;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.rest.RestStatus;
import org.opensearch.rest.RestUtils;
import org.opensearch.test.OpenSearchTestCase;
@ -118,7 +117,7 @@ public class AzureBlobContainerRetriesTests extends OpenSearchTestCase {
@Before
public void setUp() throws Exception {
threadPool = new TestThreadPool(getTestClass().getName(), AzureRepositoryPlugin.executorBuilder());
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.start();
super.setUp();
}

View File

@ -73,11 +73,6 @@ grant codeBase "${codebase.junit}" {
permission java.lang.reflect.ReflectPermission "suppressAccessChecks";
};
grant codeBase "${codebase.mocksocket}" {
// mocksocket makes and accepts socket connections
permission java.net.SocketPermission "*", "accept,connect";
};
grant codeBase "${codebase.opensearch-nio}" {
// opensearch-nio makes and accepts socket connections
permission java.net.SocketPermission "*", "accept,connect";

View File

@ -63,7 +63,6 @@ import org.opensearch.common.xcontent.XContentBuilder;
import org.opensearch.common.xcontent.XContentFactory;
import org.opensearch.core.internal.io.IOUtils;
import org.opensearch.index.IndexNotFoundException;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.opensearch.search.SearchHit;
import org.opensearch.search.SearchHits;
import org.opensearch.search.aggregations.InternalAggregations;
@ -180,7 +179,7 @@ public class RemoteClusterConnectionTests extends OpenSearchTestCase {
@SuppressForbidden(reason = "calls getLocalHost here but it's fine in this case")
public void testSlowNodeCanBeCancelled() throws IOException, InterruptedException {
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
socket.bind(new InetSocketAddress(InetAddress.getLocalHost(), 0), 1);
socket.setReuseAddress(true);
DiscoveryNode seedNode = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),

View File

@ -46,7 +46,6 @@ dependencies {
api "commons-logging:commons-logging:${versions.commonslogging}"
api "commons-codec:commons-codec:${versions.commonscodec}"
api "org.elasticsearch:securemock:${versions.securemock}"
api "org.elasticsearch:mocksocket:${versions.mocksocket}"
}
compileJava.options.compilerArgs -= '-Xlint:cast'

View File

@ -42,12 +42,15 @@ import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.io.FileSystemUtils;
import org.opensearch.common.io.PathUtils;
import org.opensearch.common.network.IfConfig;
import org.opensearch.common.network.NetworkAddress;
import org.opensearch.common.settings.Settings;
import org.opensearch.plugins.PluginInfo;
import org.opensearch.secure_sm.SecureSM;
import org.junit.Assert;
import java.io.InputStream;
import java.net.NetworkInterface;
import java.net.SocketException;
import java.net.SocketPermission;
import java.net.URL;
import java.nio.file.Files;
@ -66,6 +69,7 @@ import java.util.Map;
import java.util.Objects;
import java.util.Properties;
import java.util.Set;
import java.util.stream.Collectors;
import static com.carrotsearch.randomizedtesting.RandomizedTest.systemPropertyAsBoolean;
@ -161,7 +165,7 @@ public class BootstrapForTesting {
return opensearchPolicy.implies(domain, permission) || testFramework.implies(domain, permission);
}
});
System.setSecurityManager(SecureSM.createTestSecureSM());
System.setSecurityManager(SecureSM.createTestSecureSM(getTrustedHosts()));
Security.selfTest();
// guarantee plugin classes are initialized first, in case they have one-time hacks.
@ -272,6 +276,25 @@ public class BootstrapForTesting {
return raw;
}
/**
* Collect host addresses of all local interfaces so we could check
* if the network connection is being made only on those.
* @return host names and addresses of all local interfaces
*/
private static Set<String> getTrustedHosts() {
//
try {
return Collections
.list(NetworkInterface.getNetworkInterfaces())
.stream()
.flatMap(iface -> Collections.list(iface.getInetAddresses()).stream())
.map(address -> NetworkAddress.format(address))
.collect(Collectors.toSet());
} catch (final SocketException e) {
return Collections.emptySet();
}
}
// does nothing, just easy way to make sure the class is loaded.
public static void ensureInitialized() {}
}

View File

@ -45,7 +45,6 @@ import org.opensearch.common.io.Streams;
import org.opensearch.common.unit.ByteSizeValue;
import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.CountDown;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.test.OpenSearchTestCase;
import org.junit.After;
import org.junit.Before;
@ -81,7 +80,7 @@ public abstract class AbstractBlobContainerRetriesTestCase extends OpenSearchTes
@Before
public void setUp() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.start();
super.setUp();
}

View File

@ -47,7 +47,6 @@ import org.opensearch.common.SuppressForbidden;
import org.opensearch.common.bytes.BytesReference;
import org.opensearch.common.network.InetAddresses;
import org.opensearch.common.settings.Settings;
import org.elasticsearch.mocksocket.MockHttpServer;
import org.opensearch.repositories.RepositoriesService;
import org.opensearch.repositories.Repository;
import org.opensearch.repositories.RepositoryMissingException;
@ -102,7 +101,7 @@ public abstract class OpenSearchMockAPIBasedRepositoryIntegTestCase extends Open
@BeforeClass
public static void startHttpServer() throws Exception {
httpServer = MockHttpServer.createHttp(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer = HttpServer.create(new InetSocketAddress(InetAddress.getLoopbackAddress(), 0), 0);
httpServer.setExecutor(r -> {
try {
r.run();

View File

@ -64,7 +64,6 @@ import org.opensearch.common.unit.TimeValue;
import org.opensearch.common.util.concurrent.AbstractRunnable;
import org.opensearch.common.util.concurrent.ConcurrentCollections;
import org.opensearch.core.internal.io.IOUtils;
import org.elasticsearch.mocksocket.MockServerSocket;
import org.opensearch.node.Node;
import org.opensearch.tasks.Task;
import org.opensearch.test.OpenSearchTestCase;
@ -1938,7 +1937,7 @@ public abstract class AbstractSimpleTransportTestCase extends OpenSearchTestCase
public void testTimeoutPerConnection() throws IOException {
assumeTrue("Works only on BSD network stacks", Constants.MAC_OS_X || Constants.FREE_BSD);
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
// note - this test uses backlog=1 which is implementation specific ie. it might not work on some TCP/IP stacks
// on linux (at least newer ones) the listen(addr, backlog=1) should just ignore new connections if the queue is full which
// means that once we received an ACK from the client we just drop the packet on the floor (which is what we want) and we run
@ -2057,7 +2056,7 @@ public abstract class AbstractSimpleTransportTestCase extends OpenSearchTestCase
}
public void testTcpHandshakeTimeout() throws IOException {
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
socket.bind(getLocalEphemeral(), 1);
socket.setReuseAddress(true);
DiscoveryNode dummy = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),
@ -2078,7 +2077,7 @@ public abstract class AbstractSimpleTransportTestCase extends OpenSearchTestCase
}
public void testTcpHandshakeConnectionReset() throws IOException, InterruptedException {
try (ServerSocket socket = new MockServerSocket()) {
try (ServerSocket socket = new ServerSocket()) {
socket.bind(getLocalEphemeral(), 1);
socket.setReuseAddress(true);
DiscoveryNode dummy = new DiscoveryNode("TEST", new TransportAddress(socket.getInetAddress(),