Merge pull request #2047 from binlijin/master

optimize InputRowSerde
This commit is contained in:
Himanshu 2015-12-09 15:14:07 -06:00
commit f29c25b826
1 changed files with 99 additions and 77 deletions

View File

@ -19,13 +19,14 @@
package io.druid.indexer;
import com.google.common.base.Charsets;
import com.google.common.base.Supplier;
import com.google.common.base.Throwables;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
import com.google.common.io.ByteArrayDataOutput;
import com.google.common.io.ByteStreams;
import com.metamx.common.IAE;
import com.metamx.common.ISE;
import com.metamx.common.logger.Logger;
import io.druid.data.input.InputRow;
import io.druid.data.input.MapBasedInputRow;
@ -35,16 +36,11 @@ import io.druid.segment.incremental.IncrementalIndex;
import io.druid.segment.serde.ComplexMetricSerde;
import io.druid.segment.serde.ComplexMetrics;
import org.apache.hadoop.io.ArrayWritable;
import org.apache.hadoop.io.BytesWritable;
import org.apache.hadoop.io.FloatWritable;
import org.apache.hadoop.io.LongWritable;
import org.apache.hadoop.io.MapWritable;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.io.WritableUtils;
import java.io.DataInput;
import java.io.IOException;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
@ -54,8 +50,6 @@ public class InputRowSerde
{
private static final Logger log = new Logger(InputRowSerde.class);
private static final Text[] EMPTY_TEXT_ARRAY = new Text[0];
public static final byte[] toBytes(final InputRow row, AggregatorFactory[] aggs)
{
try {
@ -67,35 +61,12 @@ public class InputRowSerde
//writing all dimensions
List<String> dimList = row.getDimensions();
Text[] dims = EMPTY_TEXT_ARRAY;
if(dimList != null) {
dims = new Text[dimList.size()];
for (int i = 0; i < dims.length; i++) {
dims[i] = new Text(dimList.get(i));
}
}
StringArrayWritable sw = new StringArrayWritable(dims);
sw.write(out);
MapWritable mw = new MapWritable();
if(dimList != null) {
WritableUtils.writeVInt(out, dimList.size());
if (dimList != null) {
for (String dim : dimList) {
List<String> dimValue = row.getDimension(dim);
if (dimValue == null || dimValue.size() == 0) {
continue;
}
if (dimValue.size() == 1) {
mw.put(new Text(dim), new Text(dimValue.get(0)));
} else {
Text[] dimValueArr = new Text[dimValue.size()];
for (int i = 0; i < dimValueArr.length; i++) {
dimValueArr[i] = new Text(dimValue.get(i));
}
mw.put(new Text(dim), new StringArrayWritable(dimValueArr));
}
List<String> dimValues = row.getDimension(dim);
writeString(dim, out);
writeStringArray(dimValues, out);
}
}
@ -108,8 +79,10 @@ public class InputRowSerde
return row;
}
};
WritableUtils.writeVInt(out, aggs.length);
for (AggregatorFactory aggFactory : aggs) {
String k = aggFactory.getName();
writeString(k, out);
Aggregator agg = aggFactory.factorize(
IncrementalIndex.makeColumnSelectorFactory(
@ -123,24 +96,73 @@ public class InputRowSerde
String t = aggFactory.getTypeName();
if (t.equals("float")) {
mw.put(new Text(k), new FloatWritable(agg.getFloat()));
out.writeFloat(agg.getFloat());
} else if (t.equals("long")) {
mw.put(new Text(k), new LongWritable(agg.getLong()));
WritableUtils.writeVLong(out, agg.getLong());
} else {
//its a complex metric
Object val = agg.get();
ComplexMetricSerde serde = getComplexMetricSerde(t);
mw.put(new Text(k), new BytesWritable(serde.toBytes(val)));
writeBytes(serde.toBytes(val), out);
}
}
mw.write(out);
return out.toByteArray();
} catch(IOException ex) {
throw Throwables.propagate(ex);
}
}
private static void writeBytes(byte[] value, ByteArrayDataOutput out) throws IOException
{
WritableUtils.writeVInt(out, value.length);
out.write(value, 0, value.length);
}
private static void writeString(String value, ByteArrayDataOutput out) throws IOException
{
writeBytes(value.getBytes(Charsets.UTF_8), out);
}
private static void writeStringArray(List<String> values, ByteArrayDataOutput out) throws IOException
{
if (values == null || values.size() == 0) {
WritableUtils.writeVInt(out, 0);
return;
}
WritableUtils.writeVInt(out, values.size());
for (String value : values) {
writeString(value, out);
}
}
private static String readString(DataInput in) throws IOException
{
byte[] result = readBytes(in);
return new String(result, Charsets.UTF_8);
}
private static byte[] readBytes(DataInput in) throws IOException
{
int size = WritableUtils.readVInt(in);
byte[] result = new byte[size];
in.readFully(result, 0, size);
return result;
}
private static List<String> readStringArray(DataInput in) throws IOException
{
int count = WritableUtils.readVInt(in);
if (count == 0) {
return null;
}
List<String> values = Lists.newArrayListWithCapacity(count);
for (int i = 0; i < count; i++) {
values.add(readString(in));
}
return values;
}
public static final InputRow fromBytes(byte[] data, AggregatorFactory[] aggs)
{
try {
@ -149,52 +171,38 @@ public class InputRowSerde
//Read timestamp
long timestamp = in.readLong();
//Read dimensions
StringArrayWritable sw = new StringArrayWritable();
sw.readFields(in);
List<String> dimensions = Arrays.asList(sw.toStrings());
MapWritable mw = new MapWritable();
mw.readFields(in);
Map<String, Object> event = Maps.newHashMap();
for (String d : dimensions) {
Writable v = mw.get(new Text(d));
if (v == null) {
//Read dimensions
List<String> dimensions = Lists.newArrayList();
int dimNum = WritableUtils.readVInt(in);
for (int i = 0; i < dimNum; i++) {
String dimension = readString(in);
dimensions.add(dimension);
List<String> dimensionValues = readStringArray(in);
if (dimensionValues == null) {
continue;
}
if (v instanceof Text) {
event.put(d, ((Text) v).toString());
} else if (v instanceof StringArrayWritable) {
event.put(d, Arrays.asList(((StringArrayWritable) v).toStrings()));
if (dimensionValues.size() == 1) {
event.put(dimension, dimensionValues.get(0));
} else {
throw new ISE("unknown dim value type %s", v.getClass().getName());
event.put(dimension, dimensionValues);
}
}
//Read metrics
for (AggregatorFactory aggFactory : aggs) {
String k = aggFactory.getName();
Writable v = mw.get(new Text(k));
if (v == null) {
continue;
}
String t = aggFactory.getTypeName();
if (t.equals("float")) {
event.put(k, ((FloatWritable) v).get());
} else if (t.equals("long")) {
event.put(k, ((LongWritable) v).get());
int metricSize = WritableUtils.readVInt(in);
for (int i = 0; i < metricSize; i++) {
String metric = readString(in);
String type = getType(metric, aggs, i);
if (type.equals("float")) {
event.put(metric, in.readFloat());
} else if (type.equals("long")) {
event.put(metric, WritableUtils.readVLong(in));
} else {
//its a complex metric
ComplexMetricSerde serde = getComplexMetricSerde(t);
BytesWritable bw = (BytesWritable) v;
event.put(k, serde.fromBytes(bw.getBytes(), 0, bw.getLength()));
ComplexMetricSerde serde = getComplexMetricSerde(type);
byte[] value = readBytes(in);
event.put(metric, serde.fromBytes(value, 0, value.length));
}
}
@ -204,6 +212,20 @@ public class InputRowSerde
}
}
private static String getType(String metric, AggregatorFactory[] aggs, int i)
{
if (aggs[i].getName().equals(metric)) {
return aggs[i].getTypeName();
}
log.warn("Aggs disordered, fall backs to loop.");
for (AggregatorFactory agg : aggs) {
if (agg.getName().equals(metric)) {
return agg.getTypeName();
}
}
return null;
}
private static ComplexMetricSerde getComplexMetricSerde(String type)
{
ComplexMetricSerde serde = ComplexMetrics.getSerdeForType(type);