HBASE-14801 Enhance the Spark-HBase connector catalog with json format (Zhan Zhang)

This commit is contained in:
Jonathan M Hsieh 2016-03-09 10:41:56 -08:00
parent ad9b91a904
commit 97cce850fe
9 changed files with 751 additions and 334 deletions

View File

@ -27,8 +27,10 @@ import org.apache.hadoop.hbase.filter.FilterBase;
import org.apache.hadoop.hbase.spark.protobuf.generated.FilterProtos;
import org.apache.hadoop.hbase.util.ByteStringer;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.spark.sql.datasources.hbase.Field;
import scala.collection.mutable.MutableList;
import java.io.IOException;
import java.util.HashMap;
import java.util.List;
@ -66,7 +68,7 @@ public class SparkSQLPushDownFilter extends FilterBase{
public SparkSQLPushDownFilter(DynamicLogicExpression dynamicLogicExpression,
byte[][] valueFromQueryArray,
MutableList<SchemaQualifierDefinition> columnDefinitions) {
MutableList<Field> fields) {
this.dynamicLogicExpression = dynamicLogicExpression;
this.valueFromQueryArray = valueFromQueryArray;
@ -74,12 +76,12 @@ public class SparkSQLPushDownFilter extends FilterBase{
this.currentCellToColumnIndexMap =
new HashMap<>();
for (int i = 0; i < columnDefinitions.size(); i++) {
SchemaQualifierDefinition definition = columnDefinitions.get(i).get();
for (int i = 0; i < fields.size(); i++) {
Field field = fields.apply(i);
byte[] cfBytes = field.cfBytes();
ByteArrayComparable familyByteComparable =
new ByteArrayComparable(definition.columnFamilyBytes(),
0, definition.columnFamilyBytes().length);
new ByteArrayComparable(cfBytes, 0, cfBytes.length);
HashMap<ByteArrayComparable, String> qualifierIndexMap =
currentCellToColumnIndexMap.get(familyByteComparable);
@ -88,11 +90,11 @@ public class SparkSQLPushDownFilter extends FilterBase{
qualifierIndexMap = new HashMap<>();
currentCellToColumnIndexMap.put(familyByteComparable, qualifierIndexMap);
}
byte[] qBytes = field.colBytes();
ByteArrayComparable qualifierByteComparable =
new ByteArrayComparable(definition.qualifierBytes(), 0,
definition.qualifierBytes().length);
new ByteArrayComparable(qBytes, 0, qBytes.length);
qualifierIndexMap.put(qualifierByteComparable, definition.columnName());
qualifierIndexMap.put(qualifierByteComparable, field.colName());
}
}

View File

@ -29,7 +29,8 @@ import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositione
import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.datasources.hbase.{Field, HBaseTableCatalog}
import org.apache.spark.sql.types.{DataType => SparkDataType}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@ -48,13 +49,6 @@ import scala.collection.mutable
* Through the HBase Bytes object commands.
*/
class DefaultSource extends RelationProvider with Logging {
val TABLE_KEY:String = "hbase.table"
val SCHEMA_COLUMNS_MAPPING_KEY:String = "hbase.columns.mapping"
val HBASE_CONFIG_RESOURCES_LOCATIONS:String = "hbase.config.resources"
val USE_HBASE_CONTEXT:String = "hbase.use.hbase.context"
val PUSH_DOWN_COLUMN_FILTER:String = "hbase.push.down.column.filter"
/**
* Is given input from SparkSQL to construct a BaseRelation
* @param sqlContext SparkSQL context
@ -64,87 +58,26 @@ class DefaultSource extends RelationProvider with Logging {
override def createRelation(sqlContext: SQLContext,
parameters: Map[String, String]):
BaseRelation = {
val tableName = parameters.get(TABLE_KEY)
if (tableName.isEmpty)
new IllegalArgumentException("Invalid value for " + TABLE_KEY +" '" + tableName + "'")
val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "")
val hbaseConfigResources = parameters.getOrElse(HBASE_CONFIG_RESOURCES_LOCATIONS, "")
val useHBaseReources = parameters.getOrElse(USE_HBASE_CONTEXT, "true")
val usePushDownColumnFilter = parameters.getOrElse(PUSH_DOWN_COLUMN_FILTER, "true")
new HBaseRelation(tableName.get,
generateSchemaMappingMap(schemaMappingString),
hbaseConfigResources,
useHBaseReources.equalsIgnoreCase("true"),
usePushDownColumnFilter.equalsIgnoreCase("true"),
parameters)(sqlContext)
}
/**
* Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of
* SchemaQualifierDefinitions with the original sql column name as the key
* @param schemaMappingString The schema mapping string from the SparkSQL map
* @return A map of definitions keyed by the SparkSQL column name
*/
def generateSchemaMappingMap(schemaMappingString:String):
java.util.HashMap[String, SchemaQualifierDefinition] = {
try {
val columnDefinitions = schemaMappingString.split(',')
val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]()
columnDefinitions.map(cd => {
val parts = cd.trim.split(' ')
//Make sure we get three parts
//<ColumnName> <ColumnType> <ColumnFamily:Qualifier>
if (parts.length == 3) {
val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') {
Array[String]("", "key")
} else {
parts(2).split(':')
}
resultingMap.put(parts(0), new SchemaQualifierDefinition(parts(0),
parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1)))
} else {
throw new IllegalArgumentException("Invalid value for schema mapping '" + cd +
"' should be '<columnName> <columnType> <columnFamily>:<qualifier>' " +
"for columns and '<columnName> <columnType> :<qualifier>' for rowKeys")
}
})
resultingMap
} catch {
case e:Exception => throw
new IllegalArgumentException("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY +
" '" + schemaMappingString + "'", e )
}
new HBaseRelation(parameters, None)(sqlContext)
}
}
/**
* Implementation of Spark BaseRelation that will build up our scan logic
* , do the scan pruning, filter push down, and value conversions
*
* @param tableName HBase table that we plan to read from
* @param schemaMappingDefinition SchemaMapping information to map HBase
* Qualifiers to SparkSQL columns
* @param configResources Optional comma separated list of config resources
* to get based on their URI
* @param useHBaseContext If true this will look to see if
* HBaseContext.latest is populated to use that
* connection information
* @param sqlContext SparkSQL context
*/
case class HBaseRelation (val tableName:String,
val schemaMappingDefinition:
java.util.HashMap[String, SchemaQualifierDefinition],
val configResources:String,
val useHBaseContext:Boolean,
val usePushDownColumnFilter:Boolean,
@transient parameters: Map[String, String] ) (
@transient val sqlContext:SQLContext)
case class HBaseRelation (
@transient parameters: Map[String, String],
userSpecifiedSchema: Option[StructType]
)(@transient val sqlContext: SQLContext)
extends BaseRelation with PrunedFilteredScan with Logging {
val catalog = HBaseTableCatalog(parameters)
def tableName = catalog.name
val configResources = parameters.getOrElse(HBaseSparkConf.HBASE_CONFIG_RESOURCES_LOCATIONS, "")
val useHBaseContext = parameters.get(HBaseSparkConf.USE_HBASE_CONTEXT).map(_.toBoolean).getOrElse(true)
val usePushDownColumnFilter = parameters.get(HBaseSparkConf.PUSH_DOWN_COLUMN_FILTER)
.map(_.toBoolean).getOrElse(true)
// The user supplied per table parameter will overwrite global ones in SparkConf
val blockCacheEnable = parameters.get(HBaseSparkConf.BLOCK_CACHE_ENABLE).map(_.toBoolean)
@ -176,33 +109,12 @@ case class HBaseRelation (val tableName:String,
def hbaseConf = wrappedConf.value
/**
* Generates a Spark SQL schema object so Spark SQL knows what is being
* Generates a Spark SQL schema objeparametersct so Spark SQL knows what is being
* provided by this BaseRelation
*
* @return schema generated from the SCHEMA_COLUMNS_MAPPING_KEY value
*/
override def schema: StructType = {
val metadataBuilder = new MetadataBuilder()
val structFieldArray = new Array[StructField](schemaMappingDefinition.size())
val schemaMappingDefinitionIt = schemaMappingDefinition.values().iterator()
var indexCounter = 0
while (schemaMappingDefinitionIt.hasNext) {
val c = schemaMappingDefinitionIt.next()
val metadata = metadataBuilder.putString("name", c.columnName).build()
val structField =
new StructField(c.columnName, c.columnSparkSqlType, nullable = true, metadata)
structFieldArray(indexCounter) = structField
indexCounter += 1
}
val result = new StructType(structFieldArray)
result
}
override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType)
/**
* Here we are building the functionality to populate the resulting RDD[Row]
@ -218,7 +130,6 @@ case class HBaseRelation (val tableName:String,
*/
override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
val pushDownTuple = buildPushDownPredicatesResource(filters)
val pushDownRowKeyFilter = pushDownTuple._1
var pushDownDynamicLogicExpression = pushDownTuple._2
@ -236,17 +147,13 @@ case class HBaseRelation (val tableName:String,
logDebug("valueArray: " + valueArray.length)
val requiredQualifierDefinitionList =
new mutable.MutableList[SchemaQualifierDefinition]
new mutable.MutableList[Field]
requiredColumns.foreach( c => {
val definition = schemaMappingDefinition.get(c)
requiredQualifierDefinitionList += definition
val field = catalog.getField(c)
requiredQualifierDefinitionList += field
})
//Create a local variable so that scala doesn't have to
// serialize the whole HBaseRelation Object
val serializableDefinitionMap = schemaMappingDefinition
//retain the information for unit testing checks
DefaultSourceStaticUtils.populateLatestExecutionRules(pushDownRowKeyFilter,
pushDownDynamicLogicExpression)
@ -258,8 +165,8 @@ case class HBaseRelation (val tableName:String,
pushDownRowKeyFilter.points.foreach(p => {
val get = new Get(p)
requiredQualifierDefinitionList.foreach( d => {
if (d.columnFamilyBytes.length > 0)
get.addColumn(d.columnFamilyBytes, d.qualifierBytes)
if (d.isRowKey)
get.addColumn(d.cfBytes, d.colBytes)
})
getList.add(get)
})
@ -276,7 +183,7 @@ case class HBaseRelation (val tableName:String,
var resultRDD: RDD[Row] = {
val tmp = hRdd.map{ r =>
Row.fromSeq(requiredColumns.map(c =>
DefaultSourceStaticUtils.getValue(c, serializableDefinitionMap, r)))
DefaultSourceStaticUtils.getValue(catalog.getField(c), r)))
}
if (tmp.partitions.size > 0) {
tmp
@ -291,11 +198,10 @@ case class HBaseRelation (val tableName:String,
scan.setBatch(batchNum)
scan.setCaching(cacheSize)
requiredQualifierDefinitionList.foreach( d =>
scan.addColumn(d.columnFamilyBytes, d.qualifierBytes))
scan.addColumn(d.cfBytes, d.colBytes))
val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
Row.fromSeq(requiredColumns.map(c => DefaultSourceStaticUtils.getValue(c,
serializableDefinitionMap, r._2)))
Row.fromSeq(requiredColumns.map(c => DefaultSourceStaticUtils.getValue(catalog.getField(c), r._2)))
})
resultRDD=rdd
}
@ -337,74 +243,73 @@ case class HBaseRelation (val tableName:String,
filter match {
case EqualTo(attr, value) =>
val columnDefinition = schemaMappingDefinition.get(attr)
if (columnDefinition != null) {
if (columnDefinition.columnFamily.isEmpty) {
val field = catalog.getField(attr)
if (field != null) {
if (field.isRowKey) {
parentRowKeyFilter.mergeIntersect(new RowKeyFilter(
DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString), null))
DefaultSourceStaticUtils.getByteValue(field,
value.toString), null))
}
val byteValue =
DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString)
DefaultSourceStaticUtils.getByteValue(field, value.toString)
valueArray += byteValue
}
new EqualLogicExpression(attr, valueArray.length - 1, false)
case LessThan(attr, value) =>
val columnDefinition = schemaMappingDefinition.get(attr)
if (columnDefinition != null) {
if (columnDefinition.columnFamily.isEmpty) {
val field = catalog.getField(attr)
if (field != null) {
if (field.isRowKey) {
parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
new ScanRange(DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString), false,
new ScanRange(DefaultSourceStaticUtils.getByteValue(field,
value.toString), false,
new Array[Byte](0), true)))
}
val byteValue =
DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString)
DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
value.toString)
valueArray += byteValue
}
new LessThanLogicExpression(attr, valueArray.length - 1)
case GreaterThan(attr, value) =>
val columnDefinition = schemaMappingDefinition.get(attr)
if (columnDefinition != null) {
if (columnDefinition.columnFamily.isEmpty) {
val field = catalog.getField(attr)
if (field != null) {
if (field.isRowKey) {
parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString), false)))
new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field,
value.toString), false)))
}
val byteValue =
DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString)
DefaultSourceStaticUtils.getByteValue(field,
value.toString)
valueArray += byteValue
}
new GreaterThanLogicExpression(attr, valueArray.length - 1)
case LessThanOrEqual(attr, value) =>
val columnDefinition = schemaMappingDefinition.get(attr)
if (columnDefinition != null) {
if (columnDefinition.columnFamily.isEmpty) {
val field = catalog.getField(attr)
if (field != null) {
if (field.isRowKey) {
parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
new ScanRange(DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString), true,
new ScanRange(DefaultSourceStaticUtils.getByteValue(field,
value.toString), true,
new Array[Byte](0), true)))
}
val byteValue =
DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString)
DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
value.toString)
valueArray += byteValue
}
new LessThanOrEqualLogicExpression(attr, valueArray.length - 1)
case GreaterThanOrEqual(attr, value) =>
val columnDefinition = schemaMappingDefinition.get(attr)
if (columnDefinition != null) {
if (columnDefinition.columnFamily.isEmpty) {
val field = catalog.getField(attr)
if (field != null) {
if (field.isRowKey) {
parentRowKeyFilter.mergeIntersect(new RowKeyFilter(null,
new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString), true)))
new ScanRange(null, true, DefaultSourceStaticUtils.getByteValue(field,
value.toString), true)))
}
val byteValue =
DefaultSourceStaticUtils.getByteValue(attr,
schemaMappingDefinition, value.toString)
DefaultSourceStaticUtils.getByteValue(catalog.getField(attr),
value.toString)
valueArray += byteValue
}
@ -435,32 +340,6 @@ case class HBaseRelation (val tableName:String,
}
}
/**
* Construct to contains column data that spend SparkSQL and HBase
*
* @param columnName SparkSQL column name
* @param colType SparkSQL column type
* @param columnFamily HBase column family
* @param qualifier HBase qualifier name
*/
case class SchemaQualifierDefinition(columnName:String,
colType:String,
columnFamily:String,
qualifier:String) extends Serializable {
val columnFamilyBytes = Bytes.toBytes(columnFamily)
val qualifierBytes = Bytes.toBytes(qualifier)
val columnSparkSqlType:DataType = if (colType.equals("BOOLEAN")) BooleanType
else if (colType.equals("TINYINT")) IntegerType
else if (colType.equals("INT")) IntegerType
else if (colType.equals("BIGINT")) LongType
else if (colType.equals("FLOAT")) FloatType
else if (colType.equals("DOUBLE")) DoubleType
else if (colType.equals("STRING")) StringType
else if (colType.equals("TIMESTAMP")) TimestampType
else if (colType.equals("DECIMAL")) StringType
else throw new IllegalArgumentException("Unsupported column type :" + colType)
}
/**
* Construct to contain a single scan ranges information. Also
* provide functions to merge with other scan ranges through AND
@ -788,35 +667,6 @@ class ColumnFilterCollection {
})
}
/**
* This will collect all the filter information in a way that is optimized
* for the HBase filter commend. Allowing the filter to be accessed
* with columnFamily and qualifier information
*
* @param schemaDefinitionMap Schema Map that will help us map the right filters
* to the correct columns
* @return HashMap oc column filters
*/
def generateFamilyQualifiterFilterMap(schemaDefinitionMap:
java.util.HashMap[String,
SchemaQualifierDefinition]):
util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter] = {
val familyQualifierFilterMap =
new util.HashMap[ColumnFamilyQualifierMapKeyWrapper, ColumnFilter]()
columnFilterMap.foreach( e => {
val definition = schemaDefinitionMap.get(e._1)
//Don't add rowKeyFilter
if (definition.columnFamilyBytes.size > 0) {
familyQualifierFilterMap.put(
new ColumnFamilyQualifierMapKeyWrapper(
definition.columnFamilyBytes, 0, definition.columnFamilyBytes.length,
definition.qualifierBytes, 0, definition.qualifierBytes.length), e._2)
}
})
familyQualifierFilterMap
}
override def toString:String = {
val strBuilder = new StringBuilder
columnFilterMap.foreach( e => strBuilder.append(e))
@ -836,7 +686,7 @@ object DefaultSourceStaticUtils {
val rawDouble = new RawDouble
val rawString = RawString.ASCENDING
val byteRange = new ThreadLocal[PositionedByteRange]{
val byteRange = new ThreadLocal[PositionedByteRange] {
override def initialValue(): PositionedByteRange = {
val range = new SimplePositionedMutableByteRange()
range.setOffset(0)
@ -844,11 +694,11 @@ object DefaultSourceStaticUtils {
}
}
def getFreshByteRange(bytes:Array[Byte]): PositionedByteRange = {
def getFreshByteRange(bytes: Array[Byte]): PositionedByteRange = {
getFreshByteRange(bytes, 0, bytes.length)
}
def getFreshByteRange(bytes:Array[Byte], offset:Int = 0, length:Int):
def getFreshByteRange(bytes: Array[Byte], offset: Int = 0, length: Int):
PositionedByteRange = {
byteRange.get().set(bytes).setLength(length).setOffset(offset)
}
@ -867,7 +717,7 @@ object DefaultSourceStaticUtils {
* @param dynamicLogicExpression The dynamicLogicExpression used in the last query
*/
def populateLatestExecutionRules(rowKeyFilter: RowKeyFilter,
dynamicLogicExpression: DynamicLogicExpression):Unit = {
dynamicLogicExpression: DynamicLogicExpression): Unit = {
lastFiveExecutionRules.add(new ExecutionRuleForUnitTesting(
rowKeyFilter, dynamicLogicExpression))
while (lastFiveExecutionRules.size() > 5) {
@ -879,25 +729,16 @@ object DefaultSourceStaticUtils {
* This method will convert the result content from HBase into the
* SQL value type that is requested by the Spark SQL schema definition
*
* @param columnName The name of the SparkSQL Column
* @param schemaMappingDefinition The schema definition map
* @param field The structure of the SparkSQL Column
* @param r The result object from HBase
* @return The converted object type
*/
def getValue(columnName: String,
schemaMappingDefinition:
java.util.HashMap[String, SchemaQualifierDefinition],
r: Result): Any = {
val columnDef = schemaMappingDefinition.get(columnName)
if (columnDef == null) throw new IllegalArgumentException("Unknown column:" + columnName)
if (columnDef.columnFamilyBytes.isEmpty) {
def getValue(field: Field,
r: Result): Any = {
if (field.isRowKey) {
val row = r.getRow
columnDef.columnSparkSqlType match {
field.dt match {
case IntegerType => rawInteger.decode(getFreshByteRange(row))
case LongType => rawLong.decode(getFreshByteRange(row))
case FloatType => rawFloat.decode(getFreshByteRange(row))
@ -908,9 +749,9 @@ object DefaultSourceStaticUtils {
}
} else {
val cellByteValue =
r.getColumnLatestCell(columnDef.columnFamilyBytes, columnDef.qualifierBytes)
r.getColumnLatestCell(field.cfBytes, field.colBytes)
if (cellByteValue == null) null
else columnDef.columnSparkSqlType match {
else field.dt match {
case IntegerType => rawInteger.decode(getFreshByteRange(cellByteValue.getValueArray,
cellByteValue.getValueOffset, cellByteValue.getValueLength))
case LongType => rawLong.decode(getFreshByteRange(cellByteValue.getValueArray,
@ -933,52 +774,41 @@ object DefaultSourceStaticUtils {
* This will convert the value from SparkSQL to be stored into HBase using the
* right byte Type
*
* @param columnName SparkSQL column name
* @param schemaMappingDefinition Schema definition map
* @param value String value from SparkSQL
* @return Returns the byte array to go into HBase
*/
def getByteValue(columnName: String,
schemaMappingDefinition:
java.util.HashMap[String, SchemaQualifierDefinition],
value: String): Array[Byte] = {
def getByteValue(field: Field,
value: String): Array[Byte] = {
field.dt match {
case IntegerType =>
val result = new Array[Byte](Bytes.SIZEOF_INT)
val localDataRange = getFreshByteRange(result)
rawInteger.encode(localDataRange, value.toInt)
localDataRange.getBytes
case LongType =>
val result = new Array[Byte](Bytes.SIZEOF_LONG)
val localDataRange = getFreshByteRange(result)
rawLong.encode(localDataRange, value.toLong)
localDataRange.getBytes
case FloatType =>
val result = new Array[Byte](Bytes.SIZEOF_FLOAT)
val localDataRange = getFreshByteRange(result)
rawFloat.encode(localDataRange, value.toFloat)
localDataRange.getBytes
case DoubleType =>
val result = new Array[Byte](Bytes.SIZEOF_DOUBLE)
val localDataRange = getFreshByteRange(result)
rawDouble.encode(localDataRange, value.toDouble)
localDataRange.getBytes
case StringType =>
Bytes.toBytes(value)
case TimestampType =>
val result = new Array[Byte](Bytes.SIZEOF_LONG)
val localDataRange = getFreshByteRange(result)
rawLong.encode(localDataRange, value.toLong)
localDataRange.getBytes
val columnDef = schemaMappingDefinition.get(columnName)
if (columnDef == null) {
throw new IllegalArgumentException("Unknown column:" + columnName)
} else {
columnDef.columnSparkSqlType match {
case IntegerType =>
val result = new Array[Byte](Bytes.SIZEOF_INT)
val localDataRange = getFreshByteRange(result)
rawInteger.encode(localDataRange, value.toInt)
localDataRange.getBytes
case LongType =>
val result = new Array[Byte](Bytes.SIZEOF_LONG)
val localDataRange = getFreshByteRange(result)
rawLong.encode(localDataRange, value.toLong)
localDataRange.getBytes
case FloatType =>
val result = new Array[Byte](Bytes.SIZEOF_FLOAT)
val localDataRange = getFreshByteRange(result)
rawFloat.encode(localDataRange, value.toFloat)
localDataRange.getBytes
case DoubleType =>
val result = new Array[Byte](Bytes.SIZEOF_DOUBLE)
val localDataRange = getFreshByteRange(result)
rawDouble.encode(localDataRange, value.toDouble)
localDataRange.getBytes
case StringType =>
Bytes.toBytes(value)
case TimestampType =>
val result = new Array[Byte](Bytes.SIZEOF_LONG)
val localDataRange = getFreshByteRange(result)
rawLong.encode(localDataRange, value.toLong)
localDataRange.getBytes
case _ => Bytes.toBytes(value)
}
case _ => Bytes.toBytes(value)
}
}
}

View File

@ -31,4 +31,9 @@ object HBaseSparkConf{
val defaultBatchNum = 1000
val BULKGET_SIZE = "spark.hbase.bulkGetSize"
val defaultBulkGetSize = 1000
val HBASE_CONFIG_RESOURCES_LOCATIONS = "hbase.config.resources"
val USE_HBASE_CONTEXT = "hbase.use.hbase.context"
val PUSH_DOWN_COLUMN_FILTER = "hbase.pushdown.column.filter"
val defaultPushDownColumnFilter = true
}

View File

@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.spark._
import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.spark.datasources.HBaseResources._
import org.apache.spark.sql.datasources.hbase.Field
import org.apache.spark.{SparkEnv, TaskContext, Logging, Partition}
import org.apache.spark.rdd.RDD
@ -31,7 +32,7 @@ import scala.collection.mutable
class HBaseTableScanRDD(relation: HBaseRelation,
val hbaseContext: HBaseContext,
@transient val filter: Option[SparkSQLPushDownFilter] = None,
val columns: Seq[SchemaQualifierDefinition] = Seq.empty
val columns: Seq[Field] = Seq.empty
)extends RDD[Result](relation.sqlContext.sparkContext, Nil) with Logging {
private def sparkConf = SparkEnv.get.conf
@transient var ranges = Seq.empty[Range]
@ -98,15 +99,15 @@ class HBaseTableScanRDD(relation: HBaseRelation,
tbr: TableResource,
g: Seq[Array[Byte]],
filter: Option[SparkSQLPushDownFilter],
columns: Seq[SchemaQualifierDefinition],
columns: Seq[Field],
hbaseContext: HBaseContext): Iterator[Result] = {
g.grouped(relation.bulkGetSize).flatMap{ x =>
val gets = new ArrayList[Get]()
x.foreach{ y =>
val g = new Get(y)
columns.foreach { d =>
if (d.columnFamilyBytes.length > 0) {
g.addColumn(d.columnFamilyBytes, d.qualifierBytes)
if (!d.isRowKey) {
g.addColumn(d.cfBytes, d.colBytes)
}
}
filter.foreach(g.setFilter(_))
@ -149,7 +150,7 @@ class HBaseTableScanRDD(relation: HBaseRelation,
private def buildScan(range: Range,
filter: Option[SparkSQLPushDownFilter],
columns: Seq[SchemaQualifierDefinition]): Scan = {
columns: Seq[Field]): Scan = {
val scan = (range.lower, range.upper) match {
case (Some(Bound(a, b)), Some(Bound(c, d))) => new Scan(a, c)
case (None, Some(Bound(c, d))) => new Scan(Array[Byte](), c)
@ -158,8 +159,8 @@ class HBaseTableScanRDD(relation: HBaseRelation,
}
columns.foreach { d =>
if (d.columnFamilyBytes.length > 0) {
scan.addColumn(d.columnFamilyBytes, d.qualifierBytes)
if (!d.isRowKey) {
scan.addColumn(d.cfBytes, d.colBytes)
}
}
scan.setCacheBlocks(relation.blockCacheEnable)

View File

@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark.datasources
import java.io.ByteArrayInputStream
import org.apache.avro.Schema
import org.apache.avro.Schema.Type._
import org.apache.avro.generic.GenericDatumReader
import org.apache.avro.generic.GenericDatumWriter
import org.apache.avro.generic.GenericRecord
import org.apache.avro.generic.{GenericDatumReader, GenericDatumWriter, GenericRecord}
import org.apache.avro.io._
import org.apache.commons.io.output.ByteArrayOutputStream
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.types._
trait SerDes {
def serialize(value: Any): Array[Byte]
def deserialize(bytes: Array[Byte], start: Int, end: Int): Any
}
class DoubleSerDes extends SerDes {
override def serialize(value: Any): Array[Byte] = Bytes.toBytes(value.asInstanceOf[Double])
override def deserialize(bytes: Array[Byte], start: Int, end: Int): Any = {
Bytes.toDouble(bytes, start)
}
}

View File

@ -0,0 +1,30 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.datasources.hbase
import org.apache.spark.sql.catalyst.SqlLexical
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types.DataType
object DataTypeParserWrapper {
lazy val dataTypeParser = new DataTypeParser {
override val lexical = new SqlLexical
}
def parse(dataTypeString: String): DataType = dataTypeParser.toDataType(dataTypeString)
}

View File

@ -0,0 +1,339 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.spark.sql.datasources.hbase
import org.apache.hadoop.hbase.spark.datasources._
import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.Logging
import org.apache.spark.sql.catalyst.util.DataTypeParser
import org.apache.spark.sql.types._
import org.json4s.jackson.JsonMethods._
import scala.collection.mutable
// Due the access issue defined in spark, we have to locate the file in this package.
// The definition of each column cell, which may be composite type
// TODO: add avro support
case class Field(
colName: String,
cf: String,
col: String,
sType: Option[String] = None,
avroSchema: Option[String] = None,
serdes: Option[SerDes]= None,
len: Int = -1) extends Logging {
override def toString = s"$colName $cf $col"
val isRowKey = cf == HBaseTableCatalog.rowKey
var start: Int = _
def cfBytes: Array[Byte] = {
if (isRowKey) {
Bytes.toBytes("")
} else {
Bytes.toBytes(cf)
}
}
def colBytes: Array[Byte] = {
if (isRowKey) {
Bytes.toBytes("key")
} else {
Bytes.toBytes(col)
}
}
val dt = {
sType.map(DataTypeParser.parse(_)).get
}
var length: Int = {
if (len == -1) {
dt match {
case BinaryType | StringType => -1
case BooleanType => Bytes.SIZEOF_BOOLEAN
case ByteType => 1
case DoubleType => Bytes.SIZEOF_DOUBLE
case FloatType => Bytes.SIZEOF_FLOAT
case IntegerType => Bytes.SIZEOF_INT
case LongType => Bytes.SIZEOF_LONG
case ShortType => Bytes.SIZEOF_SHORT
case _ => -1
}
} else {
len
}
}
override def equals(other: Any): Boolean = other match {
case that: Field =>
colName == that.colName && cf == that.cf && col == that.col
case _ => false
}
}
// The row key definition, with each key refer to the col defined in Field, e.g.,
// key1:key2:key3
case class RowKey(k: String) {
val keys = k.split(":")
var fields: Seq[Field] = _
var varLength = false
def length = {
if (varLength) {
-1
} else {
fields.foldLeft(0){case (x, y) =>
x + y.length
}
}
}
}
// The map between the column presented to Spark and the HBase field
case class SchemaMap(map: mutable.HashMap[String, Field]) {
def toFields = map.map { case (name, field) =>
StructField(name, field.dt)
}.toSeq
def fields = map.values
def getField(name: String) = map(name)
}
// The definition of HBase and Relation relation schema
case class HBaseTableCatalog(
namespace: String,
name: String,
row: RowKey,
sMap: SchemaMap,
numReg: Int) extends Logging {
def toDataType = StructType(sMap.toFields)
def getField(name: String) = sMap.getField(name)
def getRowKey: Seq[Field] = row.fields
def getPrimaryKey= row.keys(0)
def getColumnFamilies = {
sMap.fields.map(_.cf).filter(_ != HBaseTableCatalog.rowKey)
}
// Setup the start and length for each dimension of row key at runtime.
def dynSetupRowKey(rowKey: HBaseType) {
logDebug(s"length: ${rowKey.length}")
if(row.varLength) {
var start = 0
row.fields.foreach { f =>
logDebug(s"start: $start")
f.start = start
f.length = {
// If the length is not defined
if (f.length == -1) {
f.dt match {
case StringType =>
var pos = rowKey.indexOf(HBaseTableCatalog.delimiter, start)
if (pos == -1 || pos > rowKey.length) {
// this is at the last dimension
pos = rowKey.length
}
pos - start
// We don't know the length, assume it extend to the end of the rowkey.
case _ => rowKey.length - start
}
} else {
f.length
}
}
start += f.length
}
}
}
def initRowKey = {
val fields = sMap.fields.filter(_.cf == HBaseTableCatalog.rowKey)
row.fields = row.keys.flatMap(n => fields.find(_.col == n))
// The length is determined at run time if it is string or binary and the length is undefined.
if (row.fields.filter(_.length == -1).isEmpty) {
var start = 0
row.fields.foreach { f =>
f.start = start
start += f.length
}
} else {
row.varLength = true
}
}
initRowKey
}
object HBaseTableCatalog {
val newTable = "newtable"
// The json string specifying hbase catalog information
val tableCatalog = "catalog"
// The row key with format key1:key2 specifying table row key
val rowKey = "rowkey"
// The key for hbase table whose value specify namespace and table name
val table = "table"
// The namespace of hbase table
val nameSpace = "namespace"
// The name of hbase table
val tableName = "name"
// The name of columns in hbase catalog
val columns = "columns"
val cf = "cf"
val col = "col"
val `type` = "type"
// the name of avro schema json string
val avro = "avro"
val delimiter: Byte = 0
val serdes = "serdes"
val length = "length"
/**
* User provide table schema definition
* {"tablename":"name", "rowkey":"key1:key2",
* "columns":{"col1":{"cf":"cf1", "col":"col1", "type":"type1"},
* "col2":{"cf":"cf2", "col":"col2", "type":"type2"}}}
* Note that any col in the rowKey, there has to be one corresponding col defined in columns
*/
def apply(params: Map[String, String]): HBaseTableCatalog = {
val parameters = convert(params)
// println(jString)
val jString = parameters(tableCatalog)
val map = parse(jString).values.asInstanceOf[Map[String, _]]
val tableMeta = map.get(table).get.asInstanceOf[Map[String, _]]
val nSpace = tableMeta.get(nameSpace).getOrElse("default").asInstanceOf[String]
val tName = tableMeta.get(tableName).get.asInstanceOf[String]
val cIter = map.get(columns).get.asInstanceOf[Map[String, Map[String, String]]].toIterator
val schemaMap = mutable.HashMap.empty[String, Field]
cIter.foreach { case (name, column) =>
val sd = {
column.get(serdes).asInstanceOf[Option[String]].map(n =>
Class.forName(n).newInstance().asInstanceOf[SerDes]
)
}
val len = column.get(length).map(_.toInt).getOrElse(-1)
val sAvro = column.get(avro).map(parameters(_))
val f = Field(name, column.getOrElse(cf, rowKey),
column.get(col).get,
column.get(`type`),
sAvro, sd, len)
schemaMap.+=((name, f))
}
val numReg = parameters.get(newTable).map(x => x.toInt).getOrElse(0)
val rKey = RowKey(map.get(rowKey).get.asInstanceOf[String])
HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), numReg)
}
val TABLE_KEY: String = "hbase.table"
val SCHEMA_COLUMNS_MAPPING_KEY: String = "hbase.columns.mapping"
/* for backward compatibility. Convert the old definition to new json based definition formated as below
val catalog = s"""{
|"table":{"namespace":"default", "name":"htable"},
|"rowkey":"key1:key2",
|"columns":{
|"col1":{"cf":"rowkey", "col":"key1", "type":"string"},
|"col2":{"cf":"rowkey", "col":"key2", "type":"double"},
|"col3":{"cf":"cf1", "col":"col2", "type":"binary"},
|"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"},
|"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[DoubleSerDes].getName}"},
|"col6":{"cf":"cf1", "col":"col5", "type":"$map"},
|"col7":{"cf":"cf1", "col":"col6", "type":"$array"},
|"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"}
|}
|}""".stripMargin
*/
@deprecated("Please use new json format to define HBaseCatalog")
def convert(parameters: Map[String, String]): Map[String, String] = {
val tableName = parameters.get(TABLE_KEY).getOrElse(null)
// if the hbase.table is not defined, we assume it is json format already.
if (tableName == null) return parameters
val schemaMappingString = parameters.getOrElse(SCHEMA_COLUMNS_MAPPING_KEY, "")
import scala.collection.JavaConverters._
val schemaMap = generateSchemaMappingMap(schemaMappingString).asScala.map(_._2.asInstanceOf[SchemaQualifierDefinition])
val rowkey = schemaMap.filter {
_.columnFamily == "rowkey"
}.map(_.columnName)
val cols = schemaMap.map { x =>
s""""${x.columnName}":{"cf":"${x.columnFamily}", "col":"${x.qualifier}", "type":"${x.colType}"}""".stripMargin
}
val jsonCatalog =
s"""{
|"table":{"namespace":"default", "name":"${tableName}"},
|"rowkey":"${rowkey.mkString(":")}",
|"columns":{
|${cols.mkString(",")}
|}
|}
""".stripMargin
parameters ++ Map(HBaseTableCatalog.tableCatalog->jsonCatalog)
}
/**
* Reads the SCHEMA_COLUMNS_MAPPING_KEY and converts it to a map of
* SchemaQualifierDefinitions with the original sql column name as the key
*
* @param schemaMappingString The schema mapping string from the SparkSQL map
* @return A map of definitions keyed by the SparkSQL column name
*/
def generateSchemaMappingMap(schemaMappingString:String):
java.util.HashMap[String, SchemaQualifierDefinition] = {
println(schemaMappingString)
try {
val columnDefinitions = schemaMappingString.split(',')
val resultingMap = new java.util.HashMap[String, SchemaQualifierDefinition]()
columnDefinitions.map(cd => {
val parts = cd.trim.split(' ')
//Make sure we get three parts
//<ColumnName> <ColumnType> <ColumnFamily:Qualifier>
if (parts.length == 3) {
val hbaseDefinitionParts = if (parts(2).charAt(0) == ':') {
Array[String]("rowkey", parts(0))
} else {
parts(2).split(':')
}
resultingMap.put(parts(0), new SchemaQualifierDefinition(parts(0),
parts(1), hbaseDefinitionParts(0), hbaseDefinitionParts(1)))
} else {
throw new IllegalArgumentException("Invalid value for schema mapping '" + cd +
"' should be '<columnName> <columnType> <columnFamily>:<qualifier>' " +
"for columns and '<columnName> <columnType> :<qualifier>' for rowKeys")
}
})
resultingMap
} catch {
case e:Exception => throw
new IllegalArgumentException("Invalid value for " + SCHEMA_COLUMNS_MAPPING_KEY +
" '" +
schemaMappingString + "'", e )
}
}
}
/**
* Construct to contains column data that spend SparkSQL and HBase
*
* @param columnName SparkSQL column name
* @param colType SparkSQL column type
* @param columnFamily HBase column family
* @param qualifier HBase qualifier name
*/
case class SchemaQualifierDefinition(columnName:String,
colType:String,
columnFamily:String,
qualifier:String)

View File

@ -21,6 +21,7 @@ import org.apache.hadoop.hbase.client.{Put, ConnectionFactory}
import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.hadoop.hbase.util.Bytes
import org.apache.hadoop.hbase.{TableName, HBaseTestingUtility}
import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog
import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext, Logging}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
@ -137,20 +138,37 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
connection.close()
}
def hbaseTable1Catalog = s"""{
|"table":{"namespace":"default", "name":"t1"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"b", "type":"string"}
|}
|}""".stripMargin
new HBaseContext(sc, TEST_UTIL.getConfiguration)
sqlContext = new SQLContext(sc)
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,",
"hbase.table" -> "t1"))
Map(HBaseTableCatalog.tableCatalog->hbaseTable1Catalog))
df.registerTempTable("hbaseTable1")
def hbaseTable2Catalog = s"""{
|"table":{"namespace":"default", "name":"t2"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"int"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"b", "type":"string"}
|}
|}""".stripMargin
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD INT :key, A_FIELD STRING c:a, B_FIELD STRING c:b,",
"hbase.table" -> "t2"))
Map(HBaseTableCatalog.tableCatalog->hbaseTable2Catalog))
df.registerTempTable("hbaseTable2")
}
@ -512,13 +530,20 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
assert(scanRange1.isUpperBoundEqualTo)
}
test("Test table that doesn't exist") {
val catalog = s"""{
|"table":{"namespace":"default", "name":"t1NotThere"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"c", "type":"string"}
|}
|}""".stripMargin
intercept[Exception] {
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b,",
"hbase.table" -> "t1NotThere"))
Map(HBaseTableCatalog.tableCatalog->catalog))
df.registerTempTable("hbaseNonExistingTmp")
@ -530,11 +555,20 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
}
test("Test table with column that doesn't exist") {
val catalog = s"""{
|"table":{"namespace":"default", "name":"t1"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
|"C_FIELD":{"cf":"c", "col":"c", "type":"string"}
|}
|}""".stripMargin
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, C_FIELD STRING c:c,",
"hbase.table" -> "t1"))
Map(HBaseTableCatalog.tableCatalog->catalog))
df.registerTempTable("hbaseFactColumnTmp")
@ -549,10 +583,18 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
}
test("Test table with INT column") {
val catalog = s"""{
|"table":{"namespace":"default", "name":"t1"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
|"I_FIELD":{"cf":"c", "col":"i", "type":"int"}
|}
|}""".stripMargin
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, I_FIELD INT c:i,",
"hbase.table" -> "t1"))
Map(HBaseTableCatalog.tableCatalog->catalog))
df.registerTempTable("hbaseIntTmp")
@ -571,10 +613,18 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
}
test("Test table with INT column defined at wrong type") {
val catalog = s"""{
|"table":{"namespace":"default", "name":"t1"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
|"I_FIELD":{"cf":"c", "col":"i", "type":"string"}
|}
|}""".stripMargin
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, I_FIELD STRING c:i,",
"hbase.table" -> "t1"))
Map(HBaseTableCatalog.tableCatalog->catalog))
df.registerTempTable("hbaseIntWrongTypeTmp")
@ -594,32 +644,19 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
assert(localResult(0).getString(2).charAt(3).toByte == 1)
}
test("Test improperly formatted column mapping") {
intercept[IllegalArgumentException] {
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD,STRING,:key, A_FIELD,STRING,c:a, B_FIELD,STRING,c:b, I_FIELD,STRING,c:i,",
"hbase.table" -> "t1"))
df.registerTempTable("hbaseBadTmp")
val result = sqlContext.sql("SELECT KEY_FIELD, " +
"B_FIELD, I_FIELD FROM hbaseBadTmp")
val executionRules = DefaultSourceStaticUtils.lastFiveExecutionRules.poll()
assert(executionRules.dynamicLogicExpression == null)
result.take(5)
}
}
test("Test bad column type") {
intercept[IllegalArgumentException] {
val catalog = s"""{
|"table":{"namespace":"default", "name":"t1"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"FOOBAR"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"I_FIELD":{"cf":"c", "col":"i", "type":"string"}
|}
|}""".stripMargin
intercept[Exception] {
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD FOOBAR :key, A_FIELD STRING c:a, B_FIELD STRING c:b, I_FIELD STRING c:i,",
"hbase.table" -> "t1"))
Map(HBaseTableCatalog.tableCatalog->catalog))
df.registerTempTable("hbaseIntWrongTypeTmp")
@ -665,10 +702,18 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
}
test("Test table with sparse column") {
val catalog = s"""{
|"table":{"namespace":"default", "name":"t1"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
|"Z_FIELD":{"cf":"c", "col":"z", "type":"string"}
|}
|}""".stripMargin
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, Z_FIELD STRING c:z,",
"hbase.table" -> "t1"))
Map(HBaseTableCatalog.tableCatalog->catalog))
df.registerTempTable("hbaseZTmp")
@ -688,11 +733,19 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
}
test("Test with column logic disabled") {
val catalog = s"""{
|"table":{"namespace":"default", "name":"t1"},
|"rowkey":"key",
|"columns":{
|"KEY_FIELD":{"cf":"rowkey", "col":"key", "type":"string"},
|"A_FIELD":{"cf":"c", "col":"a", "type":"string"},
|"B_FIELD":{"cf":"c", "col":"b", "type":"string"},
|"Z_FIELD":{"cf":"c", "col":"z", "type":"string"}
|}
|}""".stripMargin
df = sqlContext.load("org.apache.hadoop.hbase.spark",
Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b, Z_FIELD STRING c:z,",
"hbase.table" -> "t1",
"hbase.push.down.column.filter" -> "false"))
Map(HBaseTableCatalog.tableCatalog->catalog,
HBaseSparkConf.PUSH_DOWN_COLUMN_FILTER -> "false"))
df.registerTempTable("hbaseNoPushDownTmp")

