Fix decimal type handling in ORC extension. (#4535)

Gian Merlino 2017-07-12 12:16:48 -07:00 committed by GitHub
parent b2865b7c7b
commit 3399d1a488
2 changed files with 114 additions and 44 deletions
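
In short: the parser previously handed Hive's primitive values to Druid untouched, so ORC decimal columns surfaced as opaque HiveDecimal objects that Druid's row handling cannot use. This change routes every primitive value (including primitive list elements) through a coercion step that narrows HiveDecimal to double, tidies up equals/hashCode/toString, and adds a regression test. A minimal standalone sketch of the coercion idea, mirroring the private coercePrimitiveObject helper in the diff below (the class name and main method here are illustrative only):

import org.apache.hadoop.hive.common.type.HiveDecimal;

public class DecimalCoercionSketch
{
  // Hive decimals become Java doubles; everything else passes through unchanged.
  static Object coercePrimitiveObject(final Object object)
  {
    if (object instanceof HiveDecimal) {
      // Note: doubleValue() can lose precision for decimals wider than a double's mantissa.
      return ((HiveDecimal) object).doubleValue();
    }
    return object;
  }

  public static void main(String[] args)
  {
    System.out.println(coercePrimitiveObject(HiveDecimal.create("3.25"))); // 3.25 (a Double)
    System.out.println(coercePrimitiveObject("foo"));                      // foo (unchanged)
  }
}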

io/druid/data/input/orc/OrcHadoopInputRowParser.java

@@ -20,6 +20,7 @@ package io.druid.data.input.orc;
 import com.fasterxml.jackson.annotation.JsonCreator;
 import com.fasterxml.jackson.annotation.JsonProperty;
+import com.google.common.annotations.VisibleForTesting;
 import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
@@ -31,6 +32,7 @@ import io.druid.data.input.impl.ParseSpec;
 import io.druid.data.input.impl.TimestampSpec;
 import io.druid.java.util.common.StringUtils;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.hive.common.type.HiveDecimal;
 import org.apache.hadoop.hive.ql.io.orc.OrcSerde;
 import org.apache.hadoop.hive.ql.io.orc.OrcStruct;
 import org.apache.hadoop.hive.serde2.SerDeException;
@@ -47,6 +49,7 @@ import org.joda.time.DateTime;
 import javax.annotation.Nullable;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Properties;
 
 public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
@@ -76,18 +79,24 @@ public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
   {
     Map<String, Object> map = Maps.newHashMap();
     List<? extends StructField> fields = oip.getAllStructFieldRefs();
-    for (StructField field: fields) {
+    for (StructField field : fields) {
       ObjectInspector objectInspector = field.getFieldObjectInspector();
-      switch(objectInspector.getCategory()) {
+      switch (objectInspector.getCategory()) {
         case PRIMITIVE:
-          PrimitiveObjectInspector primitiveObjectInspector = (PrimitiveObjectInspector)objectInspector;
-          map.put(field.getFieldName(),
-              primitiveObjectInspector.getPrimitiveJavaObject(oip.getStructFieldData(input, field)));
+          PrimitiveObjectInspector primitiveObjectInspector = (PrimitiveObjectInspector) objectInspector;
+          map.put(
+              field.getFieldName(),
+              coercePrimitiveObject(
+                  primitiveObjectInspector.getPrimitiveJavaObject(oip.getStructFieldData(input, field))
+              )
+          );
           break;
         case LIST: // array case - only 1-depth array supported yet
-          ListObjectInspector listObjectInspector = (ListObjectInspector)objectInspector;
-          map.put(field.getFieldName(),
-              getListObject(listObjectInspector, oip.getStructFieldData(input, field)));
+          ListObjectInspector listObjectInspector = (ListObjectInspector) objectInspector;
+          map.put(
+              field.getFieldName(),
+              getListObject(listObjectInspector, oip.getStructFieldData(input, field))
+          );
           break;
         default:
           break;
@@ -106,13 +115,16 @@ public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
       typeString = typeStringFromParseSpec(parseSpec);
     }
     TypeInfo typeInfo = TypeInfoUtils.getTypeInfoFromTypeString(typeString);
-    Preconditions.checkArgument(typeInfo instanceof StructTypeInfo,
-        StringUtils.format("typeString should be struct type but not [%s]", typeString));
-    Properties table = getTablePropertiesFromStructTypeInfo((StructTypeInfo)typeInfo);
+    Preconditions.checkArgument(
+        typeInfo instanceof StructTypeInfo,
+        StringUtils.format("typeString should be struct type but not [%s]", typeString)
+    );
+    Properties table = getTablePropertiesFromStructTypeInfo((StructTypeInfo) typeInfo);
     serde.initialize(new Configuration(), table);
     try {
       oip = (StructObjectInspector) serde.getObjectInspector();
-    } catch (SerDeException e) {
+    }
+    catch (SerDeException e) {
       throw new RuntimeException(e);
     }
   }
@@ -122,14 +134,16 @@ public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
     List objectList = listObjectInspector.getList(listObject);
     List list = null;
     ObjectInspector child = listObjectInspector.getListElementObjectInspector();
-    switch(child.getCategory()) {
+    switch (child.getCategory()) {
       case PRIMITIVE:
-        final PrimitiveObjectInspector primitiveObjectInspector = (PrimitiveObjectInspector)child;
-        list = Lists.transform(objectList, new Function() {
+        final PrimitiveObjectInspector primitiveObjectInspector = (PrimitiveObjectInspector) child;
+        list = Lists.transform(objectList, new Function()
+        {
          @Nullable
          @Override
-          public Object apply(@Nullable Object input) {
-            return primitiveObjectInspector.getPrimitiveJavaObject(input);
+          public Object apply(@Nullable Object input)
+          {
+            return coercePrimitiveObject(primitiveObjectInspector.getPrimitiveJavaObject(input));
          }
        });
        break;
@@ -159,12 +173,37 @@ public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
     return new OrcHadoopInputRowParser(parseSpec, typeString);
   }
 
-  public InputRowParser withTypeString(String typeString)
+  @Override
+  public boolean equals(final Object o)
   {
-    return new OrcHadoopInputRowParser(parseSpec, typeString);
+    if (this == o) {
+      return true;
+    }
+    if (o == null || getClass() != o.getClass()) {
+      return false;
+    }
+    final OrcHadoopInputRowParser that = (OrcHadoopInputRowParser) o;
+    return Objects.equals(parseSpec, that.parseSpec) &&
+           Objects.equals(typeString, that.typeString);
   }
 
-  public static String typeStringFromParseSpec(ParseSpec parseSpec)
+  @Override
+  public int hashCode()
+  {
+    return Objects.hash(parseSpec, typeString);
+  }
+
+  @Override
+  public String toString()
+  {
+    return "OrcHadoopInputRowParser{" +
+           "parseSpec=" + parseSpec +
+           ", typeString='" + typeString + '\'' +
+           '}';
+  }
+
+  @VisibleForTesting
+  static String typeStringFromParseSpec(ParseSpec parseSpec)
   {
     StringBuilder builder = new StringBuilder("struct<");
     builder.append(parseSpec.getTimestampSpec().getTimestampColumn()).append(":string");
@@ -178,7 +217,16 @@ public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
     return builder.toString();
   }
 
-  public static Properties getTablePropertiesFromStructTypeInfo(StructTypeInfo structTypeInfo)
+  private static Object coercePrimitiveObject(final Object object)
+  {
+    if (object instanceof HiveDecimal) {
+      return ((HiveDecimal) object).doubleValue();
+    } else {
+      return object;
+    }
+  }
+
+  private static Properties getTablePropertiesFromStructTypeInfo(StructTypeInfo structTypeInfo)
   {
     Properties table = new Properties();
     table.setProperty("columns", String.join(",", structTypeInfo.getAllStructFieldNames()));
@@ -186,10 +234,12 @@ public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
         ",",
         Lists.transform(
             structTypeInfo.getAllStructFieldTypeInfos(),
-            new Function<TypeInfo, String>() {
+            new Function<TypeInfo, String>()
+            {
              @Nullable
              @Override
-              public String apply(@Nullable TypeInfo typeInfo) {
+              public String apply(@Nullable TypeInfo typeInfo)
+              {
                return typeInfo.getTypeName();
              }
            }
@@ -198,24 +248,4 @@ public class OrcHadoopInputRowParser implements InputRowParser<OrcStruct>
     return table;
   }
 
-  @Override
-  public boolean equals(Object o)
-  {
-    if (!(o instanceof OrcHadoopInputRowParser)) {
-      return false;
-    }
-
-    OrcHadoopInputRowParser other = (OrcHadoopInputRowParser)o;
-
-    if (!parseSpec.equals(other.parseSpec)) {
-      return false;
-    }
-
-    if (!typeString.equals(other.typeString)) {
-      return false;
-    }
-
-    return true;
-  }
 }
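
For context, the typeString that typeStringFromParseSpec builds (or that users supply directly) describes the ORC struct layout the serde is initialized with. In an ingestion spec it sits next to the parseSpec, roughly like this hypothetical fragment (shape follows the ORC extension docs of this era; column names are invented):

"parser" : {
  "type" : "orc",
  "parseSpec" : {
    "format" : "timeAndDims",
    "timestampSpec" : { "column" : "timestamp", "format" : "auto" },
    "dimensionsSpec" : { "dimensions" : ["col1", "col2"] }
  },
  "typeString" : "struct<timestamp:string,col1:string,col2:decimal>"
}

With this fix, values from the col2 decimal column reach Druid as doubles rather than as HiveDecimal objects.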

io/druid/data/input/orc/OrcHadoopInputRowParserTest.java

@@ -25,6 +25,7 @@ import com.google.inject.Binder;
 import com.google.inject.Injector;
 import com.google.inject.Module;
 import com.google.inject.name.Names;
+import io.druid.data.input.InputRow;
 import io.druid.data.input.impl.DimensionSchema;
 import io.druid.data.input.impl.DimensionsSpec;
 import io.druid.data.input.impl.InputRowParser;
@@ -35,6 +36,14 @@ import io.druid.data.input.impl.TimestampSpec;
 import io.druid.guice.GuiceInjectors;
 import io.druid.initialization.Initialization;
 import io.druid.jackson.DefaultObjectMapper;
+import org.apache.hadoop.hive.ql.io.orc.OrcStruct;
+import org.apache.hadoop.hive.serde2.io.HiveDecimalWritable;
+import org.apache.hadoop.hive.serde2.objectinspector.SettableStructObjectInspector;
+import org.apache.hadoop.hive.serde2.typeinfo.TypeInfoUtils;
+import org.apache.hadoop.io.FloatWritable;
+import org.apache.hadoop.io.LongWritable;
+import org.apache.hadoop.io.Text;
+import org.joda.time.DateTime;
 import org.junit.Assert;
 import org.junit.Before;
 import org.junit.Test;
@@ -132,5 +141,36 @@ public class OrcHadoopInputRowParserTest
     Assert.assertEquals(expected, typeString);
   }
 
+  @Test
+  public void testParse()
+  {
+    final String typeString = "struct<timestamp:string,col1:string,col2:array<string>,col3:float,col4:bigint,col5:decimal>";
+    final OrcHadoopInputRowParser parser = new OrcHadoopInputRowParser(
+        new TimeAndDimsParseSpec(
+            new TimestampSpec("timestamp", "auto", null),
+            new DimensionsSpec(null, null, null)
+        ),
+        typeString
+    );
+
+    final SettableStructObjectInspector oi = (SettableStructObjectInspector) OrcStruct.createObjectInspector(
+        TypeInfoUtils.getTypeInfoFromTypeString(typeString)
+    );
+    final OrcStruct struct = (OrcStruct) oi.create();
+    struct.setNumFields(6);
+    oi.setStructFieldData(struct, oi.getStructFieldRef("timestamp"), new Text("2000-01-01"));
+    oi.setStructFieldData(struct, oi.getStructFieldRef("col1"), new Text("foo"));
+    oi.setStructFieldData(struct, oi.getStructFieldRef("col2"), ImmutableList.of(new Text("foo"), new Text("bar")));
+    oi.setStructFieldData(struct, oi.getStructFieldRef("col3"), new FloatWritable(1));
+    oi.setStructFieldData(struct, oi.getStructFieldRef("col4"), new LongWritable(2));
+    oi.setStructFieldData(struct, oi.getStructFieldRef("col5"), new HiveDecimalWritable(3));
+
+    final InputRow row = parser.parse(struct);
+    Assert.assertEquals("timestamp", new DateTime("2000-01-01"), row.getTimestamp());
+    Assert.assertEquals("col1", "foo", row.getRaw("col1"));
+    Assert.assertEquals("col2", ImmutableList.of("foo", "bar"), row.getRaw("col2"));
+    Assert.assertEquals("col3", 1.0f, row.getRaw("col3"));
+    Assert.assertEquals("col4", 2L, row.getRaw("col4"));
+    Assert.assertEquals("col5", 3.0d, row.getRaw("col5"));
+  }
 }