HBASE-14795 Enhance the spark-hbase scan operations (Zhan Zhang)

tedyu 2015-12-13 18:26:54 -08:00
parent f34d3e1d26
commit 676ce01c82
7 changed files with 546 additions and 77 deletions


@ -20,11 +20,11 @@ package org.apache.hadoop.hbase.spark
 import java.util
 import java.util.concurrent.ConcurrentLinkedQueue
-import org.apache.hadoop.hbase.client.{ConnectionFactory, Get, Result, Scan}
+import org.apache.hadoop.hbase.client._
+import org.apache.hadoop.hbase.spark.datasources.{HBaseTableScanRDD, HBaseRegion, SerializableConfiguration}
 import org.apache.hadoop.hbase.types._
-import org.apache.hadoop.hbase.util.{SimplePositionedMutableByteRange,
-  PositionedByteRange, Bytes}
-import org.apache.hadoop.hbase.{TableName, HBaseConfiguration}
+import org.apache.hadoop.hbase.util.{Bytes, PositionedByteRange, SimplePositionedMutableByteRange}
+import org.apache.hadoop.hbase.{HBaseConfiguration, TableName}
 import org.apache.spark.Logging
 import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.types.DataType
@ -159,7 +159,7 @@ class DefaultSource extends RelationProvider with Logging {
  * connection information
  * @param sqlContext SparkSQL context
  */
-class HBaseRelation (val tableName:String,
+case class HBaseRelation (val tableName:String,
                      val schemaMappingDefinition:
                        java.util.HashMap[String, SchemaQualifierDefinition],
                      val batchingNum:Int,
@ -179,6 +179,9 @@ class HBaseRelation (val tableName:String,
     new HBaseContext(sqlContext.sparkContext, config)
   }
 
+  val wrappedConf = new SerializableConfiguration(hbaseContext.config)
+  def hbaseConf = wrappedConf.value
+
   /**
    * Generates a Spark SQL schema object so Spark SQL knows what is being
    * provided by this BaseRelation
@ -222,6 +225,7 @@ class HBaseRelation (val tableName:String,
    */
   override def buildScan(requiredColumns: Array[String], filters: Array[Filter]): RDD[Row] = {
+
     val pushDownTuple = buildPushDownPredicatesResource(filters)
     val pushDownRowKeyFilter = pushDownTuple._1
     var pushDownDynamicLogicExpression = pushDownTuple._2
@ -253,7 +257,6 @@ class HBaseRelation (val tableName:String,
     //retain the information for unit testing checks
     DefaultSourceStaticUtils.populateLatestExecutionRules(pushDownRowKeyFilter,
       pushDownDynamicLogicExpression)
-    var resultRDD: RDD[Row] = null
     val getList = new util.ArrayList[Get]()
     val rddList = new util.ArrayList[RDD[Row]]()
@ -268,77 +271,24 @@ class HBaseRelation (val tableName:String,
       getList.add(get)
     })
-    val rangeIt = pushDownRowKeyFilter.ranges.iterator
-
-    while (rangeIt.hasNext) {
-      val r = rangeIt.next()
-      val scan = new Scan()
-      scan.setBatch(batchingNum)
-      scan.setCaching(cachingNum)
-      requiredQualifierDefinitionList.foreach( d =>
-        if (d.columnFamilyBytes.length > 0)
-          scan.addColumn(d.columnFamilyBytes, d.qualifierBytes))
-
-      if (usePushDownColumnFilter && pushDownDynamicLogicExpression != null) {
-        val pushDownFilterJava =
-          new SparkSQLPushDownFilter(pushDownDynamicLogicExpression,
-            valueArray, requiredQualifierDefinitionList)
-        scan.setFilter(pushDownFilterJava)
-      }
-
-      //Check if there is a lower bound
-      if (r.lowerBound != null && r.lowerBound.length > 0) {
-        if (r.isLowerBoundEqualTo) {
-          //HBase startRow is inclusive: Therefore it acts like isLowerBoundEqualTo
-          // by default
-          scan.setStartRow(r.lowerBound)
-        } else {
-          //Since we don't equalTo we want the next value we need
-          // to add another byte to the start key.  That new byte will be
-          // the min byte value.
-          val newArray = new Array[Byte](r.lowerBound.length + 1)
-          System.arraycopy(r.lowerBound, 0, newArray, 0, r.lowerBound.length)
-          //new Min Byte
-          newArray(r.lowerBound.length) = Byte.MinValue
-          scan.setStartRow(newArray)
-        }
-      }
-
-      //Check if there is a upperBound
-      if (r.upperBound != null && r.upperBound.length > 0) {
-        if (r.isUpperBoundEqualTo) {
-          //HBase stopRow is exclusive: therefore it DOESN'T ast like isUpperBoundEqualTo
-          // by default.  So we need to add a new max byte to the stopRow key
-          val newArray = new Array[Byte](r.upperBound.length + 1)
-          System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length)
-          //New Max Bytes
-          newArray(r.upperBound.length) = Byte.MaxValue
-          scan.setStopRow(newArray)
-        } else {
-          //Here equalTo is false for Upper bound which is exclusive and
-          // HBase stopRow acts like that by default so no need to mutate the
-          // rowKey
-          scan.setStopRow(r.upperBound)
-        }
-      }
-
-      val rdd = hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan).map(r => {
-        Row.fromSeq(requiredColumns.map(c =>
-          DefaultSourceStaticUtils.getValue(c, serializableDefinitionMap, r._2)))
-      })
-      rddList.add(rdd)
-    }
-
-    //If there is more then one RDD then we have to union them together
-    for (i <- 0 until rddList.size()) {
-      if (resultRDD == null) resultRDD = rddList.get(i)
-      else resultRDD = resultRDD.union(rddList.get(i))
-    }
+    val pushDownFilterJava = if (usePushDownColumnFilter && pushDownDynamicLogicExpression != null) {
+        Some(new SparkSQLPushDownFilter(pushDownDynamicLogicExpression,
+          valueArray, requiredQualifierDefinitionList))
+      } else {
+        None
+      }
+    val hRdd = new HBaseTableScanRDD(this, pushDownFilterJava, requiredQualifierDefinitionList.seq)
+    pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_))
+    var resultRDD: RDD[Row] = {
+      val tmp = hRdd.map{ r =>
+        Row.fromSeq(requiredColumns.map(c =>
+          DefaultSourceStaticUtils.getValue(c, serializableDefinitionMap, r)))
+      }
+      if (tmp.partitions.size > 0) {
+        tmp
+      } else {
+        null
+      }
+    }
 
     //If there are gets then we can get them from the driver and union that rdd in
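
For orientation, a hedged usage sketch (not part of this commit) of the path the rewritten buildScan now takes: row-key predicates become ScanRanges fed to HBaseTableScanRDD, which prunes them against region boundaries when planning partitions. The option keys below are assumed from this module's DefaultSource and should be verified against the source.

// Spark 1.x style; "hbase.table" and "hbase.columns.mapping" are assumed option names.
val df = sqlContext.read
  .format("org.apache.hadoop.hbase.spark")
  .option("hbase.table", "t1")
  .option("hbase.columns.mapping",
    "KEY_FIELD STRING :key, A_FIELD STRING c:a, B_FIELD STRING c:b")
  .load()
df.registerTempTable("hbaseTmp")
sqlContext.sql(
  "SELECT KEY_FIELD, A_FIELD FROM hbaseTmp " +
  "WHERE KEY_FIELD >= 'row005' AND KEY_FIELD < 'row020'").show()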


@ -57,7 +57,7 @@ import scala.collection.mutable
  * to the working and managing the life cycle of HConnections.
  */
 class HBaseContext(@transient sc: SparkContext,
-                   @transient config: Configuration,
+                   @transient val config: Configuration,
                    val tmpHdfsConfgFile: String = null)
   extends Serializable with Logging {


@ -0,0 +1,91 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark.datasources

import org.apache.hadoop.hbase.spark.SparkSQLPushDownFilter
import org.apache.spark.Partition
import org.apache.hadoop.hbase.spark.hbase._

/**
 * The Bound represents the boundary for the scan
 *
 * @param b The byte array of the bound
 * @param inc inclusive or not.
 */
case class Bound(b: Array[Byte], inc: Boolean)

// The non-overlapping ranges we need to scan; if lower is equal to upper, it is a get request
case class Range(lower: Option[Bound], upper: Option[Bound])

object Range {
  def apply(region: HBaseRegion): Range = {
    Range(region.start.map(Bound(_, true)), if (region.end.get.size == 0) {
      None
    } else {
      region.end.map((Bound(_, false)))
    })
  }
}

object Ranges {
  // We assume that
  // 1. r.lower.inc is true, and r.upper.inc is false
  // 2. for each range in rs, its upper.inc is false
  def and(r: Range, rs: Seq[Range]): Seq[Range] = {
    rs.flatMap{ s =>
      val lower = s.lower.map { x =>
        // the scan has a lower bound
        r.lower.map { y =>
          // the region has a lower bound
          if (ord.compare(x.b, y.b) < 0) {
            // scan lower bound is smaller than region server lower bound
            Some(y)
          } else {
            // scan lower bound is greater than or equal to region server lower bound
            Some(x)
          }
        }.getOrElse(Some(x))
      }.getOrElse(r.lower)

      val upper = s.upper.map { x =>
        // the scan has an upper bound
        r.upper.map { y =>
          // the region has an upper bound
          if (ord.compare(x.b, y.b) >= 0) {
            // scan upper bound is larger than server upper bound
            // but region server scan stop is exclusive. It is OK here.
            Some(y)
          } else {
            // scan upper bound is less than or equal to region server upper bound
            Some(x)
          }
        }.getOrElse(Some(x))
      }.getOrElse(r.upper)

      val c = lower.map { case x =>
        upper.map { case y =>
          ord.compare(x.b, y.b)
        }.getOrElse(-1)
      }.getOrElse(-1)
      if (c < 0) {
        Some(Range(lower, upper))
      } else {
        None
      }
    }.seq
  }
}
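
A small worked sketch (hypothetical byte values; assumes the definitions above are on the classpath) of how Ranges.and clips requested scan ranges to a single region's boundaries:

import org.apache.hadoop.hbase.spark.datasources.{Bound, Range, Ranges}
import org.apache.hadoop.hbase.util.Bytes

// Region [b, f) intersected with the requested scan range [a, d):
// the lower bound is raised to "b", the upper bound stays at "d", so one clipped range survives.
val region = Range(Some(Bound(Bytes.toBytes("b"), true)), Some(Bound(Bytes.toBytes("f"), false)))
val requested = Seq(Range(Some(Bound(Bytes.toBytes("a"), true)), Some(Bound(Bytes.toBytes("d"), false))))
val clipped = Ranges.and(region, requested)
// clipped holds a single Range: lower = Bound("b", true), upper = Bound("d", false)

// A range lying entirely past the region's end produces no scan for that region:
assert(Ranges.and(region, Seq(Range(Some(Bound(Bytes.toBytes("g"), true)), None))).isEmpty)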


@ -0,0 +1,147 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark.datasources

import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.spark.HBaseRelation

// Resource and ReferencedResource are defined for extensibility,
// e.g., to consolidate scan and bulkGet in future work.
// The user has to invoke release explicitly to release the resource,
// and potentially the parent resources.
trait Resource {
  def release(): Unit
}

case class ScanResource(tbr: TableResource, rs: ResultScanner) extends Resource {
  def release() {
    rs.close()
    tbr.release()
  }
}

trait ReferencedResource {
  var count: Int = 0
  def init(): Unit
  def destroy(): Unit

  def acquire() = synchronized {
    try {
      count += 1
      if (count == 1) {
        init()
      }
    } catch {
      case e: Throwable =>
        release()
        throw e
    }
  }

  def release() = synchronized {
    count -= 1
    if (count == 0) {
      destroy()
    }
  }

  def releaseOnException[T](func: => T): T = {
    acquire()
    val ret = {
      try {
        func
      } catch {
        case e: Throwable =>
          release()
          throw e
      }
    }
    ret
  }
}

case class TableResource(relation: HBaseRelation) extends ReferencedResource {
  var connection: Connection = _
  var table: Table = _

  override def init(): Unit = {
    connection = ConnectionFactory.createConnection(relation.hbaseConf)
    table = connection.getTable(TableName.valueOf(relation.tableName))
  }

  override def destroy(): Unit = {
    if (table != null) {
      table.close()
      table = null
    }
    if (connection != null) {
      connection.close()
      connection = null
    }
  }

  def getScanner(scan: Scan): ScanResource = releaseOnException {
    ScanResource(this, table.getScanner(scan))
  }
}

case class RegionResource(relation: HBaseRelation) extends ReferencedResource {
  var connection: Connection = _
  var rl: RegionLocator = _

  val regions = releaseOnException {
    val keys = rl.getStartEndKeys
    keys.getFirst.zip(keys.getSecond)
      .zipWithIndex
      .map(x =>
        HBaseRegion(x._2,
          Some(x._1._1),
          Some(x._1._2),
          Some(rl.getRegionLocation(x._1._1).getHostname)))
  }

  override def init(): Unit = {
    connection = ConnectionFactory.createConnection(relation.hbaseConf)
    rl = connection.getRegionLocator(TableName.valueOf(relation.tableName))
  }

  override def destroy(): Unit = {
    if (rl != null) {
      rl.close()
      rl = null
    }
    if (connection != null) {
      connection.close()
      connection = null
    }
  }
}

object HBaseResources {
  implicit def ScanResToScan(sr: ScanResource): ResultScanner = {
    sr.rs
  }

  implicit def TableResToTable(tr: TableResource): Table = {
    tr.table
  }

  implicit def RegionResToRegions(rr: RegionResource): Seq[HBaseRegion] = {
    rr.regions
  }
}
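
A hedged sketch of the intended life cycle (relation is an assumed, already-constructed HBaseRelation): getScanner acquires the table resource, which opens the Connection and Table on first use, and releasing the ScanResource closes the scanner and drops the reference count back to zero, tearing both down.

import org.apache.hadoop.hbase.client.Scan
import org.apache.hadoop.hbase.spark.HBaseRelation
import org.apache.hadoop.hbase.spark.datasources.TableResource

// Sketch only: `relation` is assumed to exist.
def scanAll(relation: HBaseRelation): Unit = {
  val tableResource = TableResource(relation)              // count == 0, nothing opened yet
  val scanResource = tableResource.getScanner(new Scan())  // acquire(): init() opens Connection + Table
  try {
    var r = scanResource.rs.next()
    while (r != null) {
      println(r)
      r = scanResource.rs.next()
    }
  } finally {
    scanResource.release() // closes the ResultScanner, then releases the table resource,
                           // whose count drops to 0 so destroy() closes Table and Connection
  }
}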


@ -0,0 +1,199 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark.datasources

import java.util.concurrent.atomic.AtomicInteger

import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.filter.Filter
import org.apache.hadoop.hbase.spark.{ScanRange, SchemaQualifierDefinition, HBaseRelation, SparkSQLPushDownFilter}
import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.spark.datasources.HBaseResources._
import org.apache.hadoop.hbase.util.Bytes
import org.apache.spark.sql.catalyst.expressions.Row
import org.apache.spark.{TaskContext, Logging, Partition}
import org.apache.spark.rdd.RDD

import scala.collection.mutable

class HBaseTableScanRDD(relation: HBaseRelation,
    @transient val filter: Option[SparkSQLPushDownFilter] = None,
    val columns: Seq[SchemaQualifierDefinition] = Seq.empty
  ) extends RDD[Result](relation.sqlContext.sparkContext, Nil) with Logging {
  var ranges = Seq.empty[Range]

  def addRange(r: ScanRange) = {
    val lower = if (r.lowerBound != null && r.lowerBound.length > 0) {
      Some(Bound(r.lowerBound, r.isLowerBoundEqualTo))
    } else {
      None
    }
    val upper = if (r.upperBound != null && r.upperBound.length > 0) {
      if (!r.isUpperBoundEqualTo) {
        Some(Bound(r.upperBound, false))
      } else {
        // HBase stopRow is exclusive: therefore it DOESN'T act like isUpperBoundEqualTo
        // by default. So we need to add a new max byte to the stopRow key
        val newArray = new Array[Byte](r.upperBound.length + 1)
        System.arraycopy(r.upperBound, 0, newArray, 0, r.upperBound.length)
        //New Max Bytes
        newArray(r.upperBound.length) = ByteMin
        Some(Bound(newArray, false))
      }
    } else {
      None
    }
    ranges :+= Range(lower, upper)
  }

  override def getPartitions: Array[Partition] = {
    val regions = RegionResource(relation)
    var idx = 0
    logDebug(s"There are ${regions.size} regions")
    val ps = regions.flatMap { x =>
      val rs = Ranges.and(Range(x), ranges)
      if (rs.size > 0) {
        if (log.isDebugEnabled) {
          rs.foreach(x => logDebug(x.toString))
        }
        idx += 1
        Some(HBaseScanPartition(idx - 1, x, rs, SerializedFilter.toSerializedTypedFilter(filter)))
      } else {
        None
      }
    }.toArray
    regions.release()
    ps.asInstanceOf[Array[Partition]]
  }

  override def getPreferredLocations(split: Partition): Seq[String] = {
    split.asInstanceOf[HBaseScanPartition].regions.server.map {
      identity
    }.toSeq
  }

  private def buildScan(range: Range,
      filter: Option[SparkSQLPushDownFilter],
      columns: Seq[SchemaQualifierDefinition]): Scan = {
    val scan = (range.lower, range.upper) match {
      case (Some(Bound(a, b)), Some(Bound(c, d))) => new Scan(a, c)
      case (None, Some(Bound(c, d))) => new Scan(Array[Byte](), c)
      case (Some(Bound(a, b)), None) => new Scan(a)
      case (None, None) => new Scan()
    }
    columns.foreach { d =>
      if (d.columnFamilyBytes.length > 0) {
        scan.addColumn(d.columnFamilyBytes, d.qualifierBytes)
      }
    }
    scan.setBatch(relation.batchingNum)
    scan.setCaching(relation.cachingNum)
    filter.foreach(scan.setFilter(_))
    scan
  }

  private def toResultIterator(scanner: ScanResource): Iterator[Result] = {
    val iterator = new Iterator[Result] {
      var cur: Option[Result] = None
      override def hasNext: Boolean = {
        if (cur.isEmpty) {
          val r = scanner.next()
          if (r == null) {
            rddResources.release(scanner)
          } else {
            cur = Some(r)
          }
        }
        cur.isDefined
      }
      override def next(): Result = {
        hasNext
        val ret = cur.get
        cur = None
        ret
      }
    }
    iterator
  }

  lazy val rddResources = RDDResources(new mutable.HashSet[Resource]())

  private def close() {
    rddResources.release()
  }

  override def compute(split: Partition, context: TaskContext): Iterator[Result] = {
    val partition = split.asInstanceOf[HBaseScanPartition]
    val scans = partition.scanRanges
      .map(buildScan(_, SerializedFilter.fromSerializedFilter(partition.sf), columns))
    val tableResource = TableResource(relation)
    context.addTaskCompletionListener(context => close())
    val sIts = scans.par
      .map(tableResource.getScanner(_))
      .map(toResultIterator(_))
      .fold(Iterator.empty: Iterator[Result]){ case (x, y) =>
        x ++ y
      }
    sIts
  }
}

case class SerializedFilter(b: Option[Array[Byte]])

object SerializedFilter {
  def toSerializedTypedFilter(f: Option[SparkSQLPushDownFilter]): SerializedFilter = {
    SerializedFilter(f.map(_.toByteArray))
  }

  def fromSerializedFilter(sf: SerializedFilter): Option[SparkSQLPushDownFilter] = {
    sf.b.map(SparkSQLPushDownFilter.parseFrom(_))
  }
}

private[hbase] case class HBaseRegion(
    override val index: Int,
    val start: Option[HBaseType] = None,
    val end: Option[HBaseType] = None,
    val server: Option[String] = None) extends Partition

private[hbase] case class HBaseScanPartition(
    override val index: Int,
    val regions: HBaseRegion,
    val scanRanges: Seq[Range],
    val sf: SerializedFilter) extends Partition

case class RDDResources(set: mutable.HashSet[Resource]) {
  def addResource(s: Resource) {
    set += s
  }
  def release() {
    set.foreach(release(_))
  }
  def release(rs: Resource) {
    try {
      rs.release()
    } finally {
      set.remove(rs)
    }
  }
}
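
A hedged sketch (the names relation, pushDownFilter, rowKeyRanges and columns are assumed to be in scope) of driving the scan RDD directly, mirroring what the rewritten HBaseRelation.buildScan does:

import org.apache.hadoop.hbase.spark.datasources.HBaseTableScanRDD
import org.apache.hadoop.hbase.util.Bytes

// Sketch only: relation is an HBaseRelation, pushDownFilter an Option[SparkSQLPushDownFilter],
// rowKeyRanges the ScanRange list recovered from the row-key predicates, and
// columns the Seq[SchemaQualifierDefinition] to fetch.
val hRdd = new HBaseTableScanRDD(relation, pushDownFilter, columns)
rowKeyRanges.foreach(hRdd.addRange(_))   // one entry per non-overlapping row-key interval
// getPartitions keeps one partition per region that intersects at least one range;
// compute() opens one Scan per clipped range and chains the result iterators.
val rowKeys = hRdd.map(r => Bytes.toString(r.getRow)).collect()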


@ -0,0 +1,46 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark.datasources

import java.io.{IOException, ObjectInputStream, ObjectOutputStream}

import org.apache.hadoop.conf.Configuration
import org.apache.spark.util.Utils

import scala.util.control.NonFatal

class SerializableConfiguration(@transient var value: Configuration) extends Serializable {
  private def writeObject(out: ObjectOutputStream): Unit = tryOrIOException {
    out.defaultWriteObject()
    value.write(out)
  }

  private def readObject(in: ObjectInputStream): Unit = tryOrIOException {
    value = new Configuration(false)
    value.readFields(in)
  }

  def tryOrIOException(block: => Unit) {
    try {
      block
    } catch {
      case e: IOException => throw e
      case NonFatal(t) => throw new IOException(t)
    }
  }
}
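
Hadoop's Configuration is not java.io.Serializable, so it cannot ride inside closures shipped to executors; this wrapper serializes it through Configuration.write/readFields instead, which is why HBaseRelation now exposes hbaseConf via wrappedConf.value. A minimal sketch of the pattern (sc is an assumed SparkContext):

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.spark.datasources.SerializableConfiguration

// Sketch only: sc is an existing SparkContext.
val wrappedConf = new SerializableConfiguration(HBaseConfiguration.create())
val quorums = sc.parallelize(1 to 4).map { _ =>
  // Each task rehydrates its own Configuration copy via readObject/readFields.
  wrappedConf.value.get("hbase.zookeeper.quorum")
}.collect()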


@ -0,0 +1,36 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark

import org.apache.hadoop.hbase.util.Bytes

import scala.math.Ordering

package object hbase {
  type HBaseType = Array[Byte]
  val ByteMax = -1.asInstanceOf[Byte]
  val ByteMin = 0.asInstanceOf[Byte]
  val ord: Ordering[HBaseType] = new Ordering[HBaseType] {
    def compare(x: Array[Byte], y: Array[Byte]): Int = {
      return Bytes.compareTo(x, y)
    }
  }
  //Do not use BinaryType.ordering
  implicit val order: Ordering[HBaseType] = ord
}
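
A quick sanity-check sketch of the ordering defined above; Bytes.compareTo treats bytes as unsigned values, which matches HBase row-key ordering (the likely reason for the note about BinaryType.ordering):

import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.util.Bytes

// ByteMax is -1 as a signed Byte (0xFF), yet it must sort *after* ByteMin (0x00)
// under the unsigned comparison HBase uses for row keys.
assert(ord.compare(Array(ByteMin), Array(ByteMax)) < 0)
assert(ord.compare(Bytes.toBytes("row005"), Bytes.toBytes("row010")) < 0)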