HBASE-16638 Reduce the number of Connection's created in classes of hbase-spark module (Weiqing Yang)

tedyu 2016-10-11 09:04:26 -07:00
parent 4d09116695
commit 9d304d3b2d
7 changed files with 470 additions and 18 deletions
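
In short, the hbase-spark call sites below stop creating a fresh Connection per operation and instead borrow one from the new reference-counted HBaseConnectionCache. A minimal sketch of the pattern change, using only the API added in this commit (hbaseConf is an existing Configuration and the table name "t1" is illustrative):

import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client.ConnectionFactory
import org.apache.hadoop.hbase.spark.HBaseConnectionCache

// Before: each call site opened and closed its own Connection.
val conn = ConnectionFactory.createConnection(hbaseConf)
val before = conn.getTable(TableName.valueOf("t1"))
conn.close()

// After: call sites share a cached SmartConnection; close() only drops a reference,
// and the cache's housekeeping thread closes idle connections later.
val smartConn = HBaseConnectionCache.getConnection(hbaseConf)
val after = smartConn.getTable(TableName.valueOf("t1"))
smartConn.close()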

View File

@@ -155,9 +155,10 @@ case class HBaseRelation (
if (numReg > 3) {
val tName = TableName.valueOf(catalog.name)
val cfs = catalog.getColumnFamilies
val connection = ConnectionFactory.createConnection(hbaseConf)
val connection = HBaseConnectionCache.getConnection(hbaseConf)
// Initialize hBase table if necessary
val admin = connection.getAdmin()
val admin = connection.getAdmin
try {
if (!admin.isTableAvailable(tName)) {
val tableDesc = new HTableDescriptor(tName)

View File

@@ -0,0 +1,243 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark
import java.io.IOException
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.client.{Admin, Connection, ConnectionFactory, RegionLocator, Table}
import org.apache.hadoop.hbase.ipc.RpcControllerFactory
import org.apache.hadoop.hbase.security.{User, UserProvider}
import org.apache.hadoop.hbase.spark.datasources.HBaseSparkConf
import org.apache.hadoop.hbase.{HConstants, TableName}
import org.apache.spark.Logging
import scala.collection.mutable
private[spark] object HBaseConnectionCache extends Logging {
// A hashmap of Spark-HBase connections. Key is HBaseConnectionKey.
val connectionMap = new mutable.HashMap[HBaseConnectionKey, SmartConnection]()
// in milliseconds
private final val DEFAULT_TIME_OUT: Long = HBaseSparkConf.connectionCloseDelay
private var timeout = DEFAULT_TIME_OUT
private var closed: Boolean = false
var housekeepingThread = new Thread(new Runnable {
override def run() {
while (true) {
try {
Thread.sleep(timeout)
} catch {
case e: InterruptedException =>
// setTimeout() and close() may interrupt the sleep and it's safe
// to ignore the exception
}
if (closed)
return
performHousekeeping(false)
}
}
})
housekeepingThread.setDaemon(true)
housekeepingThread.start()
def close(): Unit = {
try {
connectionMap.synchronized {
if (closed)
return
closed = true
housekeepingThread.interrupt()
housekeepingThread = null
HBaseConnectionCache.performHousekeeping(true)
}
} catch {
case e: Exception => logWarning("Error in finalHouseKeeping", e)
}
}
def performHousekeeping(forceClean: Boolean) = {
val tsNow: Long = System.currentTimeMillis()
connectionMap.synchronized {
connectionMap.foreach {
x => {
if(x._2.refCount < 0) {
logError("Bug to be fixed: negative refCount")
}
if(forceClean || ((x._2.refCount <= 0) && (tsNow - x._2.timestamp > timeout))) {
try{
x._2.connection.close()
} catch {
case e: IOException => logWarning("Fail to close connection", e)
}
connectionMap.remove(x._1)
}
}
}
}
}
// For testing purpose only
def getConnection(key: HBaseConnectionKey, conn: => Connection): SmartConnection = {
connectionMap.synchronized {
if (closed)
return null
val sc = connectionMap.getOrElseUpdate(key, new SmartConnection(conn))
sc.refCount += 1
sc
}
}
def getConnection(conf: Configuration): SmartConnection =
getConnection(new HBaseConnectionKey(conf), ConnectionFactory.createConnection(conf))
// For testing purpose only
def setTimeout(to: Long): Unit = {
connectionMap.synchronized {
if (closed)
return
timeout = to
housekeepingThread.interrupt()
}
}
}
private[hbase] case class SmartConnection (
connection: Connection, var refCount: Int = 0, var timestamp: Long = 0) {
def getTable(tableName: TableName): Table = connection.getTable(tableName)
def getRegionLocator(tableName: TableName): RegionLocator = connection.getRegionLocator(tableName)
def isClosed: Boolean = connection.isClosed
def getAdmin: Admin = connection.getAdmin
def close() = {
HBaseConnectionCache.connectionMap.synchronized {
refCount -= 1
if(refCount <= 0)
timestamp = System.currentTimeMillis()
}
}
}
/**
* Denotes a unique key to an HBase Connection instance.
* Please refer to 'org.apache.hadoop.hbase.client.HConnectionKey'.
*
* In essence, this class captures the properties in Configuration
* that may be used in the process of establishing a connection.
*
*/
class HBaseConnectionKey(c: Configuration) extends Logging {
val conf: Configuration = c
val CONNECTION_PROPERTIES: Array[String] = Array[String](
HConstants.ZOOKEEPER_QUORUM,
HConstants.ZOOKEEPER_ZNODE_PARENT,
HConstants.ZOOKEEPER_CLIENT_PORT,
HConstants.ZOOKEEPER_RECOVERABLE_WAITTIME,
HConstants.HBASE_CLIENT_PAUSE,
HConstants.HBASE_CLIENT_RETRIES_NUMBER,
HConstants.HBASE_RPC_TIMEOUT_KEY,
HConstants.HBASE_META_SCANNER_CACHING,
HConstants.HBASE_CLIENT_INSTANCE_ID,
HConstants.RPC_CODEC_CONF_KEY,
HConstants.USE_META_REPLICAS,
RpcControllerFactory.CUSTOM_CONTROLLER_CONF_KEY)
var username: String = _
var m_properties = mutable.HashMap.empty[String, String]
if (conf != null) {
for (property <- CONNECTION_PROPERTIES) {
val value: String = conf.get(property)
if (value != null) {
m_properties.+=((property, value))
}
}
try {
val provider: UserProvider = UserProvider.instantiate(conf)
val currentUser: User = provider.getCurrent
if (currentUser != null) {
username = currentUser.getName
}
}
catch {
case e: IOException => {
logWarning("Error obtaining current user, skipping username in HBaseConnectionKey", e)
}
}
}
// make 'properties' immutable
val properties = m_properties.toMap
override def hashCode: Int = {
val prime: Int = 31
var result: Int = 1
if (username != null) {
result = username.hashCode
}
for (property <- CONNECTION_PROPERTIES) {
val value: Option[String] = properties.get(property)
if (value.isDefined) {
result = prime * result + value.hashCode
}
}
result
}
override def equals(obj: Any): Boolean = {
if (obj == null) return false
if (getClass ne obj.getClass) return false
val that: HBaseConnectionKey = obj.asInstanceOf[HBaseConnectionKey]
if (this.username != null && !(this.username == that.username)) {
return false
}
else if (this.username == null && that.username != null) {
return false
}
if (this.properties == null) {
if (that.properties != null) {
return false
}
}
else {
if (that.properties == null) {
return false
}
var flag: Boolean = true
for (property <- CONNECTION_PROPERTIES) {
val thisValue: Option[String] = this.properties.get(property)
val thatValue: Option[String] = that.properties.get(property)
flag = true
if (thisValue eq thatValue) {
flag = false //continue, so make flag be false
}
if (flag && (thisValue == null || !(thisValue == thatValue))) {
return false
}
}
}
true
}
override def toString: String = {
"HBaseConnectionKey{" + "properties=" + properties + ", username='" + username + '\'' + '}'
}
}

View File

@@ -482,9 +482,9 @@ class HBaseContext(@transient sc: SparkContext,
applyCreds
// specify that this is a proxy user
val connection = ConnectionFactory.createConnection(config)
f(it, connection)
connection.close()
val smartConn = HBaseConnectionCache.getConnection(config)
f(it, smartConn.connection)
smartConn.close()
}
private def getConf(configBroadcast: Broadcast[SerializableWritable[Configuration]]):
@@ -522,11 +522,10 @@ class HBaseContext(@transient sc: SparkContext,
val config = getConf(configBroadcast)
applyCreds
val connection = ConnectionFactory.createConnection(config)
val res = mp(it, connection)
connection.close()
val smartConn = HBaseConnectionCache.getConnection(config)
val res = mp(it, smartConn.connection)
smartConn.close()
res
}
/**
@@ -619,7 +618,7 @@ class HBaseContext(@transient sc: SparkContext,
compactionExclude: Boolean = false,
maxSize:Long = HConstants.DEFAULT_MAX_FILE_SIZE):
Unit = {
val conn = ConnectionFactory.createConnection(config)
val conn = HBaseConnectionCache.getConnection(config)
val regionLocator = conn.getRegionLocator(tableName)
val startKeys = regionLocator.getStartKeys
val defaultCompressionStr = config.get("hfile.compression",
@@ -742,7 +741,7 @@ class HBaseContext(@transient sc: SparkContext,
compactionExclude: Boolean = false,
maxSize:Long = HConstants.DEFAULT_MAX_FILE_SIZE):
Unit = {
val conn = ConnectionFactory.createConnection(config)
val conn = HBaseConnectionCache.getConnection(config)
val regionLocator = conn.getRegionLocator(tableName)
val startKeys = regionLocator.getStartKeys
val defaultCompressionStr = config.get("hfile.compression",

View File

@@ -19,7 +19,8 @@ package org.apache.hadoop.hbase.spark.datasources
import org.apache.hadoop.hbase.TableName
import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.spark.HBaseRelation
import org.apache.hadoop.hbase.spark.{HBaseConnectionKey, SmartConnection,
HBaseConnectionCache, HBaseRelation}
import scala.language.implicitConversions
// Resource and ReferencedResources are defined for extensibility,
@@ -84,11 +85,11 @@ trait ReferencedResource {
}
case class TableResource(relation: HBaseRelation) extends ReferencedResource {
var connection: Connection = _
var connection: SmartConnection = _
var table: Table = _
override def init(): Unit = {
connection = ConnectionFactory.createConnection(relation.hbaseConf)
connection = HBaseConnectionCache.getConnection(relation.hbaseConf)
table = connection.getTable(TableName.valueOf(relation.tableName))
}
@@ -113,7 +114,7 @@ case class TableResource(relation: HBaseRelation) extends ReferencedResource {
}
case class RegionResource(relation: HBaseRelation) extends ReferencedResource {
var connection: Connection = _
var connection: SmartConnection = _
var rl: RegionLocator = _
val regions = releaseOnException {
val keys = rl.getStartEndKeys
@@ -127,7 +128,7 @@ case class RegionResource(relation: HBaseRelation) extends ReferencedResource {
}
override def init(): Unit = {
connection = ConnectionFactory.createConnection(relation.hbaseConf)
connection = HBaseConnectionCache.getConnection(relation.hbaseConf)
rl = connection.getRegionLocator(TableName.valueOf(relation.tableName))
}

View File

@@ -43,4 +43,7 @@ object HBaseSparkConf{
val MAX_VERSIONS = "hbase.spark.query.maxVersions"
val ENCODER = "hbase.spark.query.encoder"
val defaultEncoder = classOf[NaiveEncoder].getCanonicalName
// in milliseconds
val connectionCloseDelay = 10 * 60 * 1000
}

View File

@@ -23,6 +23,7 @@ import org.apache.hadoop.hbase.client._
import org.apache.hadoop.hbase.spark._
import org.apache.hadoop.hbase.spark.hbase._
import org.apache.hadoop.hbase.spark.datasources.HBaseResources._
import org.apache.hadoop.hbase.util.ShutdownHookManager
import org.apache.spark.sql.datasources.hbase.Field
import org.apache.spark.{SparkEnv, TaskContext, Logging, Partition}
import org.apache.spark.rdd.RDD
@@ -85,10 +86,14 @@ class HBaseTableScanRDD(relation: HBaseRelation,
}
}.toArray
regions.release()
ShutdownHookManager.affixShutdownHook( new Thread() {
override def run() {
HBaseConnectionCache.close()
}
}, 0)
ps.asInstanceOf[Array[Partition]]
}
override def getPreferredLocations(split: Partition): Seq[String] = {
split.asInstanceOf[HBaseScanPartition].regions.server.map {
identity
@@ -148,7 +153,6 @@ class HBaseTableScanRDD(relation: HBaseRelation,
iterator
}
private def buildScan(range: Range,
filter: Option[SparkSQLPushDownFilter],
columns: Seq[Field]): Scan = {
@@ -226,6 +230,11 @@ class HBaseTableScanRDD(relation: HBaseRelation,
.fold(Iterator.empty: Iterator[Result]){ case (x, y) =>
x ++ y
} ++ gIt
ShutdownHookManager.affixShutdownHook( new Thread() {
override def run() {
HBaseConnectionCache.close()
}
}, 0)
rIts
}

View File

@@ -0,0 +1,196 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.spark
import java.util.concurrent.ExecutorService
import scala.util.Random
import org.apache.hadoop.hbase.client.{BufferedMutator, Table, RegionLocator,
Connection, BufferedMutatorParams, Admin}
import org.apache.hadoop.conf.Configuration
import org.apache.hadoop.hbase.TableName
import org.apache.spark.Logging
import org.scalatest.FunSuite
case class HBaseConnectionKeyMocker (confId: Int) extends HBaseConnectionKey (null) {
override def hashCode: Int = {
confId
}
override def equals(obj: Any): Boolean = {
if(!obj.isInstanceOf[HBaseConnectionKeyMocker])
false
else
confId == obj.asInstanceOf[HBaseConnectionKeyMocker].confId
}
}
class ConnectionMocker extends Connection {
var isClosed: Boolean = false
def getRegionLocator (tableName: TableName): RegionLocator = null
def getConfiguration: Configuration = null
def getTable (tableName: TableName): Table = null
def getTable(tableName: TableName, pool: ExecutorService): Table = null
def getBufferedMutator (params: BufferedMutatorParams): BufferedMutator = null
def getBufferedMutator (tableName: TableName): BufferedMutator = null
def getAdmin: Admin = null
def close(): Unit = {
if (isClosed)
throw new IllegalStateException()
isClosed = true
}
def isAborted: Boolean = true
def abort(why: String, e: Throwable) = {}
}
class HBaseConnectionCacheSuit extends FunSuite with Logging {
/*
* These tests must be performed sequentially as they operate with a
* unique running thread and resource.
*
* It looks like there's no way to tell FunSuite to do so, so those
* test cases are made normal functions which are called sequentially
* in a single test case.
*/
test("all test cases") {
testBasic()
testWithPressureWithoutClose()
testWithPressureWithClose()
}
def testBasic() {
HBaseConnectionCache.setTimeout(1 * 1000)
val connKeyMocker1 = new HBaseConnectionKeyMocker(1)
val connKeyMocker1a = new HBaseConnectionKeyMocker(1)
val connKeyMocker2 = new HBaseConnectionKeyMocker(2)
val c1 = HBaseConnectionCache
.getConnection(connKeyMocker1, new ConnectionMocker)
val c1a = HBaseConnectionCache
.getConnection(connKeyMocker1a, new ConnectionMocker)
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 1)
}
val c2 = HBaseConnectionCache
.getConnection(connKeyMocker2, new ConnectionMocker)
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 2)
}
c1.close()
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 2)
}
c1a.close()
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 2)
}
Thread.sleep(3 * 1000) // Leave housekeeping thread enough time
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 1)
assert(HBaseConnectionCache.connectionMap.iterator.next()._1
.asInstanceOf[HBaseConnectionKeyMocker].confId === 2)
}
c2.close()
}
def testWithPressureWithoutClose() {
class TestThread extends Runnable {
override def run() {
for (i <- 0 to 999) {
val c = HBaseConnectionCache.getConnection(
new HBaseConnectionKeyMocker(Random.nextInt(10)), new ConnectionMocker)
}
}
}
HBaseConnectionCache.setTimeout(500)
val threads: Array[Thread] = new Array[Thread](100)
for (i <- 0 to 99) {
threads.update(i, new Thread(new TestThread()))
threads(i).run()
}
try {
threads.foreach { x => x.join() }
} catch {
case e: InterruptedException => println(e.getMessage)
}
Thread.sleep(1000)
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 10)
var totalRc : Int = 0
HBaseConnectionCache.connectionMap.foreach {
x => totalRc += x._2.refCount
}
assert(totalRc === 100 * 1000)
HBaseConnectionCache.connectionMap.foreach {
x => {
x._2.refCount = 0
x._2.timestamp = System.currentTimeMillis() - 1000
}
}
}
Thread.sleep(1000)
assert(HBaseConnectionCache.connectionMap.size === 0)
}
def testWithPressureWithClose() {
class TestThread extends Runnable {
override def run() {
for (i <- 0 to 999) {
val c = HBaseConnectionCache.getConnection(
new HBaseConnectionKeyMocker(Random.nextInt(10)), new ConnectionMocker)
Thread.`yield`()
c.close()
}
}
}
HBaseConnectionCache.setTimeout(3 * 1000)
val threads: Array[Thread] = new Array[Thread](100)
for (i <- threads.indices) {
threads.update(i, new Thread(new TestThread()))
threads(i).run()
}
try {
threads.foreach { x => x.join() }
} catch {
case e: InterruptedException => println(e.getMessage)
}
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 10)
}
Thread.sleep(6 * 1000)
HBaseConnectionCache.connectionMap.synchronized {
assert(HBaseConnectionCache.connectionMap.size === 0)
}
}
}
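
For reference, the reference-counting contract that the suite above exercises through mocks can also be sketched against the real cache. A minimal sketch, assuming a reachable HBase cluster; the configuration and timing are illustrative:

import org.apache.hadoop.hbase.HBaseConfiguration
import org.apache.hadoop.hbase.spark.HBaseConnectionCache

val conf = HBaseConfiguration.create()
val c1 = HBaseConnectionCache.getConnection(conf) // new cache entry, refCount = 1
val c2 = HBaseConnectionCache.getConnection(conf) // equal HBaseConnectionKey, same entry, refCount = 2
c1.close()                                        // refCount = 1, underlying Connection stays open
c2.close()                                        // refCount = 0, idle timestamp recorded
// The daemon housekeeping thread closes the underlying Connection once the entry has been
// idle longer than HBaseSparkConf.connectionCloseDelay (10 minutes by default).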