rework client using async api

master
Chunting Gu 4 years ago
parent 7f345b7a4e
commit f5210ba1a2

@ -1,24 +1,2 @@
# Automation test add_subdirectory(client_autotest)
add_subdirectory(client_timeout_autotest)
set(AT_SRCS
client_autotest.cc
client_timeout_autotest.cc
main.cc
)
# Common libraries to link.
set(AT_LIBS
webcc
jsoncpp
GTest::GTest)
if(UNIX)
# Add `-ldl` for Linux to avoid "undefined reference to `dlopen'".
set(AT_LIBS ${AT_LIBS} ${CMAKE_DL_LIBS})
endif()
set(AT_TARGET_NAME autotest)
add_executable(${AT_TARGET_NAME} ${AT_SRCS})
target_link_libraries(${AT_TARGET_NAME} ${AT_LIBS})
set_target_properties(${AT_TARGET_NAME} PROPERTIES FOLDER "Tests")

@ -0,0 +1,24 @@
set(SRCS
client_autotest.cc
main.cc
)
# Common libraries to link.
set(LIBS
webcc
jsoncpp
GTest::GTest
)
if(UNIX)
# Add `-ldl` for Linux to avoid "undefined reference to `dlopen'".
set(LIBS ${LIBS} ${CMAKE_DL_LIBS})
endif()
set(TARGET_NAME client_autotest)
add_executable(${TARGET_NAME} ${SRCS})
target_link_libraries(${TARGET_NAME} ${LIBS})
set_target_properties(${TARGET_NAME} PROPERTIES FOLDER "Tests")

@ -0,0 +1,22 @@
set(SRCS
client_timeout_autotest.cc
main.cc
)
# Common libraries to link.
set(LIBS
webcc
jsoncpp
GTest::GTest
)
if(UNIX)
# Add `-ldl` for Linux to avoid "undefined reference to `dlopen'".
set(LIBS ${LIBS} ${CMAKE_DL_LIBS})
endif()
set(TARGET_NAME client_timeout_autotest)
add_executable(${TARGET_NAME} ${SRCS})
target_link_libraries(${TARGET_NAME} ${LIBS})
set_target_properties(${TARGET_NAME} PROPERTIES FOLDER "Tests")

