HBASE-15336 Support Dataframe writer to the spark connector (Zhan Zhang)

tedyu 2016-03-10 06:44:29 -08:00
parent d14b6c3810
commit f6945c4631
4 changed files with 242 additions and 12 deletions


@ -21,17 +21,20 @@ import java.util
import java.util.concurrent.ConcurrentLinkedQueue
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.io.ImmutableBytesWritable
import org.apache.hadoop.hbase.mapred.TableOutputFormat
import org.apache.hadoop.hbase.spark.datasources.Utils
import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration
import org.apache.hadoop.hbase.types._
import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, HBaseConfiguration, TableName}
import org.apache.hadoop.mapred.JobConf
import org.apache.spark.Logging
import org.apache.spark.rdd.RDD
import org.apache.spark.sql.datasources.hbase.{Field, HBaseTableCatalog}
import org.apache.spark.sql.types.{DataType => SparkDataType}
import org.apache.spark.sql.{Row, SQLContext}
import org.apache.spark.sql.{DataFrame, SaveMode, Row, SQLContext}
import org.apache.spark.sql.sources._
import org.apache.spark.sql.types._
@ -48,10 +51,11 @@ import scala.collection.mutable
* - Type conversions of basic SQL types. All conversions will be
* through the HBase Bytes object commands.
*/
class DefaultSource extends RelationProvider with Logging {
class DefaultSource extends RelationProvider with CreatableRelationProvider with Logging {
/**
* Is given input from SparkSQL to construct a BaseRelation
* @param sqlContext SparkSQL context
*
* @param sqlContext SparkSQL context
* @param parameters Parameters given to us from SparkSQL
* @return A BaseRelation Object
*/
@ -60,18 +64,31 @@ class DefaultSource extends RelationProvider with Logging {
BaseRelation = {
new HBaseRelation(parameters, None)(sqlContext)
}
override def createRelation(
sqlContext: SQLContext,
mode: SaveMode,
parameters: Map[String, String],
data: DataFrame): BaseRelation = {
val relation = HBaseRelation(parameters, Some(data.schema))(sqlContext)
relation.createTable()
relation.insert(data, false)
relation
}
}
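
The CreatableRelationProvider hook added above is what Spark's DataFrame writer drives. Below is a minimal caller-side sketch, assuming an existing DataFrame df and a catalog JSON string catalogJson; the helper name writeToHBase is made up, and the option names mirror the "populate table" test further down.

import org.apache.spark.sql.DataFrame
import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog

// Sketch only: save() routes to DefaultSource.createRelation, which creates
// the HBase table when asked to and then inserts the rows as Puts.
def writeToHBase(df: DataFrame, catalogJson: String): Unit = {
  df.write
    .options(Map(
      HBaseTableCatalog.tableCatalog -> catalogJson,
      HBaseTableCatalog.newTable -> "5"))
    .format("org.apache.hadoop.hbase.spark")
    .save()
}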
/**
* Implementation of Spark BaseRelation that will build up our scan logic
* , do the scan pruning, filter push down, and value conversions
* @param sqlContext SparkSQL context
*
* @param sqlContext SparkSQL context
*/
case class HBaseRelation (
@transient parameters: Map[String, String],
userSpecifiedSchema: Option[StructType]
)(@transient val sqlContext: SQLContext)
extends BaseRelation with PrunedFilteredScan with Logging {
extends BaseRelation with PrunedFilteredScan with InsertableRelation with Logging {
val catalog = HBaseTableCatalog(parameters)
def tableName = catalog.name
val configResources = parameters.getOrElse(HBaseSparkConf.HBASE_CONFIG_RESOURCES_LOCATIONS, "")
@ -116,6 +133,90 @@ case class HBaseRelation (
*/
override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType)
def createTable() {
val numReg = parameters.get(HBaseTableCatalog.newTable).map(x => x.toInt).getOrElse(0)
val startKey = Bytes.toBytes(
parameters.get(HBaseTableCatalog.regionStart)
.getOrElse(HBaseTableCatalog.defaultRegionStart))
val endKey = Bytes.toBytes(
parameters.get(HBaseTableCatalog.regionEnd)
.getOrElse(HBaseTableCatalog.defaultRegionEnd))
if (numReg > 3) {
val tName = TableName.valueOf(catalog.name)
val cfs = catalog.getColumnFamilies
val connection = ConnectionFactory.createConnection(hbaseConf)
// Initialize the HBase table if necessary
val admin = connection.getAdmin()
try {
if (!admin.isTableAvailable(tName)) {
val tableDesc = new HTableDescriptor(tName)
cfs.foreach { x =>
val cf = new HColumnDescriptor(x.getBytes())
logDebug(s"add family $x to ${catalog.name}")
tableDesc.addFamily(cf)
}
val splitKeys = Bytes.split(startKey, endKey, numReg)
admin.createTable(tableDesc, splitKeys)
}
} finally {
admin.close()
connection.close()
}
} else {
logInfo(
s"""${HBaseTableCatalog.newTable}
|is not defined or not larger than 3, skipping table creation""".stripMargin)
}
}
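
Whether createTable() pre-splits at all is decided purely by the writer options. Here is a sketch of an option map that would take the create-table branch above; preSplitOptions is a hypothetical helper and catalogJson an assumed catalog string, while the option names come from HBaseTableCatalog (see the catalog changes below).

import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog

// Sketch only: "newtable" must parse to an Int larger than 3, otherwise the
// branch above is skipped; regionStart/regionEnd override the default range
// handed to Bytes.split.
def preSplitOptions(catalogJson: String): Map[String, String] = Map(
  HBaseTableCatalog.tableCatalog -> catalogJson,
  HBaseTableCatalog.newTable -> "5",
  HBaseTableCatalog.regionStart -> "aaaaaaa",
  HBaseTableCatalog.regionEnd -> "zzzzzzz")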
/**
* Write the rows of the given DataFrame to HBase as Puts.
*
* @param data the DataFrame to write
* @param overwrite not honored yet; rows are always written as individual Puts
*/
override def insert(data: DataFrame, overwrite: Boolean): Unit = {
val jobConfig: JobConf = new JobConf(hbaseConf, this.getClass)
jobConfig.setOutputFormat(classOf[TableOutputFormat])
jobConfig.set(TableOutputFormat.OUTPUT_TABLE, catalog.name)
var count = 0
val rkFields = catalog.getRowKey
val rkIdxedFields = rkFields.map{ case x =>
(schema.fieldIndex(x.colName), x)
}
val colsIdxedFields = schema
.fieldNames
.partition( x => rkFields.map(_.colName).contains(x))
._2.map(x => (schema.fieldIndex(x), catalog.getField(x)))
val rdd = data.rdd
def convertToPut(row: Row) = {
// construct bytes for row key
val rowBytes = rkIdxedFields.map { case (x, y) =>
Utils.toBytes(row(x), y)
}
val rLen = rowBytes.foldLeft(0) { case (x, y) =>
x + y.length
}
val rBytes = new Array[Byte](rLen)
var offset = 0
rowBytes.foreach { x =>
System.arraycopy(x, 0, rBytes, offset, x.length)
offset += x.length
}
val put = new Put(rBytes)
colsIdxedFields.foreach { case (x, y) =>
val b = Utils.toBytes(row(x), y)
put.addColumn(Bytes.toBytes(y.cf), Bytes.toBytes(y.col), b)
}
count += 1
(new ImmutableBytesWritable, put)
}
rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig)
}
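
The composite row key assembled in convertToPut is simply the concatenation of each key column's encoded bytes. A standalone sketch of that fold-and-copy step, assuming a two-part string key; the literal values are made up.

import org.apache.hadoop.hbase.util.Bytes

// Encode each row-key column, then copy the pieces back to back into a single
// array, exactly as the foldLeft/arraycopy above does for rkIdxedFields.
val pieces: Seq[Array[Byte]] = Seq(Bytes.toBytes("row001"), Bytes.toBytes("extra"))
val total = pieces.foldLeft(0)((len, p) => len + p.length)
val rowKey = new Array[Byte](total)
var offset = 0
pieces.foreach { p =>
  System.arraycopy(p, 0, rowKey, offset, p.length)
  offset += p.length
}
// rowKey now holds the bytes of "row001" immediately followed by those of "extra".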
/**
* Here we are building the functionality to populate the resulting RDD[Row]
* Here is where we will do the following:
@ -356,7 +457,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
/**
* Function to merge another scan object through an AND operation
* @param other Other scan object
*
* @param other Other scan object
*/
def mergeIntersect(other:ScanRange): Unit = {
val upperBoundCompare = compareRange(upperBound, other.upperBound)
@ -376,7 +478,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
/**
* Function to merge another scan object through an OR operation
* @param other Other scan object
*
* @param other Other scan object
*/
def mergeUnion(other:ScanRange): Unit = {


@ -0,0 +1,44 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark.datasources
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.datasources.hbase.Field
import org.apache.spark.unsafe.types.UTF8String
object Utils {
// convert input to data type
def toBytes(input: Any, field: Field): Array[Byte] = {
input match {
case data: Boolean => Bytes.toBytes(data)
case data: Byte => Array(data)
case data: Array[Byte] => data
case data: Double => Bytes.toBytes(data)
case data: Float => Bytes.toBytes(data)
case data: Int => Bytes.toBytes(data)
case data: Long => Bytes.toBytes(data)
case data: Short => Bytes.toBytes(data)
case data: UTF8String => data.getBytes
case data: String => Bytes.toBytes(data)
// TODO: add more data type support
case _ => throw new Exception(s"unsupported data type ${field.dt}")
}
}
}
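
toBytes only dispatches on the runtime type to the matching Bytes helper. A quick round-trip sketch of those primitive encodings; the values are arbitrary.

import org.apache.hadoop.hbase.util.Bytes

// Round-trip the encodings toBytes delegates to.
val asInt: Array[Byte] = Bytes.toBytes(42)          // 4-byte big-endian int
assert(Bytes.toInt(asInt) == 42)
val asDouble: Array[Byte] = Bytes.toBytes(3.5)      // 8-byte IEEE 754 double
assert(Bytes.toDouble(asDouble) == 3.5)
val asString: Array[Byte] = Bytes.toBytes("row001") // UTF-8 bytes
assert(Bytes.toString(asString) == "row001")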


@ -121,7 +121,7 @@ case class HBaseTableCatalog(
name: String,
row: RowKey,
sMap: SchemaMap,
numReg: Int) extends Logging {
@transient params: Map[String, String]) extends Logging {
def toDataType = StructType(sMap.toFields)
def getField(name: String) = sMap.getField(name)
def getRowKey: Seq[Field] = row.fields
@ -130,6 +130,8 @@ case class HBaseTableCatalog(
sMap.fields.map(_.cf).filter(_ != HBaseTableCatalog.rowKey)
}
def get(key: String) = params.get(key)
// Set up the start and length for each dimension of the row key at runtime.
def dynSetupRowKey(rowKey: HBaseType) {
logDebug(s"length: ${rowKey.length}")
@ -179,8 +181,13 @@ case class HBaseTableCatalog(
}
object HBaseTableCatalog {
// If defined and larger than 3, a new table will be created with the number of regions specified.
val newTable = "newtable"
val regionStart = "regionStart"
val defaultRegionStart = "aaaaaaa"
val regionEnd = "regionEnd"
val defaultRegionEnd = "zzzzzzz"
// The JSON string specifying the HBase catalog information
val tableCatalog = "catalog"
// The row key with format key1:key2 specifying table row key
val rowKey = "rowkey"
@ -232,9 +239,8 @@ object HBaseTableCatalog {
sAvro, sd, len)
schemaMap.+=((name, f))
}
val numReg = parameters.get(newTable).map(x => x.toInt).getOrElse(0)
val rKey = RowKey(map.get(rowKey).get.asInstanceOf[String])
HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), numReg)
HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), parameters)
}
val TABLE_KEY: String = "hbase.table"
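
For the key1:key2 row key format noted above, a catalog would list every key column under the reserved "rowkey" column family. The sketch below is purely illustrative and not part of this patch; table2 and the column names are made up, and a real composite key may also need per-column length information.

// Hypothetical catalog with a two-part row key.
val compositeKeyCatalog = s"""{
|"table":{"namespace":"default", "name":"table2"},
|"rowkey":"key1:key2",
|"columns":{
|"key1":{"cf":"rowkey", "col":"key1", "type":"string"},
|"key2":{"cf":"rowkey", "col":"key2", "type":"string"},
|"col1":{"cf":"cf1", "col":"col1", "type":"double"}}
|}""".stripMargin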


@ -26,6 +26,26 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
import org.apache.spark.{SparkConf, SparkContext, Logging}
import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
case class HBaseRecord(
col0: String,
col1: String,
col2: Double,
col3: Float,
col4: Int,
col5: Long)
object HBaseRecord {
def apply(i: Int, t: String): HBaseRecord = {
val s = s"""row${"%03d".format(i)}"""
HBaseRecord(s,
s,
i.toDouble,
i.toFloat,
i,
i.toLong)
}
}
class DefaultSourceSuite extends FunSuite with
BeforeAndAfterEach with BeforeAndAfterAll with Logging {
@transient var sc: SparkContext = null
@ -63,6 +83,7 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
sparkConf.set(HBaseSparkConf.BLOCK_CACHE_ENABLE, "true")
sparkConf.set(HBaseSparkConf.BATCH_NUM, "100")
sparkConf.set(HBaseSparkConf.CACHE_SIZE, "100")
sc = new SparkContext("local", "test", sparkConf)
val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration)
@ -759,4 +780,60 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
assert(executionRules.dynamicLogicExpression == null)
}
def writeCatalog = s"""{
|"table":{"namespace":"default", "name":"table1"},
|"rowkey":"key",
|"columns":{
|"col0":{"cf":"rowkey", "col":"key", "type":"string"},
|"col1":{"cf":"cf1", "col":"col1", "type":"string"},
|"col2":{"cf":"cf2", "col":"col2", "type":"double"},
|"col3":{"cf":"cf3", "col":"col3", "type":"float"},
|"col4":{"cf":"cf4", "col":"col4", "type":"int"},
|"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}}
|}""".stripMargin
def withCatalog(cat: String): DataFrame = {
sqlContext
.read
.options(Map(HBaseTableCatalog.tableCatalog->cat))
.format("org.apache.hadoop.hbase.spark")
.load()
}
test("populate table") {
val sql = sqlContext
import sql.implicits._
val data = (0 to 255).map { i =>
HBaseRecord(i, "extra")
}
sc.parallelize(data).toDF.write.options(
Map(HBaseTableCatalog.tableCatalog -> writeCatalog, HBaseTableCatalog.newTable -> "5"))
.format("org.apache.hadoop.hbase.spark")
.save()
}
test("empty column") {
val df = withCatalog(writeCatalog)
df.registerTempTable("table0")
val c = sqlContext.sql("select count(1) from table0").rdd.collect()(0)(0).asInstanceOf[Long]
assert(c == 256)
}
test("full query") {
val df = withCatalog(writeCatalog)
df.show
assert(df.count() == 256)
}
test("filtered query0") {
val sql = sqlContext
import sql.implicits._
val df = withCatalog(writeCatalog)
val s = df.filter($"col0" <= "row005")
.select("col0", "col1")
s.show
assert(s.count() == 6)
}
}