diff --git a/src/main/java/org/apache/hadoop/hbase/io/DataOutputOutputStream.java b/src/main/java/org/apache/hadoop/hbase/io/DataOutputOutputStream.java new file mode 100644 index 00000000000..10700c1d1e7 --- /dev/null +++ b/src/main/java/org/apache/hadoop/hbase/io/DataOutputOutputStream.java @@ -0,0 +1,69 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ +package org.apache.hadoop.hbase.io; + +import java.io.DataOutput; +import java.io.IOException; +import java.io.OutputStream; + +import org.apache.hadoop.classification.InterfaceAudience; +import org.apache.hadoop.classification.InterfaceStability; + +/** + * OutputStream implementation that wraps a DataOutput. + */ +@InterfaceAudience.Private +@InterfaceStability.Unstable +class DataOutputOutputStream extends OutputStream { + + private final DataOutput out; + + /** + * Construct an OutputStream from the given DataOutput. If 'out' + * is already an OutputStream, simply returns it. Otherwise, wraps + * it in an OutputStream. + * @param out the DataOutput to wrap + * @return an OutputStream instance that outputs to 'out' + */ + public static OutputStream constructOutputStream(DataOutput out) { + if (out instanceof OutputStream) { + return (OutputStream)out; + } else { + return new DataOutputOutputStream(out); + } + } + + private DataOutputOutputStream(DataOutput out) { + this.out = out; + } + + @Override + public void write(int b) throws IOException { + out.writeByte(b); + } + + @Override + public void write(byte[] b, int off, int len) throws IOException { + out.write(b, off, len); + } + + @Override + public void write(byte[] b) throws IOException { + out.write(b); + } +} diff --git a/src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java b/src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java index 3c7119225ad..5c206e507e4 100644 --- a/src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java +++ b/src/main/java/org/apache/hadoop/hbase/io/HbaseObjectWritable.java @@ -22,11 +22,14 @@ import java.io.ByteArrayInputStream; import java.io.ByteArrayOutputStream; import java.io.DataInput; import java.io.DataOutput; +import java.io.InputStream; import java.io.IOException; import java.io.ObjectInputStream; import java.io.ObjectOutputStream; import java.io.Serializable; import java.lang.reflect.Array; +import java.lang.reflect.InvocationTargetException; +import java.lang.reflect.Method; import java.util.ArrayList; import java.util.HashMap; import java.util.List; @@ -87,6 +90,7 @@ import org.apache.hadoop.hbase.regionserver.RegionOpeningState; import org.apache.hadoop.hbase.regionserver.wal.HLog; import org.apache.hadoop.hbase.regionserver.wal.HLogKey; import org.apache.hadoop.hbase.util.Bytes; +import org.apache.hadoop.hbase.util.ProtoUtil; import org.apache.hadoop.io.MapWritable; import org.apache.hadoop.io.ObjectWritable; import org.apache.hadoop.io.Text; @@ -94,6 +98,8 @@ import org.apache.hadoop.io.Writable; import org.apache.hadoop.io.WritableFactories; import org.apache.hadoop.io.WritableUtils; +import com.google.protobuf.Message; + /** * This is a customized version of the polymorphic hadoop * {@link ObjectWritable}. It removes UTF8 (HADOOP-414). @@ -253,6 +259,8 @@ public class HbaseObjectWritable implements Writable, WritableWithSize, Configur addToMap(RowMutation.class, code++); + addToMap(Message.class, code++); + //java.lang.reflect.Array is a placeholder for arrays not defined above GENERIC_ARRAY_CODE = code++; addToMap(Array.class, GENERIC_ARRAY_CODE); @@ -353,6 +361,8 @@ public class HbaseObjectWritable implements Writable, WritableWithSize, Configur code = CLASS_TO_CODE.get(Writable.class); } else if (c.isArray()) { code = CLASS_TO_CODE.get(Array.class); + } else if (Message.class.isAssignableFrom(c)) { + code = CLASS_TO_CODE.get(Message.class); } else if (Serializable.class.isAssignableFrom(c)){ code = CLASS_TO_CODE.get(Serializable.class); } @@ -479,6 +489,10 @@ public class HbaseObjectWritable implements Writable, WritableWithSize, Configur } } else if (declClass.isEnum()) { // enum Text.writeString(out, ((Enum)instanceObj).name()); + } else if (Message.class.isAssignableFrom(declaredClass)) { + Text.writeString(out, instanceObj.getClass().getName()); + ((Message)instance).writeDelimitedTo( + DataOutputOutputStream.constructOutputStream(out)); } else if (Writable.class.isAssignableFrom(declClass)) { // Writable Class c = instanceObj.getClass(); Integer code = CLASS_TO_CODE.get(c); @@ -627,6 +641,15 @@ public class HbaseObjectWritable implements Writable, WritableWithSize, Configur } else if (declaredClass.isEnum()) { // enum instance = Enum.valueOf((Class) declaredClass, Text.readString(in)); + } else if (declaredClass == Message.class) { + String className = Text.readString(in); + try { + declaredClass = getClassByName(conf, className); + instance = tryInstantiateProtobuf(declaredClass, in); + } catch (ClassNotFoundException e) { + LOG.error("Can't find class " + className, e); + throw new IOException("Can't find class " + className, e); + } } else { // Writable or Serializable Class instanceClass = null; int b = (byte)WritableUtils.readVInt(in); @@ -681,6 +704,67 @@ public class HbaseObjectWritable implements Writable, WritableWithSize, Configur return instance; } + /** + * Try to instantiate a protocol buffer of the given message class + * from the given input stream. + * + * @param protoClass the class of the generated protocol buffer + * @param dataIn the input stream to read from + * @return the instantiated Message instance + * @throws IOException if an IO problem occurs + */ + private static Message tryInstantiateProtobuf( + Class protoClass, + DataInput dataIn) throws IOException { + + try { + if (dataIn instanceof InputStream) { + // We can use the built-in parseDelimitedFrom and not have to re-copy + // the data + Method parseMethod = getStaticProtobufMethod(protoClass, + "parseDelimitedFrom", InputStream.class); + return (Message)parseMethod.invoke(null, (InputStream)dataIn); + } else { + // Have to read it into a buffer first, since protobuf doesn't deal + // with the DataInput interface directly. + + // Read the size delimiter that writeDelimitedTo writes + int size = ProtoUtil.readRawVarint32(dataIn); + if (size < 0) { + throw new IOException("Invalid size: " + size); + } + + byte[] data = new byte[size]; + dataIn.readFully(data); + Method parseMethod = getStaticProtobufMethod(protoClass, + "parseFrom", byte[].class); + return (Message)parseMethod.invoke(null, data); + } + } catch (InvocationTargetException e) { + + if (e.getCause() instanceof IOException) { + throw (IOException)e.getCause(); + } else { + throw new IOException(e.getCause()); + } + } catch (IllegalAccessException iae) { + throw new AssertionError("Could not access parse method in " + + protoClass); + } + } + + static Method getStaticProtobufMethod(Class declaredClass, String method, + Class ... args) { + + try { + return declaredClass.getMethod(method, args); + } catch (Exception e) { + // This is a bug in Hadoop - protobufs should all have this static method + throw new AssertionError("Protocol buffer class " + declaredClass + + " does not have an accessible parseFrom(InputStream) method!"); + } + } + @SuppressWarnings("unchecked") private static Class getClassByName(Configuration conf, String className) throws ClassNotFoundException { diff --git a/src/main/java/org/apache/hadoop/hbase/util/ProtoUtil.java b/src/main/java/org/apache/hadoop/hbase/util/ProtoUtil.java new file mode 100644 index 00000000000..92129676256 --- /dev/null +++ b/src/main/java/org/apache/hadoop/hbase/util/ProtoUtil.java @@ -0,0 +1,66 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.util; + +import java.io.DataInput; +import java.io.IOException; + +public abstract class ProtoUtil { + + /** + * Read a variable length integer in the same format that ProtoBufs encodes. + * @param in the input stream to read from + * @return the integer + * @throws IOException if it is malformed or EOF. + */ + public static int readRawVarint32(DataInput in) throws IOException { + byte tmp = in.readByte(); + if (tmp >= 0) { + return tmp; + } + int result = tmp & 0x7f; + if ((tmp = in.readByte()) >= 0) { + result |= tmp << 7; + } else { + result |= (tmp & 0x7f) << 7; + if ((tmp = in.readByte()) >= 0) { + result |= tmp << 14; + } else { + result |= (tmp & 0x7f) << 14; + if ((tmp = in.readByte()) >= 0) { + result |= tmp << 21; + } else { + result |= (tmp & 0x7f) << 21; + result |= (tmp = in.readByte()) << 28; + if (tmp < 0) { + // Discard upper 32 bits. + for (int i = 0; i < 5; i++) { + if (in.readByte() >= 0) { + return result; + } + } + throw new IOException("Malformed varint"); + } + } + } + } + return result; + } + +} diff --git a/src/test/java/org/apache/hadoop/hbase/ipc/TestPBOnWritableRpc.java b/src/test/java/org/apache/hadoop/hbase/ipc/TestPBOnWritableRpc.java new file mode 100644 index 00000000000..d5a906850d2 --- /dev/null +++ b/src/test/java/org/apache/hadoop/hbase/ipc/TestPBOnWritableRpc.java @@ -0,0 +1,127 @@ +/** + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.hadoop.hbase.ipc; + +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNotSame; + +import java.io.IOException; +import java.net.InetSocketAddress; + +import org.apache.hadoop.conf.Configuration; +import org.apache.hadoop.io.Text; +import org.apache.hadoop.io.Writable; +import org.junit.Test; + +import com.google.protobuf.DescriptorProtos; +import com.google.protobuf.DescriptorProtos.EnumDescriptorProto; + +/** Unit tests to test PB-based types on WritableRpcEngine. */ +public class TestPBOnWritableRpc { + + private static Configuration conf = new Configuration(); + + public interface TestProtocol extends VersionedProtocol { + public static final long VERSION = 1L; + + String echo(String value) throws IOException; + Writable echo(Writable value) throws IOException; + + DescriptorProtos.EnumDescriptorProto exchangeProto( + DescriptorProtos.EnumDescriptorProto arg); + } + + public static class TestImpl implements TestProtocol { + public long getProtocolVersion(String protocol, long clientVersion) { + return TestProtocol.VERSION; + } + + public ProtocolSignature getProtocolSignature(String protocol, long clientVersion, + int hashcode) { + return new ProtocolSignature(TestProtocol.VERSION, null); + } + + @Override + public String echo(String value) throws IOException { return value; } + + @Override + public Writable echo(Writable writable) { + return writable; + } + + @Override + public EnumDescriptorProto exchangeProto(EnumDescriptorProto arg) { + return arg; + } + } + + @Test(timeout=10000) + public void testCalls() throws Exception { + testCallsInternal(conf); + } + + private void testCallsInternal(Configuration conf) throws Exception { + RpcServer rpcServer = HBaseRPC.getServer(new TestImpl(), + new Class[] {TestProtocol.class}, + "localhost", // BindAddress is IP we got for this server. + 9999, // port number + 2, // number of handlers + 0, // we dont use high priority handlers in master + conf.getBoolean("hbase.rpc.verbose", false), conf, + 0); + TestProtocol proxy = null; + try { + rpcServer.start(); + + InetSocketAddress isa = + new InetSocketAddress("localhost", 9999); + proxy = (TestProtocol) HBaseRPC.waitForProxy( + TestProtocol.class, TestProtocol.VERSION, + isa, conf, -1, 8000, 8000); + + String stringResult = proxy.echo("foo"); + assertEquals(stringResult, "foo"); + + stringResult = proxy.echo((String)null); + assertEquals(stringResult, null); + + Text utf8Result = (Text)proxy.echo(new Text("hello world")); + assertEquals(utf8Result, new Text("hello world")); + + utf8Result = (Text)proxy.echo((Text)null); + assertEquals(utf8Result, null); + + // Test protobufs + EnumDescriptorProto sendProto = + EnumDescriptorProto.newBuilder().setName("test").build(); + EnumDescriptorProto retProto = proxy.exchangeProto(sendProto); + assertEquals(sendProto, retProto); + assertNotSame(sendProto, retProto); + } finally { + rpcServer.stop(); + if(proxy != null) { + HBaseRPC.stopProxy(proxy); + } + } + } + + public static void main(String[] args) throws Exception { + new TestPBOnWritableRpc().testCallsInternal(conf); + } +}