rework client using async api

master
Chunting Gu 4 years ago
parent 7f345b7a4e
commit f5210ba1a2

@ -1,24 +1,2 @@
# Automation test
set(AT_SRCS
client_autotest.cc
client_timeout_autotest.cc
main.cc
)
# Common libraries to link.
set(AT_LIBS
webcc
jsoncpp
GTest::GTest)
if(UNIX)
# Add `-ldl` for Linux to avoid "undefined reference to `dlopen'".
set(AT_LIBS ${AT_LIBS} ${CMAKE_DL_LIBS})
endif()
set(AT_TARGET_NAME autotest)
add_executable(${AT_TARGET_NAME} ${AT_SRCS})
target_link_libraries(${AT_TARGET_NAME} ${AT_LIBS})
set_target_properties(${AT_TARGET_NAME} PROPERTIES FOLDER "Tests")
add_subdirectory(client_autotest)
add_subdirectory(client_timeout_autotest)

@ -0,0 +1,24 @@
set(SRCS
client_autotest.cc
main.cc
)
# Common libraries to link.
set(LIBS
webcc
jsoncpp
GTest::GTest
)
if(UNIX)
# Add `-ldl` for Linux to avoid "undefined reference to `dlopen'".
set(LIBS ${LIBS} ${CMAKE_DL_LIBS})
endif()
set(TARGET_NAME client_autotest)
add_executable(${TARGET_NAME} ${SRCS})
target_link_libraries(${TARGET_NAME} ${LIBS})
set_target_properties(${TARGET_NAME} PROPERTIES FOLDER "Tests")

@ -0,0 +1,22 @@
set(SRCS
client_timeout_autotest.cc
main.cc
)
# Common libraries to link.
set(LIBS
webcc
jsoncpp
GTest::GTest
)
if(UNIX)
# Add `-ldl` for Linux to avoid "undefined reference to `dlopen'".
set(LIBS ${LIBS} ${CMAKE_DL_LIBS})
endif()
set(TARGET_NAME client_timeout_autotest)
add_executable(${TARGET_NAME} ${SRCS})
target_link_libraries(${TARGET_NAME} ${LIBS})
set_target_properties(${TARGET_NAME} PROPERTIES FOLDER "Tests")