View File

@ -0,0 +1,111 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark
import org.apache.hadoop.hbase.spark.datasources.{DoubleSerDes, SerDes}
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.Logging
import org.apache.spark.sql.datasources.hbase.{DataTypeParserWrapper, HBaseTableCatalog}
import org.apache.spark.sql.types._
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
class HBaseCatalogSuite extends FunSuite with BeforeAndAfterEach with BeforeAndAfterAll with Logging {
val map = s"""MAP<int, struct<varchar:string>>"""
val array = s"""array<struct<tinYint:tinyint>>"""
val arrayMap = s"""MAp<int, ARRAY<double>>"""
val catalog = s"""{
|"table":{"namespace":"default", "name":"htable"},
|"rowkey":"key1:key2",
|"columns":{
|"col1":{"cf":"rowkey", "col":"key1", "type":"string"},
|"col2":{"cf":"rowkey", "col":"key2", "type":"double"},
|"col3":{"cf":"cf1", "col":"col2", "type":"binary"},
|"col4":{"cf":"cf1", "col":"col3", "type":"timestamp"},
|"col5":{"cf":"cf1", "col":"col4", "type":"double", "serdes":"${classOf[DoubleSerDes].getName}"},
|"col6":{"cf":"cf1", "col":"col5", "type":"$map"},
|"col7":{"cf":"cf1", "col":"col6", "type":"$array"},
|"col8":{"cf":"cf1", "col":"col7", "type":"$arrayMap"}
|}
|}""".stripMargin
val parameters = Map(HBaseTableCatalog.tableCatalog->catalog)
val t = HBaseTableCatalog(parameters)
def checkDataType(dataTypeString: String, expectedDataType: DataType): Unit = {
test(s"parse ${dataTypeString.replace("\n", "")}") {
assert(DataTypeParserWrapper.parse(dataTypeString) === expectedDataType)
}
}
test("basic") {
assert(t.getField("col1").isRowKey == true)
assert(t.getPrimaryKey == "key1")
assert(t.getField("col3").dt == BinaryType)
assert(t.getField("col4").dt == TimestampType)
assert(t.getField("col5").dt == DoubleType)
assert(t.getField("col5").serdes != None)
assert(t.getField("col4").serdes == None)
assert(t.getField("col1").isRowKey)
assert(t.getField("col2").isRowKey)
assert(!t.getField("col3").isRowKey)
assert(t.getField("col2").length == Bytes.SIZEOF_DOUBLE)
assert(t.getField("col1").length == -1)
assert(t.getField("col8").length == -1)
}
checkDataType(
map,
t.getField("col6").dt
)
checkDataType(
array,
t.getField("col7").dt
)
checkDataType(
arrayMap,
t.getField("col8").dt
)
test("convert") {
val m = Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,",
"hbase.table" -> "t1")
val map = HBaseTableCatalog.convert(m)
val json = map.get(HBaseTableCatalog.tableCatalog).get
val parameters = Map(HBaseTableCatalog.tableCatalog->json)
val t = HBaseTableCatalog(parameters)
assert(t.getField("KEY_FIELD").isRowKey)
assert(DataTypeParserWrapper.parse("STRING") === t.getField("A_FIELD").dt)
assert(!t.getField("A_FIELD").isRowKey)
assert(DataTypeParserWrapper.parse("DOUBLE") === t.getField("B_FIELD").dt)
assert(DataTypeParserWrapper.parse("BINARY") === t.getField("C_FIELD").dt)
}
test("compatiblity") {
val m = Map("hbase.columns.mapping" ->
"KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD DOUBLE c:b, C_FIELD BINARY c:c,",
"hbase.table" -> "t1")
val t = HBaseTableCatalog(m)
assert(t.getField("KEY_FIELD").isRowKey)
assert(DataTypeParserWrapper.parse("STRING") === t.getField("A_FIELD").dt)
assert(!t.getField("A_FIELD").isRowKey)
assert(DataTypeParserWrapper.parse("DOUBLE") === t.getField("B_FIELD").dt)
assert(DataTypeParserWrapper.parse("BINARY") === t.getField("C_FIELD").dt)
}
}