HDFS-12427: libhdfs++: Prevent requests from holding dangling pointer to RpcEngine. Contributed by James Clampffer.

This commit is contained in:
James 2017-10-30 13:54:05 -04:00 committed by James Clampffer
parent 48db24a430
commit fc5e44d9ed
12 changed files with 240 additions and 160 deletions

View File

@ -212,6 +212,8 @@ private:
* The IoService must be the first member variable to ensure that it gets
* destroyed last. This allows other members to dequeue things from the
* service in their own destructors.
* A side effect of this is that requests may outlive the RpcEngine they
* reference.
**/
std::shared_ptr<IoServiceImpl> io_service_;
const Options options_;

View File

@ -42,11 +42,11 @@ namespace hdfs {
void NameNodeOperations::Connect(const std::string &cluster_name,
const std::vector<ResolvedNamenodeInfo> &servers,
std::function<void(const Status &)> &&handler) {
engine_.Connect(cluster_name, servers, handler);
engine_->Connect(cluster_name, servers, handler);
}
bool NameNodeOperations::CancelPendingConnect() {
return engine_.CancelPendingConnect();
return engine_->CancelPendingConnect();
}
void NameNodeOperations::GetBlockLocations(const std::string & path, uint64_t offset, uint64_t length,
@ -678,7 +678,7 @@ void NameNodeOperations::DisallowSnapshot(const std::string & path, std::functio
}
void NameNodeOperations::SetFsEventCallback(fs_event_callback callback) {
engine_.SetFsEventCallback(callback);
engine_->SetFsEventCallback(callback);
}
void NameNodeOperations::HdfsFileStatusProtoToStatInfo(

View File

@ -46,8 +46,9 @@ public:
const std::string &client_name, const std::string &user_name,
const char *protocol_name, int protocol_version) :
io_service_(io_service),
engine_(io_service, options, client_name, user_name, protocol_name, protocol_version),
namenode_(& engine_), options_(options) {}
engine_(std::make_shared<RpcEngine>(io_service, options, client_name, user_name, protocol_name, protocol_version)),
namenode_(engine_), options_(options) {}
void Connect(const std::string &cluster_name,
const std::vector<ResolvedNamenodeInfo> &servers,
@ -119,7 +120,14 @@ private:
static void GetFsStatsResponseProtoToFsInfo(hdfs::FsInfo & fs_info, const std::shared_ptr<::hadoop::hdfs::GetFsStatsResponseProto> & fs);
::asio::io_service * io_service_;
RpcEngine engine_;
// This is the only permanent owner of the RpcEngine, however the RPC layer
// needs to reference count it prevent races during FileSystem destruction.
// In order to do this they hold weak_ptrs and promote them to shared_ptr
// when calling non-blocking RpcEngine methods e.g. get_client_id().
std::shared_ptr<RpcEngine> engine_;
// Automatically generated methods for RPC calls. See protoc_gen_hrpc.cc
ClientNamenodeProtocol namenode_;
const Options options_;
};

View File

@ -64,11 +64,11 @@ void StubGenerator::EmitService(const ServiceDescriptor *service,
out->Print("\n// GENERATED AUTOMATICALLY. DO NOT MODIFY.\n"
"class $service$ {\n"
"private:\n"
" ::hdfs::RpcEngine *const engine_;\n"
" std::shared_ptr<::hdfs::RpcEngine> engine_;\n"
"public:\n"
" typedef std::function<void(const ::hdfs::Status &)> Callback;\n"
" typedef ::google::protobuf::MessageLite Message;\n"
" inline $service$(::hdfs::RpcEngine *engine)\n"
" inline $service$(std::shared_ptr<::hdfs::RpcEngine> engine)\n"
" : engine_(engine) {}\n",
"service", service->name());
for (int i = 0; i < service->method_count(); ++i) {

View File

@ -38,6 +38,11 @@ using namespace ::std::placeholders;
static const int kNoRetry = -1;
// Protobuf helper functions.
// Note/todo: Using the zero-copy protobuf API here makes the simple procedures
// below tricky to read and debug while providing minimal benefit. Reducing
// allocations in BlockReader (HDFS-11266) and smarter use of std::stringstream
// will have a much larger impact according to cachegrind profiles on common
// workloads.
static void AddHeadersToPacket(std::string *res,
std::initializer_list<const pb::MessageLite *> headers,
const std::string *payload) {
@ -82,50 +87,33 @@ static void ConstructPayload(std::string *res, const pb::MessageLite *header) {
buf = header->SerializeWithCachedSizesToArray(buf);
}
static void ConstructPayload(std::string *res, const std::string *request) {
int len =
pbio::CodedOutputStream::VarintSize32(request->size()) + request->size();
res->reserve(len);
pbio::StringOutputStream ss(res);
pbio::CodedOutputStream os(&ss);
uint8_t *buf = os.GetDirectBufferForNBytesAndAdvance(len);
assert(buf);
buf = pbio::CodedOutputStream::WriteVarint32ToArray(request->size(), buf);
buf = os.WriteStringToArray(*request, buf);
}
static void SetRequestHeader(LockFreeRpcEngine *engine, int call_id,
static void SetRequestHeader(std::weak_ptr<LockFreeRpcEngine> weak_engine, int call_id,
const std::string &method_name, int retry_count,
RpcRequestHeaderProto *rpc_header,
RequestHeaderProto *req_header) {
RequestHeaderProto *req_header)
{
// Ensure the RpcEngine is live. If it's not then the FileSystem is being destructed.
std::shared_ptr<LockFreeRpcEngine> counted_engine = weak_engine.lock();
if(!counted_engine) {
LOG_ERROR(kRPC, << "SetRequestHeader attempted to access an invalid RpcEngine");
return;
}
rpc_header->set_rpckind(RPC_PROTOCOL_BUFFER);
rpc_header->set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
rpc_header->set_callid(call_id);
if (retry_count != kNoRetry)
if (retry_count != kNoRetry) {
rpc_header->set_retrycount(retry_count);
rpc_header->set_clientid(engine->client_id());
}
rpc_header->set_clientid(counted_engine->client_id());
req_header->set_methodname(method_name);
req_header->set_declaringclassprotocolname(engine->protocol_name());
req_header->set_clientprotocolversion(engine->protocol_version());
req_header->set_declaringclassprotocolname(counted_engine->protocol_name());
req_header->set_clientprotocolversion(counted_engine->protocol_version());
}
// Request implementation
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
const std::string &request, Handler &&handler)
: engine_(engine),
method_name_(method_name),
call_id_(call_id),
timer_(engine->io_service()),
handler_(std::move(handler)),
retry_count_(engine->retry_policy() ? 0 : kNoRetry),
failover_count_(0) {
ConstructPayload(&payload_, &request);
}
Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
Request::Request(std::shared_ptr<LockFreeRpcEngine> engine, const std::string &method_name, int call_id,
const pb::MessageLite *request, Handler &&handler)
: engine_(engine),
method_name_(method_name),
@ -133,13 +121,14 @@ Request::Request(LockFreeRpcEngine *engine, const std::string &method_name, int
timer_(engine->io_service()),
handler_(std::move(handler)),
retry_count_(engine->retry_policy() ? 0 : kNoRetry),
failover_count_(0) {
failover_count_(0)
{
ConstructPayload(&payload_, request);
}
Request::Request(LockFreeRpcEngine *engine, Handler &&handler)
Request::Request(std::shared_ptr<LockFreeRpcEngine> engine, Handler &&handler)
: engine_(engine),
call_id_(-1),
call_id_(-1/*Handshake ID*/),
timer_(engine->io_service()),
handler_(std::move(handler)),
retry_count_(engine->retry_policy() ? 0 : kNoRetry),

View File

@ -23,6 +23,7 @@
#include "common/new_delete.h"
#include <string>
#include <memory>
#include <google/protobuf/message_lite.h>
#include <google/protobuf/io/coded_stream.h>
@ -48,14 +49,13 @@ class Request {
typedef std::function<void(::google::protobuf::io::CodedInputStream *is,
const Status &status)> Handler;
Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
const std::string &request, Handler &&callback);
Request(LockFreeRpcEngine *engine, const std::string &method_name, int call_id,
// Constructors will not make any blocking calls while holding the shared_ptr<RpcEngine>
Request(std::shared_ptr<LockFreeRpcEngine> engine, const std::string &method_name, int call_id,
const ::google::protobuf::MessageLite *request, Handler &&callback);
// Null request (with no actual message) used to track the state of an
// initial Connect call
Request(LockFreeRpcEngine *engine, Handler &&handler);
Request(std::shared_ptr<LockFreeRpcEngine> engine, Handler &&handler);
int call_id() const { return call_id_; }
std::string method_name() const { return method_name_; }
@ -71,7 +71,7 @@ class Request {
std::string GetDebugString() const;
private:
LockFreeRpcEngine *const engine_;
std::weak_ptr<LockFreeRpcEngine> engine_;
const std::string method_name_;
const int call_id_;

View File

@ -52,7 +52,7 @@ class SaslProtocol;
class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
public:
MEMCHECKED_CLASS(RpcConnection)
RpcConnection(LockFreeRpcEngine *engine);
RpcConnection(std::shared_ptr<LockFreeRpcEngine> engine);
virtual ~RpcConnection();
// Note that a single server can have multiple endpoints - especially both
@ -82,8 +82,8 @@ class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
void SetClusterName(std::string cluster_name);
void SetAuthInfo(const AuthInfo& auth_info);
LockFreeRpcEngine *engine() { return engine_; }
::asio::io_service &io_service();
std::weak_ptr<LockFreeRpcEngine> engine() { return engine_; }
::asio::io_service *GetIoService();
protected:
struct Response {
@ -139,7 +139,7 @@ class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
void ClearAndDisconnect(const ::asio::error_code &ec);
std::shared_ptr<Request> RemoveFromRunningQueue(int call_id);
LockFreeRpcEngine *const engine_;
std::weak_ptr<LockFreeRpcEngine> engine_;
std::shared_ptr<Response> current_response_state_;
AuthInfo auth_info_;
@ -158,16 +158,17 @@ class RpcConnection : public std::enable_shared_from_this<RpcConnection> {
// State machine for performing a SASL handshake
std::shared_ptr<SaslProtocol> sasl_protocol_;
// The request being sent over the wire; will also be in requests_on_fly_
std::shared_ptr<Request> request_over_the_wire_;
// The request being sent over the wire; will also be in sent_requests_
std::shared_ptr<Request> outgoing_request_;
// Requests to be sent over the wire
std::deque<std::shared_ptr<Request>> pending_requests_;
// Requests to be sent over the wire during authentication; not retried if
// there is a connection error
std::deque<std::shared_ptr<Request>> auth_requests_;
// Requests that are waiting for responses
typedef std::unordered_map<int, std::shared_ptr<Request>> RequestOnFlyMap;
RequestOnFlyMap requests_on_fly_;
typedef std::unordered_map<int, std::shared_ptr<Request>> SentRequestMap;
SentRequestMap sent_requests_;
std::shared_ptr<LibhdfsEvents> event_handlers_;
std::string cluster_name_;

View File

@ -66,17 +66,29 @@ static void AddHeadersToPacket(
RpcConnection::~RpcConnection() {}
RpcConnection::RpcConnection(LockFreeRpcEngine *engine)
RpcConnection::RpcConnection(std::shared_ptr<LockFreeRpcEngine> engine)
: engine_(engine),
connected_(kNotYetConnected) {}
::asio::io_service &RpcConnection::io_service() {
return engine_->io_service();
::asio::io_service *RpcConnection::GetIoService() {
std::shared_ptr<LockFreeRpcEngine> pinnedEngine = engine_.lock();
if(!pinnedEngine) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access invalid RpcEngine");
return nullptr;
}
return &pinnedEngine->io_service();
}
void RpcConnection::StartReading() {
auto shared_this = shared_from_this();
io_service().post([shared_this, this] () {
::asio::io_service *service = GetIoService();
if(!service) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access invalid IoService");
return;
}
service->post([shared_this, this] () {
OnRecvCompleted(::asio::error_code(), 0);
});
}
@ -151,12 +163,19 @@ void RpcConnection::ContextComplete(const Status &s) {
void RpcConnection::AsyncFlushPendingRequests() {
std::shared_ptr<RpcConnection> shared_this = shared_from_this();
io_service().post([shared_this, this]() {
::asio::io_service *service = GetIoService();
if(!service) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access invalid IoService");
return;
}
service->post([shared_this, this]() {
std::lock_guard<std::mutex> state_lock(connection_state_lock_);
LOG_TRACE(kRPC, << "RpcConnection::AsyncFlushPendingRequests called (connected=" << ToString(connected_) << ")");
if (!request_over_the_wire_) {
if (!outgoing_request_) {
FlushPendingRequests();
}
});
@ -209,7 +228,13 @@ Status RpcConnection::HandleRpcResponse(std::shared_ptr<Response> response) {
return status;
}
io_service().post([req, response, status]() {
::asio::io_service *service = GetIoService();
if(!service) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access invalid IoService");
return Status::Error("RpcConnection attempted to access invalid IoService");
}
service->post([req, response, status]() {
req->OnResponseArrived(response->in.get(), status); // Never call back while holding a lock
});
@ -267,23 +292,29 @@ std::shared_ptr<std::string> RpcConnection::PrepareContextPacket() {
// after the SASL handshake (if any)
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
auto res = std::make_shared<std::string>();
std::shared_ptr<LockFreeRpcEngine> pinnedEngine = engine_.lock();
if(!pinnedEngine) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access invalid RpcEngine");
return std::make_shared<std::string>();
}
RpcRequestHeaderProto h;
h.set_rpckind(RPC_PROTOCOL_BUFFER);
h.set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
h.set_callid(RpcEngine::kCallIdConnectionContext);
h.set_clientid(engine_->client_name());
std::shared_ptr<std::string> serializedPacketBuffer = std::make_shared<std::string>();
IpcConnectionContextProto handshake;
handshake.set_protocol(engine_->protocol_name());
RpcRequestHeaderProto headerProto;
headerProto.set_rpckind(RPC_PROTOCOL_BUFFER);
headerProto.set_rpcop(RpcRequestHeaderProto::RPC_FINAL_PACKET);
headerProto.set_callid(RpcEngine::kCallIdConnectionContext);
headerProto.set_clientid(pinnedEngine->client_name());
IpcConnectionContextProto handshakeContextProto;
handshakeContextProto.set_protocol(pinnedEngine->protocol_name());
const std::string & user_name = auth_info_.getUser();
if (!user_name.empty()) {
*handshake.mutable_userinfo()->mutable_effectiveuser() = user_name;
*handshakeContextProto.mutable_userinfo()->mutable_effectiveuser() = user_name;
}
AddHeadersToPacket(res.get(), {&h, &handshake}, nullptr);
AddHeadersToPacket(serializedPacketBuffer.get(), {&headerProto, &handshakeContextProto}, nullptr);
return res;
return serializedPacketBuffer;
}
void RpcConnection::AsyncRpc(
@ -310,11 +341,22 @@ void RpcConnection::AsyncRpc_locked(
handler(status);
};
int call_id = (method_name != SASL_METHOD_NAME ? engine_->NextCallId() : RpcEngine::kCallIdSasl);
auto r = std::make_shared<Request>(engine_, method_name, call_id, req,
std::move(wrapped_handler));
auto r_vector = std::vector<std::shared_ptr<Request> > (1, r);
SendRpcRequests(r_vector);
std::shared_ptr<Request> rpcRequest;
{ // Scope to minimize how long RpcEngine's lifetime may be extended
std::shared_ptr<LockFreeRpcEngine> pinnedEngine = engine_.lock();
if(!pinnedEngine) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access invalid RpcEngine");
handler(Status::Error("Invalid RpcEngine access."));
return;
}
int call_id = (method_name != SASL_METHOD_NAME ? pinnedEngine->NextCallId() : RpcEngine::kCallIdSasl);
rpcRequest = std::make_shared<Request>(pinnedEngine, method_name, call_id,
req, std::move(wrapped_handler));
}
SendRpcRequests({rpcRequest});
}
void RpcConnection::AsyncRpc(const std::vector<std::shared_ptr<Request> > & requests) {
@ -330,13 +372,20 @@ void RpcConnection::SendRpcRequests(const std::vector<std::shared_ptr<Request> >
// Oops. The connection failed _just_ before the engine got a chance
// to send it. Register it as a failure
Status status = Status::ResourceUnavailable("RpcConnection closed before send.");
engine_->AsyncRpcCommsError(status, shared_from_this(), requests);
std::shared_ptr<LockFreeRpcEngine> pinnedEngine = engine_.lock();
if(!pinnedEngine) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access invalid RpcEngine");
return;
}
pinnedEngine->AsyncRpcCommsError(status, shared_from_this(), requests);
} else {
for (auto r: requests) {
if (r->method_name() != SASL_METHOD_NAME)
pending_requests_.push_back(r);
for (auto request : requests) {
if (request->method_name() != SASL_METHOD_NAME)
pending_requests_.push_back(request);
else
auth_requests_.push_back(r);
auth_requests_.push_back(request);
}
if (connected_ == kConnected || connected_ == kHandshaking || connected_ == kAuthenticating) { // Dont flush if we're waiting or handshaking
FlushPendingRequests();
@ -395,26 +444,32 @@ void RpcConnection::CommsError(const Status &status) {
// Anything that has been queued to the connection (on the fly or pending)
// will get dinged for a retry
std::vector<std::shared_ptr<Request>> requestsToReturn;
std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
std::transform(sent_requests_.begin(), sent_requests_.end(),
std::back_inserter(requestsToReturn),
std::bind(&RequestOnFlyMap::value_type::second, _1));
requests_on_fly_.clear();
std::bind(&SentRequestMap::value_type::second, _1));
sent_requests_.clear();
requestsToReturn.insert(requestsToReturn.end(),
std::make_move_iterator(pending_requests_.begin()),
std::make_move_iterator(pending_requests_.end()));
pending_requests_.clear();
engine_->AsyncRpcCommsError(status, shared_from_this(), requestsToReturn);
std::shared_ptr<LockFreeRpcEngine> pinnedEngine = engine_.lock();
if(!pinnedEngine) {
LOG_ERROR(kRPC, << "RpcConnection@" << this << " attempted to access an invalid RpcEngine");
return;
}
pinnedEngine->AsyncRpcCommsError(status, shared_from_this(), requestsToReturn);
}
void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
Disconnect();
std::vector<std::shared_ptr<Request>> requests;
std::transform(requests_on_fly_.begin(), requests_on_fly_.end(),
std::transform(sent_requests_.begin(), sent_requests_.end(),
std::back_inserter(requests),
std::bind(&RequestOnFlyMap::value_type::second, _1));
requests_on_fly_.clear();
std::bind(&SentRequestMap::value_type::second, _1));
sent_requests_.clear();
requests.insert(requests.end(),
std::make_move_iterator(pending_requests_.begin()),
std::make_move_iterator(pending_requests_.end()));
@ -426,13 +481,13 @@ void RpcConnection::ClearAndDisconnect(const ::asio::error_code &ec) {
std::shared_ptr<Request> RpcConnection::RemoveFromRunningQueue(int call_id) {
assert(lock_held(connection_state_lock_)); // Must be holding lock before calling
auto it = requests_on_fly_.find(call_id);
if (it == requests_on_fly_.end()) {
auto it = sent_requests_.find(call_id);
if (it == sent_requests_.end()) {
return std::shared_ptr<Request>();
}
auto req = it->second;
requests_on_fly_.erase(it);
sent_requests_.erase(it);
return req;
}

View File

@ -40,7 +40,7 @@ class RpcConnectionImpl : public RpcConnection {
public:
MEMCHECKED_CLASS(RpcConnectionImpl)
RpcConnectionImpl(RpcEngine *engine);
RpcConnectionImpl(std::shared_ptr<RpcEngine> engine);
virtual ~RpcConnectionImpl() override;
virtual void Connect(const std::vector<::asio::ip::tcp::endpoint> &server,
@ -73,7 +73,7 @@ public:
};
template <class Socket>
RpcConnectionImpl<Socket>::RpcConnectionImpl(RpcEngine *engine)
RpcConnectionImpl<Socket>::RpcConnectionImpl(std::shared_ptr<RpcEngine> engine)
: RpcConnection(engine),
options_(engine->options()),
socket_(engine->io_service()),
@ -88,8 +88,8 @@ RpcConnectionImpl<Socket>::~RpcConnectionImpl() {
if (pending_requests_.size() > 0)
LOG_WARN(kRPC, << "RpcConnectionImpl::~RpcConnectionImpl called with items in the pending queue");
if (requests_on_fly_.size() > 0)
LOG_WARN(kRPC, << "RpcConnectionImpl::~RpcConnectionImpl called with items in the requests_on_fly queue");
if (sent_requests_.size() > 0)
LOG_WARN(kRPC, << "RpcConnectionImpl::~RpcConnectionImpl called with items in the sent_requests queue");
}
template <class Socket>
@ -101,13 +101,23 @@ void RpcConnectionImpl<Socket>::Connect(
this->auth_info_ = auth_info;
auto connectionSuccessfulReq = std::make_shared<Request>(
engine_, [handler](::google::protobuf::io::CodedInputStream *is,
const Status &status) {
(void)is;
handler(status);
});
pending_requests_.push_back(connectionSuccessfulReq);
std::shared_ptr<Request> connectionRequest;
{ // Scope to minimize how long RpcEngine's lifetime may be extended
std::shared_ptr<LockFreeRpcEngine> pinned_engine = engine_.lock();
if(!pinned_engine) {
LOG_ERROR(kRPC, << "RpcConnectionImpl@" << this << " attempted to access invalid RpcEngine");
handler(Status::Error("Invalid RpcEngine access."));
return;
}
connectionRequest = std::make_shared<Request>(pinned_engine,
[handler](::google::protobuf::io::CodedInputStream *is,const Status &status) {
(void)is;
handler(status);
});
}
pending_requests_.push_back(connectionRequest);
this->ConnectAndFlush(server); // need "this" so compiler can infer type of CAF
}
@ -263,7 +273,7 @@ void RpcConnectionImpl<Socket>::OnSendCompleted(const ::asio::error_code &ec,
LOG_TRACE(kRPC, << "RpcConnectionImpl::OnSendCompleted called");
request_over_the_wire_.reset();
outgoing_request_.reset();
if (ec) {
LOG_WARN(kRPC, << "Network error during RPC write: " << ec.message());
CommsError(ToStatus(ec));
@ -283,7 +293,7 @@ void RpcConnectionImpl<Socket>::FlushPendingRequests() {
LOG_TRACE(kRPC, << "RpcConnectionImpl::FlushPendingRequests called");
// Don't send if we don't need to
if (request_over_the_wire_) {
if (outgoing_request_) {
return;
}
@ -324,9 +334,9 @@ void RpcConnectionImpl<Socket>::FlushPendingRequests() {
std::shared_ptr<std::string> payload = std::make_shared<std::string>();
req->GetPacket(payload.get());
if (!payload->empty()) {
assert(requests_on_fly_.find(req->call_id()) == requests_on_fly_.end());
requests_on_fly_[req->call_id()] = req;
request_over_the_wire_ = req;
assert(sent_requests_.find(req->call_id()) == sent_requests_.end());
sent_requests_[req->call_id()] = req;
outgoing_request_ = req;
req->timer().expires_from_now(
std::chrono::milliseconds(options_.rpc_timeout));
@ -343,7 +353,15 @@ void RpcConnectionImpl<Socket>::FlushPendingRequests() {
OnSendCompleted(ec, size);
});
} else { // Nothing to send for this request, inform the handler immediately
io_service().post(
::asio::io_service *service = GetIoService();
if(!service) {
LOG_ERROR(kRPC, << "RpcConnectionImpl@" << this << " attempted to access null IoService");
// No easy way to bail out of this context, but the only way to get here is when
// the FileSystem is being destroyed.
return;
}
service->post(
// Never hold locks when calling a callback
[req]() { req->OnResponseArrived(nullptr, Status::OK()); }
);
@ -433,7 +451,7 @@ void RpcConnectionImpl<Socket>::Disconnect() {
LOG_INFO(kRPC, << "RpcConnectionImpl::Disconnect called");
request_over_the_wire_.reset();
outgoing_request_.reset();
if (connected_ == kConnecting || connected_ == kHandshaking || connected_ == kAuthenticating || connected_ == kConnected) {
// Don't print out errors, we were expecting a disconnect here
SafeDisconnect(get_asio_socket_ptr(&socket_));

View File

@ -171,7 +171,7 @@ std::shared_ptr<RpcConnection> RpcEngine::NewConnection()
{
LOG_DEBUG(kRPC, << "RpcEngine::NewConnection called");
return std::make_shared<RpcConnectionImpl<::asio::ip::tcp::socket>>(this);
return std::make_shared<RpcConnectionImpl<::asio::ip::tcp::socket>>(shared_from_this());
}
std::shared_ptr<RpcConnection> RpcEngine::InitializeConnection()

View File

@ -75,16 +75,16 @@ public:
std::vector<std::shared_ptr<Request>> pendingRequests) = 0;
virtual const RetryPolicy * retry_policy() const = 0;
virtual const RetryPolicy *retry_policy() = 0;
virtual int NextCallId() = 0;
virtual const std::string &client_name() const = 0;
virtual const std::string &client_id() const = 0;
virtual const std::string &user_name() const = 0;
virtual const std::string &protocol_name() const = 0;
virtual int protocol_version() const = 0;
virtual const std::string &client_name() = 0;
virtual const std::string &client_id() = 0;
virtual const std::string &user_name() = 0;
virtual const std::string &protocol_name() = 0;
virtual int protocol_version() = 0;
virtual ::asio::io_service &io_service() = 0;
virtual const Options &options() const = 0;
virtual const Options &options() = 0;
};
@ -95,7 +95,7 @@ public:
* Threading model: thread-safe. All callbacks will be called back from
* an asio pool and will not hold any internal locks
*/
class RpcEngine : public LockFreeRpcEngine {
class RpcEngine : public LockFreeRpcEngine, public std::enable_shared_from_this<RpcEngine> {
public:
MEMCHECKED_CLASS(RpcEngine)
enum { kRpcVersion = 9 };
@ -133,20 +133,20 @@ class RpcEngine : public LockFreeRpcEngine {
std::vector<std::shared_ptr<Request>> pendingRequests);
const RetryPolicy * retry_policy() const override { return retry_policy_.get(); }
const RetryPolicy * retry_policy() override { return retry_policy_.get(); }
int NextCallId() override { return ++call_id_; }
void TEST_SetRpcConnection(std::shared_ptr<RpcConnection> conn);
void TEST_SetRetryPolicy(std::unique_ptr<const RetryPolicy> policy);
std::unique_ptr<const RetryPolicy> TEST_GenerateRetryPolicyUsingOptions();
const std::string &client_name() const override { return client_name_; }
const std::string &client_id() const override { return client_id_; }
const std::string &user_name() const override { return auth_info_.getUser(); }
const std::string &protocol_name() const override { return protocol_name_; }
int protocol_version() const override { return protocol_version_; }
const std::string &client_name() override { return client_name_; }
const std::string &client_id() override { return client_id_; }
const std::string &user_name() override { return auth_info_.getUser(); }
const std::string &protocol_name() override { return protocol_name_; }
int protocol_version() override { return protocol_version_; }
::asio::io_service &io_service() override { return *io_service_; }
const Options &options() const override { return options_; }
const Options &options() override { return options_; }
static std::string GetRandomClientName();
void SetFsEventCallback(fs_event_callback callback);

View File

@ -71,7 +71,7 @@ protected:
// Stuff in some dummy endpoints so we don't error out
last_endpoints_ = make_endpoint()[0].endpoints;
return std::make_shared<RpcConnectionImpl<SharedMockRPCConnection>>(this);
return std::make_shared<RpcConnectionImpl<SharedMockRPCConnection>>(shared_from_this());
}
};
@ -106,9 +106,9 @@ using namespace hdfs;
TEST(RpcEngineTest, TestRoundTrip) {
::asio::io_service io_service;
Options options;
RpcEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<RpcEngine> engine = std::make_shared<RpcEngine>(&io_service, options, "foo", "", "protocol", 1);
auto conn =
std::make_shared<RpcConnectionImpl<MockRPCConnection> >(&engine);
std::make_shared<RpcConnectionImpl<MockRPCConnection> >(engine);
conn->TEST_set_connected(true);
conn->StartReading();
@ -122,14 +122,14 @@ TEST(RpcEngineTest, TestRoundTrip) {
.WillOnce(Return(RpcResponse(h, server_resp.SerializeAsString())));
std::shared_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(conn_ptr);
engine->TEST_SetRpcConnection(conn_ptr);
bool complete = false;
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
engine->AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
ASSERT_TRUE(stat.ok());
ASSERT_EQ("foo", resp->message());
complete = true;
@ -142,9 +142,9 @@ TEST(RpcEngineTest, TestRoundTrip) {
TEST(RpcEngineTest, TestConnectionResetAndFail) {
::asio::io_service io_service;
Options options;
RpcEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<RpcEngine> engine = std::make_shared<RpcEngine>(&io_service, options, "foo", "", "protocol", 1);
auto conn =
std::make_shared<RpcConnectionImpl<MockRPCConnection> >(&engine);
std::make_shared<RpcConnectionImpl<MockRPCConnection> >(engine);
conn->TEST_set_connected(true);
conn->StartReading();
@ -158,13 +158,13 @@ TEST(RpcEngineTest, TestConnectionResetAndFail) {
h, "", make_error_code(::asio::error::connection_reset))));
std::shared_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(conn_ptr);
engine->TEST_SetRpcConnection(conn_ptr);
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
engine->AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());
@ -179,11 +179,12 @@ TEST(RpcEngineTest, TestConnectionResetAndRecover) {
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<SharedConnectionEngine> engine
= std::make_shared<SharedConnectionEngine>(&io_service, options, "foo", "", "protocol", 1);
// Normally determined during RpcEngine::Connect, but in this case options
// provides enough info to determine policy here.
engine.TEST_SetRetryPolicy(engine.TEST_GenerateRetryPolicyUsingOptions());
engine->TEST_SetRetryPolicy(engine->TEST_GenerateRetryPolicyUsingOptions());
EchoResponseProto server_resp;
@ -205,7 +206,7 @@ TEST(RpcEngineTest, TestConnectionResetAndRecover) {
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
engine->AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
@ -219,11 +220,12 @@ TEST(RpcEngineTest, TestConnectionResetAndRecoverWithDelay) {
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 1;
SharedConnectionEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<SharedConnectionEngine> engine =
std::make_shared<SharedConnectionEngine>(&io_service, options, "foo", "", "protocol", 1);
// Normally determined during RpcEngine::Connect, but in this case options
// provides enough info to determine policy here.
engine.TEST_SetRetryPolicy(engine.TEST_GenerateRetryPolicyUsingOptions());
engine->TEST_SetRetryPolicy(engine->TEST_GenerateRetryPolicyUsingOptions());
EchoResponseProto server_resp;
server_resp.set_message("foo");
@ -244,7 +246,7 @@ TEST(RpcEngineTest, TestConnectionResetAndRecoverWithDelay) {
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
engine->AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
@ -272,11 +274,12 @@ TEST(RpcEngineTest, TestConnectionFailure)
Options options;
options.max_rpc_retries = 0;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<SharedConnectionEngine> engine
= std::make_shared<SharedConnectionEngine>(&io_service, options, "foo", "", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")));
engine.Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
engine->Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());
@ -298,13 +301,14 @@ TEST(RpcEngineTest, TestConnectionFailureRetryAndFailure)
Options options;
options.max_rpc_retries = 2;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<SharedConnectionEngine> engine =
std::make_shared<SharedConnectionEngine>(&io_service, options, "foo", "", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")));
engine.Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
engine->Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());
@ -326,13 +330,14 @@ TEST(RpcEngineTest, TestConnectionFailureAndRecover)
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<SharedConnectionEngine> engine =
std::make_shared<SharedConnectionEngine>(&io_service, options, "foo", "", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(::asio::error_code(), "")))
.WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
engine.Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
engine->Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
@ -347,16 +352,17 @@ TEST(RpcEngineTest, TestEventCallbacks)
Options options;
options.max_rpc_retries = 99;
options.rpc_retry_delay_ms = 0;
SharedConnectionEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<SharedConnectionEngine> engine =
std::make_shared<SharedConnectionEngine>(&io_service, options, "foo", "", "protocol", 1);
// Normally determined during RpcEngine::Connect, but in this case options
// provides enough info to determine policy here.
engine.TEST_SetRetryPolicy(engine.TEST_GenerateRetryPolicyUsingOptions());
engine->TEST_SetRetryPolicy(engine->TEST_GenerateRetryPolicyUsingOptions());
// Set up event callbacks
int calls = 0;
std::vector<std::string> callbacks;
engine.SetFsEventCallback([&calls, &callbacks] (const char * event,
engine->SetFsEventCallback([&calls, &callbacks] (const char * event,
const char * cluster,
int64_t value) {
(void)cluster; (void)value;
@ -393,7 +399,7 @@ TEST(RpcEngineTest, TestEventCallbacks)
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
bool complete = false;
engine.AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
engine->AsyncRpc("test", &req, resp, [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
@ -431,13 +437,14 @@ TEST(RpcEngineTest, TestConnectionFailureAndAsyncRecover)
Options options;
options.max_rpc_retries = 1;
options.rpc_retry_delay_ms = 1;
SharedConnectionEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<SharedConnectionEngine> engine =
std::make_shared<SharedConnectionEngine>(&io_service, options, "foo", "", "protocol", 1);
EXPECT_CALL(*producer, Produce())
.WillOnce(Return(std::make_pair(make_error_code(::asio::error::connection_reset), "")))
.WillOnce(Return(std::make_pair(::asio::error_code(), "")))
.WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
engine.Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
engine->Connect("", make_endpoint(), [&complete, &io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_TRUE(stat.ok());
@ -455,9 +462,9 @@ TEST(RpcEngineTest, TestTimeout) {
::asio::io_service io_service;
Options options;
options.rpc_timeout = 1;
RpcEngine engine(&io_service, options, "foo", "", "protocol", 1);
std::shared_ptr<RpcEngine> engine = std::make_shared<RpcEngine>(&io_service, options, "foo", "", "protocol", 1);
auto conn =
std::make_shared<RpcConnectionImpl<MockRPCConnection> >(&engine);
std::make_shared<RpcConnectionImpl<MockRPCConnection> >(engine);
conn->TEST_set_connected(true);
conn->StartReading();
@ -465,14 +472,14 @@ TEST(RpcEngineTest, TestTimeout) {
.WillOnce(Return(std::make_pair(::asio::error::would_block, "")));
std::shared_ptr<RpcConnection> conn_ptr(conn);
engine.TEST_SetRpcConnection(conn_ptr);
engine->TEST_SetRpcConnection(conn_ptr);
bool complete = false;
EchoRequestProto req;
req.set_message("foo");
std::shared_ptr<EchoResponseProto> resp(new EchoResponseProto());
engine.AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
engine->AsyncRpc("test", &req, resp, [resp, &complete,&io_service](const Status &stat) {
complete = true;
io_service.stop();
ASSERT_FALSE(stat.ok());