@ -9,7 +9,7 @@
namespace {
const char* kData = "Hello, World!";
const char* const kData = "Hello, World!";
const std::uint16_t kPort = 8080;
@ -83,8 +83,8 @@ TEST_F(ClientTimeoutTest, NoTimeout) {
TEST_F(ClientTimeoutTest, Timeout) {
webcc::ClientSession session;
// Change timeout to 1s.
session.set_timeout(1);
// Change read timeout to 1s.
session.set_read_timeout(1);
webcc::ResponsePtr r;
bool timeout = false;

@ -0,0 +1,11 @@
#include "gtest/gtest.h"
#include "webcc/logger.h"
int main(int argc, char* argv[]) {
// Set webcc::LOG_CONSOLE to enable logging.
WEBCC_LOG_INIT("", 0);
::testing::InitGoogleTest(&argc, argv);
return RUN_ALL_TESTS();
}

@ -8,7 +8,8 @@ int main() {
WEBCC_LOG_INIT("", webcc::LOG_CONSOLE);
webcc::ClientSession session;
session.set_connect_timeout(5);
session.set_read_timeout(5);
session.Accept("application/json");
webcc::ResponsePtr r;

@ -30,7 +30,7 @@ int main(int argc, const char* argv[]) {
threads.emplace_back([&url]() {
// NOTE: Each thread has its own client session.
webcc::ClientSession session;
session.set_timeout(180);
session.set_read_timeout(180);
try {
LOG_USER("Start");

@ -27,7 +27,10 @@ int main(int argc, char* argv[]) {
webcc::ClientSession session;
try {
auto r = session.Send(webcc::RequestBuilder{}.Get(url)(), true);
auto r = session.Send(webcc::RequestBuilder{}.Get(url)(), true,
[](std::size_t length, std::size_t total_length) {
std::cout << "Progress " << length << " / " << total_length << std::endl;
});
if (auto file_body = r->file_body()) {
file_body->Move(path);

@ -3,23 +3,40 @@
#include "webcc/logger.h"
using boost::asio::ip::tcp;
using namespace std::placeholders;
namespace webcc {
Client::Client()
: timer_(io_context_),
ssl_verify_(true),
buffer_size_(kBufferSize),
timeout_(kMaxReadSeconds),
closed_(false),
timer_canceled_(false) {
#if WEBCC_ENABLE_SSL
Client::Client(boost::asio::io_context& io_context,
boost::asio::ssl::context& ssl_context)
: io_context_(io_context),
ssl_context_(ssl_context),
resolver_(io_context),
deadline_timer_(io_context) {
}
#else
Client::Client(boost::asio::io_context& io_context)
: io_context_(io_context),
resolver_(io_context),
deadline_timer_(io_context) {
}
Error Client::Request(RequestPtr request, bool connect, bool stream) {
closed_ = false;
timer_canceled_ = false;
#endif // WEBCC_ENABLE_SSL
Error Client::Request(RequestPtr request, bool stream) {
LOG_VERB("Request begin");
request_finished_ = false;
error_ = Error{};
request_ = request;
length_read_ = 0;
response_.reset(new Response{});
response_parser_.Init(response_.get(), stream);
@ -35,246 +52,312 @@ Error Client::Request(RequestPtr request, bool connect, bool stream) {
// have Content-Length;
// - If request.Accept-Encoding is "identity", the response will have
// Content-Length.
if (request->method() == methods::kHead) {
if (request_->method() == methods::kHead) {
response_parser_.set_ignore_body(true);
} else {
// Reset in case the connection is persistent.
response_parser_.set_ignore_body(false);
}
io_context_.restart();
if (connect) {
// No existing socket connection was specified, create a new one.
Connect(request);
if (error_) {
return error_;
}
if (!connected_) {
AsyncConnect();
} else {
AsyncWrite();
}
WriteRequest(request);
// Wait for the request to be finished.
std::unique_lock<std::mutex> response_lock{ request_mutex_ };
request_cv_.wait(response_lock, [=] { return request_finished_; });
if (error_) {
return error_;
}
ReadResponse();
LOG_VERB("Request end");
return error_;
}
void Client::Close() {
if (closed_) {
if (!connected_) {
//resolver_.cancel(); // TODO
if (socket_) {
// Cancel any async operations on the socket.
LOG_VERB("Close socket");
socket_->Close();
// Make sure the current request, if any, could be finished.
FinishRequest();
}
return;
}
closed_ = true;
LOG_INFO("Close socket...");
connected_ = false;
if (socket_) {
LOG_INFO("Shutdown & close socket");
socket_->Shutdown();
socket_->Close();
// Make sure the current request, if any, could be finished.
FinishRequest();
}
LOG_INFO("Socket closed");
}
void Client::Connect(RequestPtr request) {
if (request->url().scheme() == "https") {
void Client::AsyncConnect() {
if (request_->url().scheme() == "https") {
#if WEBCC_ENABLE_SSL
socket_.reset(new SslSocket{ io_context_, ssl_verify_ });
DoConnect(request, "443");
socket_.reset(new SslSocket{ io_context_, ssl_context_, ssl_verify_ });
AsyncResolve("443");
#else
LOG_ERRO("SSL/HTTPS support is not enabled.");
error_.Set(Error::kSyntaxError, "SSL/HTTPS is not supported");
FinishRequest();
#endif // WEBCC_ENABLE_SSL
} else {
socket_.reset(new Socket{ io_context_ });
DoConnect(request, "80");
AsyncResolve("80");
}
}
void Client::DoConnect(RequestPtr request, const std::string& default_port) {
tcp::resolver resolver(io_context_);
std::string port = request->port();
void Client::AsyncResolve(const std::string& default_port) {
std::string port = request_->port();
if (port.empty()) {
port = default_port;
}
LOG_VERB("Resolve host (%s)...", request->host().c_str());
boost::system::error_code ec;
LOG_VERB("Resolve host (%s)", request_->host().c_str());
// The protocol depends on the `host`, both V4 and V6 are supported.
auto endpoints = resolver.resolve(request->host(), port, ec);
resolver_.async_resolve(request_->host(), port,
std::bind(&Client::OnResolve, this, _1, _2));
}
void Client::OnResolve(boost::system::error_code ec,
tcp::resolver::results_type endpoints) {
if (ec) {
LOG_ERRO("Host resolve error (%s): %s, %s.", ec.message().c_str(),
request->host().c_str(), port.c_str());
LOG_ERRO("Host resolve error (%s)", ec.message().c_str());
error_.Set(Error::kResolveError, "Host resolve error");
FinishRequest();
return;
}
LOG_VERB("Connect to server...");
LOG_VERB("Connect socket");
// Use sync API directly since we don't need timeout control.
AsyncWaitDeadlineTimer(connect_timeout_);
if (!socket_->Connect(request->host(), endpoints)) {
error_.Set(Error::kConnectError, "Endpoint connect error");
Close();
return;
socket_->AsyncConnect(request_->host(), endpoints,
std::bind(&Client::OnConnect, this, _1, _2));
}
LOG_VERB("Socket connected.");
}
void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) {
LOG_VERB("On connect");
void Client::WriteRequest(RequestPtr request) {
LOG_VERB("HTTP request:\n%s", request->Dump().c_str());
StopDeadlineTimer();
// NOTE:
// It doesn't make much sense to set a timeout for socket write.
// I find that it's almost impossible to simulate a situation in the server
// side to test this timeout.
if (ec) {
if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or Close().
LOG_WARN("Connect operation aborted");
} else {
LOG_INFO("Connect error");
// No need to close socket since no async operation is on it.
// socket_->Close();
}
// Use sync API directly since we don't need timeout control.
error_.Set(Error::kConnectError, "Socket connect error");
FinishRequest();
return;
}
boost::system::error_code ec;
LOG_INFO("Socket connected");
if (socket_->Write(request->GetPayload(), &ec)) {
// Write request body.
auto body = request->body();
body->InitPayload();
for (auto p = body->NextPayload(true); !p.empty();
p = body->NextPayload(true)) {
if (!socket_->Write(p, &ec)) {
break;
}
connected_ = true;
AsyncWrite();
}
void Client::AsyncWrite() {
LOG_VERB("Request:\n%s", request_->Dump().c_str());
socket_->AsyncWrite(request_->GetPayload(),
std::bind(&Client::OnWrite, this, _1, _2));
}
void Client::OnWrite(boost::system::error_code ec, std::size_t length) {
if (ec) {
LOG_ERRO("Socket write error (%s).", ec.message().c_str());
Close();
error_.Set(Error::kSocketWriteError, "Socket write error");
HandleWriteError(ec);
return;
}
LOG_INFO("Request sent.");
request_->body()->InitPayload();
AsyncWriteBody();
}
void Client::ReadResponse() {
LOG_VERB("Read response (timeout: %ds)...", timeout_);
void Client::AsyncWriteBody() {
auto p = request_->body()->NextPayload(true);
if (!p.empty()) {
socket_->AsyncWrite(p, std::bind(&Client::OnWriteBody, this, _1, _2));
} else {
LOG_INFO("Request send");
DoReadResponse();
// Start the read deadline timer.
AsyncWaitDeadlineTimer(read_timeout_);
if (!error_) {
LOG_VERB("HTTP response:\n%s", response_->Dump().c_str());
// Start to read response.
AsyncRead();
}
}
void Client::DoReadResponse() {
boost::system::error_code ec = boost::asio::error::would_block;
std::size_t length = 0;
// The read handler.
auto handler = [&ec, &length](boost::system::error_code inner_ec,
std::size_t inner_length) {
ec = inner_ec;
length = inner_length;
};
void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) {
if (ec) {
HandleWriteError(ec);
return;
}
while (true) {
ec = boost::asio::error::would_block;
length = 0;
// Continue to write the next payload of body.
AsyncWriteBody();
}
socket_->AsyncReadSome(std::move(handler), &buffer_);
void Client::HandleWriteError(boost::system::error_code ec) {
if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or Close().
LOG_WARN("Write operation aborted");
} else {
LOG_ERRO("Socket write error (%s)", ec.message().c_str());
Close();
}
// Start the timer.
DoWaitTimer();
error_.Set(Error::kSocketWriteError, "Socket write error");
FinishRequest();
}
// Block until the asynchronous operation has completed.
do {
io_context_.run_one();
} while (ec == boost::asio::error::would_block);
void Client::AsyncRead() {
socket_->AsyncReadSome(std::bind(&Client::OnRead, this, _1, _2), &buffer_);
}
// Stop the timer.
CancelTimer();
void Client::OnRead(boost::system::error_code ec, std::size_t length) {
StopDeadlineTimer();
// The error normally is caused by timeout. See OnTimer().
if (ec || length == 0) {
if (ec) {
if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or Close().
LOG_WARN("Read operation aborted");
} else {
LOG_ERRO("Socket read error (%s)", ec.message().c_str());
Close();
}
error_.Set(Error::kSocketReadError, "Socket read error");
LOG_ERRO("Socket read error (%s).", ec.message().c_str());
break;
FinishRequest();
return;
}
LOG_INFO("Read data, length: %u.", length);
length_read_ += length;
LOG_INFO("Read length: %u", length);
// Parse the piece of data just read.
if (!response_parser_.Parse(buffer_.data(), length)) {
LOG_ERRO("Failed to parse the response");
Close();
error_.Set(Error::kParseError, "HTTP parse error");
LOG_ERRO("Failed to parse the HTTP response.");
break;
error_.Set(Error::kParseError, "Response parse error");
FinishRequest();
return;
}
// Inform progress callback if it's specified.
if (progress_callback_) {
if (response_parser_.header_ended()) {
// NOTE: Need to get rid of the header length.
progress_callback_(length_read_ - response_parser_.header_length(),
response_parser_.content_length());
}
}
if (response_parser_.finished()) {
// Stop trying to read once all content has been received, because
// some servers will block extra call to read_some().
LOG_VERB("Response:\n%s", response_->Dump().c_str());
if (response_->IsConnectionKeepAlive()) {
// Close the timer but keep the socket connection.
LOG_INFO("Keep the socket connection alive.");
LOG_INFO("Keep the socket connection alive");
} else {
Close();
}
// Stop reading.
LOG_INFO("Finished to read the HTTP response.");
break;
// Stop trying to read once all content has been received, because some
// servers will block extra call to read_some().
LOG_INFO("Finished to read the response");
FinishRequest();
return;
}
// Continue to read the response.
AsyncRead();
}
void Client::AsyncWaitDeadlineTimer(int seconds) {
if (seconds <= 0) {
deadline_timer_stopped_ = true;
return;
}
void Client::DoWaitTimer() {
LOG_VERB("Wait timer asynchronously.");
timer_.expires_after(std::chrono::seconds(timeout_));
timer_.async_wait(std::bind(&Client::OnTimer, this, std::placeholders::_1));
LOG_VERB("Async wait deadline timer");
deadline_timer_stopped_ = false;
deadline_timer_.expires_after(std::chrono::seconds(seconds));
deadline_timer_.async_wait(std::bind(&Client::OnDeadlineTimer, this, _1));
}
void Client::OnTimer(boost::system::error_code ec) {
LOG_VERB("On timer.");
void Client::OnDeadlineTimer(boost::system::error_code ec) {
LOG_VERB("On deadline timer");
deadline_timer_stopped_ = true;
// timer_.cancel() was called.
// deadline_timer_.cancel() was called.
if (ec == boost::asio::error::operation_aborted) {
LOG_VERB("Timer canceled.");
LOG_VERB("Deadline timer canceled");
return;
}
if (closed_) {
LOG_VERB("Socket has been closed.");
return;
}
LOG_WARN("Timeout");
if (timer_.expiry() <= boost::asio::steady_timer::clock_type::now()) {
// The deadline has passed. The socket is closed so that any outstanding
// asynchronous operations are canceled.
LOG_WARN("HTTP client timed out.");
error_.set_timeout(true);
// Cancel the async operations on the socket.
// OnXxx() will be called with `error::operation_aborted`.
if (connected_) {
Close();
return;
} else {
socket_->Close();
}
// Put the actor back to sleep.
DoWaitTimer();
error_.set_timeout(true);
}
void Client::CancelTimer() {
if (timer_canceled_) {
void Client::StopDeadlineTimer() {
if (deadline_timer_stopped_) {
return;
}
LOG_INFO("Cancel timer...");
timer_.cancel();
LOG_INFO("Cancel deadline timer");
timer_canceled_ = true;
try {
// Cancel the async wait operation on this timer.
deadline_timer_.cancel();
} catch (const boost::system::system_error&) {
}
deadline_timer_stopped_ = true;
}
void Client::FinishRequest() {
{
std::lock_guard<std::mutex> lock{ request_mutex_ };
if (!request_finished_) {
request_finished_ = true;
} else {
return;
}
}
request_cv_.notify_one();
}
} // namespace webcc

@ -1,8 +1,9 @@
#ifndef WEBCC_CLIENT_H_
#define WEBCC_CLIENT_H_
#include <cassert>
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
@ -19,18 +20,22 @@
namespace webcc {
// Synchronous HTTP & HTTPS client.
// In synchronous mode, a request won't return until the response is received
// or timeout occurs.
// Please don't use the same client object in multiple threads.
// A request won't return until the response is received or timeout occurs.
class Client {
public:
Client();
~Client() = default;
// TODO
#if WEBCC_ENABLE_SSL
Client(boost::asio::io_context& io_context,
boost::asio::ssl::context& ssl_context);
#else
explicit Client(boost::asio::io_context& io_context);
#endif
Client(const Client&) = delete;
Client& operator=(const Client&) = delete;
~Client() = default;
void set_ssl_verify(bool ssl_verify) {
ssl_verify_ = ssl_verify;
}
@ -41,19 +46,35 @@ public:
}
}
// Set the timeout (in seconds) for reading response.
void set_timeout(int timeout) {
void set_connect_timeout(int timeout) {
if (timeout > 0) {
connect_timeout_ = timeout;
}
}
void set_read_timeout(int timeout) {
if (timeout > 0) {
timeout_ = timeout;
read_timeout_ = timeout;
}
}
// Set progress callback to be informed about the read progress.
// TODO: What about write?
// TODO: std::move?
void set_progress_callback(ProgressCallback callback) {
progress_callback_ = callback;
}
// Connect to server, send request, wait until response is received.
Error Request(RequestPtr request, bool connect = true, bool stream = false);
// Connect, send request, wait until response is received.
Error Request(RequestPtr request, bool stream = false);
// Close the socket.
void Close();
bool connected() const {
return connected_;
}
ResponsePtr response() const {
return response_;
}
@ -66,57 +87,82 @@ public:
response_parser_.Init(nullptr, false);
}
bool closed() const {
return closed_;
}
private:
void Connect(RequestPtr request);
void AsyncConnect();
void AsyncResolve(const std::string& default_port);
void DoConnect(RequestPtr request, const std::string& default_port);
void OnResolve(boost::system::error_code ec,
boost::asio::ip::tcp::resolver::results_type endpoints);
void WriteRequest(RequestPtr request);
void OnConnect(boost::system::error_code ec, boost::asio::ip::tcp::endpoint);
void ReadResponse();
void AsyncWrite();
void OnWrite(boost::system::error_code ec, std::size_t length);
void DoReadResponse();
void AsyncWriteBody();
void OnWriteBody(boost::system::error_code ec, std::size_t length);
void DoWaitTimer();
void OnTimer(boost::system::error_code ec);
void HandleWriteError(boost::system::error_code ec);
// Cancel any async-operations waiting on the timer.
void CancelTimer();
void AsyncRead();
void OnRead(boost::system::error_code ec, std::size_t length);
void AsyncWaitDeadlineTimer(int seconds);
void OnDeadlineTimer(boost::system::error_code ec);
void StopDeadlineTimer();
void FinishRequest();
private:
boost::asio::io_context io_context_;
boost::asio::io_context& io_context_;
#if WEBCC_ENABLE_SSL
boost::asio::ssl::context& ssl_context_;
#endif
// Socket connection.
std::unique_ptr<SocketBase> socket_;
boost::asio::ip::tcp::resolver resolver_;
bool request_finished_ = true;
std::condition_variable request_cv_;
std::mutex request_mutex_;
RequestPtr request_;
ResponsePtr response_;
ResponseParser response_parser_;
// Timer for the timeout control.
boost::asio::steady_timer timer_;
// The length already read.
std::size_t length_read_ = 0;
// The buffer for reading response.
std::vector<char> buffer_;
// Verify the certificate of the peer or not (for HTTPS).
bool ssl_verify_;
bool ssl_verify_ = true;
// The size of the buffer for reading response.
// 0 means default value will be used.
std::size_t buffer_size_;
std::size_t buffer_size_ = kBufferSize;
// Timeout (seconds) for connecting to server.
// Default as 0 to disable our own control (i.e., deadline_timer_).
int connect_timeout_ = 0;
// Timeout (seconds) for reading response.
int read_timeout_ = kMaxReadSeconds;
// Timeout (seconds) for receiving response.
int timeout_;
// Deadline timer for connecting to server.
boost::asio::steady_timer deadline_timer_;
bool deadline_timer_stopped_ = true;
// Connection closed.
bool closed_;
// Socket connected or not.
bool connected_ = false;
// Deadline timer canceled.
bool timer_canceled_;
// Progress callback (optional).
ProgressCallback progress_callback_;
Error error_;
};

@ -6,7 +6,7 @@ namespace webcc {
ClientPool::~ClientPool() {
if (!clients_.empty()) {
LOG_INFO("Close socket for all (%u) connections in the pool.",
LOG_INFO("Close socket for all (%u) connections in the pool",
clients_.size());
for (auto& pair : clients_) {
@ -21,21 +21,21 @@ ClientPtr ClientPool::Get(const Key& key) const {
if (it != clients_.end()) {
return it->second;
} else {
return ClientPtr{};
return {};
}
}
void ClientPool::Add(const Key& key, ClientPtr client) {
clients_[key] = client;
LOG_INFO("Added connection to pool (%s, %s, %s).",
LOG_INFO("Connection added to pool (%s, %s, %s)",
key.scheme.c_str(), key.host.c_str(), key.port.c_str());
}
void ClientPool::Remove(const Key& key) {
clients_.erase(key);
LOG_INFO("Removed connection from pool (%s, %s, %s).",
LOG_INFO("Connection removed from pool (%s, %s, %s)",
key.scheme.c_str(), key.host.c_str(), key.port.c_str());
}

@ -1,16 +1,123 @@
#include "webcc/client_session.h"
#include <cassert>
#include "webcc/base64.h"
#include "webcc/logger.h"
#include "webcc/url.h"
#include "webcc/utility.h"
namespace ssl = boost::asio::ssl;
namespace webcc {
ClientSession::ClientSession(int timeout, bool ssl_verify,
std::size_t buffer_size)
: timeout_(timeout), ssl_verify_(ssl_verify), buffer_size_(buffer_size) {
#if WEBCC_ENABLE_SSL
#if (defined(_WIN32) || defined(_WIN64))
// Let OpenSSL on Windows use the system certificate store
// 1. Load your certificate (in PCCERT_CONTEXT structure) from Windows Cert
// store using Crypto APIs.
// 2. Get encrypted content of it in binary format as it is.
// [PCCERT_CONTEXT->pbCertEncoded].
// 3. Parse this binary buffer into X509 certificate Object using OpenSSL's
// d2i_X509() method.
// 4. Get handle to OpenSSL's trust store using SSL_CTX_get_cert_store()
// method.
// 5. Load above parsed X509 certificate into this trust store using
// X509_STORE_add_cert() method.
// 6. You are done!
// NOTES: Enum Windows store with "ROOT" (not "CA").
// See: https://stackoverflow.com/a/11763389/6825348
static bool UseSystemCertificateStore(SSL_CTX* ssl_ctx) {
// NOTE: Cannot use nullptr to replace NULL.
HCERTSTORE cert_store = ::CertOpenSystemStoreW(NULL, L"ROOT");
if (cert_store == nullptr) {
LOG_ERRO("Cannot open Windows system certificate store.");
return false;
}
X509_STORE* x509_store = SSL_CTX_get_cert_store(ssl_ctx);
PCCERT_CONTEXT cert_context = nullptr;
while (cert_context = CertEnumCertificatesInStore(cert_store, cert_context)) {
auto in = (const unsigned char**)&cert_context->pbCertEncoded;
X509* x509 = d2i_X509(nullptr, in, cert_context->cbCertEncoded);
if (x509 != nullptr) {
if (X509_STORE_add_cert(x509_store, x509) == 0) {
LOG_ERRO("Cannot add Windows root certificate.");
}
X509_free(x509);
}
}
CertFreeCertificateContext(cert_context);
CertCloseStore(cert_store, 0);
return true;
}
#endif // defined(_WIN32) || defined(_WIN64)
#endif // WEBCC_ENABLE_SSL
ClientSession::ClientSession(bool ssl_verify, std::size_t buffer_size)
: work_guard_(boost::asio::make_work_guard(io_context_)),
#if WEBCC_ENABLE_SSL
ssl_context_(ssl::context::sslv23),
#endif
ssl_verify_(ssl_verify), buffer_size_(buffer_size) {
#if WEBCC_ENABLE_SSL
#if (defined(_WIN32) || defined(_WIN64))
// if (ssl_verify_) {
UseSystemCertificateStore(ssl_context_.native_handle());
// }
#else
// Use the default paths for finding CA certificates.
ssl_context_.set_default_verify_paths();
#endif // defined(_WIN32) || defined(_WIN64)
#endif // WEBCC_ENABLE_SSL
InitHeaders();
Start();
}
ClientSession::~ClientSession() {
Stop();
}
void ClientSession::Start() {
if (started_) {
return;
}
started_ = true;
io_context_.restart();
// Run the io context off in its own thread so that it operates completely
// asynchronously with respect to the rest of the program.
io_thread_.reset(new std::thread{ [this]() { io_context_.run(); }});
LOG_INFO("Loop is now running");
}
void ClientSession::Stop() {
if (!started_) {
return;
}
Cancel();
io_context_.stop();
io_thread_->join();
LOG_INFO("Loop stopped");
started_ = false;
}
void ClientSession::Accept(const std::string& content_types) {
@ -68,9 +175,16 @@ void ClientSession::AuthToken(const std::string& token) {
return Auth("Token", token);
}
ResponsePtr ClientSession::Send(RequestPtr request, bool stream) {
ResponsePtr ClientSession::Send(RequestPtr request, bool stream,
ProgressCallback callback) {
assert(request);
std::lock_guard<std::mutex> lock{ mutex_ };
if (!started_) {
throw Error{ Error::kStateError, "Loop is not running" };
}
for (auto& h : headers_.data()) {
if (!request->HasHeader(h.first)) {
request->SetHeader(h.first, h.second);
@ -84,7 +198,13 @@ ResponsePtr ClientSession::Send(RequestPtr request, bool stream) {
request->Prepare();
return DoSend(request, stream);
return DoSend(request, stream, callback);
}
void ClientSession::Cancel() {
if (client_) {
client_->Close();
}
}
void ClientSession::InitHeaders() {
@ -99,34 +219,40 @@ void ClientSession::InitHeaders() {
headers_.Set(headers::kConnection, "Keep-Alive");
}
ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream) {
ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream,
ProgressCallback callback) {
const ClientPool::Key key{ request->url() };
// Reuse a pooled connection.
bool reuse = false;
ClientPtr client = pool_.Get(key);
if (!client) {
client.reset(new Client{});
#if WEBCC_ENABLE_SSL
client.reset(new Client{ io_context_, ssl_context_ });
#else
client.reset(new Client{ io_context_ });
#endif // WEBCC_ENABLE_SSL
reuse = false;
} else {
LOG_VERB("Reuse an existing connection.");
LOG_VERB("Reuse an existing connection");
reuse = true;
}
client->set_ssl_verify(ssl_verify_);
client->set_buffer_size(buffer_size_);
client->set_timeout(timeout_);
client->set_connect_timeout(connect_timeout_);
client->set_read_timeout(read_timeout_);
Error error = client->Request(request, !reuse, stream);
client->set_progress_callback(callback);
if (error) {
if (reuse && error.code() == Error::kSocketWriteError) {
LOG_WARN("Cannot send request with the reused connection. "
"The server must have closed it, reconnect and try again.");
error = client->Request(request, true, stream);
}
}
// Save current client for cancel.
client_ = client;
Error error = client->Request(request, stream);
client_.reset();
if (error) {
// Remove the failed connection from pool.
@ -139,11 +265,11 @@ ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream) {
// Update connection pool.
if (reuse) {
if (client->closed()) {
if (!client->connected()) {
pool_.Remove(key);
}
} else {
if (!client->closed()) {
if (client->connected()) {
pool_.Add(key, client);
}
}

@ -1,9 +1,14 @@
#ifndef WEBCC_CLIENT_SESSION_H_
#define WEBCC_CLIENT_SESSION_H_
#include <memory>
#include <mutex>
#include <string>
#include <thread>
#include <vector>
#include "boost/asio/io_context.hpp"
#include "webcc/client_pool.h"
#include "webcc/request_builder.h"
#include "webcc/response.h"
@ -11,18 +16,32 @@
namespace webcc {
// HTTP requests session providing connection-pooling, configuration and more.
// A session shouldn't be shared by multiple threads. Please create a new
// session for each thread instead.
// NOTE: If a session is shared by multiple threads, the requests sent through
// it will be serialized by using a mutex.
class ClientSession {
public:
explicit ClientSession(int timeout = 0, bool ssl_verify = true,
std::size_t buffer_size = 0);
explicit ClientSession(bool ssl_verify = true, std::size_t buffer_size = 0);
~ClientSession();
// Start Asio loop in a thread.
// You don't have to call Start() manually because it's called by the
// constructor.
void Start();
// Stop Asio loop.
// You can call Start() to resume the loop.
void Stop();
~ClientSession() = default;
void set_connect_timeout(int timeout) {
if (timeout > 0) {
connect_timeout_ = timeout;
}
}
void set_timeout(int timeout) {
void set_read_timeout(int timeout) {
if (timeout > 0) {
timeout_ = timeout;
read_timeout_ = timeout;
}
}
@ -73,14 +92,35 @@ public:
// the response body will be FileBody, and you can easily move the temp file
// to another path with FileBody::Move(). So, |stream| is really useful for
// downloading files (JPEG, etc.) or saving memory for huge data responses.
ResponsePtr Send(RequestPtr request, bool stream = false);
ResponsePtr Send(RequestPtr request, bool stream = false,
ProgressCallback callback = {});
// Cancel any in-progress connecting, writing or reading.
void Cancel();
private:
void InitHeaders();
ResponsePtr DoSend(RequestPtr request, bool stream);
ResponsePtr DoSend(RequestPtr request, bool stream,
ProgressCallback callback);
private:
boost::asio::io_context io_context_;
// The thread to run Asio loop.
std::unique_ptr<std::thread> io_thread_;
using ExecutorType = boost::asio::io_context::executor_type;
boost::asio::executor_work_guard<ExecutorType> work_guard_;
// TODO
#if WEBCC_ENABLE_SSL
boost::asio::ssl::context ssl_context_;
#endif
// Is Asio loop running?
bool started_ = false;
// The media (or MIME) type of `Content-Type` header.
// E.g., "application/json".
std::string media_type_;
@ -92,18 +132,27 @@ private:
// Additional headers for each request.
Headers headers_;
// Timeout in seconds for receiving response.
int timeout_;
// Timeout (seconds) for connecting to server.
int connect_timeout_ = 0;
// Timeout (seconds) for reading response.
int read_timeout_ = 0;
// Verify the certificate of the peer or not.
bool ssl_verify_;
bool ssl_verify_ = true;
// The size of the buffer for reading response.
// 0 means default value will be used.
std::size_t buffer_size_;
// Pool for Keep-Alive client connections.
// Keep-Alive client connections.
ClientPool pool_;
// Current requested client.
ClientPtr client_;
// The mutex to guard the request.
std::mutex mutex_;
};
} // namespace webcc

@ -32,7 +32,7 @@ void Connection::Start() {
}
void Connection::Close() {
LOG_INFO("Shutdown socket...");
LOG_INFO("Shutdown socket");
// Initiate graceful connection closure.
// Socket close VS. shutdown:
@ -41,17 +41,17 @@ void Connection::Close() {
socket_.shutdown(tcp::socket::shutdown_both, ec);
if (ec) {
LOG_WARN("Socket shutdown error (%s).", ec.message().c_str());
LOG_WARN("Socket shutdown error (%s)", ec.message().c_str());
ec.clear();
// Don't return, try to close the socket anywhere.
}
LOG_INFO("Close socket...");
LOG_INFO("Close socket");
socket_.close(ec);
if (ec) {
LOG_ERRO("Socket close error (%s).", ec.message().c_str());
LOG_ERRO("Socket close error (%s)", ec.message().c_str());
}
}
@ -92,13 +92,13 @@ void Connection::DoRead() {
void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
if (ec) {
if (ec == boost::asio::error::eof) {
LOG_INFO("Socket read EOF (%s).", ec.message().c_str());
LOG_INFO("Socket read EOF (%s)", ec.message().c_str());
} else if (ec == boost::asio::error::operation_aborted) {
// The socket of this connection has been closed.
// This happens, e.g., when the server was stopped by a signal (Ctrl-C).
LOG_WARN("Socket operation aborted (%s).", ec.message().c_str());
LOG_WARN("Socket operation aborted (%s)", ec.message().c_str());
} else {
LOG_ERRO("Socket read error (%s).", ec.message().c_str());
LOG_ERRO("Socket read error (%s)", ec.message().c_str());
}
// Don't try to send any response back.
@ -111,7 +111,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
}
if (!request_parser_.Parse(buffer_.data(), length)) {
LOG_ERRO("Failed to parse HTTP request.");
LOG_ERRO("Failed to parse request");
// Send Bad Request (400) to the client and no Keep-Alive.
SendResponse(Status::kBadRequest, true);
// Close the socket connection.
@ -125,7 +125,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
return;
}
LOG_VERB("HTTP request:\n%s", request_->Dump().c_str());
LOG_VERB("Request:\n%s", request_->Dump().c_str());
// Enqueue this connection once the request has been read.
// Some worker thread will handle the request later.
@ -133,7 +133,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) {
}
void Connection::DoWrite() {
LOG_VERB("HTTP response:\n%s", response_->Dump().c_str());
LOG_VERB("Response:\n%s", response_->Dump().c_str());
// Firstly, write the headers.
boost::asio::async_write(socket_, response_->GetPayload(),
@ -177,11 +177,11 @@ void Connection::OnWriteBody(boost::system::error_code ec, std::size_t length) {
}
void Connection::OnWriteOK() {
LOG_INFO("Response has been sent back.");
LOG_INFO("Response has been sent back");
if (request_->IsConnectionKeepAlive()) {
LOG_INFO("The client asked for a keep-alive connection.");
LOG_INFO("Continue to read the next request...");
LOG_INFO("The client asked for a keep-alive connection");
LOG_INFO("Continue to read the next request");
Start();
} else {
pool_->Close(shared_from_this());
@ -189,7 +189,7 @@ void Connection::OnWriteOK() {
}
void Connection::OnWriteError(boost::system::error_code ec) {
LOG_ERRO("Socket write error (%s).", ec.message().c_str());
LOG_ERRO("Socket write error (%s)", ec.message().c_str());
if (ec != boost::asio::error::operation_aborted) {
pool_->Close(shared_from_this());

@ -5,11 +5,11 @@
namespace webcc {
void ConnectionPool::Start(ConnectionPtr c) {
LOG_VERB("Starting connection...");
LOG_VERB("Start connection");
{
// Lock the container only.
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock{ mutex_ };
connections_.insert(c);
}
@ -17,11 +17,11 @@ void ConnectionPool::Start(ConnectionPtr c) {
}
void ConnectionPool::Close(ConnectionPtr c) {
LOG_VERB("Closing connection...");
LOG_VERB("Close connection");
{
// Lock the container only.
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock{ mutex_ };
connections_.erase(c);
}
@ -30,10 +30,10 @@ void ConnectionPool::Close(ConnectionPtr c) {
void ConnectionPool::Clear() {
// Lock all since we are going to stop anyway.
std::lock_guard<std::mutex> lock(mutex_);
std::lock_guard<std::mutex> lock{ mutex_ };
if (!connections_.empty()) {
LOG_VERB("Closing all (%u) connections...", connections_.size());
LOG_VERB("Close all (%u) connections", connections_.size());
for (auto& c : connections_) {
c->Close();
}

@ -3,6 +3,7 @@
#include <cassert>
#include <exception>
#include <functional>
#include <iosfwd>
#include <string>
#include <vector>
@ -23,11 +24,14 @@ using UrlArgs = std::vector<std::string>;
using Payload = std::vector<boost::asio::const_buffer>;
using ProgressCallback =
std::function<void(std::size_t length, std::size_t total_length)>;
// -----------------------------------------------------------------------------
const char* const kCRLF = "\r\n";
const std::size_t kInvalidLength = std::string::npos;
const std::size_t kInvalidLength = -1;
// Default timeout for reading response.
const int kMaxReadSeconds = 30;
@ -150,6 +154,7 @@ public:
enum Code {
kUnknownError = -1,
kOK = 0,
kStateError,
kSyntaxError,
kResolveError,
kConnectError,

@ -45,7 +45,12 @@ static FILE* FOpen(const bfs::path& path, bool overwrite) {
}
struct Logger {
Logger() : file(nullptr), modes(0) {
Logger() = default;
~Logger() {
if (file != nullptr) {
fclose(file);
}
}
void Init(const bfs::path& path, int _modes) {
@ -57,14 +62,8 @@ struct Logger {
}
}
~Logger() {
if (file != nullptr) {
fclose(file);
}
}
FILE* file;
int modes;
FILE* file = nullptr;
int modes = 0;
std::mutex mutex;
};
@ -161,10 +160,10 @@ static bfs::path InitLogPath(const bfs::path& dir) {
return bfs::current_path() / WEBCC_LOG_FILE_NAME;
}
if (!bfs::exists(dir) || !bfs::is_directory(dir)) {
boost::system::error_code ec;
if (!bfs::exists(dir, ec) || !bfs::is_directory(dir, ec)) {
if (!bfs::create_directories(dir, ec) || ec) {
return bfs::path{};
return {};
}
}
@ -216,7 +215,7 @@ void Log(int level, const char* file, int line, const char* format, ...) {
va_list args;
va_start(args, format);
fprintf(g_logger.file, "%s, %s, %7s, %20s, %4d, ",
fprintf(g_logger.file, "%s, %s, %7s, %25s, %4d, ",
timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
file, line);
@ -239,12 +238,12 @@ void Log(int level, const char* file, int line, const char* format, ...) {
if (g_terminal_has_color) {
if (level < WEBCC_WARN) {
fprintf(stderr, "%s%s, %s, %7s, %20s, %4d, ",
fprintf(stderr, "%s%s, %s, %7s, %25s, %4d, ",
TERM_RESET,
timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
file, line);
} else {
fprintf(stderr, "%s%s%s, %s, %7s, %20s, %4d, ",
fprintf(stderr, "%s%s%s, %s, %7s, %25s, %4d, ",
TERM_RESET,
level == WEBCC_WARN ? TERM_YELLOW : TERM_RED,
timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
@ -255,7 +254,7 @@ void Log(int level, const char* file, int line, const char* format, ...) {
fprintf(stderr, "%s\n", TERM_RESET);
} else {
fprintf(stderr, "%s, %s, %7s, %20s, %4d, ",
fprintf(stderr, "%s, %s, %7s, %25s, %4d, ",
timestamp.c_str(), kLevelNames[level], thread_id.c_str(),
file, line);

@ -8,7 +8,7 @@
namespace webcc {
Message::Message() : body_(new Body{}), content_length_(kInvalidLength) {
Message::Message() : body_(new Body{}) {
}
void Message::SetBody(BodyPtr body, bool set_length) {

@ -118,7 +118,7 @@ protected:
std::string start_line_;
std::size_t content_length_;
std::size_t content_length_ = kInvalidLength;
};
} // namespace webcc

@ -49,14 +49,14 @@ bool StringBodyHandler::Finish() {
auto body = std::make_shared<StringBody>(std::move(content_), IsCompressed());
#if WEBCC_ENABLE_GZIP
LOG_INFO("Decompress the HTTP content...");
LOG_INFO("Decompress the HTTP content");
if (!body->Decompress()) {
LOG_ERRO("Cannot decompress the HTTP content!");
LOG_ERRO("Cannot decompress the HTTP content");
return false;
}
#else
if (body->compressed()) {
LOG_WARN("Compressed HTTP content remains untouched.");
LOG_WARN("Compressed HTTP content remains untouched");
}
#endif // WEBCC_ENABLE_GZIP
@ -79,7 +79,7 @@ bool FileBodyHandler::OpenFile() {
temp_path_.string().c_str());
} catch (const bfs::filesystem_error&) {
LOG_ERRO("Failed to generate temp path for streaming.");
LOG_ERRO("Failed to generate temp path for streaming");
return false;
}
@ -132,6 +132,8 @@ bool Parser::Parse(const char* data, std::size_t length) {
return ParseContent(data, length);
}
header_length_ += length;
// Append the new data to the pending data.
pending_data_.append(data, length);
@ -140,11 +142,13 @@ bool Parser::Parse(const char* data, std::size_t length) {
}
if (!header_ended_) {
LOG_INFO("HTTP headers will continue in next read.");
LOG_INFO("HTTP headers will continue in next read");
return true;
}
LOG_INFO("HTTP headers just ended.");
LOG_INFO("HTTP headers just ended");
header_length_ -= pending_data_.size();
if (!OnHeadersEnd()) {
// Only request parser can reach here when no view matches the request.
@ -170,6 +174,7 @@ void Parser::Reset() {
stream_ = false;
pending_data_.clear();
header_length_ = 0;
content_length_ = kInvalidLength;
content_type_.Reset();
@ -336,7 +341,7 @@ bool Parser::ParseChunkedContent(const char* data, std::size_t length) {
return false;
}
LOG_VERB("Chunk size: %u.", chunk_size_);
LOG_VERB("Chunk size: %u", chunk_size_);
}
if (chunk_size_ == 0) {
@ -378,14 +383,14 @@ bool Parser::ParseChunkedContent(const char* data, std::size_t length) {
}
bool Parser::ParseChunkSize() {
LOG_VERB("Parse chunk size.");
LOG_VERB("Parse chunk size");
std::string line;
if (!GetNextLine(0, &line, true)) {
return true;
}
LOG_VERB("Chunk size line: [%s].", line.c_str());
LOG_VERB("Chunk size line: [%s]", line.c_str());
std::string hex_str; // e.g., "cf0" (3312)
@ -397,7 +402,7 @@ bool Parser::ParseChunkSize() {
}
if (!to_size_t(hex_str, 16, &chunk_size_)) {
LOG_ERRO("Invalid chunk-size: %s.", hex_str.c_str());
LOG_ERRO("Invalid chunk-size: %s", hex_str.c_str());
return false;
}

@ -94,17 +94,35 @@ private:
class Parser {
public:
Parser();
virtual ~Parser() = default;
Parser(const Parser&) = delete;
Parser& operator=(const Parser&) = delete;
virtual ~Parser() = default;
void Init(Message* message);
bool finished() const {
return finished_;
}
// If the headers part has been parsed or not.
bool header_ended() const {
return header_ended_;
}
// Get the length of the headers part.
// Available after the headers have been parsed (see header_ended()).
std::size_t header_length() const {
return header_length_;
}
// The content length parsed from `Content-Length` header.
// kInvalidLength if the content is chunked.
std::size_t content_length() const {
return content_length_;
}
bool Parse(const char* data, std::size_t length);
protected:
@ -144,25 +162,28 @@ protected:
bool Finish();
protected:
Message* message_;
Message* message_ = nullptr;
std::unique_ptr<BodyHandler> body_handler_;
// Data streaming or not.
bool stream_;
bool stream_ = false;
// Data waiting to be parsed.
std::string pending_data_;
// The length of the headers part.
std::size_t header_length_ = 0;
// Temporary data and helper flags for parsing.
std::size_t content_length_;
std::size_t content_length_ = kInvalidLength;
ContentType content_type_;
bool start_line_parsed_;
bool content_length_parsed_;
bool header_ended_;
bool chunked_;
std::size_t chunk_size_;
bool finished_;
bool start_line_parsed_ = false;
bool content_length_parsed_ = false;
bool header_ended_ = false;
bool chunked_ = false;
std::size_t chunk_size_ = kInvalidLength;
bool finished_ = false;
};
} // namespace webcc

@ -23,9 +23,6 @@ Server::Server(boost::asio::ip::tcp protocol, std::uint16_t port,
: protocol_(protocol),
port_(port),
doc_root_(doc_root),
buffer_size_(kBufferSize),
file_chunk_size_(1024),
running_(false),
acceptor_(io_context_),
signals_(io_context_) {
AddSignals();
@ -35,12 +32,12 @@ void Server::Run(std::size_t workers, std::size_t loops) {
assert(workers > 0);
{
std::lock_guard<std::mutex> lock(state_mutex_);
std::lock_guard<std::mutex> lock{ state_mutex_ };
assert(worker_threads_.empty());
if (IsRunning()) {
LOG_WARN("Server is already running.");
LOG_WARN("Server is already running");
return;
}
@ -48,11 +45,11 @@ void Server::Run(std::size_t workers, std::size_t loops) {
io_context_.restart();
if (!Listen(port_)) {
LOG_ERRO("Server is NOT going to run.");
LOG_ERRO("Server is NOT going to run");
return;
}
LOG_INFO("Server is going to run...");
LOG_INFO("Server is going to run");
AsyncWaitSignals();
@ -70,7 +67,7 @@ void Server::Run(std::size_t workers, std::size_t loops) {
// asynchronous operation outstanding: the asynchronous accept call waiting
// for new incoming connections.
LOG_INFO("Loop is running in %u thread(s).", loops);
LOG_INFO("Loop is running in %u thread(s)", loops);
if (loops == 1) {
// Run the loop in current thread.
@ -88,7 +85,7 @@ void Server::Run(std::size_t workers, std::size_t loops) {
}
void Server::Stop() {
std::lock_guard<std::mutex> lock(state_mutex_);
std::lock_guard<std::mutex> lock{ state_mutex_ };
DoStop();
}
@ -112,7 +109,7 @@ void Server::AsyncWaitSignals() {
// The server is stopped by canceling all outstanding asynchronous
// operations. Once all operations have finished the io_context::run()
// call will exit.
LOG_INFO("On signal %d, stopping the server...", signo);
LOG_INFO("On signal %d, stop the server", signo);
DoStop();
});
@ -121,12 +118,12 @@ void Server::AsyncWaitSignals() {
bool Server::Listen(std::uint16_t port) {
boost::system::error_code ec;
tcp::endpoint endpoint(protocol_, port);
tcp::endpoint endpoint{ protocol_, port };
// Open the acceptor.
acceptor_.open(endpoint.protocol(), ec);
if (ec) {
LOG_ERRO("Acceptor open error (%s).", ec.message().c_str());
LOG_ERRO("Acceptor open error (%s)", ec.message().c_str());
return false;
}
@ -141,7 +138,7 @@ bool Server::Listen(std::uint16_t port) {
// Bind to the server address.
acceptor_.bind(endpoint, ec);
if (ec) {
LOG_ERRO("Acceptor bind error (%s).", ec.message().c_str());
LOG_ERRO("Acceptor bind error (%s)", ec.message().c_str());
return false;
}
@ -150,7 +147,7 @@ bool Server::Listen(std::uint16_t port) {
// has not started to accept the connection yet.
acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec);
if (ec) {
LOG_ERRO("Acceptor listen error (%s).", ec.message().c_str());
LOG_ERRO("Acceptor listen error (%s)", ec.message().c_str());
return false;
}
@ -167,9 +164,10 @@ void Server::AsyncAccept() {
}
if (!ec) {
LOG_INFO("Accepted a connection.");
LOG_INFO("Accepted a connection");
using namespace std::placeholders;
auto view_matcher = std::bind(&Server::MatchViewOrStatic, this, _1,
_2, _3);
@ -205,16 +203,16 @@ void Server::DoStop() {
}
void Server::WorkerRoutine() {
LOG_INFO("Worker is running.");
LOG_INFO("Worker is running");
for (;;) {
auto connection = queue_.PopOrWait();
if (!connection) {
LOG_INFO("Worker is going to stop.");
LOG_INFO("Worker is going to stop");
// For stopping next worker.
queue_.Push(ConnectionPtr());
queue_.Push({});
// Stop this worker.
break;
@ -225,13 +223,13 @@ void Server::WorkerRoutine() {
}
void Server::StopWorkers() {
LOG_INFO("Stopping workers...");
LOG_INFO("Stop workers");
// Clear/drop pending connections.
// The connections will be closed later (see DoStop).
// Alternatively, we can wait for the pending connections to be handled.
if (queue_.Size() != 0) {
LOG_INFO("Clear pending connections...");
LOG_INFO("Clear pending connections");
queue_.Clear();
}
@ -252,7 +250,7 @@ void Server::StopWorkers() {
// last worker thread.
queue_.Clear();
LOG_INFO("All workers have been stopped.");
LOG_INFO("Workers stopped");
}
void Server::Handle(ConnectionPtr connection) {
@ -307,7 +305,9 @@ bool Server::MatchViewOrStatic(const std::string& method,
// Try to match a static file.
if (method == methods::kGet && !doc_root_.empty()) {
bfs::path path = doc_root_ / url;
if (!bfs::is_directory(path) && bfs::exists(path)) {
boost::system::error_code ec;
if (!bfs::is_directory(path, ec) && bfs::exists(path, ec)) {
return true;
}
}
@ -319,7 +319,7 @@ ResponsePtr Server::ServeStatic(RequestPtr request) {
assert(request->method() == methods::kGet);
if (doc_root_.empty()) {
LOG_INFO("The doc root was not specified.");
LOG_INFO("The doc root was not specified");
return {};
}
@ -340,7 +340,7 @@ ResponsePtr Server::ServeStatic(RequestPtr request) {
return response;
} catch (const Error& error) {
LOG_ERRO("File error: %s.", error.message().c_str());
LOG_ERRO("File error: %s", error.message().c_str());
return {};
}
}

@ -24,11 +24,11 @@ public:
Server(boost::asio::ip::tcp protocol, std::uint16_t port,
const boost::filesystem::path& doc_root = {});
~Server() = default;
Server(const Server&) = delete;
Server& operator=(const Server&) = delete;
~Server() = default;
void set_buffer_size(std::size_t buffer_size) {
if (buffer_size > 0) {
buffer_size_ = buffer_size;
@ -105,20 +105,20 @@ private:
boost::asio::ip::tcp protocol_;
// Port number.
std::uint16_t port_;
std::uint16_t port_ = 0;
// The directory with the static files to be served.
boost::filesystem::path doc_root_;
// The size of the buffer for reading request.
std::size_t buffer_size_;
std::size_t buffer_size_ = kBufferSize;
// The size of the chunk loaded into memory each time when serving a
// static file.
std::size_t file_chunk_size_;
std::size_t file_chunk_size_ = 1024;
// Is the server running?
bool running_;
bool running_ = false;
// The mutex for guarding the state of the server.
std::mutex state_mutex_;

@ -3,9 +3,9 @@
#if WEBCC_ENABLE_SSL
#if (defined(_WIN32) || defined(_WIN64))
#include <windows.h>
#include <wincrypt.h>
#include <cryptuiapi.h>
#include <wincrypt.h>
#include <windows.h>
#include "openssl/x509.h"
@ -18,6 +18,9 @@
#include "webcc/logger.h"
using boost::asio::ip::tcp;
using namespace std::placeholders;
namespace webcc {
// -----------------------------------------------------------------------------
@ -25,48 +28,37 @@ namespace webcc {
Socket::Socket(boost::asio::io_context& io_context) : socket_(io_context) {
}
bool Socket::Connect(const std::string& /*host*/, const Endpoints& endpoints) {
boost::system::error_code ec;
boost::asio::connect(socket_, endpoints, ec);
if (ec) {
LOG_ERRO("Socket connect error (%s).", ec.message().c_str());
return false;
void Socket::AsyncConnect(const std::string& host, const Endpoints& endpoints,
ConnectHandler&& handler) {
boost::asio::async_connect(socket_, endpoints, std::move(handler));
}
return true;
}
bool Socket::Write(const Payload& payload, boost::system::error_code* ec) {
boost::asio::write(socket_, payload, *ec);
return !(*ec);
}
bool Socket::ReadSome(std::vector<char>* buffer, std::size_t* size,
boost::system::error_code* ec) {
*size = socket_.read_some(boost::asio::buffer(*buffer), *ec);
return (*size != 0 && !(*ec));
void Socket::AsyncWrite(const Payload& payload, WriteHandler&& handler) {
boost::asio::async_write(socket_, payload, std::move(handler));
}
void Socket::AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) {
socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler));
}
bool Socket::Close() {
bool Socket::Shutdown() {
boost::system::error_code ec;
socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec);
socket_.shutdown(tcp::socket::shutdown_both, ec);
if (ec) {
LOG_WARN("Socket shutdown error (%s).", ec.message().c_str());
ec.clear();
// Don't return, try to close the socket anywhere.
LOG_WARN("Socket shutdown error (%s)", ec.message().c_str());
return false;
}
return true;
}
bool Socket::Close() {
boost::system::error_code ec;
socket_.close(ec);
if (ec) {
LOG_WARN("Socket close error (%s).", ec.message().c_str());
LOG_WARN("Socket close error (%s)", ec.message().c_str());
return false;
}
@ -77,129 +69,89 @@ bool Socket::Close() {
#if WEBCC_ENABLE_SSL
#if (defined(_WIN32) || defined(_WIN64))
namespace ssl = boost::asio::ssl;
// Let OpenSSL on Windows use the system certificate store
// 1. Load your certificate (in PCCERT_CONTEXT structure) from Windows Cert
// store using Crypto APIs.
// 2. Get encrypted content of it in binary format as it is.
// [PCCERT_CONTEXT->pbCertEncoded].
// 3. Parse this binary buffer into X509 certificate Object using OpenSSL's
// d2i_X509() method.
// 4. Get handle to OpenSSL's trust store using SSL_CTX_get_cert_store()
// method.
// 5. Load above parsed X509 certificate into this trust store using
// X509_STORE_add_cert() method.
// 6. You are done!
// NOTES: Enum Windows store with "ROOT" (not "CA").
// See: https://stackoverflow.com/a/11763389/6825348
static bool UseSystemCertificateStore(SSL_CTX* ssl_ctx) {
// NOTE: Cannot use nullptr to replace NULL.
HCERTSTORE cert_store = ::CertOpenSystemStoreW(NULL, L"ROOT");
if (cert_store == nullptr) {
LOG_ERRO("Cannot open Windows system certificate store.");
return false;
SslSocket::SslSocket(boost::asio::io_context& io_context,
ssl::context& ssl_context, bool ssl_verify)
: ssl_socket_(io_context, ssl_context),
ssl_verify_(ssl_verify) {
}
X509_STORE* x509_store = SSL_CTX_get_cert_store(ssl_ctx);
PCCERT_CONTEXT cert_context = nullptr;
void SslSocket::AsyncConnect(const std::string& host,
const Endpoints& endpoints,
ConnectHandler&& handler) {
connect_handler_ = std::move(handler);
while (cert_context = CertEnumCertificatesInStore(cert_store, cert_context)) {
auto in = (const unsigned char**)&cert_context->pbCertEncoded;
X509* x509 = d2i_X509(nullptr, in, cert_context->cbCertEncoded);
if (x509 != nullptr) {
if (X509_STORE_add_cert(x509_store, x509) == 0) {
LOG_ERRO("Cannot add Windows root certificate.");
if (ssl_verify_) {
ssl_socket_.set_verify_mode(ssl::verify_peer);
} else {
ssl_socket_.set_verify_mode(ssl::verify_none);
}
X509_free(x509);
}
}
// ssl::host_name_verification has been added since Boost 1.73 to replace
// ssl::rfc2818_verification.
#if BOOST_VERSION < 107300
ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host));
#else
ssl_socket_.set_verify_callback(ssl::host_name_verification(host));
#endif // BOOST_VERSION < 107300
CertFreeCertificateContext(cert_context);
CertCloseStore(cert_store, 0);
return true;
boost::asio::async_connect(ssl_socket_.lowest_layer(), endpoints,
std::bind(&SslSocket::OnConnect, this, _1, _2));
}
#endif // defined(_WIN32) || defined(_WIN64)
namespace ssl = boost::asio::ssl;
SslSocket::SslSocket(boost::asio::io_context& io_context, bool ssl_verify)
: ssl_context_(ssl::context::sslv23),
ssl_socket_(io_context, ssl_context_),
ssl_verify_(ssl_verify) {
#if (defined(_WIN32) || defined(_WIN64))
if (ssl_verify_) {
UseSystemCertificateStore(ssl_context_.native_handle());
void SslSocket::AsyncWrite(const Payload& payload, WriteHandler&& handler) {
boost::asio::async_write(ssl_socket_, payload, std::move(handler));
}
#else
// Use the default paths for finding CA certificates.
ssl_context_.set_default_verify_paths();
#endif // defined(_WIN32) || defined(_WIN64)
void SslSocket::AsyncReadSome(ReadHandler&& handler,
std::vector<char>* buffer) {
ssl_socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler));
}
bool SslSocket::Connect(const std::string& host, const Endpoints& endpoints) {
bool SslSocket::Shutdown() {
boost::system::error_code ec;
boost::asio::connect(ssl_socket_.lowest_layer(), endpoints, ec);
ssl_socket_.lowest_layer().shutdown(tcp::socket::shutdown_both, ec);
if (ec) {
LOG_ERRO("Socket connect error (%s).", ec.message().c_str());
LOG_WARN("Socket shutdown error (%s)", ec.message().c_str());
return false;
}
return Handshake(host);
}
bool SslSocket::Write(const Payload& payload, boost::system::error_code* ec) {
boost::asio::write(ssl_socket_, payload, *ec);
return !(*ec);
}
bool SslSocket::ReadSome(std::vector<char>* buffer, std::size_t* size,
boost::system::error_code* ec) {
*size = ssl_socket_.read_some(boost::asio::buffer(*buffer), *ec);
return (*size != 0 && !(*ec));
}
void SslSocket::AsyncReadSome(ReadHandler&& handler,
std::vector<char>* buffer) {
ssl_socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler));
return true;
}
bool SslSocket::Close() {
boost::system::error_code ec;
ssl_socket_.lowest_layer().close(ec);
return !ec;
if (ec) {
LOG_WARN("Socket close error (%s)", ec.message().c_str());
return false;
}
bool SslSocket::Handshake(const std::string& host) {
if (ssl_verify_) {
ssl_socket_.set_verify_mode(ssl::verify_peer);
} else {
ssl_socket_.set_verify_mode(ssl::verify_none);
return true;
}
// ssl::host_name_verification has been added since Boost 1.73 to replace
// ssl::rfc2818_verification.
#if BOOST_VERSION < 107300
ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host));
#else
ssl_socket_.set_verify_callback(ssl::host_name_verification(host));
#endif // BOOST_VERSION < 107300
void SslSocket::OnConnect(boost::system::error_code ec,
tcp::endpoint endpoint) {
if (ec) {
connect_handler_(ec, std::move(endpoint));
return;
}
// Use sync API directly since we don't need timeout control.
boost::system::error_code ec;
ssl_socket_.handshake(ssl::stream_base::client, ec);
// Backup endpoint
endpoint_ = std::move(endpoint);
ssl_socket_.async_handshake(ssl::stream_base::client,
[this](boost::system::error_code ec) {
if (ec) {
LOG_ERRO("Handshake error (%s).", ec.message().c_str());
return false;
LOG_ERRO("Handshake error (%s)", ec.message().c_str());
}
return true;
connect_handler_(ec, std::move(endpoint_));
});
}
#endif // WEBCC_ENABLE_SSL

@ -18,24 +18,29 @@ namespace webcc {
class SocketBase {
public:
virtual ~SocketBase() = default;
using Endpoints = boost::asio::ip::tcp::resolver::results_type;
using ConnectHandler = std::function<void(boost::system::error_code,
boost::asio::ip::tcp::endpoint)>;
using WriteHandler =
std::function<void(boost::system::error_code, std::size_t)>;
using ReadHandler =
std::function<void(boost::system::error_code, std::size_t)>;
// TODO: Remove |host|
virtual bool Connect(const std::string& host, const Endpoints& endpoints) = 0;
virtual ~SocketBase() = default;
virtual bool Write(const Payload& payload, boost::system::error_code* ec) = 0;
virtual void AsyncConnect(const std::string& host, const Endpoints& endpoints,
ConnectHandler&& handler) = 0;
virtual bool ReadSome(std::vector<char>* buffer, std::size_t* size,
boost::system::error_code* ec) = 0;
virtual void AsyncWrite(const Payload& payload, WriteHandler&& handler) = 0;
virtual void AsyncReadSome(ReadHandler&& handler,
std::vector<char>* buffer) = 0;
virtual bool Shutdown() = 0;
virtual bool Close() = 0;
};
@ -45,15 +50,15 @@ class Socket : public SocketBase {
public:
explicit Socket(boost::asio::io_context& io_context);
bool Connect(const std::string& host, const Endpoints& endpoints) override;
bool Write(const Payload& payload, boost::system::error_code* ec) override;
void AsyncConnect(const std::string& host, const Endpoints& endpoints,
ConnectHandler&& handler) override;
bool ReadSome(std::vector<char>* buffer, std::size_t* size,
boost::system::error_code* ec) override;
void AsyncWrite(const Payload& payload, WriteHandler&& handler) override;
void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override;
bool Shutdown() override;
bool Close() override;
private:
@ -66,29 +71,32 @@ private:
class SslSocket : public SocketBase {
public:
explicit SslSocket(boost::asio::io_context& io_context,
SslSocket(boost::asio::io_context& io_context,
boost::asio::ssl::context& ssl_context,
bool ssl_verify = true);
bool Connect(const std::string& host, const Endpoints& endpoints) override;
void AsyncConnect(const std::string& host, const Endpoints& endpoints,
ConnectHandler&& handler) override;
bool Write(const Payload& payload, boost::system::error_code* ec) override;
bool ReadSome(std::vector<char>* buffer, std::size_t* size,
boost::system::error_code* ec) override;
void AsyncWrite(const Payload& payload, WriteHandler&& handler) override;
void AsyncReadSome(ReadHandler&& handler, std::vector<char>* buffer) override;
bool Shutdown() override;
bool Close() override;
private:
bool Handshake(const std::string& host);
void OnConnect(boost::system::error_code ec,
boost::asio::ip::tcp::endpoint endpoint);
boost::asio::ssl::context ssl_context_;
ConnectHandler connect_handler_;
boost::asio::ip::tcp::endpoint endpoint_;
boost::asio::ssl::stream<boost::asio::ip::tcp::socket> ssl_socket_;
// Verify the certificate of the peer (remote server) or not.
bool ssl_verify_;
bool ssl_verify_ = true;
};
#endif // WEBCC_ENABLE_SSL

Loading…
Cancel
Save