HBASE-15336 Support Dataframe writer to the spark connector (Zhan Zhang)
parent d14b6c3810 · commit f6945c4631
@@ -21,17 +21,20 @@ import java.util
 import java.util.concurrent.ConcurrentLinkedQueue
 import org.apache.hadoop.hbase.client._
 import org.apache.hadoop.hbase.io.ImmutableBytesWritable
+import org.apache.hadoop.hbase.mapred.TableOutputFormat
+import org.apache.hadoop.hbase.spark.datasources.Utils
 import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
 import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
 import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration
 import org.apache.hadoop.hbase.types._
 import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
-import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
+import org.apache.hadoop.hbase.{HColumnDescriptor, HTableDescriptor, HBaseConfiguration, TableName}
+import org.apache.hadoop.mapred.JobConf
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.datasources.hbase.{Field, HBaseTableCatalog}
 import org.apache.spark.sql.types.{DataType => SparkDataType}
-import org.apache.spark.sql.{Row, SQLContext}
+import org.apache.spark.sql.{DataFrame, SaveMode, Row, SQLContext}
 import org.apache.spark.sql.sources._
 import org.apache.spark.sql.types._
@@ -48,10 +51,11 @@ import scala.collection.mutable
  * - Type conversions of basic SQL types. All conversions will be
  *   through the HBase Bytes object commands.
  */
-class DefaultSource extends RelationProvider with Logging {
+class DefaultSource extends RelationProvider with CreatableRelationProvider with Logging {
   /**
    * Is given input from SparkSQL to construct a BaseRelation
-   * @param sqlContext SparkSQL context
+   *
+   * @param sqlContext SparkSQL context
    * @param parameters Parameters given to us from SparkSQL
    * @return A BaseRelation Object
    */
@@ -60,18 +64,31 @@ class DefaultSource extends RelationProvider with Logging {
     BaseRelation = {
     new HBaseRelation(parameters, None)(sqlContext)
   }
+
+  override def createRelation(
+      sqlContext: SQLContext,
+      mode: SaveMode,
+      parameters: Map[String, String],
+      data: DataFrame): BaseRelation = {
+    val relation = HBaseRelation(parameters, Some(data.schema))(sqlContext)
+    relation.createTable()
+    relation.insert(data, false)
+    relation
+  }
 }
 
 /**
  * Implementation of Spark BaseRelation that will build up our scan logic
  * , do the scan pruning, filter push down, and value conversions
- * @param sqlContext SparkSQL context
+ *
+ * @param sqlContext SparkSQL context
  */
 case class HBaseRelation (
     @transient parameters: Map[String, String],
     userSpecifiedSchema: Option[StructType]
   )(@transient val sqlContext: SQLContext)
-  extends BaseRelation with PrunedFilteredScan with Logging {
+  extends BaseRelation with PrunedFilteredScan with InsertableRelation with Logging {
   val catalog = HBaseTableCatalog(parameters)
   def tableName = catalog.name
   val configResources = parameters.getOrElse(HBaseSparkConf.HBASE_CONFIG_RESOURCES_LOCATIONS, "")
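The createRelation override above is the entry point Spark invokes when a DataFrame is saved through this source: it builds an HBaseRelation from the write options plus the DataFrame schema, optionally creates the target table, and delegates to insert(). A minimal caller-side sketch, assuming a DataFrame `df` and a catalog JSON string shaped like the writeCatalog used in the test suite further down:

    import org.apache.spark.sql.DataFrame
    import org.apache.spark.sql.datasources.hbase.HBaseTableCatalog

    // Sketch only: `df` and `catalog` are supplied by the caller.
    // "newtable" -> "5" asks HBaseRelation.createTable() to create and pre-split
    // the table; the option is ignored unless it is larger than 3.
    def saveToHBase(df: DataFrame, catalog: String): Unit = {
      df.write
        .options(Map(
          HBaseTableCatalog.tableCatalog -> catalog,
          HBaseTableCatalog.newTable -> "5"))
        .format("org.apache.hadoop.hbase.spark")
        .save()
    }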
@@ -116,6 +133,90 @@ case class HBaseRelation (
    */
   override val schema: StructType = userSpecifiedSchema.getOrElse(catalog.toDataType)
 
+  def createTable() {
+    val numReg = parameters.get(HBaseTableCatalog.newTable).map(x => x.toInt).getOrElse(0)
+    val startKey = Bytes.toBytes(
+      parameters.get(HBaseTableCatalog.regionStart)
+        .getOrElse(HBaseTableCatalog.defaultRegionStart))
+    val endKey = Bytes.toBytes(
+      parameters.get(HBaseTableCatalog.regionEnd)
+        .getOrElse(HBaseTableCatalog.defaultRegionEnd))
+    if (numReg > 3) {
+      val tName = TableName.valueOf(catalog.name)
+      val cfs = catalog.getColumnFamilies
+      val connection = ConnectionFactory.createConnection(hbaseConf)
+      // Initialize the HBase table if necessary
+      val admin = connection.getAdmin()
+      try {
+        if (!admin.isTableAvailable(tName)) {
+          val tableDesc = new HTableDescriptor(tName)
+          cfs.foreach { x =>
+            val cf = new HColumnDescriptor(x.getBytes())
+            logDebug(s"add family $x to ${catalog.name}")
+            tableDesc.addFamily(cf)
+          }
+          val splitKeys = Bytes.split(startKey, endKey, numReg)
+          admin.createTable(tableDesc, splitKeys)
+        }
+      } finally {
+        admin.close()
+        connection.close()
+      }
+    } else {
+      logInfo(
+        s"""${HBaseTableCatalog.newTable}
+           |is not defined or not larger than 3, skipping table creation""".stripMargin)
+    }
+  }
+
+  /**
+   * Write the given DataFrame to HBase as Puts.
+   *
+   * @param data the DataFrame to write
+   * @param overwrite ignored by this implementation; rows are always appended as Puts
+   */
+  override def insert(data: DataFrame, overwrite: Boolean): Unit = {
+    val jobConfig: JobConf = new JobConf(hbaseConf, this.getClass)
+    jobConfig.setOutputFormat(classOf[TableOutputFormat])
+    jobConfig.set(TableOutputFormat.OUTPUT_TABLE, catalog.name)
+    var count = 0
+    val rkFields = catalog.getRowKey
+    val rkIdxedFields = rkFields.map{ case x =>
+      (schema.fieldIndex(x.colName), x)
+    }
+    val colsIdxedFields = schema
+      .fieldNames
+      .partition( x => rkFields.map(_.colName).contains(x))
+      ._2.map(x => (schema.fieldIndex(x), catalog.getField(x)))
+    val rdd = data.rdd
+    def convertToPut(row: Row) = {
+      // construct bytes for row key
+      val rowBytes = rkIdxedFields.map { case (x, y) =>
+        Utils.toBytes(row(x), y)
+      }
+      val rLen = rowBytes.foldLeft(0) { case (x, y) =>
+        x + y.length
+      }
+      val rBytes = new Array[Byte](rLen)
+      var offset = 0
+      rowBytes.foreach { x =>
+        System.arraycopy(x, 0, rBytes, offset, x.length)
+        offset += x.length
+      }
+      val put = new Put(rBytes)
+
+      colsIdxedFields.foreach { case (x, y) =>
+        val b = Utils.toBytes(row(x), y)
+        put.addColumn(Bytes.toBytes(y.cf), Bytes.toBytes(y.col), b)
+      }
+      count += 1
+      (new ImmutableBytesWritable, put)
+    }
+    rdd.map(convertToPut(_)).saveAsHadoopDataset(jobConfig)
+  }
+
   /**
    * Here we are building the functionality to populate the resulting RDD[Row]
    * Here is where we will do the following:
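insert() converts every Row into one Put: the byte encodings of all row-key columns are concatenated, in catalog order, to form the HBase rowkey, and each remaining column becomes a cell in its mapped column family and qualifier. A standalone sketch of just the key-concatenation step, using example values in place of a real catalog (the real code obtains the parts via Utils.toBytes on the row-key fields):

    import org.apache.hadoop.hbase.util.Bytes

    // Hypothetical two-part row key: a String column followed by an Int column.
    val keyParts: Seq[Array[Byte]] = Seq(Bytes.toBytes("row005"), Bytes.toBytes(5))

    // Concatenate the parts exactly as convertToPut does.
    val rowKey = new Array[Byte](keyParts.map(_.length).sum)
    var offset = 0
    keyParts.foreach { part =>
      System.arraycopy(part, 0, rowKey, offset, part.length)
      offset += part.length
    }
    // rowKey now holds the 6 bytes of "row005" followed by the 4 bytes of the Int.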
@@ -356,7 +457,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
 
   /**
    * Function to merge another scan object through an AND operation
-   * @param other Other scan object
+   *
+   * @param other Other scan object
    */
   def mergeIntersect(other:ScanRange): Unit = {
     val upperBoundCompare = compareRange(upperBound, other.upperBound)
@@ -376,7 +478,8 @@ class ScanRange(var upperBound:Array[Byte], var isUpperBoundEqualTo:Boolean,
 
   /**
    * Function to merge another scan object through an OR operation
-   * @param other Other scan object
+   *
+   * @param other Other scan object
    */
   def mergeUnion(other:ScanRange): Unit = {
@@ -0,0 +1,44 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.hadoop.hbase.spark.datasources
+
+import org.apache.hadoop.hbase.util.Bytes
+import org.apache.spark.sql.datasources.hbase.Field
+import org.apache.spark.unsafe.types.UTF8String
+
+object Utils {
+  // convert input to data type
+  def toBytes(input: Any, field: Field): Array[Byte] = {
+    input match {
+      case data: Boolean => Bytes.toBytes(data)
+      case data: Byte => Array(data)
+      case data: Array[Byte] => data
+      case data: Double => Bytes.toBytes(data)
+      case data: Float => Bytes.toBytes(data)
+      case data: Int => Bytes.toBytes(data)
+      case data: Long => Bytes.toBytes(data)
+      case data: Short => Bytes.toBytes(data)
+      case data: UTF8String => data.getBytes
+      case data: String => Bytes.toBytes(data)
+      // TODO: add more data type support
+      case _ => throw new Exception(s"unsupported data type ${field.dt}")
+    }
+  }
+}
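Utils.toBytes dispatches on the runtime type and, apart from Byte and UTF8String, reuses the standard HBase Bytes encodings, so values written by the connector remain readable with the matching Bytes.toXxx decoders. A small round-trip sketch under that assumption:

    import org.apache.hadoop.hbase.util.Bytes

    // Encode the way Utils.toBytes would for a Long and a String column...
    val longBytes = Bytes.toBytes(42L)
    val strBytes  = Bytes.toBytes("row042")

    // ...and decode with the matching Bytes helpers.
    assert(Bytes.toLong(longBytes) == 42L)
    assert(Bytes.toString(strBytes) == "row042")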
@@ -121,7 +121,7 @@ case class HBaseTableCatalog(
     name: String,
     row: RowKey,
     sMap: SchemaMap,
-    numReg: Int) extends Logging {
+    @transient params: Map[String, String]) extends Logging {
   def toDataType = StructType(sMap.toFields)
   def getField(name: String) = sMap.getField(name)
   def getRowKey: Seq[Field] = row.fields
@@ -130,6 +130,8 @@ case class HBaseTableCatalog(
     sMap.fields.map(_.cf).filter(_ != HBaseTableCatalog.rowKey)
   }
 
+  def get(key: String) = params.get(key)
+
   // Setup the start and length for each dimension of row key at runtime.
   def dynSetupRowKey(rowKey: HBaseType) {
     logDebug(s"length: ${rowKey.length}")
@@ -179,8 +181,13 @@ case class HBaseTableCatalog(
 }
 
 object HBaseTableCatalog {
+  // If defined and larger than 3, a new table will be created with the number of regions specified.
+  val newTable = "newtable"
   // The json string specifying hbase catalog information
+  val regionStart = "regionStart"
+  val defaultRegionStart = "aaaaaaa"
+  val regionEnd = "regionEnd"
+  val defaultRegionEnd = "zzzzzzz"
   val tableCatalog = "catalog"
   // The row key with format key1:key2 specifying table row key
   val rowKey = "rowkey"
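The newtable, regionStart and regionEnd keys added above feed HBaseRelation.createTable(), which, when newtable is larger than 3, pre-splits the newly created table with Bytes.split between the two boundary keys (defaulting to "aaaaaaa" and "zzzzzzz"). A quick sketch of what those defaults produce, assuming five requested regions as in the "populate table" test below:

    import org.apache.hadoop.hbase.util.Bytes

    // Same call createTable() makes for newtable = "5": split the default
    // key range into boundary keys for the new regions.
    val splitKeys: Array[Array[Byte]] =
      Bytes.split(Bytes.toBytes("aaaaaaa"), Bytes.toBytes("zzzzzzz"), 5)
    splitKeys.foreach(k => println(Bytes.toStringBinary(k)))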
@@ -232,9 +239,8 @@ object HBaseTableCatalog {
         sAvro, sd, len)
       schemaMap.+=((name, f))
     }
-    val numReg = parameters.get(newTable).map(x => x.toInt).getOrElse(0)
     val rKey = RowKey(map.get(rowKey).get.asInstanceOf[String])
-    HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), numReg)
+    HBaseTableCatalog(nSpace, tName, rKey, SchemaMap(schemaMap), parameters)
   }
 
   val TABLE_KEY: String = "hbase.table"
@@ -26,6 +26,26 @@ import org.apache.spark.sql.{DataFrame, SQLContext}
 import org.apache.spark.{SparkConf, SparkContext, Logging}
 import org.scalatest.{BeforeAndAfterAll, BeforeAndAfterEach, FunSuite}
 
+case class HBaseRecord(
+    col0: String,
+    col1: String,
+    col2: Double,
+    col3: Float,
+    col4: Int,
+    col5: Long)
+
+object HBaseRecord {
+  def apply(i: Int, t: String): HBaseRecord = {
+    val s = s"""row${"%03d".format(i)}"""
+    HBaseRecord(s,
+      s,
+      i.toDouble,
+      i.toFloat,
+      i,
+      i.toLong)
+  }
+}
+
 class DefaultSourceSuite extends FunSuite with
 BeforeAndAfterEach with BeforeAndAfterAll with Logging {
   @transient var sc: SparkContext = null
@@ -63,6 +83,7 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
     sparkConf.set(HBaseSparkConf.BLOCK_CACHE_ENABLE, "true")
     sparkConf.set(HBaseSparkConf.BATCH_NUM, "100")
     sparkConf.set(HBaseSparkConf.CACHE_SIZE, "100")
 
     sc = new SparkContext("local", "test", sparkConf)
 
     val connection = ConnectionFactory.createConnection(TEST_UTIL.getConfiguration)
@@ -759,4 +780,60 @@ BeforeAndAfterEach with BeforeAndAfterAll with Logging {
 
     assert(executionRules.dynamicLogicExpression == null)
   }
+
+  def writeCatalog = s"""{
+                    |"table":{"namespace":"default", "name":"table1"},
+                    |"rowkey":"key",
+                    |"columns":{
+                    |"col0":{"cf":"rowkey", "col":"key", "type":"string"},
+                    |"col1":{"cf":"cf1", "col":"col1", "type":"string"},
+                    |"col2":{"cf":"cf2", "col":"col2", "type":"double"},
+                    |"col3":{"cf":"cf3", "col":"col3", "type":"float"},
+                    |"col4":{"cf":"cf4", "col":"col4", "type":"int"},
+                    |"col5":{"cf":"cf5", "col":"col5", "type":"bigint"}}
+                    |}
+                    |}""".stripMargin
+
+  def withCatalog(cat: String): DataFrame = {
+    sqlContext
+      .read
+      .options(Map(HBaseTableCatalog.tableCatalog->cat))
+      .format("org.apache.hadoop.hbase.spark")
+      .load()
+  }
+
+  test("populate table") {
+    val sql = sqlContext
+    import sql.implicits._
+    val data = (0 to 255).map { i =>
+      HBaseRecord(i, "extra")
+    }
+    sc.parallelize(data).toDF.write.options(
+      Map(HBaseTableCatalog.tableCatalog -> writeCatalog, HBaseTableCatalog.newTable -> "5"))
+      .format("org.apache.hadoop.hbase.spark")
+      .save()
+  }
+
+  test("empty column") {
+    val df = withCatalog(writeCatalog)
+    df.registerTempTable("table0")
+    val c = sqlContext.sql("select count(1) from table0").rdd.collect()(0)(0).asInstanceOf[Long]
+    assert(c == 256)
+  }
+
+  test("full query") {
+    val df = withCatalog(writeCatalog)
+    df.show
+    assert(df.count() == 256)
+  }
+
+  test("filtered query0") {
+    val sql = sqlContext
+    import sql.implicits._
+    val df = withCatalog(writeCatalog)
+    val s = df.filter($"col0" <= "row005")
+      .select("col0", "col1")
+    s.show
+    assert(s.count() == 6)
+  }
 }