From 1ceb25cf09f5057bb9ec23eef90373c8febbc6e2 Mon Sep 17 00:00:00 2001 From: zhangduo Date: Fri, 19 May 2017 22:12:00 +0800 Subject: [PATCH] HBASE-18081 The way we process connection preamble in SimpleRpcServer is broken --- .../hbase/ipc/SimpleServerRpcConnection.java | 54 ++++--- .../ipc/TestRpcServerSlowConnectionSetup.java | 136 ++++++++++++++++++ 2 files changed, 161 insertions(+), 29 deletions(-) create mode 100644 hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcServerSlowConnectionSetup.java diff --git a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java index 50a1a6be5dc..b2507d8faed 100644 --- a/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java +++ b/hbase-server/src/main/java/org/apache/hadoop/hbase/ipc/SimpleServerRpcConnection.java @@ -63,6 +63,7 @@ class SimpleServerRpcConnection extends ServerRpcConnection { final SocketChannel channel; private ByteBuff data; private ByteBuffer dataLengthBuffer; + private ByteBuffer preambleBuffer; protected final ConcurrentLinkedDeque responseQueue = new ConcurrentLinkedDeque<>(); final Lock responseWriteLock = new ReentrantLock(); @@ -130,22 +131,25 @@ class SimpleServerRpcConnection extends ServerRpcConnection { } private int readPreamble() throws IOException { - int count; - // Check for 'HBas' magic. - this.dataLengthBuffer.flip(); - if (!Arrays.equals(HConstants.RPC_HEADER, dataLengthBuffer.array())) { - return doBadPreambleHandling( - "Expected HEADER=" + Bytes.toStringBinary(HConstants.RPC_HEADER) + " but received HEADER=" + - Bytes.toStringBinary(dataLengthBuffer.array()) + " from " + toString()); + if (preambleBuffer == null) { + preambleBuffer = ByteBuffer.allocate(6); } - // Now read the next two bytes, the version and the auth to use. - ByteBuffer versionAndAuthBytes = ByteBuffer.allocate(2); - count = this.rpcServer.channelRead(channel, versionAndAuthBytes); - if (count < 0 || versionAndAuthBytes.remaining() > 0) { + int count = this.rpcServer.channelRead(channel, preambleBuffer); + if (count < 0 || preambleBuffer.remaining() > 0) { return count; } - int version = versionAndAuthBytes.get(0); - byte authbyte = versionAndAuthBytes.get(1); + // Check for 'HBas' magic. + preambleBuffer.flip(); + for (int i = 0; i < HConstants.RPC_HEADER.length; i++) { + if (HConstants.RPC_HEADER[i] != preambleBuffer.get(i)) { + return doBadPreambleHandling("Expected HEADER=" + + Bytes.toStringBinary(HConstants.RPC_HEADER) + " but received HEADER=" + + Bytes.toStringBinary(preambleBuffer.array(), 0, HConstants.RPC_HEADER.length) + + " from " + toString()); + } + } + int version = preambleBuffer.get(HConstants.RPC_HEADER.length); + byte authbyte = preambleBuffer.get(HConstants.RPC_HEADER.length + 1); this.authMethod = AuthMethod.valueOf(authbyte); if (version != SimpleRpcServer.CURRENT_VERSION) { String msg = getFatalConnectionString(version, authbyte); @@ -178,8 +182,7 @@ class SimpleServerRpcConnection extends ServerRpcConnection { if (authMethod != AuthMethod.SIMPLE) { useSasl = true; } - - dataLengthBuffer.clear(); + preambleBuffer = null; // do not need it anymore connectionPreambleRead = true; return count; } @@ -200,26 +203,19 @@ class SimpleServerRpcConnection extends ServerRpcConnection { * @throws InterruptedException */ public int readAndProcess() throws IOException, InterruptedException { - // Try and read in an int. If new connection, the int will hold the 'HBas' HEADER. If it - // does, read in the rest of the connection preamble, the version and the auth method. - // Else it will be length of the data to read (or -1 if a ping). We catch the integer - // length into the 4-byte this.dataLengthBuffer. - int count = read4Bytes(); - if (count < 0 || dataLengthBuffer.remaining() > 0) { - return count; - } - // If we have not read the connection setup preamble, look to see if that is on the wire. if (!connectionPreambleRead) { - count = readPreamble(); + int count = readPreamble(); if (!connectionPreambleRead) { return count; } + } - count = read4Bytes(); - if (count < 0 || dataLengthBuffer.remaining() > 0) { - return count; - } + // Try and read in an int. it will be length of the data to read (or -1 if a ping). We catch the + // integer length into the 4-byte this.dataLengthBuffer. + int count = read4Bytes(); + if (count < 0 || dataLengthBuffer.remaining() > 0) { + return count; } // We have read a length and we have read the preamble. It is either the connection header diff --git a/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcServerSlowConnectionSetup.java b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcServerSlowConnectionSetup.java new file mode 100644 index 00000000000..fba5ca7937e --- /dev/null +++ b/hbase-server/src/test/java/org/apache/hadoop/hbase/ipc/TestRpcServerSlowConnectionSetup.java @@ -0,0 +1,136 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.ipc; + +import static org.apache.hadoop.hbase.ipc.TestProtobufRpcServiceImpl.SERVICE; +import static org.junit.Assert.assertEquals; + +import java.io.BufferedInputStream; +import java.io.DataInputStream; +import java.io.DataOutputStream; +import java.io.IOException; +import java.net.InetSocketAddress; +import java.net.Socket; +import java.util.Arrays; +import java.util.List; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.hbase.HBaseConfiguration; +import org.apache.hadoop.hbase.HConstants; +import org.apache.hadoop.hbase.client.MetricsConnection; +import org.apache.hadoop.hbase.ipc.RpcServer.BlockingServiceAndInterface; +import org.apache.hadoop.hbase.security.AuthMethod; +import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestProtos.EmptyRequestProto; +import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestProtos.EmptyResponseProto; +import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestRpcServiceProtos; +import org.apache.hadoop.hbase.shaded.ipc.protobuf.generated.TestRpcServiceProtos.TestProtobufRpcProto; +import org.apache.hadoop.hbase.shaded.protobuf.ProtobufUtil; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ConnectionHeader; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.RequestHeader; +import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.ResponseHeader; +import org.apache.hadoop.hbase.testclassification.MediumTests; +import org.apache.hadoop.hbase.testclassification.RPCTests; +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.experimental.categories.Category; +import org.junit.runner.RunWith; +import org.junit.runners.Parameterized; +import org.junit.runners.Parameterized.Parameter; +import org.junit.runners.Parameterized.Parameters; + +import com.google.common.collect.Lists; + +@RunWith(Parameterized.class) +@Category({ RPCTests.class, MediumTests.class }) +public class TestRpcServerSlowConnectionSetup { + + private RpcServer server; + + private Socket socket; + + @Parameter + public Class rpcServerImpl; + + @Parameters(name = "{index}: rpcServerImpl={0}") + public static List params() { + return Arrays.asList(new Object[] { SimpleRpcServer.class }, + new Object[] { NettyRpcServer.class }); + } + + @Before + public void setUp() throws IOException { + Configuration conf = HBaseConfiguration.create(); + conf.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, rpcServerImpl.getName()); + server = RpcServerFactory.createRpcServer(null, "testRpcServer", + Lists.newArrayList(new BlockingServiceAndInterface(SERVICE, null)), + new InetSocketAddress("localhost", 0), conf, new FifoRpcScheduler(conf, 1)); + server.start(); + socket = new Socket("localhost", server.getListenerAddress().getPort()); + } + + @After + public void tearDown() throws IOException { + if (socket != null) { + socket.close(); + } + if (server != null) { + server.stop(); + } + } + + @Test + public void test() throws IOException, InterruptedException { + int rpcHeaderLen = HConstants.RPC_HEADER.length; + byte[] preamble = new byte[rpcHeaderLen + 2]; + System.arraycopy(HConstants.RPC_HEADER, 0, preamble, 0, rpcHeaderLen); + preamble[rpcHeaderLen] = HConstants.RPC_CURRENT_VERSION; + preamble[rpcHeaderLen + 1] = AuthMethod.SIMPLE.code; + socket.getOutputStream().write(preamble, 0, rpcHeaderLen + 1); + socket.getOutputStream().flush(); + Thread.sleep(5000); + socket.getOutputStream().write(preamble, rpcHeaderLen + 1, 1); + socket.getOutputStream().flush(); + + ConnectionHeader header = ConnectionHeader.newBuilder() + .setServiceName(TestRpcServiceProtos.TestProtobufRpcProto.getDescriptor().getFullName()) + .setVersionInfo(ProtobufUtil.getVersionInfo()).build(); + DataOutputStream dos = new DataOutputStream(socket.getOutputStream()); + dos.writeInt(header.getSerializedSize()); + header.writeTo(dos); + dos.flush(); + + int callId = 10; + Call call = new Call(callId, TestProtobufRpcProto.getDescriptor().findMethodByName("ping"), + EmptyRequestProto.getDefaultInstance(), null, EmptyResponseProto.getDefaultInstance(), 1000, + HConstants.NORMAL_QOS, null, MetricsConnection.newCallStats()); + RequestHeader requestHeader = IPCUtil.buildRequestHeader(call, null); + dos.writeInt(IPCUtil.getTotalSizeWhenWrittenDelimited(requestHeader, call.param)); + requestHeader.writeDelimitedTo(dos); + call.param.writeDelimitedTo(dos); + dos.flush(); + + DataInputStream dis = new DataInputStream(new BufferedInputStream(socket.getInputStream())); + int size = dis.readInt(); + ResponseHeader responseHeader = ResponseHeader.parseDelimitedFrom(dis); + assertEquals(callId, responseHeader.getCallId()); + EmptyResponseProto.Builder builder = EmptyResponseProto.newBuilder(); + builder.mergeDelimitedFrom(dis); + assertEquals(size, IPCUtil.getTotalSizeWhenWrittenDelimited(responseHeader, builder.build())); + } +}