@ -9,7 +9,7 @@
namespace { namespace {
const char* kData = "Hello, World!"; const char* const kData = "Hello, World!";
const std::uint16_t kPort = 8080; const std::uint16_t kPort = 8080;
@ -83,8 +83,8 @@ TEST_F(ClientTimeoutTest, NoTimeout) {
TEST_F(ClientTimeoutTest, Timeout) { TEST_F(ClientTimeoutTest, Timeout) {
webcc::ClientSession session; webcc::ClientSession session;
// Change timeout to 1s. // Change read timeout to 1s.
session.set_timeout(1); session.set_read_timeout(1);
webcc::ResponsePtr r; webcc::ResponsePtr r;
bool timeout = false; bool timeout = false;

@ -0,0 +1,11 @@
#include "gtest/gtest.h"
#include "webcc/logger.h"
int main(int argc, char* argv[]) {
// Set webcc::LOG_CONSOLE to enable logging.
WEBCC_LOG_INIT("", 0);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

@ -8,7 +8,8 @@ int main() {
WEBCC_LOG_INIT("", webcc::LOG_CONSOLE); WEBCC_LOG_INIT("", webcc::LOG_CONSOLE);
webcc::ClientSession session; webcc::ClientSession session;
session.set_connect_timeout(5);
session.set_read_timeout(5);
session.Accept("application/json"); session.Accept("application/json");
webcc::ResponsePtr r; webcc::ResponsePtr r;

@ -30,7 +30,7 @@ int main(int argc, const char* argv[]) {
threads.emplace_back([&url]() { threads.emplace_back([&url]() {
// NOTE: Each thread has its own client session. // NOTE: Each thread has its own client session.
webcc::ClientSession session; webcc::ClientSession session;
session.set_timeout(180); session.set_read_timeout(180);
try { try {
LOG_USER("Start"); LOG_USER("Start");

@ -27,7 +27,10 @@ int main(int argc, char* argv[]) {
webcc::ClientSession session; webcc::ClientSession session;
try { try {
auto r = session.Send(webcc::RequestBuilder{}.Get(url)(), true); auto r = session.Send(webcc::RequestBuilder{}.Get(url)(), true,
[](std::size_t length, std::size_t total_length) {
std::cout << "Progress " << length << " / " << total_length << std::endl;
});
if (auto file_body = r->file_body()) { if (auto file_body = r->file_body()) {
file_body->Move(path); file_body->Move(path);

@ -3,23 +3,40 @@
#include "webcc/logger.h" #include "webcc/logger.h"
using boost::asio::ip::tcp; using boost::asio::ip::tcp;
using namespace std::placeholders;
namespace webcc { namespace webcc {
Client::Client() #if WEBCC_ENABLE_SSL
: timer_(io_context_),
ssl_verify_(true), Client::Client(boost::asio::io_context& io_context,
buffer_size_(kBufferSize), boost::asio::ssl::context& ssl_context)
timeout_(kMaxReadSeconds), : io_context_(io_context),
closed_(false), ssl_context_(ssl_context),
timer_canceled_(false) { resolver_(io_context),
deadline_timer_(io_context) {
} }
Error Client::Request(RequestPtr request, bool connect, bool stream) { #else
closed_ = false;
timer_canceled_ = false; Client::Client(boost::asio::io_context& io_context)
: io_context_(io_context),
resolver_(io_context),
deadline_timer_(io_context) {
}
#endif // WEBCC_ENABLE_SSL
Error Client::Request(RequestPtr request, bool stream) {
LOG_VERB("Request begin");
request_finished_ = false;
error_ = Error{}; error_ = Error{};
request_ = request;
length_read_ = 0;
response_.reset(new Response{}); response_.reset(new Response{});
response_parser_.Init(response_.get(), stream); response_parser_.Init(response_.get(), stream);
@ -35,246 +52,312 @@ Error Client::Request(RequestPtr request, bool connect, bool stream) {
// have Content-Length; // have Content-Length;
// - If request.Accept-Encoding is "identity", the response will have // - If request.Accept-Encoding is "identity", the response will have
// Content-Length. // Content-Length.
if (request->method() == methods::kHead) { if (request_->method() == methods::kHead) {
response_parser_.set_ignore_body(true); response_parser_.set_ignore_body(true);
} else { } else {
// Reset in case the connection is persistent. // Reset in case the connection is persistent.
response_parser_.set_ignore_body(false); response_parser_.set_ignore_body(false);
} }
io_context_.restart(); if (!connected_) {
AsyncConnect();
if (connect) { } else {
// No existing socket connection was specified, create a new one. AsyncWrite();
Connect(request);
if (error_) {
return error_;
}
} }
WriteRequest(request); // Wait for the request to be finished.
std::unique_lock<std::mutex> response_lock{ request_mutex_ };
if (error_) { request_cv_.wait(response_lock, [=] { return request_finished_; });
return error_;
}
ReadResponse(); LOG_VERB("Request end");
return error_; return error_;
} }
void Client::Close() { void Client::Close() {
if (closed_) { if (!connected_) {
//resolver_.cancel(); // TODO
if (socket_) {
// Cancel any async operations on the socket.
LOG_VERB("Close socket");
socket_->Close();
// Make sure the current request, if any, could be finished.
FinishRequest();
}
return; return;
} }
closed_ = true; connected_ = false;
LOG_INFO("Close socket..."); if (socket_) {
LOG_INFO("Shutdown & close socket");
socket_->Shutdown();
socket_->Close();
// Make sure the current request, if any, could be finished.
FinishRequest();
}
socket_->Close(); LOG_INFO("Socket closed");
} }
void Client::Connect(RequestPtr request) { void Client::AsyncConnect() {
if (request->url().scheme() == "https") { if (request_->url().scheme() == "https") {
#if WEBCC_ENABLE_SSL #if WEBCC_ENABLE_SSL
socket_.reset(new SslSocket{ io_context_, ssl_verify_ }); socket_.reset(new SslSocket{ io_context_, ssl_context_, ssl_verify_ });
DoConnect(request, "443"); AsyncResolve("443");
#else #else
LOG_ERRO("SSL/HTTPS support is not enabled."); LOG_ERRO("SSL/HTTPS support is not enabled.");
error_.Set(Error::kSyntaxError, "SSL/HTTPS is not supported"); error_.Set(Error::kSyntaxError, "SSL/HTTPS is not supported");
FinishRequest();
#endif // WEBCC_ENABLE_SSL #endif // WEBCC_ENABLE_SSL
} else { } else {
socket_.reset(new Socket{ io_context_ }); socket_.reset(new Socket{ io_context_ });
DoConnect(request, "80"); AsyncResolve("80");
} }
} }
void Client::DoConnect(RequestPtr request, const std::string& default_port) { void Client::AsyncResolve(const std::string& default_port) {
tcp::resolver resolver(io_context_); std::string port = request_->port();
std::string port = request->port();
if (port.empty()) { if (port.empty()) {
port = default_port; port = default_port;
} }
LOG_VERB("Resolve host (%s)...", request->host().c_str()); LOG_VERB("Resolve host (%s)", request_->host().c_str());
boost::system::error_code ec;
// The protocol depends on the `host`, both V4 and V6 are supported. // The protocol depends on the `host`, both V4 and V6 are supported.
auto endpoints = resolver.resolve(request->host(), port, ec); resolver_.async_resolve(request_->host(), port,
std::bind(&Client::OnResolve, this, _1, _2));
}
void Client::OnResolve(boost::system::error_code ec,
tcp::resolver::results_type endpoints) {
if (ec) { if (ec) {
LOG_ERRO("Host resolve error (%s): %s, %s.", ec.message().c_str(), LOG_ERRO("Host resolve error (%s)", ec.message().c_str());
request->host().c_str(), port.c_str());
error_.Set(Error::kResolveError, "Host resolve error"); error_.Set(Error::kResolveError, "Host resolve error");
FinishRequest();
return; return;
} }
LOG_VERB("Connect to server..."); LOG_VERB("Connect socket");
// Use sync API directly since we don't need timeout control. AsyncWaitDeadlineTimer(connect_timeout_);
if (!socket_->Connect(request->host(), endpoints)) { socket_->AsyncConnect(request_->host(), endpoints,
error_.Set(Error::kConnectError, "Endpoint connect error"); std::bind(&Client::OnConnect, this, _1, _2));
Close(); }
void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) {
LOG_VERB("On connect");
StopDeadlineTimer();
if (ec) {
if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or Close().
LOG_WARN("Connect operation aborted");
} else {
LOG_INFO("Connect error");
// No need to close socket since no async operation is on it.
// socket_->Close();
}
error_.Set(Error::kConnectError, "Socket connect error");
FinishRequest();
return; return;
} }
LOG_VERB("Socket connected."); LOG_INFO("Socket connected");
connected_ = true;
AsyncWrite();
} }
void Client::WriteRequest(RequestPtr request) { void Client::AsyncWrite() {
LOG_VERB("HTTP request:\n%s", request->Dump().c_str()); LOG_VERB("Request:\n%s", request_->Dump().c_str());
// NOTE: socket_->AsyncWrite(request_->GetPayload(),
// It doesn't make much sense to set a timeout for socket write. std::bind(&Client::OnWrite, this, _1, _2));
// I find that it's almost impossible to simulate a situation in the server }
// side to test this timeout.
// Use sync API directly since we don't need timeout control. void Client::OnWrite(boost::system::error_code ec, std::size_t length) {
if (ec) {
HandleWriteError(ec);
return;
}
boost::system::error_code ec; request_->body()->InitPayload();
if (socket_->Write(request->GetPayload(), &ec)) { AsyncWriteBody();
// Write request body. }
auto body = request->body();
body->InitPayload(); void Client::AsyncWriteBody() {
for (auto p = body->NextPayload(true); !p.empty(); auto p = request_->body()->NextPayload(true);
p = body->NextPayload(true)) {
if (!socket_->Write(p, &ec)) { if (!p.empty()) {
break; socket_->AsyncWrite(p, std::bind(&Client::OnWriteBody, this, _1, _2));
} } else {
} LOG_INFO("Request send");
// Start the read deadline timer.
AsyncWaitDeadlineTimer(read_timeout_);
// Start to read response.
AsyncRead();
} }
}
void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) {
if (ec) { if (ec) {
LOG_ERRO("Socket write error (%s).", ec.message().c_str()); HandleWriteError(ec);
Close(); return;
error_.Set(Error::kSocketWriteError, "Socket write error");
} }
LOG_INFO("Request sent."); // Continue to write the next payload of body.
AsyncWriteBody();
} }
void Client::ReadResponse() { void Client::HandleWriteError(boost::system::error_code ec) {
LOG_VERB("Read response (timeout: %ds)...", timeout_); if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or Close().
DoReadResponse(); LOG_WARN("Write operation aborted");
} else {
if (!error_) { LOG_ERRO("Socket write error (%s)", ec.message().c_str());
LOG_VERB("HTTP response:\n%s", response_->Dump().c_str()); Close();
} }
error_.Set(Error::kSocketWriteError, "Socket write error");
FinishRequest();
} }
void Client::DoReadResponse() { void Client::AsyncRead() {
boost::system::error_code ec = boost::asio::error::would_block; socket_->AsyncReadSome(std::bind(&Client::OnRead, this, _1, _2), &buffer_);
std::size_t length = 0; }
// The read handler. void Client::OnRead(boost::system::error_code ec, std::size_t length) {
auto handler = [&ec, &length](boost::system::error_code inner_ec, StopDeadlineTimer();
std::size_t inner_length) {
ec = inner_ec;
length = inner_length;
};
while (true) { if (ec) {
ec = boost::asio::error::would_block; if (ec == boost::asio::error::operation_aborted) {
length = 0; // Socket has been closed by OnDeadlineTimer() or Close().
LOG_WARN("Read operation aborted");
} else {
LOG_ERRO("Socket read error (%s)", ec.message().c_str());
Close();
}
socket_->AsyncReadSome(std::move(handler), &buffer_); error_.Set(Error::kSocketReadError, "Socket read error");
FinishRequest();
return;
}
// Start the timer. length_read_ += length;
DoWaitTimer();
// Block until the asynchronous operation has completed. LOG_INFO("Read length: %u", length);
do {
io_context_.run_one();
} while (ec == boost::asio::error::would_block);
// Stop the timer. // Parse the piece of data just read.
CancelTimer(); if (!response_parser_.Parse(buffer_.data(), length)) {
LOG_ERRO("Failed to parse the response");
Close();
error_.Set(Error::kParseError, "Response parse error");
FinishRequest();
return;
}
// The error normally is caused by timeout. See OnTimer(). // Inform progress callback if it's specified.
if (ec || length == 0) { if (progress_callback_) {
Close(); if (response_parser_.header_ended()) {
error_.Set(Error::kSocketReadError, "Socket read error"); // NOTE: Need to get rid of the header length.
LOG_ERRO("Socket read error (%s).", ec.message().c_str()); progress_callback_(length_read_ - response_parser_.header_length(),
break; response_parser_.content_length());
} }
}
LOG_INFO("Read data, length: %u.", length); if (response_parser_.finished()) {
LOG_VERB("Response:\n%s", response_->Dump().c_str());
// Parse the piece of data just read. if (response_->IsConnectionKeepAlive()) {
if (!response_parser_.Parse(buffer_.data(), length)) { LOG_INFO("Keep the socket connection alive");
} else {
Close(); Close();
error_.Set(Error::kParseError, "HTTP parse error");
LOG_ERRO("Failed to parse the HTTP response.");
break;
} }
if (response_parser_.finished()) { // Stop trying to read once all content has been received, because some
// Stop trying to read once all content has been received, because // servers will block extra call to read_some().
// some servers will block extra call to read_some().
if (response_->IsConnectionKeepAlive()) { LOG_INFO("Finished to read the response");
// Close the timer but keep the socket connection. FinishRequest();
LOG_INFO("Keep the socket connection alive."); return;
} else {
Close();
}
// Stop reading.
LOG_INFO("Finished to read the HTTP response.");
break;
}
} }
// Continue to read the response.
AsyncRead();
} }
void Client::DoWaitTimer() { void Client::AsyncWaitDeadlineTimer(int seconds) {
LOG_VERB("Wait timer asynchronously."); if (seconds <= 0) {
timer_.expires_after(std::chrono::seconds(timeout_)); deadline_timer_stopped_ = true;
timer_.async_wait(std::bind(&Client::OnTimer, this, std::placeholders::_1)); return;
}
LOG_VERB("Async wait deadline timer");
deadline_timer_stopped_ = false;
deadline_timer_.expires_after(std::chrono::seconds(seconds));
deadline_timer_.async_wait(std::bind(&Client::OnDeadlineTimer, this, _1));
} }
void Client::OnTimer(boost::system::error_code ec) { void Client::OnDeadlineTimer(boost::system::error_code ec) {
LOG_VERB("On timer."); LOG_VERB("On deadline timer");
deadline_timer_stopped_ = true;
// timer_.cancel() was called. // deadline_timer_.cancel() was called.
if (ec == boost::asio::error::operation_aborted) { if (ec == boost::asio::error::operation_aborted) {
LOG_VERB("Timer canceled."); LOG_VERB("Deadline timer canceled");
return; return;
} }
if (closed_) { LOG_WARN("Timeout");
LOG_VERB("Socket has been closed.");
return;
}
if (timer_.expiry() <= boost::asio::steady_timer::clock_type::now()) { // Cancel the async operations on the socket.
// The deadline has passed. The socket is closed so that any outstanding // OnXxx() will be called with `error::operation_aborted`.
// asynchronous operations are canceled. if (connected_) {
LOG_WARN("HTTP client timed out.");
error_.set_timeout(true);
Close(); Close();
return; } else {
socket_->Close();
} }
// Put the actor back to sleep. error_.set_timeout(true);
DoWaitTimer();
} }
void Client::CancelTimer() { void Client::StopDeadlineTimer() {
if (timer_canceled_) { if (deadline_timer_stopped_) {
return; return;
} }
LOG_INFO("Cancel timer..."); LOG_INFO("Cancel deadline timer");
timer_.cancel();
try {
// Cancel the async wait operation on this timer.
deadline_timer_.cancel();
} catch (const boost::system::system_error&) {
}
deadline_timer_stopped_ = true;
}
timer_canceled_ = true; void Client::FinishRequest() {
{
std::lock_guard<std::mutex> lock{ request_mutex_ };
if (!request_finished_) {
request_finished_ = true;
} else {
return;
}
}
request_cv_.notify_one();
} }
} // namespace webcc } // namespace webcc

@ -1,8 +1,9 @@
#ifndef WEBCC_CLIENT_H_ #ifndef WEBCC_CLIENT_H_
#define WEBCC_CLIENT_H_ #define WEBCC_CLIENT_H_
#include <cassert> #include <condition_variable>
#include <memory> #include <memory>
#include <mutex>
#include <string> #include <string>
#include <vector> #include <vector>
@ -19,18 +20,22 @@
namespace webcc { namespace webcc {
// Synchronous HTTP & HTTPS client. // Synchronous HTTP & HTTPS client.
// In synchronous mode, a request won't return until the response is received // A request won't return until the response is received or timeout occurs.
// or timeout occurs.
// Please don't use the same client object in multiple threads.
class Client { class Client {
public: public:
Client(); // TODO
#if WEBCC_ENABLE_SSL
~Client() = default; Client(boost::asio::io_context& io_context,
boost::asio::ssl::context& ssl_context);
#else
explicit Client(boost::asio::io_context& io_context);
#endif
Client(const Client&) = delete; Client(const Client&) = delete;
Client& operator=(const Client&) = delete; Client& operator=(const Client&) = delete;
~Client() = default;
void set_ssl_verify(bool ssl_verify) { void set_ssl_verify(bool ssl_verify) {
ssl_verify_ = ssl_verify; ssl_verify_ = ssl_verify;
} }
@ -41,19 +46,35 @@ public:
} }
} }
// Set the timeout (in seconds) for reading response. void set_connect_timeout(int timeout) {
void set_timeout(int timeout) { if (timeout > 0) {
connect_timeout_ = timeout;
}
}
void set_read_timeout(int timeout) {
if (timeout > 0) { if (timeout > 0) {
timeout_ = timeout; read_timeout_ = timeout;
} }
} }
// Connect to server, send request, wait until response is received. // Set progress callback to be informed about the read progress.
Error Request(RequestPtr request, bool connect = true, bool stream = false); // TODO: What about write?
// TODO: std::move?
void set_progress_callback(ProgressCallback callback) {
progress_callback_ = callback;
}
// Connect, send request, wait until response is received.
Error Request(RequestPtr request, bool stream = false);
// Close the socket. // Close the socket.
void Close(); void Close();
bool connected() const {
return connected_;
}
ResponsePtr response() const { ResponsePtr response() const {
return response_; return response_;
} }
@ -66,57 +87,82 @@ public:
response_parser_.Init(nullptr, false); response_parser_.Init(nullptr, false);
} }
bool closed() const {
return closed_;
}
private: private:
void Connect(RequestPtr request); void AsyncConnect();
void AsyncResolve(const std::string& default_port);
void DoConnect(RequestPtr request, const std::string& default_port); void OnResolve(boost::system::error_code ec,
boost::asio::ip::tcp::resolver::results_type endpoints);
void WriteRequest(RequestPtr request); void OnConnect(boost::system::error_code ec, boost::asio::ip::tcp::endpoint);
void ReadResponse(); void AsyncWrite();
void OnWrite(boost::system::error_code ec, std::size_t length);
void DoReadResponse(); void AsyncWriteBody();
void OnWriteBody(boost::system::error_code ec, std::size_t length);
void DoWaitTimer(); void HandleWriteError(boost::system::error_code ec);
void OnTimer(boost::system::error_code ec);
// Cancel any async-operations waiting on the timer. void AsyncRead();
void CancelTimer(); void OnRead(boost::system::error_code ec, std::size_t length);
void AsyncWaitDeadlineTimer(int seconds);
void OnDeadlineTimer(boost::system::error_code ec);
void StopDeadlineTimer();
void FinishRequest();
private: private:
boost::asio::io_context io_context_; boost::asio::io_context& io_context_;
#if WEBCC_ENABLE_SSL
boost::asio::ssl::context& ssl_context_;
#endif
// Socket connection.
std::unique_ptr<SocketBase> socket_; std::unique_ptr<SocketBase> socket_;
boost::asio::ip::tcp::resolver resolver_;
bool request_finished_ = true;
std::condition_variable request_cv_;
std::mutex request_mutex_;
RequestPtr request_;
ResponsePtr response_; ResponsePtr response_;
ResponseParser response_parser_; ResponseParser response_parser_;
// Timer for the timeout control. // The length already read.
boost::asio::steady_timer timer_; std::size_t length_read_ = 0;
// The buffer for reading response. // The buffer for reading response.
std::vector<char> buffer_; std::vector<char> buffer_;
// Verify the certificate of the peer or not (for HTTPS). // Verify the certificate of the peer or not (for HTTPS).
bool ssl_verify_; bool ssl_verify_ = true;
// The size of the buffer for reading response. // The size of the buffer for reading response.
// 0 means default value will be used. // 0 means default value will be used.
std::size_t buffer_size_; std::size_t buffer_size_ = kBufferSize;
// Timeout (seconds) for connecting to server.
// Default as 0 to disable our own control (i.e., deadline_timer_).
int connect_timeout_ = 0;
// Timeout (seconds) for reading response.
int read_timeout_ = kMaxReadSeconds;
// Timeout (seconds) for receiving response. // Deadline timer for connecting to server.
int timeout_; boost::asio::steady_timer deadline_timer_;
bool deadline_timer_stopped_ = true;
// Connection closed. // Socket connected or not.
bool closed_; bool connected_ = false;
// Deadline timer canceled. // Progress callback (optional).
bool timer_canceled_; ProgressCallback progress_callback_;
Error error_; Error error_;
}; };

@ -6,7 +6,7 @@ namespace webcc {
ClientPool::~ClientPool() { ClientPool::~ClientPool() {
if (!clients_.empty()) { if (!clients_.empty()) {
LOG_INFO("Close socket for all (%u) connections in the pool.", LOG_INFO("Close socket for all (%u) connections in the pool",
clients_.size()); clients_.size());
for (auto& pair : clients_) { for (auto& pair : clients_) {
@ -21,21 +21,21 @@ ClientPtr ClientPool::Get(const Key& key) const {
if (it != clients_.end()) { if (it != clients_.end()) {
return it->second; return it->second;
} else { } else {
return ClientPtr{}; return {};
} }
} }
void ClientPool::Add(const Key& key, ClientPtr client) { void ClientPool::Add(const Key& key, ClientPtr client) {
clients_[key] = client; clients_[key] = client;
LOG_INFO("Added connection to pool (%s, %s, %s).", LOG_INFO("Connection added to pool (%s, %s, %s)",
key.scheme.c_str(), key.host.c_str(), key.port.c_str()); key.scheme.c_str(), key.host.c_str(), key.port.c_str());
} }
void ClientPool::Remove(const Key& key) { void ClientPool::Remove(const Key& key) {
clients_.erase(key); clients_.erase(key);
LOG_INFO("Removed connection from pool (%s, %s, %s).", LOG_INFO("Connection removed from pool (%s, %s, %s)",
key.scheme.c_str(), key.host.c_str(), key.port.c_str()); key.scheme.c_str(), key.host.c_str(), key.port.c_str());
} }

@ -1,16 +1,123 @@
#include "webcc/client_session.h" #include "webcc/client_session.h"
#include <cassert>
#include "webcc/base64.h" #include "webcc/base64.h"
#include "webcc/logger.h" #include "webcc/logger.h"
#include "webcc/url.h" #include "webcc/url.h"
#include "webcc/utility.h" #include "webcc/utility.h"
namespace ssl = boost::asio::ssl;
namespace webcc { namespace webcc {
ClientSession::ClientSession(int timeout, bool ssl_verify, #if WEBCC_ENABLE_SSL
std::size_t buffer_size) #if (defined(_WIN32) || defined(_WIN64))
: timeout_(timeout), ssl_verify_(ssl_verify), buffer_size_(buffer_size) {
// Let OpenSSL on Windows use the system certificate store
// 1. Load your certificate (in PCCERT_CONTEXT structure) from Windows Cert
// store using Crypto APIs.
// 2. Get encrypted content of it in binary format as it is.
// [PCCERT_CONTEXT->pbCertEncoded].
// 3. Parse this binary buffer into X509 certificate Object using OpenSSL's
// d2i_X509() method.
// 4. Get handle to OpenSSL's trust store using SSL_CTX_get_cert_store()
// method.
// 5. Load above parsed X509 certificate into this trust store using
// X509_STORE_add_cert() method.
// 6. You are done!
// NOTES: Enum Windows store with "ROOT" (not "CA").
// See: https://stackoverflow.com/a/11763389/6825348
static bool UseSystemCertificateStore(SSL_CTX* ssl_ctx) {
// NOTE: Cannot use nullptr to replace NULL.
HCERTSTORE cert_store = ::CertOpenSystemStoreW(NULL, L"ROOT");
if (cert_store == nullptr) {
LOG_ERRO("Cannot open Windows system certificate store.");
return false;
}
X509_STORE* x509_store = SSL_CTX_get_cert_store(ssl_ctx);
PCCERT_CONTEXT cert_context = nullptr;
while (cert_context = CertEnumCertificatesInStore(cert_store, cert_context)) {
auto in = (const unsigned char**)&cert_context->pbCertEncoded;
X509* x509 = d2i_X509(nullptr, in, cert_context->cbCertEncoded);
if (x509 != nullptr) {
if (X509_STORE_add_cert(x509_store, x509) == 0) {
LOG_ERRO("Cannot add Windows root certificate.");
}
X509_free(x509);
}
}
CertFreeCertificateContext(cert_context);
CertCloseStore(cert_store, 0);
return true;
}
#endif // defined(_WIN32) || defined(_WIN64)
#endif // WEBCC_ENABLE_SSL
ClientSession::ClientSession(bool ssl_verify, std::size_t buffer_size)
: work_guard_(boost::asio::make_work_guard(io_context_)),
#if WEBCC_ENABLE_SSL
ssl_context_(ssl::context::sslv23),
#endif
ssl_verify_(ssl_verify), buffer_size_(buffer_size) {
#if WEBCC_ENABLE_SSL
#if (defined(_WIN32) || defined(_WIN64))
// if (ssl_verify_) {
UseSystemCertificateStore(ssl_context_.native_handle());
// }
#else
// Use the default paths for finding CA certificates.
ssl_context_.set_default_verify_paths();
#endif // defined(_WIN32) || defined(_WIN64)
#endif // WEBCC_ENABLE_SSL
InitHeaders(); InitHeaders();
Start();
}
ClientSession::~ClientSession() {
Stop();
}
void ClientSession::Start() {
if (started_) {
return;
}
started_ = true;
io_context_.restart();
// Run the io context off in its own thread so that it operates completely
// asynchronously with respect to the rest of the program.
io_thread_.reset(new std::thread{ [this]() { io_context_.run(); }});
LOG_INFO("Loop is now running");
}
void ClientSession::Stop() {
if (!started_) {
return;
}
Cancel();
io_context_.stop();
io_thread_->join();
LOG_INFO("Loop stopped");
started_ = false;
} }
void ClientSession::Accept(const std::string& content_types) { void ClientSession::Accept(const std::string& content_types) {
@ -68,9 +175,16 @@ void ClientSession::AuthToken(const std::string& token) {
return Auth("Token", token); return Auth("Token", token);
} }
ResponsePtr ClientSession::Send(RequestPtr request, bool stream) { ResponsePtr ClientSession::Send(RequestPtr request, bool stream,
ProgressCallback callback) {
assert(request); assert(request);
std::lock_guard<std::mutex> lock{ mutex_ };
if (!started_) {
throw Error{ Error::kStateError, "Loop is not running" };
}
for (auto& h : headers_.data()) { for (auto& h : headers_.data()) {
if (!request->HasHeader(h.first)) { if (!request->HasHeader(h.first)) {
request->SetHeader(h.first, h.second); request->SetHeader(h.first, h.second);
@ -84,7 +198,13 @@ ResponsePtr ClientSession::Send(RequestPtr request, bool stream) {
request->Prepare(); request->Prepare();
return DoSend(request, stream); return DoSend(request, stream, callback);
}
void ClientSession::Cancel() {
if (client_) {
client_->Close();
}
} }
void ClientSession::InitHeaders() { void ClientSession::InitHeaders() {
@ -99,34 +219,40 @@ void ClientSession::InitHeaders() {
headers_.Set(headers::kConnection, "Keep-Alive"); headers_.Set(headers::kConnection, "Keep-Alive");
} }
ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream) { ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream,
ProgressCallback callback) {
const ClientPool::Key key{ request->url() }; const ClientPool::Key key{ request->url() };
// Reuse a pooled connection. // Reuse a pooled connection.
bool reuse = false; bool reuse = false;
ClientPtr client = pool_.Get(key); ClientPtr client = pool_.Get(key);
if (!client) { if (!client) {
client.reset(new Client{}); #if WEBCC_ENABLE_SSL
client.reset(new Client{ io_context_, ssl_context_ });
#else
client.reset(new Client{ io_context_ });
#endif // WEBCC_ENABLE_SSL
reuse = false; reuse = false;
} else { } else {
LOG_VERB("Reuse an existing connection."); LOG_VERB("Reuse an existing connection");
reuse = true; reuse = true;
} }
client->set_ssl_verify(ssl_verify_); client->set_ssl_verify(ssl_verify_);
client->set_buffer_size(buffer_size_); client->set_buffer_size(buffer_size_);
client->set_timeout(timeout_); client->set_connect_timeout(connect_timeout_);
client->set_read_timeout(read_timeout_);
Error error = client->Request(request, !reuse, stream); client->set_progress_callback(callback);
if (error) { // Save current client for cancel.
if (reuse && error.code() == Error::kSocketWriteError) { client_ = client;
LOG_WARN("Cannot send request with the reused connection. "
"The server must have closed it, reconnect and try again."); Error error = client->Request(request, stream);
error = client->Request(request, true, stream);
} client_.reset();
}
if (error) { if (error) {
// Remove the failed connection from pool. // Remove the failed connection from pool.
@ -139,11 +265,11 @@ ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream) {
// Update connection pool. // Update connection pool.
if (reuse) { if (reuse) {
if (client->closed()) { if (!client->connected()) {
pool_.Remove(key); pool_.Remove(key);
} }
} else { } else {
if (!client->closed()) { if (client->connected()) {
pool_.Add(key, client); pool_.Add(key, client);
} }
} }

@ -1,9 +1,14 @@
#ifndef WEBCC_CLIENT_SESSION_H_ #ifndef WEBCC_CLIENT_SESSION_H_
#define WEBCC_CLIENT_SESSION_H_ #define WEBCC_CLIENT_SESSION_H_
#include <memory>
#include <mutex>
#include <string> #include <string>
#include <thread>
#include <vector> #include <vector>
#include "boost/asio/io_context.hpp"
#include "webcc/client_pool.h" #include "webcc/client_pool.h"
#include "webcc/request_builder.h" #include "webcc/request_builder.h"
#include "webcc/response.h" #include "webcc/response.h"
@ -11,18 +16,32 @@
namespace webcc { namespace webcc {
// HTTP requests session providing connection-pooling, configuration and more. // HTTP requests session providing connection-pooling, configuration and more.
// A session shouldn't be shared by multiple threads. Please create a new // NOTE: If a session is shared by multiple threads, the requests sent through
// session for each thread instead. // it will be serialized by using a mutex.
class ClientSession { class ClientSession {
public: public:
explicit ClientSession(int timeout = 0, bool ssl_verify = true, explicit ClientSession(bool ssl_verify = true, std::size_t buffer_size = 0);
std::size_t buffer_size = 0);
~ClientSession();
// Start Asio loop in a thread.
// You don't have to call Start() manually because it's called by the
// constructor.
void Start();
// Stop Asio loop.
// You can call Start() to resume the loop.
void Stop();
~ClientSession() = default; void set_connect_timeout(int timeout) {
if (timeout > 0) {
connect_timeout_ = timeout;
}
}
void set_timeout(int timeout) { void set_read_timeout(int timeout) {
if (timeout > 0) { if (timeout > 0) {
timeout_ = timeout; read_timeout_ = timeout;
} }
} }
@ -73,14 +92,35 @@ public:
// the response body will be FileBody, and you can easily move the temp file // the response body will be FileBody, and you can easily move the temp file
// to another path with FileBody::Move(). So, |stream| is really useful for // to another path with FileBody::Move(). So, |stream| is really useful for
// downloading files (JPEG, etc.) or saving memory for huge data responses. // downloading files (JPEG, etc.) or saving memory for huge data responses.
ResponsePtr Send(RequestPtr request, bool stream = false); ResponsePtr Send(RequestPtr request, bool stream = false,
ProgressCallback callback = {});
// Cancel any in-progress connecting, writing or reading.
void Cancel();
private: private:
void InitHeaders(); void InitHeaders();
ResponsePtr DoSend(RequestPtr request, bool stream); ResponsePtr DoSend(RequestPtr request, bool stream,
ProgressCallback callback);
private: private:
boost::asio::io_context io_context_;
// The thread to run Asio loop.
std::unique_ptr<std::thread> io_thread_;
using ExecutorType = boost::asio::io_context::executor_type;
boost::asio::executor_work_guard<ExecutorType> work_guard_;
// TODO
#if WEBCC_ENABLE_SSL
boost::asio::ssl::context ssl_context_;
#endif
// Is Asio loop running?
bool started_ = false;
// The media (or MIME) type of `Content-Type` header. // The media (or MIME) type of `Content-Type` header.
// E.g., "application/json". // E.g., "application/json".
std::string media_type_; std::string media_type_;
@ -92,18 +132,27 @@ private:
// Additional headers for each request. // Additional headers for each request.
Headers headers_; Headers headers_;
// Timeout in seconds for receiving response. // Timeout (seconds) for connecting to server.
int timeout_; int connect_timeout_ = 0;
// Timeout (seconds) for reading response.
int read_timeout_ = 0;
// Verify the certificate of the peer or not. // Verify the certificate of the peer or not.
bool ssl_verify_; bool ssl_verify_ = true;
// The size of the buffer for reading response. // The size of the buffer for reading response.
// 0 means default value will be used. // 0 means default value will be used.
std::size_t buffer_size_; std::size_t buffer_size_;
// Pool for Keep-Alive client connections. // Keep-Alive client connections.
ClientPool pool_; ClientPool pool_;
// Current requested client.
ClientPtr client_;
// The mutex to guard the request.
std::mutex mutex_;
}; };
} // namespace webcc } // namespace webcc

@ -32,7 +32,7 @@ void Connection::Start() {
} }
void Connection::Close() { void Connection::Close() {
LOG_INFO("Shutdown socket..."); LOG_INFO("Shutdown socket");
// Initiate graceful connection closure. // Initiate graceful connection closure.
// Socket close VS. shutdown: // Socket close VS. shutdown:
@ -41,17 +41,17 @@ void Connection::Close() {
socket_.shutdown(tcp::socket::shutdown_both, ec); socket_.shutdown(tcp::socket::shutdown_both, ec);
if (ec) { if (ec) {
LOG_WARN("Socket shutdown error (%s).", ec.message().c_str()); LOG_WARN("Socket shutdown error (%s)", ec.message().c_str());
ec.clear(); ec.clear();
// Don't return, try to close the socket anywhere. // Don't return, try to close the socket anywhere.
} }
LOG_INFO("Close socket..."); LOG_INFO("Close socket");
socket_.close(ec); socket_.close(ec);
if (ec) { if (ec) {
LOG_ERRO("Socket close error (%s).", ec.message().c_str()); LOG_ERRO("Socket close error (%s)", ec.message().c_str());
} }
} }
@ -92,13 +92,13 @@ void Connection::DoRead() {
void Connection::OnRead(boost::system::error_code ec, std::size_t length) { void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
if (ec) { if (ec) {
if (ec == boost::asio::error::eof) { if (ec == boost::asio::error::eof) {
LOG_INFO("Socket read EOF (%s).", ec.message().c_str()); LOG_INFO("Socket read EOF (%s)", ec.message().c_str());
} else if (ec == boost::asio::error::operation_aborted) { } else if (ec == boost::asio::error::operation_aborted) {
// The socket of this connection has been closed. // The socket of this connection has been closed.
// This happens, e.g., when the server was stopped by a signal (Ctrl-C). // This happens, e.g., when the server was stopped by a signal (Ctrl-C).
LOG_WARN("Socket operation aborted (%s).", ec.message().c_str()); LOG_WARN("Socket operation aborted (%s)", ec.message().c_str());
} else { } else {
LOG_ERRO("Socket read error (%s).", ec.message().c_str()); LOG_ERRO("Socket read error (%s)", ec.message().c_str());
} }
// Don't try to send any response back. // Don't try to send any response back.
@ -111,7 +111,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
} }
if (!request_parser_.Parse(buffer_.data(), length)) { if (!request_parser_.Parse(buffer_.data(), length)) {
LOG_ERRO("Failed to parse HTTP request."); LOG_ERRO("Failed to parse request");
// Send Bad Request (400) to the client and no Keep-Alive. // Send Bad Request (400) to the client and no Keep-Alive.
SendResponse(Status::kBadRequest, true); SendResponse(Status::kBadRequest, true);
// Close the socket connection. // Close the socket connection.
@ -125,7 +125,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
return; return;
} }
LOG_VERB("HTTP request:\n%s", request_->Dump().c_str()); LOG_VERB("Request:\n%s", request_->Dump().c_str());
// Enqueue this connection once the request has been read. // Enqueue this connection once the request has been read.
// Some worker thread will handle the request later. // Some worker thread will handle the request later.
@ -133,7 +133,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
} }
void Connection::DoWrite() { void Connection::DoWrite() {
LOG_VERB("HTTP response:\n%s", response_->Dump().c_str()); LOG_VERB("Response:\n%s", response_->Dump().c_str());
// Firstly, write the headers. // Firstly, write the headers.
boost::asio::async_write(socket_, response_->GetPayload(), boost::asio::async_write(socket_, response_->GetPayload(),
@ -177,11 +177,11 @@ void Connection::OnWriteBody(boost::system::error_code ec, std::size_t length) {
} }
void Connection::OnWriteOK() { void Connection::OnWriteOK() {
LOG_INFO("Response has been sent back."); LOG_INFO("Response has been sent back");
if (request_->IsConnectionKeepAlive()) { if (request_->IsConnectionKeepAlive()) {
LOG_INFO("The client asked for a keep-alive connection."); LOG_INFO("The client asked for a keep-alive connection");
LOG_INFO("Continue to read the next request..."); LOG_INFO("Continue to read the next request");
Start(); Start();
} else { } else {
pool_->Close(shared_from_this()); pool_->Close(shared_from_this());
@ -189,7 +189,7 @@ void Connection::OnWriteOK() {
} }
void Connection::OnWriteError(boost::system::error_code ec) { void Connection::OnWriteError(boost::system::error_code ec) {
LOG_ERRO("Socket write error (%s).", ec.message().c_str()); LOG_ERRO("Socket write error (%s)", ec.message().c_str());
if (ec != boost::asio::error::operation_aborted) { if (ec != boost::asio::error::operation_aborted) {
pool_->Close(shared_from_this()); pool_->Close(shared_from_this());

@ -5,11 +5,11 @@
namespace webcc { namespace webcc {
void ConnectionPool::Start(ConnectionPtr c) { void ConnectionPool::Start(ConnectionPtr c) {
LOG_VERB("Starting connection..."); LOG_VERB("Start connection");
{ {
// Lock the container only. // Lock the container only.
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock{ mutex_ };
connections_.insert(c); connections_.insert(c);
} }
@ -17,11 +17,11 @@ void ConnectionPool::Start(ConnectionPtr c) {
} }
void ConnectionPool::Close(ConnectionPtr c) { void ConnectionPool::Close(ConnectionPtr c) {
LOG_VERB("Closing connection..."); LOG_VERB("Close connection");
{ {
// Lock the container only. // Lock the container only.
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock{ mutex_ };
connections_.erase(c); connections_.erase(c);
} }
@ -30,10 +30,10 @@ void ConnectionPool::Close(ConnectionPtr c) {
void ConnectionPool::Clear() { void ConnectionPool::Clear() {
// Lock all since we are going to stop anyway. // Lock all since we are going to stop anyway.
std::lock_guard<std::mutex> lock(mutex_); std::lock_guard<std::mutex> lock{ mutex_ };
if (!connections_.empty()) { if (!connections_.empty()) {
LOG_VERB("Closing all (%u) connections...", connections_.size()); LOG_VERB("Close all (%u) connections", connections_.size());
for (auto& c : connections_) { for (auto& c : connections_) {
c->Close(); c->Close();
} }

@ -3,6 +3,7 @@
#include <cassert> #include <cassert>
#include <exception> #include <exception>
#include <functional>
#include <iosfwd> #include <iosfwd>
#include <string> #include <string>
#include <vector> #include <vector>
@ -23,11 +24,14 @@ using UrlArgs = std::vector<std::string>;
using Payload = std::vector<boost::asio::const_buffer>; using Payload = std::vector<boost::asio::const_buffer>;
using ProgressCallback =
std::function<void(std::size_t length, std::size_t total_length)>;
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
const char* const kCRLF = "\r\n"; const char* const kCRLF = "\r\n";
const std::size_t kInvalidLength = std::string::npos; const std::size_t kInvalidLength = -1;
// Default timeout for reading response. // Default timeout for reading response.
const int kMaxReadSeconds = 30; const int kMaxReadSeconds = 30;
@ -150,6 +154,7 @@ public:
enum Code { enum Code {
kUnknownError = -1, kUnknownError = -1,
kOK = 0, kOK = 0,
kStateError,
kSyntaxError, kSyntaxError,
kResolveError, kResolveError,
kConnectError, kConnectError,
@ -166,7 +171,7 @@ public:
} }
// Note that `noexcept` is required by GCC. // Note that `noexcept` is required by GCC.
const char* what() const noexcept override{ const char* what() const noexcept override {
return message_.c_str(); return message_.c_str();
} }

@ -45,7 +45,12 @@ static FILE* FOpen(const bfs::path& path, bool overwrite) {
} }
struct Logger { struct Logger {
Logger() : file(nullptr), modes(0) { Logger() = default;
~Logger() {
if (file != nullptr) {
fclose(file);
}
} }
void Init(const bfs::path& path, int _modes) { void Init(const bfs::path& path, int _modes) {
@ -57,14 +62,8 @@ struct Logger {
} }
} }
~Logger() { FILE* file = nullptr;
if (file != nullptr) { int modes = 0;
fclose(file);
}
}
FILE* file;
int modes;
std::mutex mutex; std::mutex mutex;
}; };
@ -161,10 +160,10 @@ static bfs::path InitLogPath(const bfs::path& dir) {
return bfs::current_path() / WEBCC_LOG_FILE_NAME; return bfs::current_path() / WEBCC_LOG_FILE_NAME;
} }
if (!bfs::exists(dir) || !bfs::is_directory(dir)) { boost::system::error_code ec;
boost::system::error_code ec; if (!bfs::exists(dir, ec) || !bfs::is_directory(dir, ec)) {
if (!bfs::create_directories(dir, ec) || ec) { if (!bfs::create_directories(dir, ec) || ec) {
return bfs::path{}; return {};
} }
} }
@ -216,7 +215,7 @@ void Log(int level, const char* file, int line, const char* format, ...) {
va_list args; va_list args;
va_start(args, format); va_start(args, format);
fprintf(g_logger.file, "%s, %s, %7s, %20s, %4d, ", fprintf(g_logger.file, "%s, %s, %7s, %25s, %4d, ",
timestamp.c_str(), kLevelNames[level], thread_id.c_str(), timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
file, line); file, line);
@ -239,12 +238,12 @@ void Log(int level, const char* file, int line, const char* format, ...) {
if (g_terminal_has_color) { if (g_terminal_has_color) {
if (level < WEBCC_WARN) { if (level < WEBCC_WARN) {
fprintf(stderr, "%s%s, %s, %7s, %20s, %4d, ", fprintf(stderr, "%s%s, %s, %7s, %25s, %4d, ",
TERM_RESET, TERM_RESET,
timestamp.c_str(), kLevelNames[level], thread_id.c_str(), timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
file, line); file, line);
} else { } else {
fprintf(stderr, "%s%s%s, %s, %7s, %20s, %4d, ", fprintf(stderr, "%s%s%s, %s, %7s, %25s, %4d, ",
TERM_RESET, TERM_RESET,
level == WEBCC_WARN ? TERM_YELLOW : TERM_RED, level == WEBCC_WARN ? TERM_YELLOW : TERM_RED,
timestamp.c_str(), kLevelNames[level], thread_id.c_str(), timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
@ -255,7 +254,7 @@ void Log(int level, const char* file, int line, const char* format, ...) {
fprintf(stderr, "%s\n", TERM_RESET); fprintf(stderr, "%s\n", TERM_RESET);
} else { } else {
fprintf(stderr, "%s, %s, %7s, %20s, %4d, ", fprintf(stderr, "%s, %s, %7s, %25s, %4d, ",
timestamp.c_str(), kLevelNames[level], thread_id.c_str(), timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
file, line); file, line);

@ -8,7 +8,7 @@
namespace webcc { namespace webcc {
Message::Message() : body_(new Body{}), content_length_(kInvalidLength) { Message::Message() : body_(new Body{}) {
} }
void Message::SetBody(BodyPtr body, bool set_length) { void Message::SetBody(BodyPtr body, bool set_length) {

@ -118,7 +118,7 @@ protected:
std::string start_line_; std::string start_line_;
std::size_t content_length_; std::size_t content_length_ = kInvalidLength;
}; };
} // namespace webcc } // namespace webcc

@ -49,14 +49,14 @@ bool StringBodyHandler::Finish() {
auto body = std::make_shared<StringBody>(std::move(content_), IsCompressed()); auto body = std::make_shared<StringBody>(std::move(content_), IsCompressed());
#if WEBCC_ENABLE_GZIP #if WEBCC_ENABLE_GZIP
LOG_INFO("Decompress the HTTP content..."); LOG_INFO("Decompress the HTTP content");
if (!body->Decompress()) { if (!body->Decompress()) {
LOG_ERRO("Cannot decompress the HTTP content!"); LOG_ERRO("Cannot decompress the HTTP content");
return false; return false;
} }
#else #else
if (body->compressed()) { if (body->compressed()) {
LOG_WARN("Compressed HTTP content remains untouched."); LOG_WARN("Compressed HTTP content remains untouched");
} }
#endif // WEBCC_ENABLE_GZIP #endif // WEBCC_ENABLE_GZIP
@ -79,7 +79,7 @@ bool FileBodyHandler::OpenFile() {
temp_path_.string().c_str()); temp_path_.string().c_str());
} catch (const bfs::filesystem_error&) { } catch (const bfs::filesystem_error&) {
LOG_ERRO("Failed to generate temp path for streaming."); LOG_ERRO("Failed to generate temp path for streaming");
return false; return false;
} }
@ -132,6 +132,8 @@ bool Parser::Parse(const char* data, std::size_t length) {
return ParseContent(data, length); return ParseContent(data, length);
} }
header_length_ += length;
// Append the new data to the pending data. // Append the new data to the pending data.
pending_data_.append(data, length); pending_data_.append(data, length);
@ -140,11 +142,13 @@ bool Parser::Parse(const char* data, std::size_t length) {
} }
if (!header_ended_) { if (!header_ended_) {
LOG_INFO("HTTP headers will continue in next read."); LOG_INFO("HTTP headers will continue in next read");
return true; return true;
} }
LOG_INFO("HTTP headers just ended."); LOG_INFO("HTTP headers just ended");
header_length_ -= pending_data_.size();
if (!OnHeadersEnd()) { if (!OnHeadersEnd()) {
// Only request parser can reach here when no view matches the request. // Only request parser can reach here when no view matches the request.
@ -170,6 +174,7 @@ void Parser::Reset() {
stream_ = false; stream_ = false;
pending_data_.clear(); pending_data_.clear();
header_length_ = 0;
content_length_ = kInvalidLength; content_length_ = kInvalidLength;
content_type_.Reset(); content_type_.Reset();
@ -336,7 +341,7 @@ bool Parser::ParseChunkedContent(const char* data, std::size_t length) {
return false; return false;
} }
LOG_VERB("Chunk size: %u.", chunk_size_); LOG_VERB("Chunk size: %u", chunk_size_);
} }
if (chunk_size_ == 0) { if (chunk_size_ == 0) {
@ -378,14 +383,14 @@ bool Parser::ParseChunkedContent(const char* data, std::size_t length) {
} }
bool Parser::ParseChunkSize() { bool Parser::ParseChunkSize() {
LOG_VERB("Parse chunk size."); LOG_VERB("Parse chunk size");
std::string line; std::string line;
if (!GetNextLine(0, &line, true)) { if (!GetNextLine(0, &line, true)) {
return true; return true;
} }
LOG_VERB("Chunk size line: [%s].", line.c_str()); LOG_VERB("Chunk size line: [%s]", line.c_str());
std::string hex_str; // e.g., "cf0" (3312) std::string hex_str; // e.g., "cf0" (3312)
@ -397,7 +402,7 @@ bool Parser::ParseChunkSize() {
} }
if (!to_size_t(hex_str, 16, &chunk_size_)) { if (!to_size_t(hex_str, 16, &chunk_size_)) {
LOG_ERRO("Invalid chunk-size: %s.", hex_str.c_str()); LOG_ERRO("Invalid chunk-size: %s", hex_str.c_str());
return false; return false;
} }

@ -94,17 +94,35 @@ private:
class Parser { class Parser {
public: public:
Parser(); Parser();
virtual ~Parser() = default;
Parser(const Parser&) = delete; Parser(const Parser&) = delete;
Parser& operator=(const Parser&) = delete; Parser& operator=(const Parser&) = delete;
virtual ~Parser() = default;
void Init(Message* message); void Init(Message* message);
bool finished() const { bool finished() const {
return finished_; return finished_;
} }
// If the headers part has been parsed or not.
bool header_ended() const {
return header_ended_;
}
// Get the length of the headers part.
// Available after the headers have been parsed (see header_ended()).
std::size_t header_length() const {
return header_length_;
}
// The content length parsed from `Content-Length` header.
// kInvalidLength if the content is chunked.
std::size_t content_length() const {
return content_length_;
}
bool Parse(const char* data, std::size_t length); bool Parse(const char* data, std::size_t length);
protected: protected:
@ -144,25 +162,28 @@ protected:
bool Finish(); bool Finish();
protected: protected:
Message* message_; Message* message_ = nullptr;
std::unique_ptr<BodyHandler> body_handler_; std::unique_ptr<BodyHandler> body_handler_;
// Data streaming or not. // Data streaming or not.
bool stream_; bool stream_ = false;
// Data waiting to be parsed. // Data waiting to be parsed.
std::string pending_data_; std::string pending_data_;
// The length of the headers part.
std::size_t header_length_ = 0;
// Temporary data and helper flags for parsing. // Temporary data and helper flags for parsing.
std::size_t content_length_; std::size_t content_length_ = kInvalidLength;
ContentType content_type_; ContentType content_type_;
bool start_line_parsed_; bool start_line_parsed_ = false;
bool content_length_parsed_; bool content_length_parsed_ = false;
bool header_ended_; bool header_ended_ = false;
bool chunked_; bool chunked_ = false;
std::size_t chunk_size_; std::size_t chunk_size_ = kInvalidLength;
bool finished_; bool finished_ = false;
}; };
} // namespace webcc } // namespace webcc

@ -23,9 +23,6 @@ Server::Server(boost::asio::ip::tcp protocol, std::uint16_t port,
: protocol_(protocol), : protocol_(protocol),
port_(port), port_(port),
doc_root_(doc_root), doc_root_(doc_root),
buffer_size_(kBufferSize),
file_chunk_size_(1024),
running_(false),
acceptor_(io_context_), acceptor_(io_context_),
signals_(io_context_) { signals_(io_context_) {
AddSignals(); AddSignals();
@ -35,12 +32,12 @@ void Server::Run(std::size_t workers, std::size_t loops) {
assert(workers > 0); assert(workers > 0);
{ {
std::lock_guard<std::mutex> lock(state_mutex_); std::lock_guard<std::mutex> lock{ state_mutex_ };
assert(worker_threads_.empty()); assert(worker_threads_.empty());
if (IsRunning()) { if (IsRunning()) {
LOG_WARN("Server is already running."); LOG_WARN("Server is already running");
return; return;
} }
@ -48,11 +45,11 @@ void Server::Run(std::size_t workers, std::size_t loops) {
io_context_.restart(); io_context_.restart();
if (!Listen(port_)) { if (!Listen(port_)) {
LOG_ERRO("Server is NOT going to run."); LOG_ERRO("Server is NOT going to run");
return; return;
} }
LOG_INFO("Server is going to run..."); LOG_INFO("Server is going to run");
AsyncWaitSignals(); AsyncWaitSignals();
@ -70,7 +67,7 @@ void Server::Run(std::size_t workers, std::size_t loops) {
// asynchronous operation outstanding: the asynchronous accept call waiting // asynchronous operation outstanding: the asynchronous accept call waiting
// for new incoming connections. // for new incoming connections.
LOG_INFO("Loop is running in %u thread(s).", loops); LOG_INFO("Loop is running in %u thread(s)", loops);
if (loops == 1) { if (loops == 1) {
// Run the loop in current thread. // Run the loop in current thread.
@ -88,7 +85,7 @@ void Server::Run(std::size_t workers, std::size_t loops) {
} }
void Server::Stop() { void Server::Stop() {
std::lock_guard<std::mutex> lock(state_mutex_); std::lock_guard<std::mutex> lock{ state_mutex_ };
DoStop(); DoStop();
} }
@ -112,7 +109,7 @@ void Server::AsyncWaitSignals() {
// The server is stopped by canceling all outstanding asynchronous // The server is stopped by canceling all outstanding asynchronous
// operations. Once all operations have finished the io_context::run() // operations. Once all operations have finished the io_context::run()
// call will exit. // call will exit.
LOG_INFO("On signal %d, stopping the server...", signo); LOG_INFO("On signal %d, stop the server", signo);
DoStop(); DoStop();
}); });
@ -121,12 +118,12 @@ void Server::AsyncWaitSignals() {
bool Server::Listen(std::uint16_t port) { bool Server::Listen(std::uint16_t port) {
boost::system::error_code ec; boost::system::error_code ec;
tcp::endpoint endpoint(protocol_, port); tcp::endpoint endpoint{ protocol_, port };
// Open the acceptor. // Open the acceptor.
acceptor_.open(endpoint.protocol(), ec); acceptor_.open(endpoint.protocol(), ec);
if (ec) { if (ec) {
LOG_ERRO("Acceptor open error (%s).", ec.message().c_str()); LOG_ERRO("Acceptor open error (%s)", ec.message().c_str());
return false; return false;
} }
@ -141,7 +138,7 @@ bool Server::Listen(std::uint16_t port) {
// Bind to the server address. // Bind to the server address.
acceptor_.bind(endpoint, ec); acceptor_.bind(endpoint, ec);
if (ec) { if (ec) {
LOG_ERRO("Acceptor bind error (%s).", ec.message().c_str()); LOG_ERRO("Acceptor bind error (%s)", ec.message().c_str());
return false; return false;
} }
@ -150,7 +147,7 @@ bool Server::Listen(std::uint16_t port) {
// has not started to accept the connection yet. // has not started to accept the connection yet.
acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec);
if (ec) { if (ec) {
LOG_ERRO("Acceptor listen error (%s).", ec.message().c_str()); LOG_ERRO("Acceptor listen error (%s)", ec.message().c_str());
return false; return false;
} }
@ -167,9 +164,10 @@ void Server::AsyncAccept() {
} }
if (!ec) { if (!ec) {
LOG_INFO("Accepted a connection."); LOG_INFO("Accepted a connection");
using namespace std::placeholders; using namespace std::placeholders;
auto view_matcher = std::bind(&Server::MatchViewOrStatic, this, _1, auto view_matcher = std::bind(&Server::MatchViewOrStatic, this, _1,
_2, _3); _2, _3);
@ -205,16 +203,16 @@ void Server::DoStop() {
} }
void Server::WorkerRoutine() { void Server::WorkerRoutine() {
LOG_INFO("Worker is running."); LOG_INFO("Worker is running");
for (;;) { for (;;) {
auto connection = queue_.PopOrWait(); auto connection = queue_.PopOrWait();
if (!connection) { if (!connection) {
LOG_INFO("Worker is going to stop."); LOG_INFO("Worker is going to stop");
// For stopping next worker. // For stopping next worker.
queue_.Push(ConnectionPtr()); queue_.Push({});
// Stop this worker. // Stop this worker.
break; break;
@ -225,13 +223,13 @@ void Server::WorkerRoutine() {
} }
void Server::StopWorkers() { void Server::StopWorkers() {
LOG_INFO("Stopping workers..."); LOG_INFO("Stop workers");
// Clear/drop pending connections. // Clear/drop pending connections.
// The connections will be closed later (see DoStop). // The connections will be closed later (see DoStop).
// Alternatively, we can wait for the pending connections to be handled. // Alternatively, we can wait for the pending connections to be handled.
if (queue_.Size() != 0) { if (queue_.Size() != 0) {
LOG_INFO("Clear pending connections..."); LOG_INFO("Clear pending connections");
queue_.Clear(); queue_.Clear();
} }
@ -252,7 +250,7 @@ void Server::StopWorkers() {
// last worker thread. // last worker thread.
queue_.Clear(); queue_.Clear();
LOG_INFO("All workers have been stopped."); LOG_INFO("Workers stopped");
} }
void Server::Handle(ConnectionPtr connection) { void Server::Handle(ConnectionPtr connection) {
@ -307,7 +305,9 @@ bool Server::MatchViewOrStatic(const std::string& method,
// Try to match a static file. // Try to match a static file.
if (method == methods::kGet && !doc_root_.empty()) { if (method == methods::kGet && !doc_root_.empty()) {
bfs::path path = doc_root_ / url; bfs::path path = doc_root_ / url;
if (!bfs::is_directory(path) && bfs::exists(path)) {
boost::system::error_code ec;
if (!bfs::is_directory(path, ec) && bfs::exists(path, ec)) {
return true; return true;
} }
} }
@ -319,7 +319,7 @@ ResponsePtr Server::ServeStatic(RequestPtr request) {
assert(request->method() == methods::kGet); assert(request->method() == methods::kGet);
if (doc_root_.empty()) { if (doc_root_.empty()) {
LOG_INFO("The doc root was not specified."); LOG_INFO("The doc root was not specified");
return {}; return {};
} }
@ -340,7 +340,7 @@ ResponsePtr Server::ServeStatic(RequestPtr request) {
return response; return response;
} catch (const Error& error) { } catch (const Error& error) {
LOG_ERRO("File error: %s.", error.message().c_str()); LOG_ERRO("File error: %s", error.message().c_str());
return {}; return {};
} }
} }

@ -24,11 +24,11 @@ public:
Server(boost::asio::ip::tcp protocol, std::uint16_t port, Server(boost::asio::ip::tcp protocol, std::uint16_t port,
const boost::filesystem::path& doc_root = {}); const boost::filesystem::path& doc_root = {});
~Server() = default;
Server(const Server&) = delete; Server(const Server&) = delete;
Server& operator=(const Server&) = delete; Server& operator=(const Server&) = delete;
~Server() = default;
void set_buffer_size(std::size_t buffer_size) { void set_buffer_size(std::size_t buffer_size) {
if (buffer_size > 0) { if (buffer_size > 0) {
buffer_size_ = buffer_size; buffer_size_ = buffer_size;
@ -105,20 +105,20 @@ private:
boost::asio::ip::tcp protocol_; boost::asio::ip::tcp protocol_;
// Port number. // Port number.
std::uint16_t port_; std::uint16_t port_ = 0;
// The directory with the static files to be served. // The directory with the static files to be served.
boost::filesystem::path doc_root_; boost::filesystem::path doc_root_;
// The size of the buffer for reading request. // The size of the buffer for reading request.
std::size_t buffer_size_; std::size_t buffer_size_ = kBufferSize;
// The size of the chunk loaded into memory each time when serving a // The size of the chunk loaded into memory each time when serving a
// static file. // static file.
std::size_t file_chunk_size_; std::size_t file_chunk_size_ = 1024;
// Is the server running? // Is the server running?
bool running_; bool running_ = false;
// The mutex for guarding the state of the server. // The mutex for guarding the state of the server.
std::mutex state_mutex_; std::mutex state_mutex_;

@ -3,9 +3,9 @@
#if WEBCC_ENABLE_SSL #if WEBCC_ENABLE_SSL
#if (defined(_WIN32) || defined(_WIN64)) #if (defined(_WIN32) || defined(_WIN64))
#include <windows.h>
#include <wincrypt.h>
#include <cryptuiapi.h> #include <cryptuiapi.h>
#include <wincrypt.h>
#include <windows.h>
#include "openssl/x509.h" #include "openssl/x509.h"
@ -18,6 +18,9 @@
#include "webcc/logger.h" #include "webcc/logger.h"
using boost::asio::ip::tcp;
using namespace std::placeholders;
namespace webcc { namespace webcc {
// ----------------------------------------------------------------------------- // -----------------------------------------------------------------------------
@ -25,48 +28,37 @@ namespace webcc {
Socket::Socket(boost::asio::io_context& io_context) : socket_(io_context) { Socket::Socket(boost::asio::io_context& io_context) : socket_(io_context) {
} }
bool Socket::Connect(const std::string& /*host*/, const Endpoints& endpoints) { void Socket::AsyncConnect(const std::string& host, const Endpoints& endpoints,
boost::system::error_code ec; ConnectHandler&& handler) {
boost::asio::connect(socket_, endpoints, ec); boost::asio::async_connect(socket_, endpoints, std::move(handler));
if (ec) {
LOG_ERRO("Socket connect error (%s).", ec.message().c_str());
return false;
}
return true;
}
bool Socket::Write(const Payload& payload, boost::system::error_code* ec) {
boost::asio::write(socket_, payload, *ec);
return !(*ec);
} }
bool Socket::ReadSome(std::vector<char>* buffer, std::size_t* size, void Socket::AsyncWrite(const Payload& payload, WriteHandler&& handler) {
boost::system::error_code* ec) { boost::asio::async_write(socket_, payload, std::move(handler));
*size = socket_.read_some(boost::asio::buffer(*buffer), *ec);
return (*size != 0 && !(*ec));
} }
void Socket::AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) { void Socket::AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) {
socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler));
} }
bool Socket::Close() { bool Socket::Shutdown() {
boost::system::error_code ec; boost::system::error_code ec;
socket_.shutdown(tcp::socket::shutdown_both, ec);
socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
if (ec) { if (ec) {
LOG_WARN("Socket shutdown error (%s).", ec.message().c_str()); LOG_WARN("Socket shutdown error (%s)", ec.message().c_str());
ec.clear(); return false;
// Don't return, try to close the socket anywhere.
} }
return true;
}
bool Socket::Close() {
boost::system::error_code ec;
socket_.close(ec); socket_.close(ec);
if (ec) { if (ec) {
LOG_WARN("Socket close error (%s).", ec.message().c_str()); LOG_WARN("Socket close error (%s)", ec.message().c_str());
return false; return false;
} }
@ -77,91 +69,40 @@ bool Socket::Close() {
#if WEBCC_ENABLE_SSL #if WEBCC_ENABLE_SSL
#if (defined(_WIN32) || defined(_WIN64))
// Let OpenSSL on Windows use the system certificate store
// 1. Load your certificate (in PCCERT_CONTEXT structure) from Windows Cert
// store using Crypto APIs.
// 2. Get encrypted content of it in binary format as it is.
// [PCCERT_CONTEXT->pbCertEncoded].
// 3. Parse this binary buffer into X509 certificate Object using OpenSSL's
// d2i_X509() method.
// 4. Get handle to OpenSSL's trust store using SSL_CTX_get_cert_store()
// method.
// 5. Load above parsed X509 certificate into this trust store using
// X509_STORE_add_cert() method.
// 6. You are done!
// NOTES: Enum Windows store with "ROOT" (not "CA").
// See: https://stackoverflow.com/a/11763389/6825348
static bool UseSystemCertificateStore(SSL_CTX* ssl_ctx) {
// NOTE: Cannot use nullptr to replace NULL.
HCERTSTORE cert_store = ::CertOpenSystemStoreW(NULL, L"ROOT");
if (cert_store == nullptr) {
LOG_ERRO("Cannot open Windows system certificate store.");
return false;
}
X509_STORE* x509_store = SSL_CTX_get_cert_store(ssl_ctx);
PCCERT_CONTEXT cert_context = nullptr;
while (cert_context = CertEnumCertificatesInStore(cert_store, cert_context)) {
auto in = (const unsigned char**)&cert_context->pbCertEncoded;
X509* x509 = d2i_X509(nullptr, in, cert_context->cbCertEncoded);
if (x509 != nullptr) {
if (X509_STORE_add_cert(x509_store, x509) == 0) {
LOG_ERRO("Cannot add Windows root certificate.");
}
X509_free(x509);
}
}
CertFreeCertificateContext(cert_context);
CertCloseStore(cert_store, 0);
return true;
}
#endif // defined(_WIN32) || defined(_WIN64)
namespace ssl = boost::asio::ssl; namespace ssl = boost::asio::ssl;
SslSocket::SslSocket(boost::asio::io_context& io_context, bool ssl_verify) SslSocket::SslSocket(boost::asio::io_context& io_context,
: ssl_context_(ssl::context::sslv23), ssl::context& ssl_context, bool ssl_verify)
ssl_socket_(io_context, ssl_context_), : ssl_socket_(io_context, ssl_context),
ssl_verify_(ssl_verify) { ssl_verify_(ssl_verify) {
#if (defined(_WIN32) || defined(_WIN64))
if (ssl_verify_) {
UseSystemCertificateStore(ssl_context_.native_handle());
}
#else
// Use the default paths for finding CA certificates.
ssl_context_.set_default_verify_paths();
#endif // defined(_WIN32) || defined(_WIN64)
} }
bool SslSocket::Connect(const std::string& host, const Endpoints& endpoints) { void SslSocket::AsyncConnect(const std::string& host,
boost::system::error_code ec; const Endpoints& endpoints,
boost::asio::connect(ssl_socket_.lowest_layer(), endpoints, ec); ConnectHandler&& handler) {
connect_handler_ = std::move(handler);
if (ec) {
LOG_ERRO("Socket connect error (%s).", ec.message().c_str()); if (ssl_verify_) {
return false; ssl_socket_.set_verify_mode(ssl::verify_peer);
} else {
ssl_socket_.set_verify_mode(ssl::verify_none);
} }
return Handshake(host); // ssl::host_name_verification has been added since Boost 1.73 to replace
} // ssl::rfc2818_verification.
#if BOOST_VERSION < 107300
ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host));
#else
ssl_socket_.set_verify_callback(ssl::host_name_verification(host));
#endif // BOOST_VERSION < 107300
bool SslSocket::Write(const Payload& payload, boost::system::error_code* ec) { boost::asio::async_connect(ssl_socket_.lowest_layer(), endpoints,
boost::asio::write(ssl_socket_, payload, *ec); std::bind(&SslSocket::OnConnect, this, _1, _2));
return !(*ec);
} }
bool SslSocket::ReadSome(std::vector<char>* buffer, std::size_t* size, void SslSocket::AsyncWrite(const Payload& payload, WriteHandler&& handler) {
boost::system::error_code* ec) { boost::asio::async_write(ssl_socket_, payload, std::move(handler));
*size = ssl_socket_.read_some(boost::asio::buffer(*buffer), *ec);
return (*size != 0 && !(*ec));
} }
void SslSocket::AsyncReadSome(ReadHandler&& handler, void SslSocket::AsyncReadSome(ReadHandler&& handler,
@ -169,39 +110,50 @@ void SslSocket::AsyncReadSome(ReadHandler&& handler,
ssl_socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); ssl_socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler));
} }
bool SslSocket::Close() { bool SslSocket::Shutdown() {
boost::system::error_code ec; boost::system::error_code ec;
ssl_socket_.lowest_layer().close(ec); ssl_socket_.lowest_layer().shutdown(tcp::socket::shutdown_both, ec);
return !ec;
}
bool SslSocket::Handshake(const std::string& host) { if (ec) {
if (ssl_verify_) { LOG_WARN("Socket shutdown error (%s)", ec.message().c_str());
ssl_socket_.set_verify_mode(ssl::verify_peer); return false;
} else {
ssl_socket_.set_verify_mode(ssl::verify_none);
} }
// ssl::host_name_verification has been added since Boost 1.73 to replace return true;
// ssl::rfc2818_verification. }
#if BOOST_VERSION < 107300
ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host));
#else
ssl_socket_.set_verify_callback(ssl::host_name_verification(host));
#endif // BOOST_VERSION < 107300
// Use sync API directly since we don't need timeout control. bool SslSocket::Close() {
boost::system::error_code ec; boost::system::error_code ec;
ssl_socket_.handshake(ssl::stream_base::client, ec); ssl_socket_.lowest_layer().close(ec);
if (ec) { if (ec) {
LOG_ERRO("Handshake error (%s).", ec.message().c_str()); LOG_WARN("Socket close error (%s)", ec.message().c_str());
return false; return false;
} }
return true; return true;
} }
void SslSocket::OnConnect(boost::system::error_code ec,
tcp::endpoint endpoint) {
if (ec) {
connect_handler_(ec, std::move(endpoint));
return;
}
// Backup endpoint
endpoint_ = std::move(endpoint);
ssl_socket_.async_handshake(ssl::stream_base::client,
[this](boost::system::error_code ec) {
if (ec) {
LOG_ERRO("Handshake error (%s)", ec.message().c_str());
}
connect_handler_(ec, std::move(endpoint_));
});
}
#endif // WEBCC_ENABLE_SSL #endif // WEBCC_ENABLE_SSL
} // namespace webcc } // namespace webcc

@ -18,24 +18,29 @@ namespace webcc {
class SocketBase { class SocketBase {
public: public:
virtual ~SocketBase() = default;
using Endpoints = boost::asio::ip::tcp::resolver::results_type; using Endpoints = boost::asio::ip::tcp::resolver::results_type;
using ConnectHandler = std::function<void(boost::system::error_code,
boost::asio::ip::tcp::endpoint)>;
using WriteHandler =
std::function<void(boost::system::error_code, std::size_t)>;
using ReadHandler = using ReadHandler =
std::function<void(boost::system::error_code, std::size_t)>; std::function<void(boost::system::error_code, std::size_t)>;
// TODO: Remove |host| virtual ~SocketBase() = default;
virtual bool Connect(const std::string& host, const Endpoints& endpoints) = 0;
virtual bool Write(const Payload& payload, boost::system::error_code* ec) = 0; virtual void AsyncConnect(const std::string& host, const Endpoints& endpoints,
ConnectHandler&& handler) = 0;
virtual bool ReadSome(std::vector<char>* buffer, std::size_t* size, virtual void AsyncWrite(const Payload& payload, WriteHandler&& handler) = 0;
boost::system::error_code* ec) = 0;
virtual void AsyncReadSome(ReadHandler&& handler, virtual void AsyncReadSome(ReadHandler&& handler,
std::vector<char>* buffer) = 0; std::vector<char>* buffer) = 0;
virtual bool Shutdown() = 0;
virtual bool Close() = 0; virtual bool Close() = 0;
}; };
@ -45,15 +50,15 @@ class Socket : public SocketBase {
public: public:
explicit Socket(boost::asio::io_context& io_context); explicit Socket(boost::asio::io_context& io_context);
bool Connect(const std::string& host, const Endpoints& endpoints) override; void AsyncConnect(const std::string& host, const Endpoints& endpoints,
ConnectHandler&& handler) override;
bool Write(const Payload& payload, boost::system::error_code* ec) override;
bool ReadSome(std::vector<char>* buffer, std::size_t* size, void AsyncWrite(const Payload& payload, WriteHandler&& handler) override;
boost::system::error_code* ec) override;
void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override; void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override;
bool Shutdown() override;
bool Close() override; bool Close() override;
private: private:
@ -66,29 +71,32 @@ private:
class SslSocket : public SocketBase { class SslSocket : public SocketBase {
public: public:
explicit SslSocket(boost::asio::io_context& io_context, SslSocket(boost::asio::io_context& io_context,
bool ssl_verify = true); boost::asio::ssl::context& ssl_context,
bool ssl_verify = true);
bool Connect(const std::string& host, const Endpoints& endpoints) override; void AsyncConnect(const std::string& host, const Endpoints& endpoints,
ConnectHandler&& handler) override;
bool Write(const Payload& payload, boost::system::error_code* ec) override; void AsyncWrite(const Payload& payload, WriteHandler&& handler) override;
bool ReadSome(std::vector<char>* buffer, std::size_t* size,
boost::system::error_code* ec) override;
void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override; void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override;
bool Shutdown() override;
bool Close() override; bool Close() override;
private: private:
bool Handshake(const std::string& host); void OnConnect(boost::system::error_code ec,
boost::asio::ip::tcp::endpoint endpoint);
boost::asio::ssl::context ssl_context_; ConnectHandler connect_handler_;
boost::asio::ip::tcp::endpoint endpoint_;
boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket_; boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket_;
// Verify the certificate of the peer (remote server) or not. // Verify the certificate of the peer (remote server) or not.
bool ssl_verify_; bool ssl_verify_ = true;
}; };
#endif // WEBCC_ENABLE_SSL #endif // WEBCC_ENABLE_SSL

Loading…
Cancel
Save