HDFS-9228. libhdfs++ should respect NN retry configuration settings. Contributed by Bob Hansen

James 2015-12-07 14:37:52 -05:00 committed by James Clampffer
parent eb70a64362
commit 58f2c7183e
16 changed files with 1175 additions and 266 deletions

View File

@ -33,7 +33,7 @@ struct Options {
/**
* Maximum number of retries for RPC operations
**/
const static int NO_RPC_RETRY = -1;
const static int kNoRetry = -1;
int max_rpc_retries;
/**

View File

@ -15,4 +15,4 @@
# specific language governing permissions and limitations
# under the License.
add_library(common base64.cc status.cc sasl_digest_md5.cc hdfs_public_api.cc options.cc configuration.cc util.cc)
add_library(common base64.cc status.cc sasl_digest_md5.cc hdfs_public_api.cc options.cc configuration.cc util.cc retry_policy.cc)

View File

@ -20,6 +20,6 @@
namespace hdfs {
Options::Options() : rpc_timeout(30000), max_rpc_retries(0),
Options::Options() : rpc_timeout(30000), max_rpc_retries(kNoRetry),
rpc_retry_delay_ms(10000), host_exclusion_duration(600000) {}
}
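For orientation, a minimal usage sketch, assuming only the public Options fields visible in this commit (values illustrative): retries are now off by default, so a consumer must opt in.

#include "libhdfspp/options.h"

void ConfigureRetriesSketch() {
  hdfs::Options options;              // max_rpc_retries defaults to kNoRetry
  options.max_rpc_retries = 3;        // allow up to 3 retries per RPC
  options.rpc_retry_delay_ms = 5000;  // wait 5s between attempts (default 10s)
  // Leaving max_rpc_retries at kNoRetry (-1) preserves the old no-retry behavior.
}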

View File

@ -0,0 +1,47 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/retry_policy.h"
namespace hdfs {
RetryAction FixedDelayRetryPolicy::ShouldRetry(
const Status &s, uint64_t retries, uint64_t failovers,
bool isIdempotentOrAtMostOnce) const {
(void)s;
(void)isIdempotentOrAtMostOnce;
if (retries + failovers >= max_retries_) {
return RetryAction::fail(
"Failovers (" + std::to_string(retries + failovers) +
") exceeded maximum retries (" + std::to_string(max_retries_) + ")");
} else {
return RetryAction::retry(delay_);
}
}
RetryAction NoRetryPolicy::ShouldRetry(
const Status &s, uint64_t retries, uint64_t failovers,
bool isIdempotentOrAtMostOnce) const {
(void)s;
(void)retries;
(void)failovers;
(void)isIdempotentOrAtMostOnce;
return RetryAction::fail("No retry");
}
}

View File

@ -0,0 +1,91 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#ifndef LIB_COMMON_RETRY_POLICY_H_
#define LIB_COMMON_RETRY_POLICY_H_
#include "common/util.h"
#include <string>
#include <stdint.h>
namespace hdfs {
class RetryAction {
public:
enum RetryDecision { FAIL, RETRY, FAILOVER_AND_RETRY };
RetryDecision action;
uint64_t delayMillis;
std::string reason;
RetryAction(RetryDecision in_action, uint64_t in_delayMillis,
const std::string &in_reason)
: action(in_action), delayMillis(in_delayMillis), reason(in_reason) {}
static RetryAction fail(const std::string &reason) {
return RetryAction(FAIL, 0, reason);
}
static RetryAction retry(uint64_t delay) {
return RetryAction(RETRY, delay, "");
}
static RetryAction failover() {
return RetryAction(FAILOVER_AND_RETRY, 0, "");
}
};
class RetryPolicy {
public:
/*
* If there was an error in communications, responds with the configured
* action to take.
*/
virtual RetryAction ShouldRetry(const Status &s, uint64_t retries,
uint64_t failovers,
bool isIdempotentOrAtMostOnce) const = 0;
virtual ~RetryPolicy() {}
};
/*
* Retries with a fixed delay, up to a maximum number of attempts
*/
class FixedDelayRetryPolicy : public RetryPolicy {
public:
FixedDelayRetryPolicy(uint64_t delay, uint64_t max_retries)
: delay_(delay), max_retries_(max_retries) {}
RetryAction ShouldRetry(const Status &s, uint64_t retries,
uint64_t failovers,
bool isIdempotentOrAtMostOnce) const override;
private:
uint64_t delay_;
uint64_t max_retries_;
};
/*
* Never retries
*/
class NoRetryPolicy : public RetryPolicy {
public:
RetryAction ShouldRetry(const Status &s, uint64_t retries,
uint64_t failovers,
bool isIdempotentOrAtMostOnce) const override;
};
}
#endif
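A minimal sketch of driving this interface by hand (illustrative only; in this commit the real call site is RpcEngine::RpcCommsError):

#include "common/retry_policy.h"

void RetryLoopSketch() {
  hdfs::FixedDelayRetryPolicy policy(/*delay=*/100, /*max_retries=*/3);
  uint64_t retries = 0;
  // Any non-OK Status works here; Unimplemented() is just convenient.
  hdfs::RetryAction action = policy.ShouldRetry(
      hdfs::Status::Unimplemented(), retries++, /*failovers=*/0,
      /*isIdempotentOrAtMostOnce=*/true);
  if (action.action == hdfs::RetryAction::RETRY) {
    // wait action.delayMillis milliseconds, then re-send the request
  } else {  // RetryAction::FAIL
    // give up and surface action.reason to the caller
  }
}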

View File

@ -60,6 +60,20 @@ std::string Base64Encode(const std::string &src);
* Returns a new high-entropy client name
*/
std::string GetRandomClientName();
/* Returns true if _someone_ is holding the lock (not necessarily this thread,
* but a std::mutex doesn't track which thread is holding the lock)
*/
template<class T>
bool lock_held(T & mutex) {
bool result = !mutex.try_lock();
if (!result)
mutex.unlock();
return result;
}
}
#endif
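The helper backs assertions like assert(lock_held(connection_state_lock_)) elsewhere in this commit; a minimal sketch of the pattern (names hypothetical):

#include <cassert>
#include <mutex>
#include "common/util.h"

std::mutex state_lock;  // hypothetical lock guarding some shared state

void MutateStateLocked() {
  // Fires in debug builds if the caller forgot to take state_lock first.
  assert(hdfs::lock_held(state_lock));
  // ... safe to touch the guarded state here ...
}

void Caller() {
  std::lock_guard<std::mutex> guard(state_lock);
  MutateStateLocked();
}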

View File

@ -51,9 +51,6 @@ void NameNodeOperations::Connect(const std::string &server,
engine_.Connect(m->state().front(), next);
}));
m->Run([this, handler](const Status &status, const State &) {
if (status.ok()) {
engine_.Start();
}
handler(status);
});
}

View File

@ -26,9 +26,6 @@
#include <asio/read.hpp>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
namespace hdfs {
namespace pb = ::google::protobuf;
@ -37,17 +34,18 @@ namespace pbio = ::google::protobuf::io;
using namespace ::hadoop::common;
using namespace ::std::placeholders;
static void
ConstructPacket(std::string *res,
std::initializer_list<const pb::MessageLite *> headers,
const std::string *request) {
static const int kNoRetry = -1;
static void AddHeadersToPacket(
std::string *res, std::initializer_list<const pb::MessageLite *> headers,
const std::string *payload) {
int len = 0;
std::for_each(
headers.begin(), headers.end(),
[&len](const pb::MessageLite *v) { len += DelimitedPBMessageSize(v); });
if (request) {
len += pbio::CodedOutputStream::VarintSize32(request->size()) +
request->size();
if (payload) {
len += payload->size();
}
int net_len = htonl(len);
@ -58,6 +56,7 @@ ConstructPacket(std::string *res,
os.WriteRaw(reinterpret_cast<const char *>(&net_len), sizeof(net_len));
uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
assert(buf);
std::for_each(
headers.begin(), headers.end(), [&buf](const pb::MessageLite *v) {
@ -65,19 +64,43 @@ ConstructPacket(std::string *res,
buf = v->SerializeWithCachedSizesToArray(buf);
});
if (request) {
buf = pbio::CodedOutputStream::WriteVarint32ToArray(request->size(), buf);
buf = os.WriteStringToArray(*request, buf);
if (payload) {
buf = os.WriteStringToArray(*payload, buf);
}
}
static void SetRequestHeader(RpcEngine *engine, int call_id,
const std::string &method_name,
static void ConstructPayload(std::string *res, const pb::MessageLite *header) {
int len = DelimitedPBMessageSize(header);
res->reserve(len);
pbio::StringOutputStream ss(res);
pbio::CodedOutputStream os(&ss);
uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
assert(buf);
buf = pbio::CodedOutputStream::WriteVarint32ToArray(header->ByteSize(), buf);
buf = header->SerializeWithCachedSizesToArray(buf);
}
static void ConstructPayload(std::string *res, const std::string *request) {
int len =
pbio::CodedOutputStream::VarintSize32(request->size()) + request->size();
res->reserve(len);
pbio::StringOutputStream ss(res);
pbio::CodedOutputStream os(&ss);
uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
assert(buf);
buf = pbio::CodedOutputStream::WriteVarint32ToArray(request->size(), buf);
buf = os.WriteStringToArray(*request, buf);
}
static void SetRequestHeader(LockFreeRpcEngine *engine, int call_id,
const std::string &method_name, int retry_count,
RpcRequestHeaderProto *rpc_header,
RequestHeaderProto *req_header) {
rpc_header->set_rpckind(RPC_PROTOCOL_BUFFER);
rpc_header->set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
rpc_header->set_callid(call_id);
if (retry_count != kNoRetry)
rpc_header->set_retrycount(retry_count);
rpc_header->set_clientid(engine->client_name());
req_header->set_methodname(method_name);
@ -87,64 +110,84 @@ static void SetRequestHeader(RpcEngine *engine, int call_id,
RpcConnection::~RpcConnection() {}
RpcConnection::Request::Request(RpcConnection *parent,
const std::string &method_name,
const std::string &request, Handler &&handler)
: call_id_(parent->engine_->NextCallId()), timer_(parent->io_service()),
handler_(std::move(handler)) {
RpcRequestHeaderProto rpc_header;
RequestHeaderProto req_header;
SetRequestHeader(parent->engine_, call_id_, method_name, &rpc_header,
&req_header);
ConstructPacket(&payload_, {&rpc_header, &req_header}, &request);
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
const std::string &request, Handler &&handler)
: engine_(engine),
method_name_(method_name),
call_id_(engine->NextCallId()),
timer_(engine->io_service()),
handler_(std::move(handler)),
retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
ConstructPayload(&payload_, &request);
}
RpcConnection::Request::Request(RpcConnection *parent,
const std::string &method_name,
const pb::MessageLite *request,
Handler &&handler)
: call_id_(parent->engine_->NextCallId()), timer_(parent->io_service()),
handler_(std::move(handler)) {
RpcRequestHeaderProto rpc_header;
RequestHeaderProto req_header;
SetRequestHeader(parent->engine_, call_id_, method_name, &rpc_header,
&req_header);
ConstructPacket(&payload_, {&rpc_header, &req_header, request}, nullptr);
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name,
const pb::MessageLite *request, Handler &&handler)
: engine_(engine),
method_name_(method_name),
call_id_(engine->NextCallId()),
timer_(engine->io_service()),
handler_(std::move(handler)),
retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
ConstructPayload(&payload_, request);
}
void RpcConnection::Request::OnResponseArrived(pbio::CodedInputStream *is,
const Status &status) {
Request::Request(LockFreeRpcEngine *engine, Handler &&handler)
: engine_(engine),
call_id_(-1),
timer_(engine->io_service()),
handler_(std::move(handler)),
retry_count_(engine->retry_policy() ? 0 : kNoRetry) {
}
void Request::GetPacket(std::string *res) const {
if (payload_.empty())
return;
RpcRequestHeaderProto rpc_header;
RequestHeaderProto req_header;
SetRequestHeader(engine_, call_id_, method_name_, retry_count_, &rpc_header,
&req_header);
AddHeadersToPacket(res, {&rpc_header, &req_header}, &payload_);
}
void Request::OnResponseArrived(pbio::CodedInputStream *is,
const Status &status) {
handler_(is, status);
}
RpcConnection::RpcConnection(RpcEngine *engine)
: engine_(engine), resp_state_(kReadLength), resp_length_(0) {}
RpcConnection::RpcConnection(LockFreeRpcEngine *engine)
: engine_(engine),
connected_(false) {}
::asio::io_service &RpcConnection::io_service() {
return engine_->io_service();
}
void RpcConnection::Start() {
void RpcConnection::StartReading() {
io_service().post(std::bind(&RpcConnection::OnRecvCompleted, this,
::asio::error_code(), 0));
}
void RpcConnection::FlushPendingRequests() {
io_service().post([this]() {
void RpcConnection::AsyncFlushPendingRequests() {
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
io_service().post([shared_this, this]() {
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
if (!request_over_the_wire_) {
OnSendCompleted(::asio::error_code(), 0);
FlushPendingRequests();
}
});
}
void RpcConnection::HandleRpcResponse(const std::vector<char> &data) {
/* assumed to be called from a context that has already acquired the
* engine_state_lock */
pbio::ArrayInputStream ar(&data[0], data.size());
pbio::CodedInputStream in(&ar);
in.PushLimit(data.size());
void RpcConnection::HandleRpcResponse(std::shared_ptr<Response> response) {
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
response->ar.reset(new pbio::ArrayInputStream(&response->data_[0], response->data_.size()));
response->in.reset(new pbio::CodedInputStream(response->ar.get()));
response->in->PushLimit(response->data_.size());
RpcResponseHeaderProto h;
ReadDelimitedPBMessage(&in, &h);
ReadDelimitedPBMessage(response->in.get(), &h);
auto req = RemoveFromRunningQueue(h.callid());
if (!req) {
@ -152,12 +195,15 @@ void RpcConnection::HandleRpcResponse(const std::vector<char> &data) {
return;
}
Status stat;
Status status;
if (h.has_exceptionclassname()) {
stat =
status =
Status::Exception(h.exceptionclassname().c_str(), h.errormsg().c_str());
}
req->OnResponseArrived(&in, stat);
io_service().post([req, response, status]() {
req->OnResponseArrived(response->in.get(), status); // Never call back while holding a lock
});
}
void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
@ -166,7 +212,7 @@ void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
return;
}
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
auto r = RemoveFromRunningQueue(req->call_id());
if (!r) {
// The RPC might have been finished and removed from the queue
@ -179,6 +225,8 @@ void RpcConnection::HandleRpcTimeout(std::shared_ptr<Request> req,
}
std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
static const char kHandshakeHeader[] = {'h', 'r', 'p', 'c',
RpcEngine::kRpcVersion, 0, 0};
auto res =
@ -192,25 +240,27 @@ std::shared_ptr<std::string> RpcConnection::PrepareHandshakePacket() {
IpcConnectionContextProto handshake;
handshake.set_protocol(engine_->protocol_name());
ConstructPacket(res.get(), {&h, &handshake}, nullptr);
AddHeadersToPacket(res.get(), {&h, &handshake}, nullptr);
return res;
}
void RpcConnection::AsyncRpc(
const std::string &method_name, const ::google::protobuf::MessageLite *req,
std::shared_ptr<::google::protobuf::MessageLite> resp,
const Callback &handler) {
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
const RpcCallback &handler) {
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
auto wrapped_handler =
[resp, handler](pbio::CodedInputStream *is, const Status &status) {
if (status.ok()) {
ReadDelimitedPBMessage(is, resp.get());
if (is) { // Connect messages carry no response body, so is may be null
ReadDelimitedPBMessage(is, resp.get());
}
}
handler(status);
};
auto r = std::make_shared<Request>(this, method_name, req,
auto r = std::make_shared<Request>(engine_, method_name, req,
std::move(wrapped_handler));
pending_requests_.push_back(r);
FlushPendingRequests();
@ -219,29 +269,62 @@ void RpcConnection::AsyncRpc(
void RpcConnection::AsyncRawRpc(const std::string &method_name,
const std::string &req,
std::shared_ptr<std::string> resp,
Callback &&handler) {
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
RpcCallback &&handler) {
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
auto wrapped_handler =
[this, resp, handler](pbio::CodedInputStream *is, const Status &status) {
if (status.ok()) {
uint32_t size = 0;
is->ReadVarint32(&size);
auto limit = is->PushLimit(size);
is->ReadString(resp.get(), limit);
is->PopLimit(limit);
}
handler(status);
};
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
auto wrapped_handler = [shared_this, this, resp, handler](
pbio::CodedInputStream *is, const Status &status) {
if (status.ok()) {
uint32_t size = 0;
is->ReadVarint32(&size);
auto limit = is->PushLimit(size);
is->ReadString(resp.get(), limit);
is->PopLimit(limit);
}
handler(status);
};
auto r = std::make_shared<Request>(this, method_name, req,
auto r = std::make_shared<Request>(engine_, method_name, req,
std::move(wrapped_handler));
pending_requests_.push_back(r);
FlushPendingRequests();
}
void RpcConnection::PreEnqueueRequests(
std::vector<std::shared_ptr<Request>> requests) {
// Public method - acquire lock
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
assert(!connected_);
pending_requests_.insert(pending_requests_.end(), requests.begin(),
requests.end());
// Don't start sending yet; will flush when connected
}
void RpcConnection::CommsError(const Status &status) {
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
Disconnect();
// Anything that has been queued to the connection (on the fly or pending)
// will get dinged for a retry
std::vector<std::shared_ptr<Request>> requestsToReturn;
std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
std::back_inserter(requestsToReturn),
std::bind(&RequestOnFlyMap::value_type::second, _1));
requests_on_fly_.clear();
requestsToReturn.insert(requestsToReturn.end(),
std::make_move_iterator(pending_requests_.begin()),
std::make_move_iterator(pending_requests_.end()));
pending_requests_.clear();
engine_->AsyncRpcCommsError(status, requestsToReturn);
}
void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
Shutdown();
Disconnect();
std::vector<std::shared_ptr<Request>> requests;
std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
std::back_inserter(requests),
@ -256,8 +339,8 @@ void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
}
}
std::shared_ptr<RpcConnection::Request>
RpcConnection::RemoveFromRunningQueue(int call_id) {
std::shared_ptr<Request> RpcConnection::RemoveFromRunningQueue(int call_id) {
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
auto it = requests_on_fly_.find(call_id);
if (it == requests_on_fly_.end()) {
return std::shared_ptr<Request>();

View File

@ -29,46 +29,90 @@
namespace hdfs {
template <class NextLayer> class RpcConnectionImpl : public RpcConnection {
template <class NextLayer>
class RpcConnectionImpl : public RpcConnection {
public:
RpcConnectionImpl(RpcEngine *engine);
virtual void Connect(const ::asio::ip::tcp::endpoint &server,
Callback &&handler) override;
virtual void Handshake(Callback &&handler) override;
virtual void Shutdown() override;
RpcCallback &handler);
virtual void ConnectAndFlush(
const ::asio::ip::tcp::endpoint &server) override;
virtual void Handshake(RpcCallback &handler) override;
virtual void Disconnect() override;
virtual void OnSendCompleted(const ::asio::error_code &ec,
size_t transferred) override;
virtual void OnRecvCompleted(const ::asio::error_code &ec,
size_t transferred) override;
virtual void FlushPendingRequests() override;
NextLayer &next_layer() { return next_layer_; }
private:
void TEST_set_connected(bool new_value) { connected_ = new_value; }
private:
const Options options_;
NextLayer next_layer_;
};
template <class NextLayer>
RpcConnectionImpl<NextLayer>::RpcConnectionImpl(RpcEngine *engine)
: RpcConnection(engine), options_(engine->options()),
: RpcConnection(engine),
options_(engine->options()),
next_layer_(engine->io_service()) {}
template <class NextLayer>
void RpcConnectionImpl<NextLayer>::Connect(
const ::asio::ip::tcp::endpoint &server, Callback &&handler) {
next_layer_.async_connect(server,
[handler](const ::asio::error_code &ec) {
handler(ToStatus(ec));
const ::asio::ip::tcp::endpoint &server, RpcCallback &handler) {
auto connectionSuccessfulReq = std::make_shared<Request>(
engine_, [handler](::google::protobuf::io::CodedInputStream *is,
const Status &status) {
(void)is;
handler(status);
});
pending_requests_.push_back(connectionSuccessfulReq);
this->ConnectAndFlush(server); // need "this" so the compiler can resolve the call to ConnectAndFlush
}
template <class NextLayer>
void RpcConnectionImpl<NextLayer>::Handshake(Callback &&handler) {
void RpcConnectionImpl<NextLayer>::ConnectAndFlush(
const ::asio::ip::tcp::endpoint &server) {
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
next_layer_.async_connect(server,
[shared_this, this](const ::asio::error_code &ec) {
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
Status status = ToStatus(ec);
if (status.ok()) {
StartReading();
Handshake([shared_this, this](const Status &s) {
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
if (s.ok()) {
FlushPendingRequests();
} else {
CommsError(s);
};
});
} else {
CommsError(status);
}
});
}
template <class NextLayer>
void RpcConnectionImpl<NextLayer>::Handshake(RpcCallback &handler) {
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
auto shared_this = shared_from_this();
auto handshake_packet = PrepareHandshakePacket();
::asio::async_write(
next_layer_, asio::buffer(*handshake_packet),
[handshake_packet, handler](const ::asio::error_code &ec, size_t) {
handler(ToStatus(ec));
});
::asio::async_write(next_layer_, asio::buffer(*handshake_packet),
[handshake_packet, handler, shared_this, this](
const ::asio::error_code &ec, size_t) {
Status status = ToStatus(ec);
if (status.ok()) {
connected_ = true;
}
handler(status);
});
}
template <class NextLayer>
@ -76,82 +120,129 @@ void RpcConnectionImpl<NextLayer>::OnSendCompleted(const ::asio::error_code &ec,
size_t) {
using std::placeholders::_1;
using std::placeholders::_2;
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
request_over_the_wire_.reset();
if (ec) {
// Current RPC has failed -- abandon the
// connection and do proper clean up
ClearAndDisconnect(ec);
LOG_WARN() << "Network error during RPC write: " << ec.message();
CommsError(ToStatus(ec));
return;
}
if (!pending_requests_.size()) {
FlushPendingRequests();
}
template <class NextLayer>
void RpcConnectionImpl<NextLayer>::FlushPendingRequests() {
using namespace ::std::placeholders;
// Lock should be held
assert(lock_held(connection_state_lock_));
if (pending_requests_.empty()) {
return;
}
if (!connected_) {
return;
}
// Don't send if we don't need to
if (request_over_the_wire_) {
return;
}
std::shared_ptr<Request> req = pending_requests_.front();
pending_requests_.erase(pending_requests_.begin());
requests_on_fly_[req->call_id()] = req;
request_over_the_wire_ = req;
req->timer().expires_from_now(
std::chrono::milliseconds(options_.rpc_timeout));
req->timer().async_wait(std::bind(
&RpcConnectionImpl<NextLayer>::HandleRpcTimeout, this, req, _1));
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
std::shared_ptr<std::string> payload = std::make_shared<std::string>();
req->GetPacket(payload.get());
if (!payload->empty()) {
requests_on_fly_[req->call_id()] = req;
request_over_the_wire_ = req;
asio::async_write(
next_layer_, asio::buffer(req->payload()),
std::bind(&RpcConnectionImpl<NextLayer>::OnSendCompleted, this, _1, _2));
req->timer().expires_from_now(
std::chrono::milliseconds(options_.rpc_timeout));
req->timer().async_wait(std::bind(
&RpcConnection::HandleRpcTimeout, this, req, _1));
asio::async_write(next_layer_, asio::buffer(*payload),
[shared_this, this, payload](const ::asio::error_code &ec,
size_t size) {
OnSendCompleted(ec, size);
});
} else { // Nothing to send for this request, inform the handler immediately
io_service().post(
// Never hold locks when calling a callback
[req]() { req->OnResponseArrived(nullptr, Status::OK()); }
);
// Reschedule to flush the next one
AsyncFlushPendingRequests();
}
}
template <class NextLayer>
void RpcConnectionImpl<NextLayer>::OnRecvCompleted(const ::asio::error_code &ec,
size_t) {
using std::placeholders::_1;
using std::placeholders::_2;
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
switch (ec.value()) {
case 0:
// No errors
break;
case asio::error::operation_aborted:
// The event loop has been shut down. Ignore the error.
return;
default:
LOG_WARN() << "Network error during RPC: " << ec.message();
ClearAndDisconnect(ec);
return;
case 0:
// No errors
break;
case asio::error::operation_aborted:
// The event loop has been shut down. Ignore the error.
return;
default:
LOG_WARN() << "Network error during RPC read: " << ec.message();
CommsError(ToStatus(ec));
return;
}
if (resp_state_ == kReadLength) {
resp_state_ = kReadContent;
auto buf = ::asio::buffer(reinterpret_cast<char *>(&resp_length_),
sizeof(resp_length_));
asio::async_read(next_layer_, buf,
std::bind(&RpcConnectionImpl<NextLayer>::OnRecvCompleted,
this, _1, _2));
if (!response_) { /* start a new one */
response_ = std::make_shared<Response>();
}
} else if (resp_state_ == kReadContent) {
resp_state_ = kParseResponse;
resp_length_ = ntohl(resp_length_);
resp_data_.resize(resp_length_);
asio::async_read(next_layer_, ::asio::buffer(resp_data_),
std::bind(&RpcConnectionImpl<NextLayer>::OnRecvCompleted,
this, _1, _2));
} else if (resp_state_ == kParseResponse) {
resp_state_ = kReadLength;
HandleRpcResponse(resp_data_);
resp_data_.clear();
Start();
if (response_->state_ == Response::kReadLength) {
response_->state_ = Response::kReadContent;
auto buf = ::asio::buffer(reinterpret_cast<char *>(&response_->length_),
sizeof(response_->length_));
asio::async_read(
next_layer_, buf,
[shared_this, this](const ::asio::error_code &ec, size_t size) {
OnRecvCompleted(ec, size);
});
} else if (response_->state_ == Response::kReadContent) {
response_->state_ = Response::kParseResponse;
response_->length_ = ntohl(response_->length_);
response_->data_.resize(response_->length_);
asio::async_read(
next_layer_, ::asio::buffer(response_->data_),
[shared_this, this](const ::asio::error_code &ec, size_t size) {
OnRecvCompleted(ec, size);
});
} else if (response_->state_ == Response::kParseResponse) {
HandleRpcResponse(response_);
response_ = nullptr;
StartReading();
}
}
template <class NextLayer> void RpcConnectionImpl<NextLayer>::Shutdown() {
template <class NextLayer>
void RpcConnectionImpl<NextLayer>::Disconnect() {
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
request_over_the_wire_.reset();
next_layer_.cancel();
next_layer_.close();
connected_ = false;
}
}

View File

@ -18,52 +18,71 @@
#include "rpc_engine.h"
#include "rpc_connection.h"
#include "common/util.h"
#include "optional.hpp"
#include <future>
namespace hdfs {
template <class T>
using optional = std::experimental::optional<T>;
RpcEngine::RpcEngine(::asio::io_service *io_service, const Options &options,
const std::string &client_name, const char *protocol_name,
int protocol_version)
: io_service_(io_service), options_(options), client_name_(client_name),
protocol_name_(protocol_name), protocol_version_(protocol_version),
call_id_(0) {
}
: io_service_(io_service),
options_(options),
client_name_(client_name),
protocol_name_(protocol_name),
protocol_version_(protocol_version),
retry_policy_(std::move(MakeRetryPolicy(options))),
call_id_(0),
retry_timer(*io_service) {}
void RpcEngine::Connect(const ::asio::ip::tcp::endpoint &server,
const std::function<void(const Status &)> &handler) {
conn_.reset(new RpcConnectionImpl<::asio::ip::tcp::socket>(this));
conn_->Connect(server, [this, handler](const Status &stat) {
if (!stat.ok()) {
handler(stat);
} else {
conn_->Handshake([handler](const Status &s) { handler(s); });
}
RpcCallback &handler) {
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
last_endpoint_ = server;
conn_ = NewConnection();
conn_->Connect(server, handler);
}
void RpcEngine::Shutdown() {
io_service_->post([this]() {
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
conn_->Disconnect();
conn_.reset();
});
}
void RpcEngine::Start() { conn_->Start(); }
void RpcEngine::Shutdown() {
io_service_->post([this]() { conn_->Shutdown(); });
std::unique_ptr<const RetryPolicy> RpcEngine::MakeRetryPolicy(const Options &options) {
if (options.max_rpc_retries > 0) {
return std::unique_ptr<RetryPolicy>(new FixedDelayRetryPolicy(options.rpc_retry_delay_ms, options.max_rpc_retries));
} else {
return nullptr;
}
}
void RpcEngine::TEST_SetRpcConnection(std::unique_ptr<RpcConnection> *conn) {
conn_.reset(conn->release());
void RpcEngine::TEST_SetRpcConnection(std::shared_ptr<RpcConnection> conn) {
conn_ = conn;
}
void RpcEngine::AsyncRpc(
const std::string &method_name, const ::google::protobuf::MessageLite *req,
const std::shared_ptr<::google::protobuf::MessageLite> &resp,
const std::function<void(const Status &)> &handler) {
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
if (!conn_) {
conn_ = NewConnection();
conn_->ConnectAndFlush(last_endpoint_);
}
conn_->AsyncRpc(method_name, req, resp, handler);
}
Status
RpcEngine::Rpc(const std::string &method_name,
const ::google::protobuf::MessageLite *req,
const std::shared_ptr<::google::protobuf::MessageLite> &resp) {
Status RpcEngine::Rpc(
const std::string &method_name, const ::google::protobuf::MessageLite *req,
const std::shared_ptr<::google::protobuf::MessageLite> &resp) {
auto stat = std::make_shared<std::promise<Status>>();
std::future<Status> future(stat->get_future());
AsyncRpc(method_name, req, resp,
@ -71,13 +90,95 @@ RpcEngine::Rpc(const std::string &method_name,
return future.get();
}
std::shared_ptr<RpcConnection> RpcEngine::NewConnection()
{
return std::make_shared<RpcConnectionImpl<::asio::ip::tcp::socket>>(this);
}
Status RpcEngine::RawRpc(const std::string &method_name, const std::string &req,
std::shared_ptr<std::string> resp) {
std::shared_ptr<RpcConnection> conn;
{
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
if (!conn_) {
conn_ = NewConnection();
conn_->ConnectAndFlush(last_endpoint_);
}
conn = conn_;
}
auto stat = std::make_shared<std::promise<Status>>();
std::future<Status> future(stat->get_future());
conn_->AsyncRawRpc(method_name, req, resp,
conn->AsyncRawRpc(method_name, req, resp,
[stat](const Status &status) { stat->set_value(status); });
return future.get();
}
void RpcEngine::AsyncRpcCommsError(
const Status &status,
std::vector<std::shared_ptr<Request>> pendingRequests) {
io_service().post([this, status, pendingRequests]() {
RpcCommsError(status, pendingRequests);
});
}
void RpcEngine::RpcCommsError(
const Status &status,
std::vector<std::shared_ptr<Request>> pendingRequests) {
(void)status;
std::lock_guard<std::mutex> state_lock(engine_state_lock_);
auto head_action = optional<RetryAction>();
// Filter out anything with too many retries already
for (auto it = pendingRequests.begin(); it < pendingRequests.end();) {
auto req = *it;
RetryAction retry = RetryAction::fail(""); // Default to fail
if (retry_policy()) {
retry = retry_policy()->ShouldRetry(status, req->IncrementRetryCount(), 0, true);
}
if (retry.action == RetryAction::FAIL) {
// If we've exceeded the maximum number of retries, take the latest error
// and pass it on. There might be a good argument for caching the first
// error rather than the last one, but that gets messy
io_service().post([req, status]() {
req->OnResponseArrived(nullptr, status); // Never call back while holding a lock
});
it = pendingRequests.erase(it);
} else {
if (!head_action) {
head_action = retry;
}
++it;
}
}
// Close the connection and retry any requests that might have been sent to
// the NN
if (!pendingRequests.empty() &&
head_action && head_action->action != RetryAction::FAIL) {
conn_ = NewConnection();
conn_->PreEnqueueRequests(pendingRequests);
if (head_action->delayMillis > 0) {
retry_timer.expires_from_now(
std::chrono::milliseconds(options_.rpc_retry_delay_ms));
retry_timer.async_wait([this](asio::error_code ec) {
if (!ec) conn_->ConnectAndFlush(last_endpoint_);
});
} else {
conn_->ConnectAndFlush(last_endpoint_);
}
} else {
// Connection will try again if someone calls AsyncRpc
conn_.reset();
}
}
}
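From a consumer's point of view the machinery above is transparent: on a comms error the engine rebuilds the connection and re-sends anything outstanding, up to max_rpc_retries times. A hedged usage sketch (endpoint and RPC names are placeholders; constructor arguments mirror the tests below):

#include "rpc_engine.h"

void RetryingEngineSketch() {
  ::asio::io_service io_service;
  hdfs::Options options;
  options.max_rpc_retries = 2;       // two reconnect-and-resend attempts
  options.rpc_retry_delay_ms = 100;  // pause between attempts, in ms
  hdfs::RpcEngine engine(&io_service, options, "client", "protocol", 1);
  // engine.Connect(endpoint, handler);  // endpoint/handler elided
  // engine.AsyncRpc(...) callbacks see a bad Status only after the retry
  // budget (max_rpc_retries) is exhausted.
}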

View File

@ -21,7 +21,11 @@
#include "libhdfspp/options.h"
#include "libhdfspp/status.h"
#include "common/retry_policy.h"
#include <google/protobuf/message_lite.h>
#include <google/protobuf/io/coded_stream.h>
#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <asio/ip/tcp.hpp>
#include <asio/deadline_timer.hpp>
@ -34,77 +38,146 @@
namespace hdfs {
class RpcEngine;
class RpcConnection {
public:
typedef std::function<void(const Status &)> Callback;
virtual ~RpcConnection();
RpcConnection(RpcEngine *engine);
virtual void Connect(const ::asio::ip::tcp::endpoint &server,
Callback &&handler) = 0;
virtual void Handshake(Callback &&handler) = 0;
virtual void Shutdown() = 0;
/*
* NOTE ABOUT LOCKING MODELS
*
* To prevent deadlocks, anything that might acquire multiple locks must
* acquire the lock on the RpcEngine first, then the RpcConnection. Callbacks
* will never be called while holding any locks, so the components are free
* to take locks when servicing a callback.
*
* An RpcRequest or RpcConnection should never call any methods on the RpcEngine
* except for those that are exposed through the LockFreeRpcEngine interface.
*/
void Start();
typedef const std::function<void(const Status &)> RpcCallback;
class LockFreeRpcEngine;
class RpcConnection;
/*
* Internal bookkeeping for an outstanding request from the consumer.
*
* Threading model: not thread-safe; should only be accessed from a single
* thread at a time
*/
class Request {
public:
typedef std::function<void(::google::protobuf::io::CodedInputStream *is,
const Status &status)> Handler;
Request(LockFreeRpcEngine *engine, const std::string &method_name,
const std::string &request, Handler &&callback);
Request(LockFreeRpcEngine *engine, const std::string &method_name,
const ::google::protobuf::MessageLite *request, Handler &&callback);
// Null request (with no actual message) used to track the state of an
// initial Connect call
Request(LockFreeRpcEngine *engine, Handler &&handler);
int call_id() const { return call_id_; }
::asio::deadline_timer &timer() { return timer_; }
int IncrementRetryCount() { return retry_count_++; }
void GetPacket(std::string *res) const;
void OnResponseArrived(::google::protobuf::io::CodedInputStream *is,
const Status &status);
private:
LockFreeRpcEngine *const engine_;
const std::string method_name_;
const int call_id_;
::asio::deadline_timer timer_;
std::string payload_;
const Handler handler_;
int retry_count_;
};
/*
* Encapsulates a persistent connection to the NameNode, and the sending of
* RPC requests and evaluating their responses.
*
* Can have multiple RPC requests in-flight simultaneously, but they are
* evaluated in-order on the server side in a blocking manner.
*
* Threading model: public interface is thread-safe
* All handlers passed in to method calls will be called from an asio thread,
* and will not be holding any internal RpcConnection locks.
*/
class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
public:
RpcConnection(LockFreeRpcEngine *engine);
virtual ~RpcConnection();
virtual void Connect(const ::asio::ip::tcp::endpoint &server,
RpcCallback &handler) = 0;
virtual void ConnectAndFlush(const ::asio::ip::tcp::endpoint &server) = 0;
virtual void Handshake(RpcCallback &handler) = 0;
virtual void Disconnect() = 0;
void StartReading();
void AsyncRpc(const std::string &method_name,
const ::google::protobuf::MessageLite *req,
std::shared_ptr<::google::protobuf::MessageLite> resp,
const Callback &handler);
const RpcCallback &handler);
void AsyncRawRpc(const std::string &method_name, const std::string &request,
std::shared_ptr<std::string> resp, Callback &&handler);
std::shared_ptr<std::string> resp, RpcCallback &&handler);
protected:
class Request;
RpcEngine *const engine_;
// Enqueue requests before the connection is connected. Will be flushed
// on connect
void PreEnqueueRequests(std::vector<std::shared_ptr<Request>> requests);
LockFreeRpcEngine *engine() { return engine_; }
::asio::io_service &io_service();
protected:
struct Response {
enum ResponseState {
kReadLength,
kReadContent,
kParseResponse,
} state_;
unsigned length_;
std::vector<char> data_;
std::unique_ptr<::google::protobuf::io::ArrayInputStream> ar;
std::unique_ptr<::google::protobuf::io::CodedInputStream> in;
Response() : state_(kReadLength), length_(0) {}
};
LockFreeRpcEngine *const engine_;
virtual void OnSendCompleted(const ::asio::error_code &ec,
size_t transferred) = 0;
virtual void OnRecvCompleted(const ::asio::error_code &ec,
size_t transferred) = 0;
virtual void FlushPendingRequests() = 0; // Synchronously write the next request
void AsyncFlushPendingRequests(); // Queue requests to be flushed at a later time
::asio::io_service &io_service();
std::shared_ptr<std::string> PrepareHandshakePacket();
static std::string
SerializeRpcRequest(const std::string &method_name,
const ::google::protobuf::MessageLite *req);
void HandleRpcResponse(const std::vector<char> &data);
static std::string SerializeRpcRequest(
const std::string &method_name,
const ::google::protobuf::MessageLite *req);
void HandleRpcResponse(std::shared_ptr<Response> response);
void HandleRpcTimeout(std::shared_ptr<Request> req,
const ::asio::error_code &ec);
void FlushPendingRequests();
void CommsError(const Status &status);
void ClearAndDisconnect(const ::asio::error_code &ec);
std::shared_ptr<Request> RemoveFromRunningQueue(int call_id);
enum ResponseState {
kReadLength,
kReadContent,
kParseResponse,
} resp_state_;
unsigned resp_length_;
std::vector<char> resp_data_;
std::shared_ptr<Response> response_;
class Request {
public:
typedef std::function<void(::google::protobuf::io::CodedInputStream *is,
const Status &status)> Handler;
Request(RpcConnection *parent, const std::string &method_name,
const std::string &request, Handler &&callback);
Request(RpcConnection *parent, const std::string &method_name,
const ::google::protobuf::MessageLite *request, Handler &&callback);
int call_id() const { return call_id_; }
::asio::deadline_timer &timer() { return timer_; }
const std::string &payload() const { return payload_; }
void OnResponseArrived(::google::protobuf::io::CodedInputStream *is,
const Status &status);
private:
const int call_id_;
::asio::deadline_timer timer_;
std::string payload_;
Handler handler_;
};
// The request being sent over the wire
// The connection can be deferred (established lazily), especially when
// we're pausing during retry
bool connected_;
// The request being sent over the wire; will also be in requests_on_fly_
std::shared_ptr<Request> request_over_the_wire_;
// Requests to be sent over the wire
std::vector<std::shared_ptr<Request>> pending_requests_;
@ -112,11 +185,40 @@ protected:
typedef std::unordered_map<int, std::shared_ptr<Request>> RequestOnFlyMap;
RequestOnFlyMap requests_on_fly_;
// Lock for mutable parts of this class that need to be thread safe
std::mutex engine_state_lock_;
std::mutex connection_state_lock_;
};
class RpcEngine {
/*
* These methods of the RpcEngine will never acquire locks, and are safe for
* RpcConnections to call while holding a ConnectionLock.
*/
class LockFreeRpcEngine {
public:
/* Enqueues a CommsError without acquiring a lock */
virtual void AsyncRpcCommsError(const Status &status,
std::vector<std::shared_ptr<Request>> pendingRequests) = 0;
virtual const RetryPolicy * retry_policy() const = 0;
virtual int NextCallId() = 0;
virtual const std::string &client_name() const = 0;
virtual const std::string &protocol_name() const = 0;
virtual int protocol_version() const = 0;
virtual ::asio::io_service &io_service() = 0;
virtual const Options &options() const = 0;
};
/*
* An engine for reliable communication with a NameNode. Handles connection,
* retry, and (someday) failover of the requested messages.
*
* Threading model: thread-safe. All callbacks will be called back from
* an asio pool and will not hold any internal locks
*/
class RpcEngine : public LockFreeRpcEngine {
public:
enum { kRpcVersion = 9 };
enum {
kCallIdAuthorizationFailed = -1,
@ -129,6 +231,8 @@ public:
const std::string &client_name, const char *protocol_name,
int protocol_version);
void Connect(const ::asio::ip::tcp::endpoint &server, RpcCallback &handler);
void AsyncRpc(const std::string &method_name,
const ::google::protobuf::MessageLite *req,
const std::shared_ptr<::google::protobuf::MessageLite> &resp,
@ -143,29 +247,46 @@ public:
**/
Status RawRpc(const std::string &method_name, const std::string &req,
std::shared_ptr<std::string> resp);
void Connect(const ::asio::ip::tcp::endpoint &server,
const std::function<void(const Status &)> &handler);
void Start();
void Shutdown();
void TEST_SetRpcConnection(std::unique_ptr<RpcConnection> *conn);
int NextCallId() { return ++call_id_; }
/* Enqueues a CommsError without acquiring a lock */
void AsyncRpcCommsError(const Status &status,
std::vector<std::shared_ptr<Request>> pendingRequests) override;
void RpcCommsError(const Status &status,
std::vector<std::shared_ptr<Request>> pendingRequests);
const std::string &client_name() const { return client_name_; }
const std::string &protocol_name() const { return protocol_name_; }
int protocol_version() const { return protocol_version_; }
::asio::io_service &io_service() { return *io_service_; }
const Options &options() { return options_; }
const RetryPolicy * retry_policy() const override { return retry_policy_.get(); }
int NextCallId() override { return ++call_id_; }
void TEST_SetRpcConnection(std::shared_ptr<RpcConnection> conn);
const std::string &client_name() const override { return client_name_; }
const std::string &protocol_name() const override { return protocol_name_; }
int protocol_version() const override { return protocol_version_; }
::asio::io_service &io_service() override { return *io_service_; }
const Options &options() const override { return options_; }
static std::string GetRandomClientName();
protected:
std::shared_ptr<RpcConnection> conn_;
virtual std::shared_ptr<RpcConnection> NewConnection();
virtual std::unique_ptr<const RetryPolicy> MakeRetryPolicy(const Options &options);
private:
::asio::io_service *io_service_;
Options options_;
::asio::io_service * const io_service_;
const Options options_;
const std::string client_name_;
const std::string protocol_name_;
const int protocol_version_;
const std::unique_ptr<const RetryPolicy> retry_policy_; //null --> no retry
std::atomic_int call_id_;
std::unique_ptr<RpcConnection> conn_;
::asio::deadline_timer retry_timer;
// Remember the last endpoint in case we need to reconnect to retry
::asio::ip::tcp::endpoint last_endpoint_;
std::mutex engine_state_lock_;
};
}
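The locking note at the top of this header implies a fixed acquisition order when both locks are needed; a minimal conforming sketch (the mutexes here are hypothetical stand-ins for the private members):

#include <mutex>

std::mutex engine_state_lock;      // stand-in for RpcEngine's lock
std::mutex connection_state_lock;  // stand-in for RpcConnection's lock

void NeedsBothLocks() {
  // Documented order: RpcEngine lock first, then RpcConnection lock.
  std::lock_guard<std::mutex> engine_lock(engine_state_lock);
  std::lock_guard<std::mutex> conn_lock(connection_state_lock);
  // Mutate shared state here; never invoke a user callback while holding
  // either lock -- post it to the io_service instead.
}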

View File

@ -51,6 +51,10 @@ add_executable(sasl_digest_md5_test sasl_digest_md5_test.cc)
target_link_libraries(sasl_digest_md5_test common ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})
add_test(sasl_digest_md5 sasl_digest_md5_test)
add_executable(retry_policy_test retry_policy_test.cc)
target_link_libraries(retry_policy_test common gmock_main ${CMAKE_THREAD_LIBS_INIT})
add_test(retry_policy retry_policy_test)
include_directories(${CMAKE_CURRENT_BINARY_DIR})
add_executable(rpc_engine_test rpc_engine_test.cc ${PROTO_TEST_SRCS} ${PROTO_TEST_HDRS} $<TARGET_OBJECTS:test_common>)
target_link_libraries(rpc_engine_test rpc proto common ${PROTOBUF_LIBRARIES} ${OPENSSL_LIBRARIES} gmock_main ${CMAKE_THREAD_LIBS_INIT})

View File

@ -26,4 +26,15 @@ MockConnectionBase::MockConnectionBase(::asio::io_service *io_service)
MockConnectionBase::~MockConnectionBase() {}
ProducerResult SharedMockConnection::Produce() {
if (auto shared_producer = shared_connection_data_.lock()) {
return shared_producer->Produce();
} else {
assert(false && "No producer registered");
return std::make_pair(asio::error_code(), "");
}
}
std::weak_ptr<SharedConnectionData> SharedMockConnection::shared_connection_data_;
}

View File

@ -29,7 +29,21 @@
namespace hdfs {
class MockConnectionBase : public AsyncStream{
typedef std::pair<asio::error_code, std::string> ProducerResult;
class AsioProducer {
public:
/*
* Return either:
* (::asio::error_code(), <some data>) for a good result
* (<an ::asio::error instance>, <anything>) to pass an error to the caller
* (::asio::error::would_block, <anything>) to block the next call forever
*/
virtual ProducerResult Produce() = 0;
};
class MockConnectionBase : public AsioProducer, public AsyncStream {
public:
MockConnectionBase(::asio::io_service *io_service);
virtual ~MockConnectionBase();
@ -40,6 +54,9 @@ public:
std::size_t bytes_transferred) > handler) override {
if (produced_.size() == 0) {
ProducerResult r = Produce();
if (r.first == asio::error::would_block) {
return; // No more reads to do
}
if (r.first) {
io_service_->post(std::bind(handler, r.first, 0));
return;
@ -62,6 +79,13 @@ public:
io_service_->post(std::bind(handler, asio::error_code(), asio::buffer_size(buf)));
}
template <class Endpoint, class Callback>
void async_connect(const Endpoint &, Callback &&handler) {
io_service_->post([handler]() { handler(::asio::error_code()); });
}
virtual void cancel() {}
virtual void close() {}
protected:
virtual ProducerResult Produce() = 0;
::asio::io_service *io_service_;
@ -69,6 +93,48 @@ protected:
private:
asio::streambuf produced_;
};
class SharedConnectionData : public AsioProducer {
public:
bool checkProducerForConnect = false;
MOCK_METHOD0(Produce, ProducerResult());
};
class SharedMockConnection : public MockConnectionBase {
public:
using MockConnectionBase::MockConnectionBase;
template <class Endpoint, class Callback>
void async_connect(const Endpoint &, Callback &&handler) {
auto data = shared_connection_data_.lock();
assert(data);
if (!data->checkProducerForConnect) {
io_service_->post([handler]() { handler(::asio::error_code()); });
} else {
ProducerResult result = Produce();
if (result.first == asio::error::would_block) {
return; // Connect will hang
} else {
io_service_->post([handler, result]() { handler(result.first); });
}
}
}
static void SetSharedConnectionData(std::shared_ptr<SharedConnectionData> new_producer) {
shared_connection_data_ = new_producer; // get a weak reference to it
}
protected:
ProducerResult Produce() override;
static std::weak_ptr<SharedConnectionData> shared_connection_data_;
};
}
#endif

View File

@ -0,0 +1,63 @@
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "common/retry_policy.h"
#include <gmock/gmock.h>
using namespace hdfs;
TEST(RetryPolicyTest, TestNoRetry) {
NoRetryPolicy policy;
EXPECT_EQ(RetryAction::FAIL, policy.ShouldRetry(Status::Unimplemented(), 0, 0, true).action);
}
TEST(RetryPolicyTest, TestFixedDelay) {
static const uint64_t DELAY = 100;
FixedDelayRetryPolicy policy(DELAY, 10);
// No error
RetryAction result = policy.ShouldRetry(Status::Unimplemented(), 0, 0, true);
EXPECT_EQ(RetryAction::RETRY, result.action);
EXPECT_EQ(DELAY, result.delayMillis);
// Few errors
result = policy.ShouldRetry(Status::Unimplemented(), 2, 2, true);
EXPECT_EQ(RetryAction::RETRY, result.action);
EXPECT_EQ(DELAY, result.delayMillis);
result = policy.ShouldRetry(Status::Unimplemented(), 9, 0, true);
EXPECT_EQ(RetryAction::RETRY, result.action);
EXPECT_EQ(DELAY, result.delayMillis);
// Too many errors
result = policy.ShouldRetry(Status::Unimplemented(), 10, 0, true);
EXPECT_EQ(RetryAction::FAIL, result.action);
EXPECT_TRUE(result.reason.size() > 0); // some error message
result = policy.ShouldRetry(Status::Unimplemented(), 0, 10, true);
EXPECT_EQ(RetryAction::FAIL, result.action);
EXPECT_TRUE(result.reason.size() > 0); // some error message
}
int main(int argc, char *argv[]) {
// The following line must be executed to initialize Google Mock
// (and Google Test) before running the tests.
::testing::InitGoogleMock(&argc, argv);
return RUN_ALL_TESTS();
}

View File

@ -44,21 +44,33 @@ namespace pbio = ::google::protobuf::io;
namespace hdfs {
class MockRPCConnection : public MockConnectionBase {
public:
public:
MockRPCConnection(::asio::io_service &io_service)
: MockConnectionBase(&io_service) {}
MOCK_METHOD0(Produce, ProducerResult());
template <class Endpoint, class Callback>
void async_connect(const Endpoint &, Callback &&handler) {
handler(::asio::error_code());
}
void cancel() {}
void close() {}
};
static inline std::pair<error_code, string>
RpcResponse(const RpcResponseHeaderProto &h, const std::string &data,
const ::asio::error_code &ec = error_code()) {
class SharedMockRPCConnection : public SharedMockConnection {
public:
SharedMockRPCConnection(::asio::io_service &io_service)
: SharedMockConnection(&io_service) {}
};
class SharedConnectionEngine : public RpcEngine {
using RpcEngine::RpcEngine;
protected:
std::shared_ptr<RpcConnection> NewConnection() override {
return std::make_shared<RpcConnectionImpl<SharedMockRPCConnection>>(this);
}
};
}
static inline std::pair<error_code, string> RpcResponse(
const RpcResponseHeaderProto &h, const std::string &data,
const ::asio::error_code &ec = error_code()) {
uint32_t payload_length =
pbio::CodedOutputStream::VarintSize32(h.ByteSize()) +
pbio::CodedOutputStream::VarintSize32(data.size()) + h.ByteSize() +
@ -77,7 +89,7 @@ RpcResponse(const RpcResponseHeaderProto &h, const std::string &data,
return std::make_pair(ec, std::move(res));
}
}
using namespace hdfs;
@ -87,6 +99,9 @@ TEST(RpcEngineTest, TestRoundTrip) {
RpcEngine engine(&io_service, options, "foo", "protocol", 1);
RpcConnectionImpl<MockRPCConnection> *conn =
new RpcConnectionImpl<MockRPCConnection>(&engine);
conn->TEST_set_connected(true);
conn->StartReading();
EchoResponseProto server_resp;
server_resp.set_message("foo");
@ -96,27 +111,34 @@ TEST(RpcEngineTest, TestRoundTrip) {
EXPECT_CALL(conn->next_layer(), Produce())
.WillOnce(Return(RpcResponse(h, server_resp.SerializeAsString())));
std::unique_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(&conn_ptr);
std::shared_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(conn_ptr);
bool complete = false;
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [resp, &io_service](const Status &stat) {
engine.AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ("foo", resp->message());
complete = true;
io_service.stop();
});
conn->Start();
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestConnectionReset) {
TEST(RpcEngineTest, TestConnectionResetAndFail) {
::asio::io_service io_service;
Options options;
RpcEngine engine(&io_service, options, "foo", "protocol", 1);
RpcConnectionImpl<MockRPCConnection> *conn =
new RpcConnectionImpl<MockRPCConnection>(&engine);
conn->TEST_set_connected(true);
conn->StartReading();
bool complete = false;
RpcResponseHeaderProto h;
h.set_callid(1);
@ -125,23 +147,213 @@ TEST(RpcEngineTest, TestConnectionReset) {
.WillOnce(Return(RpcResponse(
h, "", make_error_code(::asio::error::connection_reset))));
std::unique_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(&conn_ptr);
std::shared_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(conn_ptr);
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [&io_service](const Status &stat) {
ASSERT_FALSE(stat.ok());
});
engine.AsyncRpc("test", &req, resp, [&io_service](const Status &stat) {
engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());
});
conn->Start();
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestConnectionResetAndRecover) {
::asio::io_service io_service;
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
EchoResponseProto server_resp;
server_resp.set_message("foo");
bool complete = false;
auto producer = std::make_shared<SharedConnectionData>();
RpcResponseHeaderProto h;
h.set_callid(1);
h.set_status(RpcResponseHeaderProto::SUCCESS);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(RpcResponse(
h, "", make_error_code(::asio::error::connection_reset))))
.WillOnce(Return(RpcResponse(h, server_resp.SerializeAsString())));
SharedMockConnection::SetSharedConnectionData(producer);
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
});
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestConnectionResetAndRecoverWithDelay) {
::asio::io_service io_service;
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 1;
SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
EchoResponseProto server_resp;
server_resp.set_message("foo");
bool complete = false;
auto producer = std::make_shared<SharedConnectionData>();
RpcResponseHeaderProto h;
h.set_callid(1);
h.set_status(RpcResponseHeaderProto::SUCCESS);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(RpcResponse(
h, "", make_error_code(::asio::error::connection_reset))))
.WillOnce(Return(RpcResponse(h, server_resp.SerializeAsString())));
SharedMockConnection::SetSharedConnectionData(producer);
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
});
::asio::deadline_timer timer(io_service);
timer.expires_from_now(std::chrono::hours(100));
timer.async_wait([](const asio::error_code & err){(void)err; ASSERT_FALSE("Timed out"); });
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestConnectionFailure)
{
auto producer = std::make_shared<SharedConnectionData>();
producer->checkProducerForConnect = true;
SharedMockConnection::SetSharedConnectionData(producer);
// Error and no retry
::asio::io_service io_service;
bool complete = false;
Options options;
options.max_rpc_retries = 0;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")));
engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());
});
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestConnectionFailureRetryAndFailure)
{
auto producer = std::make_shared<SharedConnectionData>();
producer->checkProducerForConnect = true;
SharedMockConnection::SetSharedConnectionData(producer);
::asio::io_service io_service;
bool complete = false;
Options options;
options.max_rpc_retries = 2;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")));
engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());
});
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestConnectionFailureAndRecover)
{
auto producer = std::make_shared<SharedConnectionData>();
producer->checkProducerForConnect = true;
SharedMockConnection::SetSharedConnectionData(producer);
::asio::io_service io_service;
bool complete = false;
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(::asio::error_code(), "")))
.WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
});
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestConnectionFailureAndAsyncRecover)
{
// Error and async recover
auto producer = std::make_shared<SharedConnectionData>();
producer->checkProducerForConnect = true;
SharedMockConnection::SetSharedConnectionData(producer);
::asio::io_service io_service;
bool complete = false;
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 1;
SharedConnectionEngine engine(&io_service, options, "foo", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(::asio::error_code(), "")))
.WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
engine.Connect(asio::ip::basic_endpoint<asio::ip::tcp>(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
});
::asio::deadline_timer timer(io_service);
timer.expires_from_now(std::chrono::hours(100));
timer.async_wait([](const asio::error_code & err){(void)err; ASSERT_FALSE("Timed out"); });
io_service.run();
ASSERT_TRUE(complete);
}
TEST(RpcEngineTest, TestTimeout) {
@ -151,24 +363,32 @@ TEST(RpcEngineTest, TestTimeout) {
RpcEngine engine(&io_service, options, "foo", "protocol", 1);
RpcConnectionImpl<MockRPCConnection> *conn =
new RpcConnectionImpl<MockRPCConnection>(&engine);
conn->TEST_set_connected(true);
conn->StartReading();
EXPECT_CALL(conn->next_layer(), Produce()).Times(0);
EXPECT_CALL(conn->next_layer(), Produce())
.WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
std::unique_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(&conn_ptr);
std::shared_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(conn_ptr);
bool complete = false;
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [resp, &io_service](const Status &stat) {
engine.AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());
});
::asio::deadline_timer timer(io_service);
timer.expires_from_now(std::chrono::milliseconds(options.rpc_timeout * 2));
timer.async_wait(std::bind(&RpcConnection::Start, conn));
timer.expires_from_now(std::chrono::hours(100));
timer.async_wait([](const asio::error_code & err){(void)err; ASSERT_FALSE("Timed out"); });
io_service.run();
ASSERT_TRUE(complete);
}
int main(int argc, char *argv[]) {