HBASE-14796 Enhance the Gets in the connector (Zhan Zhang)

This commit is contained in:
tedyu 2015-12-28 15:48:10 -08:00
parent 2fba25b66a
commit 6868c63660
5 changed files with 127 additions and 38 deletions

View File

@ -159,6 +159,10 @@ case class HBaseRelation (val tableName:String,
.getOrElse(sqlContext.sparkContext.getConf.getInt( .getOrElse(sqlContext.sparkContext.getConf.getInt(
HBaseSparkConf.BATCH_NUM, HBaseSparkConf.defaultBatchNum)) HBaseSparkConf.BATCH_NUM, HBaseSparkConf.defaultBatchNum))
val bulkGetSize = parameters.get(HBaseSparkConf.BULKGET_SIZE).map(_.toInt)
.getOrElse(sqlContext.sparkContext.getConf.getInt(
HBaseSparkConf.BULKGET_SIZE, HBaseSparkConf.defaultBulkGetSize))
//create or get latest HBaseContext //create or get latest HBaseContext
@transient val hbaseContext:HBaseContext = if (useHBaseContext) { @transient val hbaseContext:HBaseContext = if (useHBaseContext) {
LatestHBaseContextCache.latest LatestHBaseContextCache.latest
@ -267,6 +271,7 @@ case class HBaseRelation (val tableName:String,
None None
} }
val hRdd = new HBaseTableScanRDD(this, pushDownFilterJava, requiredQualifierDefinitionList.seq) val hRdd = new HBaseTableScanRDD(this, pushDownFilterJava, requiredQualifierDefinitionList.seq)
pushDownRowKeyFilter.points.foreach(hRdd.addPoint(_))
pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_)) pushDownRowKeyFilter.ranges.foreach(hRdd.addRange(_))
var resultRDD: RDD[Row] = { var resultRDD: RDD[Row] = {
val tmp = hRdd.map{ r => val tmp = hRdd.map{ r =>
@ -280,34 +285,6 @@ case class HBaseRelation (val tableName:String,
} }
} }
//If there are gets then we can get them from the driver and union that rdd in
// with the rest of the values.
if (getList.size() > 0) {
val connection =
ConnectionFactory.createConnection(hbaseContext.tmpHdfsConfiguration)
try {
val table = connection.getTable(TableName.valueOf(tableName))
try {
val results = table.get(getList)
val rowList = mutable.MutableList[Row]()
for (i <- 0 until results.length) {
val rowArray = requiredColumns.map(c =>
DefaultSourceStaticUtils.getValue(c, schemaMappingDefinition, results(i)))
rowList += Row.fromSeq(rowArray)
}
val getRDD = sqlContext.sparkContext.parallelize(rowList)
if (resultRDD == null) resultRDD = getRDD
else {
resultRDD = resultRDD.union(getRDD)
}
} finally {
table.close()
}
} finally {
connection.close()
}
}
if (resultRDD == null) { if (resultRDD == null) {
val scan = new Scan() val scan = new Scan()
scan.setCacheBlocks(blockCacheEnable) scan.setCacheBlocks(blockCacheEnable)

View File

@ -87,3 +87,27 @@ object Ranges {
} }
} }
object Points {
def and(r: Range, ps: Seq[Array[Byte]]): Seq[Array[Byte]] = {
ps.flatMap { p =>
if (ord.compare(r.lower.get.b, p) <= 0) {
// if region lower bound is less or equal to the point
if (r.upper.isDefined) {
// if region upper bound is defined
if (ord.compare(r.upper.get.b, p) > 0) {
// if the upper bound is greater than the point (because upper bound is exclusive)
Some(p)
} else {
None
}
} else {
// if the region upper bound is not defined (infinity)
Some(p)
}
} else {
None
}
}
}
}

View File

