Refine the connection close

master
Chunting Gu 6 years ago
parent 2a20030f7e
commit 46ea52cb90

@ -238,14 +238,17 @@ int main(int argc, char* argv[]) {
// ... // ...
try { try {
webcc::Server server(8080, 2); webcc::Server server(8080);
server.Route("/books", std::make_shared<BookListView>(), { "GET", "POST" }); server.Route("/books",
std::make_shared<BookListView>(),
{ "GET", "POST" });
server.Route(webcc::R("/books/(\\d+)"), std::make_shared<BookDetailView>(), server.Route(webcc::R("/books/(\\d+)"),
std::make_shared<BookDetailView>(),
{ "GET", "PUT", "DELETE" }); { "GET", "PUT", "DELETE" });
server.Run(); server.Start();
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << e.what() << std::endl; std::cerr << e.what() << std::endl;

@ -1,3 +1,4 @@
#include <cassert>
#include <iostream> #include <iostream>
#include "webcc/client_session.h" #include "webcc/client_session.h"
@ -11,8 +12,6 @@ int main() {
webcc::ResponsePtr r; webcc::ResponsePtr r;
try { try {
r = session.Head("http://httpbin.org/get");
r = session.Request(webcc::RequestBuilder{}. r = session.Request(webcc::RequestBuilder{}.
Get("http://httpbin.org/get"). Get("http://httpbin.org/get").
Query("key1", "value1"). Query("key1", "value1").
@ -21,20 +20,32 @@ int main() {
Header("Accept", "application/json") Header("Accept", "application/json")
()); ());
assert(r->status() == webcc::Status::kOK);
assert(!r->data().empty());
r = session.Get("http://httpbin.org/get", r = session.Get("http://httpbin.org/get",
{ "key1", "value1", "key2", "value2" }, { "key1", "value1", "key2", "value2" },
{ "Accept", "application/json" }); { "Accept", "application/json" });
assert(r->status() == webcc::Status::kOK);
assert(!r->data().empty());
r = session.Request(webcc::RequestBuilder{}. r = session.Request(webcc::RequestBuilder{}.
Post("http://httpbin.org/post"). Post("http://httpbin.org/post").
Body("{'name'='Adam', 'age'=20}"). Body("{'name'='Adam', 'age'=20}").
Json().Utf8() Json().Utf8()
()); ());
assert(r->status() == webcc::Status::kOK);
assert(!r->data().empty());
#if WEBCC_ENABLE_SSL #if WEBCC_ENABLE_SSL
r = session.Get("https://httpbin.org/get"); r = session.Get("https://httpbin.org/get");
assert(r->status() == webcc::Status::kOK);
assert(!r->data().empty());
#endif // WEBCC_ENABLE_SSL #endif // WEBCC_ENABLE_SSL
} catch (const webcc::Error& error) { } catch (const webcc::Error& error) {

@ -48,14 +48,12 @@ int main(int argc, char* argv[]) {
std::uint16_t port = static_cast<std::uint16_t>(std::atoi(argv[1])); std::uint16_t port = static_cast<std::uint16_t>(std::atoi(argv[1]));
std::size_t workers = 2;
try { try {
webcc::Server server(port, workers); webcc::Server server(port);
server.Route("/upload", std::make_shared<FileUploadView>(), { "POST" }); server.Route("/upload", std::make_shared<FileUploadView>(), { "POST" });
server.Run(); server.Start();
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << e.what() << std::endl; std::cerr << e.what() << std::endl;

@ -1,3 +1,4 @@
#include "webcc/logger.h"
#include "webcc/response_builder.h" #include "webcc/response_builder.h"
#include "webcc/server.h" #include "webcc/server.h"
@ -13,12 +14,14 @@ public:
}; };
int main() { int main() {
WEBCC_LOG_INIT("", webcc::LOG_CONSOLE);
try { try {
webcc::Server server(8080); webcc::Server server(8080);
server.Route("/", std::make_shared<HelloView>()); server.Route("/", std::make_shared<HelloView>());
server.Run(); server.Start();
} catch (const std::exception&) { } catch (const std::exception&) {
return 1; return 1;

@ -273,5 +273,8 @@ int main(int argc, char* argv[]) {
PrintBookList(books); PrintBookList(books);
} }
std::cout << "Press any key to exit: ";
std::getchar();
return 0; return 0;
} }

@ -227,7 +227,7 @@ int main(int argc, char* argv[]) {
std::size_t workers = 2; std::size_t workers = 2;
try { try {
webcc::Server server(port, workers); webcc::Server server(port);
server.Route("/books", server.Route("/books",
std::make_shared<BookListView>(sleep_seconds), std::make_shared<BookListView>(sleep_seconds),
@ -237,7 +237,7 @@ int main(int argc, char* argv[]) {
std::make_shared<BookDetailView>(sleep_seconds), std::make_shared<BookDetailView>(sleep_seconds),
{ "GET", "PUT", "DELETE" }); { "GET", "PUT", "DELETE" });
server.Run(); server.Start(workers);
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << e.what() << std::endl; std::cerr << e.what() << std::endl;

@ -26,9 +26,9 @@ int main(int argc, char* argv[]) {
std::string doc_root = argv[2]; std::string doc_root = argv[2];
try { try {
webcc::Server server(port, 1, doc_root); webcc::Server server(port, doc_root);
server.Run(); server.Start();
} catch (const std::exception& e) { } catch (const std::exception& e) {
std::cerr << e.what() << std::endl; std::cerr << e.what() << std::endl;

@ -15,6 +15,7 @@ namespace webcc {
class Body { class Body {
public: public:
Body() = default;
virtual ~Body() = default; virtual ~Body() = default;
// Get the size in bytes of the body. // Get the size in bytes of the body.

@ -21,12 +21,15 @@ Error Client::Request(RequestPtr request, bool connect) {
// Response to HEAD could also have Content-Length. // Response to HEAD could also have Content-Length.
// Set this flag to skip the reading and parsing of the body. // Set this flag to skip the reading and parsing of the body.
// The test against HttpBin.org shows that: // The test against HttpBin.org shows that:
// - If request.Accept-Encoding is "gzip, deflate", the response doesn't // - If request.Accept-Encoding is "gzip, deflate", the response won't
// have Content-Length; // have Content-Length;
// - If request.Accept-Encoding is "identity", the response do have // - If request.Accept-Encoding is "identity", the response will have
// Content-Length. // Content-Length.
if (request->method() == methods::kHead) { if (request->method() == methods::kHead) {
response_parser_.set_ignroe_body(true); response_parser_.set_ignroe_body(true);
} else {
// Reset in case the connection is persistent.
response_parser_.set_ignroe_body(false);
} }
if (connect) { if (connect) {
@ -38,7 +41,7 @@ Error Client::Request(RequestPtr request, bool connect) {
} }
} }
WriteReqeust(request); WriteRequest(request);
if (error_) { if (error_) {
return error_; return error_;
@ -58,12 +61,7 @@ void Client::Close() {
LOG_INFO("Close socket..."); LOG_INFO("Close socket...");
boost::system::error_code ec; socket_->Close();
socket_->Close(&ec);
if (ec) {
LOG_ERRO("Socket close error (%s).", ec.message().c_str());
}
} }
void Client::Restart() { void Client::Restart() {
@ -112,23 +110,23 @@ void Client::DoConnect(RequestPtr request, const std::string& default_port) {
LOG_ERRO("Host resolve error (%s): %s, %s.", ec.message().c_str(), LOG_ERRO("Host resolve error (%s): %s, %s.", ec.message().c_str(),
request->host().c_str(), port.c_str()); request->host().c_str(), port.c_str());
error_.Set(Error::kResolveError, "Host resolve error"); error_.Set(Error::kResolveError, "Host resolve error");
return;
} }
LOG_VERB("Connect to server..."); LOG_VERB("Connect to server...");
// Use sync API directly since we don't need timeout control. // Use sync API directly since we don't need timeout control.
if (!socket_->Connect(request->host(), endpoints, &ec)) { if (!socket_->Connect(request->host(), endpoints)) {
LOG_ERRO("Socket connect error (%s).", ec.message().c_str());
Close();
// TODO: Handshake error
error_.Set(Error::kConnectError, "Endpoint connect error"); error_.Set(Error::kConnectError, "Endpoint connect error");
Close();
return;
} }
LOG_VERB("Socket connected."); LOG_VERB("Socket connected.");
} }
void Client::WriteReqeust(RequestPtr request) { void Client::WriteRequest(RequestPtr request) {
LOG_VERB("HTTP request:\n%s", request->Dump().c_str()); LOG_VERB("HTTP request:\n%s", request->Dump().c_str());
// NOTE: // NOTE:

@ -65,7 +65,7 @@ private:
void DoConnect(RequestPtr request, const std::string& default_port); void DoConnect(RequestPtr request, const std::string& default_port);
void WriteReqeust(RequestPtr request); void WriteRequest(RequestPtr request);
void ReadResponse(); void ReadResponse();

@ -11,22 +11,34 @@ namespace webcc {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
void Headers::Set(const std::string& key, const std::string& value) { bool Headers::Set(const std::string& key, const std::string& value) {
if (value.empty()) {
return false;
}
auto it = Find(key); auto it = Find(key);
if (it != headers_.end()) { if (it != headers_.end()) {
it->second = value; it->second = value;
} else { } else {
headers_.push_back({ key, value }); headers_.push_back({ key, value });
} }
return true;
} }
void Headers::Set(std::string&& key, std::string&& value) { bool Headers::Set(std::string&& key, std::string&& value) {
if (value.empty()) {
return false;
}
auto it = Find(key); auto it = Find(key);
if (it != headers_.end()) { if (it != headers_.end()) {
it->second = std::move(value); it->second = std::move(value);
} else { } else {
headers_.push_back({ std::move(key), std::move(value) }); headers_.push_back({ std::move(key), std::move(value) });
} }
return true;
} }
bool Headers::Has(const std::string& key) const { bool Headers::Has(const std::string& key) const {

@ -28,9 +28,9 @@ public:
return headers_; return headers_;
} }
void Set(const std::string& key, const std::string& value); bool Set(const std::string& key, const std::string& value);
void Set(std::string&& key, std::string&& value); bool Set(std::string&& key, std::string&& value);
bool Has(const std::string& key) const; bool Has(const std::string& key) const;

@ -27,21 +27,35 @@ void Connection::Start() {
} }
void Connection::Close() { void Connection::Close() {
LOG_INFO("Close socket..."); LOG_INFO("Shutdown socket...");
// Initiate graceful connection closure.
// Socket close VS. shutdown:
// https://stackoverflow.com/questions/4160347/close-vs-shutdown-socket
boost::system::error_code ec; boost::system::error_code ec;
socket_.shutdown(tcp::socket::shutdown_both, ec);
if (ec) {
LOG_WARN("Socket shutdown error (%s).", ec.message().c_str());
ec.clear();
// Don't return, try to close the socket anywhere.
}
LOG_INFO("Close socket...");
socket_.close(ec); socket_.close(ec);
if (ec) { if (ec) {
LOG_ERRO("Socket close error (%s).", ec.message().c_str()); LOG_ERRO("Socket close error (%s).", ec.message().c_str());
} }
} }
void Connection::SendResponse(ResponsePtr response) { void Connection::SendResponse(ResponsePtr response, bool no_keep_alive) {
assert(response); assert(response);
response_ = response; response_ = response;
if (request_->IsConnectionKeepAlive()) { if (!no_keep_alive && request_->IsConnectionKeepAlive()) {
response_->SetHeader(headers::kConnection, "Keep-Alive"); response_->SetHeader(headers::kConnection, "Keep-Alive");
} else { } else {
response_->SetHeader(headers::kConnection, "Close"); response_->SetHeader(headers::kConnection, "Close");
@ -52,7 +66,7 @@ void Connection::SendResponse(ResponsePtr response) {
DoWrite(); DoWrite();
} }
void Connection::SendResponse(Status status) { void Connection::SendResponse(Status status, bool no_keep_alive) {
auto response = std::make_shared<Response>(status); auto response = std::make_shared<Response>(status);
// According to the testing based on HTTPie (and Chrome), the `Content-Length` // According to the testing based on HTTPie (and Chrome), the `Content-Length`
@ -60,7 +74,7 @@ void Connection::SendResponse(Status status) {
// is empty. // is empty.
response->SetBody(std::make_shared<Body>(), true); response->SetBody(std::make_shared<Body>(), true);
SendResponse(response); SendResponse(response, no_keep_alive);
} }
void Connection::DoRead() { void Connection::DoRead() {
@ -72,29 +86,31 @@ void Connection::DoRead() {
void Connection::OnRead(boost::system::error_code ec, std::size_t length) { void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
if (ec) { if (ec) {
// TODO
if (ec == boost::asio::error::eof) { if (ec == boost::asio::error::eof) {
LOG_WARN("Socket read EOF."); LOG_WARN("Socket read EOF (%s).", ec.message().c_str());
//} else if (ec == boost::asio::error::operation_aborted) { } else if (ec == boost::asio::error::operation_aborted) {
// LOG_WARN("Socket read aborted."); // The socket of this connection has been closed.
//} else if (ec == boost::asio::error::connection_aborted) { // This happens, e.g., when the server was stopped by a signal (Ctrl-C).
// LOG_WARN("Socket connection aborted."); LOG_WARN("Socket operation aborted (%s).", ec.message().c_str());
} else { } else {
LOG_ERRO("Socket read error (%s).", ec.message().c_str()); LOG_ERRO("Socket read error (%s).", ec.message().c_str());
} }
// Don't try to send any response back.
if (ec != boost::asio::error::operation_aborted) { if (ec != boost::asio::error::operation_aborted) {
pool_->Close(shared_from_this()); pool_->Close(shared_from_this());
} } // else: The socket of this connection has already been closed.
return; return;
} }
if (!request_parser_.Parse(buffer_.data(), length)) { if (!request_parser_.Parse(buffer_.data(), length)) {
// Bad request.
// TODO: Always close the connection?
LOG_ERRO("Failed to parse HTTP request."); LOG_ERRO("Failed to parse HTTP request.");
SendResponse(Status::kBadRequest); // Send Bad Request (400) to the client and no Keep-Alive.
SendResponse(Status::kBadRequest, true);
// Close the socket connection.
pool_->Close(shared_from_this());
return; return;
} }
@ -163,7 +179,6 @@ void Connection::OnWriteOK() {
LOG_INFO("Continue to read the next request..."); LOG_INFO("Continue to read the next request...");
Start(); Start();
} else { } else {
Shutdown();
pool_->Close(shared_from_this()); pool_->Close(shared_from_this());
} }
} }
@ -176,18 +191,4 @@ void Connection::OnWriteError(boost::system::error_code ec) {
} }
} }
// Socket close VS. shutdown:
// https://stackoverflow.com/questions/4160347/close-vs-shutdown-socket
void Connection::Shutdown() {
LOG_INFO("Shutdown socket...");
// Initiate graceful connection closure.
boost::system::error_code ec;
socket_.shutdown(tcp::socket::shutdown_both, ec);
if (ec) {
LOG_ERRO("Socket shutdown error (%s).", ec.message().c_str());
}
}
} // namespace webcc } // namespace webcc

@ -38,10 +38,14 @@ public:
void Close(); void Close();
// Send a response to the client. // Send a response to the client.
void SendResponse(ResponsePtr response); // `Connection` header will be set to "Close" if |no_keep_alive| is true no
// matter whether the client asked for Keep-Alive or not.
void SendResponse(ResponsePtr response, bool no_keep_alive = false);
// Send a response with the given status and an empty body to the client. // Send a response with the given status and an empty body to the client.
void SendResponse(Status status); // `Connection` header will be set to "Close" if |no_keep_alive| is true no
// matter whether the client asked for Keep-Alive or not.
void SendResponse(Status status, bool no_keep_alive = false);
private: private:
void DoRead(); void DoRead();
@ -54,9 +58,6 @@ private:
void OnWriteOK(); void OnWriteOK();
void OnWriteError(boost::system::error_code ec); void OnWriteError(boost::system::error_code ec);
// Shutdown the socket.
void Shutdown();
// The socket for the connection. // The socket for the connection.
boost::asio::ip::tcp::socket socket_; boost::asio::ip::tcp::socket socket_;

@ -17,11 +17,13 @@ void ConnectionPool::Close(ConnectionPtr c) {
} }
void ConnectionPool::CloseAll() { void ConnectionPool::CloseAll() {
LOG_VERB("Closing all (%u) connections...", connections_.size()); if (!connections_.empty()) {
for (auto& c : connections_) { LOG_VERB("Closing all (%u) connections...", connections_.size());
c->Close(); for (auto& c : connections_) {
c->Close();
}
connections_.clear();
} }
connections_.clear();
} }
} // namespace webcc } // namespace webcc

@ -20,7 +20,7 @@ public:
// Close a connection. // Close a connection.
void Close(ConnectionPtr c); void Close(ConnectionPtr c);
// Close all connections. // Close all pending connections.
void CloseAll(); void CloseAll();
private: private:

@ -186,7 +186,6 @@ class Error {
kSyntaxError, kSyntaxError,
kResolveError, kResolveError,
kConnectError, kConnectError,
kHandshakeError,
kSocketReadError, kSocketReadError,
kSocketWriteError, kSocketWriteError,
kParseError, kParseError,

@ -85,10 +85,12 @@ void Message::SetContentType(const std::string& media_type,
const std::string& charset) { const std::string& charset) {
using headers::kContentType; using headers::kContentType;
if (charset.empty()) { if (!media_type.empty()) {
SetHeader(kContentType, media_type); if (charset.empty()) {
} else { SetHeader(kContentType, media_type);
SetHeader(kContentType, media_type + "; charset=" + charset); } else {
SetHeader(kContentType, media_type + "; charset=" + charset);
}
} }
} }

@ -121,7 +121,7 @@ public:
RequestBuilder& Auth(const std::string& type, const std::string& credentials); RequestBuilder& Auth(const std::string& type, const std::string& credentials);
RequestBuilder& AuthBasic(const std::string& login, RequestBuilder& AuthBasic(const std::string& login,
const std::string& password); const std::string& password);
RequestBuilder& AuthToken(const std::string& token); RequestBuilder& AuthToken(const std::string& token);

@ -32,10 +32,14 @@ ResponsePtr ResponseBuilder::operator()() {
} }
} }
#endif // WEBCC_ENABLE_GZIP #endif // WEBCC_ENABLE_GZIP
} else {
response->SetBody(body_, true); // Ensure the existing of `Content-Length` header if the body is empty.
// `Content-Length: 0` is required by most HTTP clients (e.g., Chrome).
body_ = std::make_shared<webcc::Body>();
} }
response->SetBody(body_, true);
return response; return response;
} }

@ -20,9 +20,8 @@ using tcp = boost::asio::ip::tcp;
namespace webcc { namespace webcc {
Server::Server(std::uint16_t port, std::size_t workers, const Path& doc_root) Server::Server(std::uint16_t port, const Path& doc_root)
: acceptor_(io_context_), signals_(io_context_), workers_(workers), : acceptor_(io_context_), signals_(io_context_), doc_root_(doc_root) {
doc_root_(doc_root) {
RegisterSignals(); RegisterSignals();
boost::system::error_code ec; boost::system::error_code ec;
@ -89,7 +88,10 @@ bool Server::Route(const UrlRegex& regex_url, ViewPtr view,
return true; return true;
} }
void Server::Run() { void Server::Start(std::size_t workers) {
assert(workers > 0);
assert(worker_threads_.empty());
if (!acceptor_.is_open()) { if (!acceptor_.is_open()) {
LOG_ERRO("Server is NOT going to run."); LOG_ERRO("Server is NOT going to run.");
return; return;
@ -101,13 +103,12 @@ void Server::Run() {
DoAccept(); DoAccept();
// Start worker threads. // Create worker threads.
assert(workers_ > 0 && worker_threads_.empty()); for (std::size_t i = 0; i < workers; ++i) {
for (std::size_t i = 0; i < workers_; ++i) {
worker_threads_.emplace_back(std::bind(&Server::WorkerRoutine, this)); worker_threads_.emplace_back(std::bind(&Server::WorkerRoutine, this));
} }
// Run the loop.
// The io_context::run() call will block until all asynchronous operations // The io_context::run() call will block until all asynchronous operations
// have finished. While the server is running, there is always at least one // have finished. While the server is running, there is always at least one
// asynchronous operation outstanding: the asynchronous accept call waiting // asynchronous operation outstanding: the asynchronous accept call waiting
@ -116,23 +117,14 @@ void Server::Run() {
} }
void Server::Stop() { void Server::Stop() {
LOG_INFO("Stopping workers..."); // Stop listener.
acceptor_.close();
// Clear pending connections.
// The connections will be closed later (see Server::DoAwaitStop).
LOG_INFO("Clear pending connections...");
queue_.Clear();
// Enqueue a null connection to trigger the first worker to stop. // Stop worker threads.
queue_.Push(ConnectionPtr()); StopWorkers();
for (auto& thread : worker_threads_) {
if (thread.joinable()) {
thread.join();
}
}
LOG_INFO("All workers have been stopped."); // Close all pending connections.
pool_.CloseAll();
} }
void Server::Enqueue(ConnectionPtr connection) { void Server::Enqueue(ConnectionPtr connection) {
@ -178,13 +170,7 @@ void Server::DoAwaitStop() {
// call will exit. // call will exit.
LOG_INFO("On signal %d, stopping the server...", signo); LOG_INFO("On signal %d, stopping the server...", signo);
acceptor_.close();
// Stop worker threads.
Stop(); Stop();
// Close all connections.
pool_.CloseAll();
}); });
} }
@ -208,6 +194,26 @@ void Server::WorkerRoutine() {
} }
} }
void Server::StopWorkers() {
LOG_INFO("Stopping workers...");
// Clear pending connections.
// The connections will be closed later (see Server::DoAwaitStop).
LOG_INFO("Clear pending connections...");
queue_.Clear();
// Enqueue a null connection to trigger the first worker to stop.
queue_.Push(ConnectionPtr());
for (auto& t : worker_threads_) {
if (t.joinable()) {
t.join();
}
}
LOG_INFO("All workers have been stopped.");
}
void Server::Handle(ConnectionPtr connection) { void Server::Handle(ConnectionPtr connection) {
auto request = connection->request(); auto request = connection->request();

@ -19,8 +19,7 @@ namespace webcc {
class Server { class Server {
public: public:
explicit Server(std::uint16_t port, std::size_t workers = 1, explicit Server(std::uint16_t port, const Path& doc_root = {});
const Path& doc_root = {});
virtual ~Server() = default; virtual ~Server() = default;
@ -38,10 +37,10 @@ public:
bool Route(const UrlRegex& regex_url, ViewPtr view, bool Route(const UrlRegex& regex_url, ViewPtr view,
const Strings& methods = { "GET" }); const Strings& methods = { "GET" });
// Run the loop. // Start the server with a given number of worker threads.
void Run(); void Start(std::size_t workers = 1);
// Clear pending connections from the queue and stop worker threads. // Stop the server.
void Stop(); void Stop();
// Put the connection into the queue. // Put the connection into the queue.
@ -57,8 +56,12 @@ private:
// Wait for a request to stop the server. // Wait for a request to stop the server.
void DoAwaitStop(); void DoAwaitStop();
// Worker thread routine.
void WorkerRoutine(); void WorkerRoutine();
// Clear pending connections from the queue and stop worker threads.
void StopWorkers();
// Handle a connection (or more precisely, the request inside it). // Handle a connection (or more precisely, the request inside it).
// Get the request from the connection, process it, prepare the response, // Get the request from the connection, process it, prepare the response,
// then send the response back to the client. // then send the response back to the client.
@ -93,9 +96,6 @@ private:
// The signals for processing termination notifications. // The signals for processing termination notifications.
boost::asio::signal_set signals_; boost::asio::signal_set signals_;
// The number of worker threads.
std::size_t workers_;
// Worker threads. // Worker threads.
std::vector<std::thread> worker_threads_; std::vector<std::thread> worker_threads_;
@ -105,6 +105,7 @@ private:
// The queue with connection waiting for the workers to process. // The queue with connection waiting for the workers to process.
Queue<ConnectionPtr> queue_; Queue<ConnectionPtr> queue_;
// Route table.
std::vector<RouteInfo> routes_; std::vector<RouteInfo> routes_;
}; };

@ -26,13 +26,18 @@ namespace webcc {
Socket::Socket(boost::asio::io_context& io_context) : socket_(io_context) { Socket::Socket(boost::asio::io_context& io_context) : socket_(io_context) {
} }
bool Socket::Connect(const std::string& host, const Endpoints& endpoints, bool Socket::Connect(const std::string& host, const Endpoints& endpoints) {
boost::system::error_code* ec) {
boost::ignore_unused(host); boost::ignore_unused(host);
boost::asio::connect(socket_, endpoints, *ec); boost::system::error_code ec;
boost::asio::connect(socket_, endpoints, ec);
return !(*ec); if (ec) {
LOG_ERRO("Socket connect error (%s).", ec.message().c_str());
return false;
}
return true;
} }
bool Socket::Write(const Payload& payload, boost::system::error_code* ec) { bool Socket::Write(const Payload& payload, boost::system::error_code* ec) {
@ -44,8 +49,25 @@ void Socket::AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) {
socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler));
} }
void Socket::Close(boost::system::error_code* ec) { bool Socket::Close() {
socket_.close(*ec); boost::system::error_code ec;
socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
if (ec) {
LOG_WARN("Socket shutdown error (%s).", ec.message().c_str());
ec.clear();
// Don't return, try to close the socket anywhere.
}
socket_.close(ec);
if (ec) {
LOG_WARN("Socket close error (%s).", ec.message().c_str());
return false;
}
return true;
} }
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -104,15 +126,16 @@ SslSocket::SslSocket(boost::asio::io_context& io_context, bool ssl_verify)
#endif // defined(_WIN32) || defined(_WIN64) #endif // defined(_WIN32) || defined(_WIN64)
} }
bool SslSocket::Connect(const std::string& host, const Endpoints& endpoints, bool SslSocket::Connect(const std::string& host, const Endpoints& endpoints) {
boost::system::error_code* ec) { boost::system::error_code ec;
boost::asio::connect(ssl_socket_.lowest_layer(), endpoints, *ec); boost::asio::connect(ssl_socket_.lowest_layer(), endpoints, ec);
if (*ec) { if (ec) {
LOG_ERRO("Socket connect error (%s).", ec.message().c_str());
return false; return false;
} }
return Handshake(host, ec); return Handshake(host);
} }
bool SslSocket::Write(const Payload& payload, boost::system::error_code* ec) { bool SslSocket::Write(const Payload& payload, boost::system::error_code* ec) {
@ -125,12 +148,13 @@ void SslSocket::AsyncReadSome(ReadHandler&& handler,
ssl_socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); ssl_socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler));
} }
void SslSocket::Close(boost::system::error_code* ec) { bool SslSocket::Close() {
ssl_socket_.lowest_layer().close(*ec); boost::system::error_code ec;
ssl_socket_.lowest_layer().close(ec);
return !ec;
} }
bool SslSocket::Handshake(const std::string& host, bool SslSocket::Handshake(const std::string& host) {
boost::system::error_code* ec) {
if (ssl_verify_) { if (ssl_verify_) {
ssl_socket_.set_verify_mode(ssl::verify_peer); ssl_socket_.set_verify_mode(ssl::verify_peer);
} else { } else {
@ -140,10 +164,11 @@ bool SslSocket::Handshake(const std::string& host,
ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host)); ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host));
// Use sync API directly since we don't need timeout control. // Use sync API directly since we don't need timeout control.
ssl_socket_.handshake(ssl::stream_base::client, *ec); boost::system::error_code ec;
ssl_socket_.handshake(ssl::stream_base::client, ec);
if (*ec) { if (ec) {
LOG_ERRO("Handshake error (%s).", ec->message().c_str()); LOG_ERRO("Handshake error (%s).", ec.message().c_str());
return false; return false;
} }

@ -26,15 +26,14 @@ public:
std::function<void(boost::system::error_code, std::size_t)>; std::function<void(boost::system::error_code, std::size_t)>;
// TODO: Remove |host| // TODO: Remove |host|
virtual bool Connect(const std::string& host, const Endpoints& endpoints, virtual bool Connect(const std::string& host, const Endpoints& endpoints) = 0;
boost::system::error_code* ec) = 0;
virtual bool Write(const Payload& payload, boost::system::error_code* ec) = 0; virtual bool Write(const Payload& payload, boost::system::error_code* ec) = 0;
virtual void AsyncReadSome(ReadHandler&& handler, virtual void AsyncReadSome(ReadHandler&& handler,
std::vector<char>* buffer) = 0; std::vector<char>* buffer) = 0;
virtual void Close(boost::system::error_code* ec) = 0; virtual bool Close() = 0;
}; };
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -43,14 +42,13 @@ class Socket : public SocketBase {
public: public:
explicit Socket(boost::asio::io_context& io_context); explicit Socket(boost::asio::io_context& io_context);
bool Connect(const std::string& host, const Endpoints& endpoints, bool Connect(const std::string& host, const Endpoints& endpoints) override;
boost::system::error_code* ec) override;
bool Write(const Payload& payload, boost::system::error_code* ec) override; bool Write(const Payload& payload, boost::system::error_code* ec) override;
void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override; void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override;
void Close(boost::system::error_code* ec) override; bool Close() override;
private: private:
boost::asio::ip::tcp::socket socket_; boost::asio::ip::tcp::socket socket_;
@ -65,17 +63,16 @@ public:
explicit SslSocket(boost::asio::io_context& io_context, explicit SslSocket(boost::asio::io_context& io_context,
bool ssl_verify = true); bool ssl_verify = true);
bool Connect(const std::string& host, const Endpoints& endpoints, bool Connect(const std::string& host, const Endpoints& endpoints) override;
boost::system::error_code* ec) override;
bool Write(const Payload& payload, boost::system::error_code* ec) override; bool Write(const Payload& payload, boost::system::error_code* ec) override;
void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override; void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override;
void Close(boost::system::error_code* ec) override; bool Close() override;
private: private:
bool Handshake(const std::string& host, boost::system::error_code* ec); bool Handshake(const std::string& host);
boost::asio::ssl::context ssl_context_; boost::asio::ssl::context ssl_context_;

Loading…
Cancel
Save