diff --git a/shield/src/main/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransport.java b/shield/src/main/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransport.java index 95904f732af..d680bd32a67 100644 --- a/shield/src/main/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransport.java +++ b/shield/src/main/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransport.java @@ -8,6 +8,7 @@ package org.elasticsearch.shield.transport.netty; import org.elasticsearch.Version; import org.elasticsearch.common.inject.Inject; import org.elasticsearch.common.inject.internal.Nullable; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; import org.elasticsearch.common.util.BigArrays; @@ -51,8 +52,8 @@ public class ShieldNettyTransport extends NettyTransport { @Inject public ShieldNettyTransport(Settings settings, ThreadPool threadPool, NetworkService networkService, BigArrays bigArrays, Version version, @Nullable IPFilter authenticator, @Nullable ServerSSLService serverSSLService, ClientSSLService clientSSLService, - ShieldSettingsFilter settingsFilter) { - super(settings, threadPool, networkService, bigArrays, version); + ShieldSettingsFilter settingsFilter, NamedWriteableRegistry namedWriteableRegistry) { + super(settings, threadPool, networkService, bigArrays, version, namedWriteableRegistry); this.authenticator = authenticator; this.ssl = settings.getAsBoolean(TRANSPORT_SSL_SETTING, TRANSPORT_SSL_DEFAULT); this.serverSslService = serverSSLService; diff --git a/shield/src/test/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransportTests.java b/shield/src/test/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransportTests.java index e313adbcb0e..61f5ca0dcb5 100644 --- a/shield/src/test/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransportTests.java +++ b/shield/src/test/java/org/elasticsearch/shield/transport/netty/ShieldNettyTransportTests.java @@ -6,6 +6,7 @@ package org.elasticsearch.shield.transport.netty; import org.elasticsearch.Version; +import org.elasticsearch.common.io.stream.NamedWriteableRegistry; import org.elasticsearch.common.netty.OpenChannelsHandler; import org.elasticsearch.common.network.NetworkService; import org.elasticsearch.common.settings.Settings; @@ -54,7 +55,7 @@ public class ShieldNettyTransportTests extends ESTestCase { @Test public void testThatSSLCanBeDisabledByProfile() throws Exception { Settings settings = settingsBuilder().put("shield.transport.ssl", true).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", settingsBuilder().put("shield.ssl", false).build()); assertThat(factory.getPipeline().get(SslHandler.class), nullValue()); @@ -63,7 +64,7 @@ public class ShieldNettyTransportTests extends ESTestCase { @Test public void testThatSSLCanBeEnabledByProfile() throws Exception { Settings settings = settingsBuilder().put("shield.transport.ssl", false).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", settingsBuilder().put("shield.ssl", true).build()); assertThat(factory.getPipeline().get(SslHandler.class), notNullValue()); @@ -72,7 +73,7 @@ public class ShieldNettyTransportTests extends ESTestCase { @Test public void testThatProfileTakesDefaultSSLSetting() throws Exception { Settings settings = settingsBuilder().put("shield.transport.ssl", true).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.EMPTY); assertThat(factory.getPipeline().get(SslHandler.class).getEngine(), notNullValue()); @@ -81,7 +82,7 @@ public class ShieldNettyTransportTests extends ESTestCase { @Test public void testDefaultClientAuth() throws Exception { Settings settings = settingsBuilder().put("shield.transport.ssl", true).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.EMPTY); assertThat(factory.getPipeline().get(SslHandler.class).getEngine().getNeedClientAuth(), is(true)); @@ -94,7 +95,7 @@ public class ShieldNettyTransportTests extends ESTestCase { Settings settings = settingsBuilder() .put("shield.transport.ssl", true) .put(ShieldNettyTransport.TRANSPORT_CLIENT_AUTH_SETTING, value).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.EMPTY); assertThat(factory.getPipeline().get(SslHandler.class).getEngine().getNeedClientAuth(), is(true)); @@ -107,7 +108,7 @@ public class ShieldNettyTransportTests extends ESTestCase { Settings settings = settingsBuilder() .put("shield.transport.ssl", true) .put(ShieldNettyTransport.TRANSPORT_CLIENT_AUTH_SETTING, value).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.EMPTY); assertThat(factory.getPipeline().get(SslHandler.class).getEngine().getNeedClientAuth(), is(false)); @@ -120,7 +121,7 @@ public class ShieldNettyTransportTests extends ESTestCase { Settings settings = settingsBuilder() .put("shield.transport.ssl", true) .put(ShieldNettyTransport.TRANSPORT_CLIENT_AUTH_SETTING, value).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.EMPTY); assertThat(factory.getPipeline().get(SslHandler.class).getEngine().getNeedClientAuth(), is(false)); @@ -131,7 +132,7 @@ public class ShieldNettyTransportTests extends ESTestCase { public void testProfileRequiredClientAuth() throws Exception { String value = randomFrom(SSLClientAuth.REQUIRED.name(), SSLClientAuth.REQUIRED.name().toLowerCase(Locale.ROOT), "true", "TRUE"); Settings settings = settingsBuilder().put("shield.transport.ssl", true).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.builder().put(ShieldNettyTransport.TRANSPORT_PROFILE_CLIENT_AUTH_SETTING, value).build()); assertThat(factory.getPipeline().get(SslHandler.class).getEngine().getNeedClientAuth(), is(true)); @@ -142,7 +143,7 @@ public class ShieldNettyTransportTests extends ESTestCase { public void testProfileNoClientAuth() throws Exception { String value = randomFrom(SSLClientAuth.NO.name(), "false", "FALSE", SSLClientAuth.NO.name().toLowerCase(Locale.ROOT)); Settings settings = settingsBuilder().put("shield.transport.ssl", true).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.builder().put(ShieldNettyTransport.TRANSPORT_PROFILE_CLIENT_AUTH_SETTING, value).build()); assertThat(factory.getPipeline().get(SslHandler.class).getEngine().getNeedClientAuth(), is(false)); @@ -153,7 +154,7 @@ public class ShieldNettyTransportTests extends ESTestCase { public void testProfileOptionalClientAuth() throws Exception { String value = randomFrom(SSLClientAuth.OPTIONAL.name(), SSLClientAuth.OPTIONAL.name().toLowerCase(Locale.ROOT)); Settings settings = settingsBuilder().put("shield.transport.ssl", true).build(); - ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter); + ShieldNettyTransport transport = new ShieldNettyTransport(settings, mock(ThreadPool.class), mock(NetworkService.class), mock(BigArrays.class), Version.CURRENT, null, serverSSLService, clientSSLService, settingsFilter, mock(NamedWriteableRegistry.class)); setOpenChannelsHandlerToMock(transport); ChannelPipelineFactory factory = transport.configureServerChannelPipelineFactory("client", Settings.builder().put(ShieldNettyTransport.TRANSPORT_PROFILE_CLIENT_AUTH_SETTING, value).build()); assertThat(factory.getPipeline().get(SslHandler.class).getEngine().getNeedClientAuth(), is(false));