@ -38,6 +38,12 @@ case class ScanResource(tbr: TableResource, rs: ResultScanner) extends Resource
} }
} }
case class GetResource(tbr: TableResource, rs: Array[Result]) extends Resource {
def release() {
tbr.release()
}
}
trait ReferencedResource { trait ReferencedResource {
var count: Int = 0 var count: Int = 0
def init(): Unit def init(): Unit
@ -100,6 +106,10 @@ case class TableResource(relation: HBaseRelation) extends ReferencedResource {
def getScanner(scan: Scan): ScanResource = releaseOnException { def getScanner(scan: Scan): ScanResource = releaseOnException {
ScanResource(this, table.getScanner(scan)) ScanResource(this, table.getScanner(scan))
} }
def get(list: java.util.List[org.apache.hadoop.hbase.client.Get]) = releaseOnException {
GetResource(this, table.get(list))
}
} }
case class RegionResource(relation: HBaseRelation) extends ReferencedResource { case class RegionResource(relation: HBaseRelation) extends ReferencedResource {
@ -138,6 +148,10 @@ object HBaseResources{
sr.rs sr.rs
} }
implicit def GetResToResult(gr: GetResource): Array[Result] = {
gr.rs
}
implicit def TableResToTable(tr: TableResource): Table = { implicit def TableResToTable(tr: TableResource): Table = {
tr.table tr.table
} }

View File

@ -29,4 +29,6 @@ object HBaseSparkConf{
val defaultCachingSize = 1000 val defaultCachingSize = 1000
val BATCH_NUM = "spark.hbase.batchNum" val BATCH_NUM = "spark.hbase.batchNum"
val defaultBatchNum = 1000 val defaultBatchNum = 1000
val BULKGET_SIZE = "spark.hbase.bulkGetSize"
val defaultBulkGetSize = 1000
} }

View File

@ -17,6 +17,8 @@
package org.apache.hadoop.hbase.spark.datasources package org.apache.hadoop.hbase.spark.datasources
import java.util.ArrayList
import org.apache.hadoop.hbase.client._ import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.spark.{ScanRange, SchemaQualifierDefinition, HBaseRelation, SparkSQLPushDownFilter} import org.apache.hadoop.hbase.spark.{ScanRange, SchemaQualifierDefinition, HBaseRelation, SparkSQLPushDownFilter}
import org.apache.hadoop.hbase.spark.hbase._ import org.apache.hadoop.hbase.spark.hbase._
@ -32,7 +34,12 @@ class HBaseTableScanRDD(relation: HBaseRelation,
val columns: Seq[SchemaQualifierDefinition] = Seq.empty val columns: Seq[SchemaQualifierDefinition] = Seq.empty
)extends RDD[Result](relation.sqlContext.sparkContext, Nil) with Logging { )extends RDD[Result](relation.sqlContext.sparkContext, Nil) with Logging {
private def sparkConf = SparkEnv.get.conf private def sparkConf = SparkEnv.get.conf
var ranges = Seq.empty[Range] @transient var ranges = Seq.empty[Range]
@transient var points = Seq.empty[Array[Byte]]
def addPoint(p: Array[Byte]) {
points :+= p
}
def addRange(r: ScanRange) = { def addRange(r: ScanRange) = {
val lower = if (r.lowerBound != null && r.lowerBound.length > 0) { val lower = if (r.lowerBound != null && r.lowerBound.length > 0) {
Some(Bound(r.lowerBound, r.isLowerBoundEqualTo)) Some(Bound(r.lowerBound, r.isLowerBoundEqualTo))
@ -65,12 +72,13 @@ class HBaseTableScanRDD(relation: HBaseRelation,
logDebug(s"There are ${regions.size} regions") logDebug(s"There are ${regions.size} regions")
val ps = regions.flatMap { x => val ps = regions.flatMap { x =>
val rs = Ranges.and(Range(x), ranges) val rs = Ranges.and(Range(x), ranges)
if (rs.size > 0) { val ps = Points.and(Range(x), points)
if (rs.size > 0 || ps.size > 0) {
if(log.isDebugEnabled) { if(log.isDebugEnabled) {
rs.foreach(x => logDebug(x.toString)) rs.foreach(x => logDebug(x.toString))
} }
idx += 1 idx += 1
Some(HBaseScanPartition(idx - 1, x, rs, SerializedFilter.toSerializedTypedFilter(filter))) Some(HBaseScanPartition(idx - 1, x, rs, ps, SerializedFilter.toSerializedTypedFilter(filter)))
} else { } else {
None None
} }
@ -86,6 +94,57 @@ class HBaseTableScanRDD(relation: HBaseRelation,
}.toSeq }.toSeq
} }
private def buildGets(
tbr: TableResource,
g: Seq[Array[Byte]],
filter: Option[SparkSQLPushDownFilter],
columns: Seq[SchemaQualifierDefinition]): Iterator[Result] = {
g.grouped(relation.bulkGetSize).flatMap{ x =>
val gets = new ArrayList[Get]()
x.foreach{ y =>
val g = new Get(y)
columns.foreach { d =>
if (d.columnFamilyBytes.length > 0) {
g.addColumn(d.columnFamilyBytes, d.qualifierBytes)
}
}
filter.foreach(g.setFilter(_))
gets.add(g)
}
val tmp = tbr.get(gets)
rddResources.addResource(tmp)
toResultIterator(tmp)
}
}
private def toResultIterator(result: GetResource): Iterator[Result] = {
val iterator = new Iterator[Result] {
var idx = 0
var cur: Option[Result] = None
override def hasNext: Boolean = {
while(idx < result.length && cur.isEmpty) {
val r = result(idx)
idx += 1
if (!r.isEmpty) {
cur = Some(r)
}
}
if (cur.isEmpty) {
rddResources.release(result)
}
cur.isDefined
}
override def next(): Result = {
hasNext
val ret = cur.get
cur = None
ret
}
}
iterator
}
private def buildScan(range: Range, private def buildScan(range: Range,
filter: Option[SparkSQLPushDownFilter], filter: Option[SparkSQLPushDownFilter],
columns: Seq[SchemaQualifierDefinition]): Scan = { columns: Seq[SchemaQualifierDefinition]): Scan = {
@ -130,6 +189,7 @@ class HBaseTableScanRDD(relation: HBaseRelation,
} }
iterator iterator
} }
lazy val rddResources = RDDResources(new mutable.HashSet[Resource]()) lazy val rddResources = RDDResources(new mutable.HashSet[Resource]())
private def close() { private def close() {
@ -138,18 +198,29 @@ class HBaseTableScanRDD(relation: HBaseRelation,
override def compute(split: Partition, context: TaskContext): Iterator[Result] = { override def compute(split: Partition, context: TaskContext): Iterator[Result] = {
val partition = split.asInstanceOf[HBaseScanPartition] val partition = split.asInstanceOf[HBaseScanPartition]
val filter = SerializedFilter.fromSerializedFilter(partition.sf)
val scans = partition.scanRanges val scans = partition.scanRanges
.map(buildScan(_, SerializedFilter.fromSerializedFilter(partition.sf), columns)) .map(buildScan(_, filter, columns))
val tableResource = TableResource(relation) val tableResource = TableResource(relation)
context.addTaskCompletionListener(context => close()) context.addTaskCompletionListener(context => close())
val sIts = scans.par val points = partition.points
.map(tableResource.getScanner(_)) val gIt: Iterator[Result] = {
.map(toResultIterator(_)) if (points.isEmpty) {
Iterator.empty: Iterator[Result]
} else {
buildGets(tableResource, points, filter, columns)
}
}
val rIts = scans.par
.map { scan =>
val scanner = tableResource.getScanner(scan)
rddResources.addResource(scanner)
scanner
}.map(toResultIterator(_))
.fold(Iterator.empty: Iterator[Result]){ case (x, y) => .fold(Iterator.empty: Iterator[Result]){ case (x, y) =>
x ++ y x ++ y
} } ++ gIt
sIts rIts
} }
} }
@ -176,6 +247,7 @@ private[hbase] case class HBaseScanPartition(
override val index: Int, override val index: Int,
val regions: HBaseRegion, val regions: HBaseRegion,
val scanRanges: Seq[Range], val scanRanges: Seq[Range],
val points: Seq[Array[Byte]],
val sf: SerializedFilter) extends Partition val sf: SerializedFilter) extends Partition
case class RDDResources(set: mutable.HashSet[Resource]) { case class RDDResources(set: mutable.HashSet[Resource]) {