diff --git a/autotest/CMakeLists.txt b/autotest/CMakeLists.txt index cb3f7ec..26d0e56 100644 --- a/autotest/CMakeLists.txt +++ b/autotest/CMakeLists.txt @@ -1,24 +1,2 @@ -# Automation test - -set(AT_SRCS - client_autotest.cc - client_timeout_autotest.cc - main.cc - ) - -# Common libraries to link. -set(AT_LIBS - webcc - jsoncpp - GTest::GTest) - -if(UNIX) - # Add `-ldl` for Linux to avoid "undefined reference to `dlopen'". - set(AT_LIBS ${AT_LIBS} ${CMAKE_DL_LIBS}) -endif() - -set(AT_TARGET_NAME autotest) - -add_executable(${AT_TARGET_NAME} ${AT_SRCS}) -target_link_libraries(${AT_TARGET_NAME} ${AT_LIBS}) -set_target_properties(${AT_TARGET_NAME} PROPERTIES FOLDER "Tests") +add_subdirectory(client_autotest) +add_subdirectory(client_timeout_autotest) diff --git a/autotest/client_autotest/CMakeLists.txt b/autotest/client_autotest/CMakeLists.txt new file mode 100644 index 0000000..70b88af --- /dev/null +++ b/autotest/client_autotest/CMakeLists.txt @@ -0,0 +1,24 @@ +set(SRCS + client_autotest.cc + main.cc + ) + +# Common libraries to link. +set(LIBS + webcc + jsoncpp + GTest::GTest + ) + +if(UNIX) + # Add `-ldl` for Linux to avoid "undefined reference to `dlopen'". + set(LIBS ${LIBS} ${CMAKE_DL_LIBS}) +endif() + +set(TARGET_NAME client_autotest) + +add_executable(${TARGET_NAME} ${SRCS}) + +target_link_libraries(${TARGET_NAME} ${LIBS}) + +set_target_properties(${TARGET_NAME} PROPERTIES FOLDER "Tests") diff --git a/autotest/client_autotest.cc b/autotest/client_autotest/client_autotest.cc similarity index 100% rename from autotest/client_autotest.cc rename to autotest/client_autotest/client_autotest.cc diff --git a/autotest/main.cc b/autotest/client_autotest/main.cc similarity index 100% rename from autotest/main.cc rename to autotest/client_autotest/main.cc diff --git a/autotest/client_timeout_autotest/CMakeLists.txt b/autotest/client_timeout_autotest/CMakeLists.txt new file mode 100644 index 0000000..cc0c159 --- /dev/null +++ b/autotest/client_timeout_autotest/CMakeLists.txt @@ -0,0 +1,22 @@ +set(SRCS + client_timeout_autotest.cc + main.cc + ) + +# Common libraries to link. +set(LIBS + webcc + jsoncpp + GTest::GTest + ) + +if(UNIX) + # Add `-ldl` for Linux to avoid "undefined reference to `dlopen'". + set(LIBS ${LIBS} ${CMAKE_DL_LIBS}) +endif() + +set(TARGET_NAME client_timeout_autotest) + +add_executable(${TARGET_NAME} ${SRCS}) +target_link_libraries(${TARGET_NAME} ${LIBS}) +set_target_properties(${TARGET_NAME} PROPERTIES FOLDER "Tests") diff --git a/autotest/client_timeout_autotest.cc b/autotest/client_timeout_autotest/client_timeout_autotest.cc similarity index 95% rename from autotest/client_timeout_autotest.cc rename to autotest/client_timeout_autotest/client_timeout_autotest.cc index 7c2a87b..6f8f009 100644 --- a/autotest/client_timeout_autotest.cc +++ b/autotest/client_timeout_autotest/client_timeout_autotest.cc @@ -9,7 +9,7 @@ namespace { -const char* kData = "Hello, World!"; +const char* const kData = "Hello, World!"; const std::uint16_t kPort = 8080; @@ -83,8 +83,8 @@ TEST_F(ClientTimeoutTest, NoTimeout) { TEST_F(ClientTimeoutTest, Timeout) { webcc::ClientSession session; - // Change timeout to 1s. - session.set_timeout(1); + // Change read timeout to 1s. + session.set_read_timeout(1); webcc::ResponsePtr r; bool timeout = false; diff --git a/autotest/client_timeout_autotest/main.cc b/autotest/client_timeout_autotest/main.cc new file mode 100644 index 0000000..7ee7577 --- /dev/null +++ b/autotest/client_timeout_autotest/main.cc @@ -0,0 +1,11 @@ +#include "gtest/gtest.h" + +#include "webcc/logger.h" + +int main(int argc, char* argv[]) { + // Set webcc::LOG_CONSOLE to enable logging. + WEBCC_LOG_INIT("", 0); + + ::testing::InitGoogleTest(&argc, argv); + return RUN_ALL_TESTS(); +} diff --git a/examples/client_basics.cc b/examples/client_basics.cc index bfe6ee5..93cea12 100644 --- a/examples/client_basics.cc +++ b/examples/client_basics.cc @@ -8,7 +8,8 @@ int main() { WEBCC_LOG_INIT("", webcc::LOG_CONSOLE); webcc::ClientSession session; - + session.set_connect_timeout(5); + session.set_read_timeout(5); session.Accept("application/json"); webcc::ResponsePtr r; diff --git a/examples/concurrency_test.cc b/examples/concurrency_test.cc index 6fa70d9..30f1339 100644 --- a/examples/concurrency_test.cc +++ b/examples/concurrency_test.cc @@ -30,7 +30,7 @@ int main(int argc, const char* argv[]) { threads.emplace_back([&url]() { // NOTE: Each thread has its own client session. webcc::ClientSession session; - session.set_timeout(180); + session.set_read_timeout(180); try { LOG_USER("Start"); diff --git a/examples/file_downloader.cc b/examples/file_downloader.cc index 26e91ba..965c891 100644 --- a/examples/file_downloader.cc +++ b/examples/file_downloader.cc @@ -27,7 +27,10 @@ int main(int argc, char* argv[]) { webcc::ClientSession session; try { - auto r = session.Send(webcc::RequestBuilder{}.Get(url)(), true); + auto r = session.Send(webcc::RequestBuilder{}.Get(url)(), true, + [](std::size_t length, std::size_t total_length) { + std::cout << "Progress " << length << " / " << total_length << std::endl; + }); if (auto file_body = r->file_body()) { file_body->Move(path); diff --git a/webcc/client.cc b/webcc/client.cc index 981a630..77eae82 100644 --- a/webcc/client.cc +++ b/webcc/client.cc @@ -3,23 +3,40 @@ #include "webcc/logger.h" using boost::asio::ip::tcp; +using namespace std::placeholders; namespace webcc { -Client::Client() - : timer_(io_context_), - ssl_verify_(true), - buffer_size_(kBufferSize), - timeout_(kMaxReadSeconds), - closed_(false), - timer_canceled_(false) { +#if WEBCC_ENABLE_SSL + +Client::Client(boost::asio::io_context& io_context, + boost::asio::ssl::context& ssl_context) + : io_context_(io_context), + ssl_context_(ssl_context), + resolver_(io_context), + deadline_timer_(io_context) { + } -Error Client::Request(RequestPtr request, bool connect, bool stream) { - closed_ = false; - timer_canceled_ = false; +#else + +Client::Client(boost::asio::io_context& io_context) + : io_context_(io_context), + resolver_(io_context), + deadline_timer_(io_context) { +} + +#endif // WEBCC_ENABLE_SSL + +Error Client::Request(RequestPtr request, bool stream) { + LOG_VERB("Request begin"); + + request_finished_ = false; error_ = Error{}; + request_ = request; + + length_read_ = 0; response_.reset(new Response{}); response_parser_.Init(response_.get(), stream); @@ -35,246 +52,312 @@ Error Client::Request(RequestPtr request, bool connect, bool stream) { // have Content-Length; // - If request.Accept-Encoding is "identity", the response will have // Content-Length. - if (request->method() == methods::kHead) { + if (request_->method() == methods::kHead) { response_parser_.set_ignore_body(true); } else { // Reset in case the connection is persistent. response_parser_.set_ignore_body(false); } - io_context_.restart(); - - if (connect) { - // No existing socket connection was specified, create a new one. - Connect(request); - - if (error_) { - return error_; - } + if (!connected_) { + AsyncConnect(); + } else { + AsyncWrite(); } - WriteRequest(request); - - if (error_) { - return error_; - } + // Wait for the request to be finished. + std::unique_lock response_lock{ request_mutex_ }; + request_cv_.wait(response_lock, [=] { return request_finished_; }); - ReadResponse(); + LOG_VERB("Request end"); return error_; } void Client::Close() { - if (closed_) { + if (!connected_) { + //resolver_.cancel(); // TODO + if (socket_) { + // Cancel any async operations on the socket. + LOG_VERB("Close socket"); + socket_->Close(); + // Make sure the current request, if any, could be finished. + FinishRequest(); + } return; } - closed_ = true; + connected_ = false; - LOG_INFO("Close socket..."); + if (socket_) { + LOG_INFO("Shutdown & close socket"); + socket_->Shutdown(); + socket_->Close(); + // Make sure the current request, if any, could be finished. + FinishRequest(); + } - socket_->Close(); + LOG_INFO("Socket closed"); } -void Client::Connect(RequestPtr request) { - if (request->url().scheme() == "https") { +void Client::AsyncConnect() { + if (request_->url().scheme() == "https") { #if WEBCC_ENABLE_SSL - socket_.reset(new SslSocket{ io_context_, ssl_verify_ }); - DoConnect(request, "443"); + socket_.reset(new SslSocket{ io_context_, ssl_context_, ssl_verify_ }); + AsyncResolve("443"); #else LOG_ERRO("SSL/HTTPS support is not enabled."); error_.Set(Error::kSyntaxError, "SSL/HTTPS is not supported"); + FinishRequest(); #endif // WEBCC_ENABLE_SSL } else { socket_.reset(new Socket{ io_context_ }); - DoConnect(request, "80"); + AsyncResolve("80"); } } -void Client::DoConnect(RequestPtr request, const std::string& default_port) { - tcp::resolver resolver(io_context_); - - std::string port = request->port(); +void Client::AsyncResolve(const std::string& default_port) { + std::string port = request_->port(); if (port.empty()) { port = default_port; } - LOG_VERB("Resolve host (%s)...", request->host().c_str()); - - boost::system::error_code ec; + LOG_VERB("Resolve host (%s)", request_->host().c_str()); // The protocol depends on the `host`, both V4 and V6 are supported. - auto endpoints = resolver.resolve(request->host(), port, ec); + resolver_.async_resolve(request_->host(), port, + std::bind(&Client::OnResolve, this, _1, _2)); +} +void Client::OnResolve(boost::system::error_code ec, + tcp::resolver::results_type endpoints) { if (ec) { - LOG_ERRO("Host resolve error (%s): %s, %s.", ec.message().c_str(), - request->host().c_str(), port.c_str()); + LOG_ERRO("Host resolve error (%s)", ec.message().c_str()); error_.Set(Error::kResolveError, "Host resolve error"); + FinishRequest(); return; } - LOG_VERB("Connect to server..."); + LOG_VERB("Connect socket"); - // Use sync API directly since we don't need timeout control. + AsyncWaitDeadlineTimer(connect_timeout_); - if (!socket_->Connect(request->host(), endpoints)) { - error_.Set(Error::kConnectError, "Endpoint connect error"); - Close(); + socket_->AsyncConnect(request_->host(), endpoints, + std::bind(&Client::OnConnect, this, _1, _2)); +} + +void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) { + LOG_VERB("On connect"); + + StopDeadlineTimer(); + + if (ec) { + if (ec == boost::asio::error::operation_aborted) { + // Socket has been closed by OnDeadlineTimer() or Close(). + LOG_WARN("Connect operation aborted"); + } else { + LOG_INFO("Connect error"); + // No need to close socket since no async operation is on it. + // socket_->Close(); + } + + error_.Set(Error::kConnectError, "Socket connect error"); + FinishRequest(); return; } - LOG_VERB("Socket connected."); + LOG_INFO("Socket connected"); + + connected_ = true; + + AsyncWrite(); } -void Client::WriteRequest(RequestPtr request) { - LOG_VERB("HTTP request:\n%s", request->Dump().c_str()); +void Client::AsyncWrite() { + LOG_VERB("Request:\n%s", request_->Dump().c_str()); - // NOTE: - // It doesn't make much sense to set a timeout for socket write. - // I find that it's almost impossible to simulate a situation in the server - // side to test this timeout. + socket_->AsyncWrite(request_->GetPayload(), + std::bind(&Client::OnWrite, this, _1, _2)); +} - // Use sync API directly since we don't need timeout control. +void Client::OnWrite(boost::system::error_code ec, std::size_t length) { + if (ec) { + HandleWriteError(ec); + return; + } - boost::system::error_code ec; + request_->body()->InitPayload(); - if (socket_->Write(request->GetPayload(), &ec)) { - // Write request body. - auto body = request->body(); - body->InitPayload(); - for (auto p = body->NextPayload(true); !p.empty(); - p = body->NextPayload(true)) { - if (!socket_->Write(p, &ec)) { - break; - } - } + AsyncWriteBody(); +} + +void Client::AsyncWriteBody() { + auto p = request_->body()->NextPayload(true); + + if (!p.empty()) { + socket_->AsyncWrite(p, std::bind(&Client::OnWriteBody, this, _1, _2)); + } else { + LOG_INFO("Request send"); + + // Start the read deadline timer. + AsyncWaitDeadlineTimer(read_timeout_); + + // Start to read response. + AsyncRead(); } +} +void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) { if (ec) { - LOG_ERRO("Socket write error (%s).", ec.message().c_str()); - Close(); - error_.Set(Error::kSocketWriteError, "Socket write error"); + HandleWriteError(ec); + return; } - LOG_INFO("Request sent."); + // Continue to write the next payload of body. + AsyncWriteBody(); } -void Client::ReadResponse() { - LOG_VERB("Read response (timeout: %ds)...", timeout_); - - DoReadResponse(); - - if (!error_) { - LOG_VERB("HTTP response:\n%s", response_->Dump().c_str()); +void Client::HandleWriteError(boost::system::error_code ec) { + if (ec == boost::asio::error::operation_aborted) { + // Socket has been closed by OnDeadlineTimer() or Close(). + LOG_WARN("Write operation aborted"); + } else { + LOG_ERRO("Socket write error (%s)", ec.message().c_str()); + Close(); } + + error_.Set(Error::kSocketWriteError, "Socket write error"); + FinishRequest(); } -void Client::DoReadResponse() { - boost::system::error_code ec = boost::asio::error::would_block; - std::size_t length = 0; +void Client::AsyncRead() { + socket_->AsyncReadSome(std::bind(&Client::OnRead, this, _1, _2), &buffer_); +} - // The read handler. - auto handler = [&ec, &length](boost::system::error_code inner_ec, - std::size_t inner_length) { - ec = inner_ec; - length = inner_length; - }; +void Client::OnRead(boost::system::error_code ec, std::size_t length) { + StopDeadlineTimer(); - while (true) { - ec = boost::asio::error::would_block; - length = 0; + if (ec) { + if (ec == boost::asio::error::operation_aborted) { + // Socket has been closed by OnDeadlineTimer() or Close(). + LOG_WARN("Read operation aborted"); + } else { + LOG_ERRO("Socket read error (%s)", ec.message().c_str()); + Close(); + } - socket_->AsyncReadSome(std::move(handler), &buffer_); + error_.Set(Error::kSocketReadError, "Socket read error"); + FinishRequest(); + return; + } - // Start the timer. - DoWaitTimer(); + length_read_ += length; - // Block until the asynchronous operation has completed. - do { - io_context_.run_one(); - } while (ec == boost::asio::error::would_block); + LOG_INFO("Read length: %u", length); - // Stop the timer. - CancelTimer(); + // Parse the piece of data just read. + if (!response_parser_.Parse(buffer_.data(), length)) { + LOG_ERRO("Failed to parse the response"); + Close(); + error_.Set(Error::kParseError, "Response parse error"); + FinishRequest(); + return; + } - // The error normally is caused by timeout. See OnTimer(). - if (ec || length == 0) { - Close(); - error_.Set(Error::kSocketReadError, "Socket read error"); - LOG_ERRO("Socket read error (%s).", ec.message().c_str()); - break; + // Inform progress callback if it's specified. + if (progress_callback_) { + if (response_parser_.header_ended()) { + // NOTE: Need to get rid of the header length. + progress_callback_(length_read_ - response_parser_.header_length(), + response_parser_.content_length()); } + } - LOG_INFO("Read data, length: %u.", length); + if (response_parser_.finished()) { + LOG_VERB("Response:\n%s", response_->Dump().c_str()); - // Parse the piece of data just read. - if (!response_parser_.Parse(buffer_.data(), length)) { + if (response_->IsConnectionKeepAlive()) { + LOG_INFO("Keep the socket connection alive"); + } else { Close(); - error_.Set(Error::kParseError, "HTTP parse error"); - LOG_ERRO("Failed to parse the HTTP response."); - break; } - if (response_parser_.finished()) { - // Stop trying to read once all content has been received, because - // some servers will block extra call to read_some(). + // Stop trying to read once all content has been received, because some + // servers will block extra call to read_some(). - if (response_->IsConnectionKeepAlive()) { - // Close the timer but keep the socket connection. - LOG_INFO("Keep the socket connection alive."); - } else { - Close(); - } - - // Stop reading. - LOG_INFO("Finished to read the HTTP response."); - break; - } + LOG_INFO("Finished to read the response"); + FinishRequest(); + return; } + + // Continue to read the response. + AsyncRead(); } -void Client::DoWaitTimer() { - LOG_VERB("Wait timer asynchronously."); - timer_.expires_after(std::chrono::seconds(timeout_)); - timer_.async_wait(std::bind(&Client::OnTimer, this, std::placeholders::_1)); +void Client::AsyncWaitDeadlineTimer(int seconds) { + if (seconds <= 0) { + deadline_timer_stopped_ = true; + return; + } + + LOG_VERB("Async wait deadline timer"); + + deadline_timer_stopped_ = false; + + deadline_timer_.expires_after(std::chrono::seconds(seconds)); + deadline_timer_.async_wait(std::bind(&Client::OnDeadlineTimer, this, _1)); } -void Client::OnTimer(boost::system::error_code ec) { - LOG_VERB("On timer."); +void Client::OnDeadlineTimer(boost::system::error_code ec) { + LOG_VERB("On deadline timer"); + + deadline_timer_stopped_ = true; - // timer_.cancel() was called. + // deadline_timer_.cancel() was called. if (ec == boost::asio::error::operation_aborted) { - LOG_VERB("Timer canceled."); + LOG_VERB("Deadline timer canceled"); return; } - if (closed_) { - LOG_VERB("Socket has been closed."); - return; - } + LOG_WARN("Timeout"); - if (timer_.expiry() <= boost::asio::steady_timer::clock_type::now()) { - // The deadline has passed. The socket is closed so that any outstanding - // asynchronous operations are canceled. - LOG_WARN("HTTP client timed out."); - error_.set_timeout(true); + // Cancel the async operations on the socket. + // OnXxx() will be called with `error::operation_aborted`. + if (connected_) { Close(); - return; + } else { + socket_->Close(); } - // Put the actor back to sleep. - DoWaitTimer(); + error_.set_timeout(true); } -void Client::CancelTimer() { - if (timer_canceled_) { +void Client::StopDeadlineTimer() { + if (deadline_timer_stopped_) { return; } - LOG_INFO("Cancel timer..."); - timer_.cancel(); + LOG_INFO("Cancel deadline timer"); + + try { + // Cancel the async wait operation on this timer. + deadline_timer_.cancel(); + } catch (const boost::system::system_error&) { + } + + deadline_timer_stopped_ = true; +} - timer_canceled_ = true; +void Client::FinishRequest() { + { + std::lock_guard lock{ request_mutex_ }; + if (!request_finished_) { + request_finished_ = true; + } else { + return; + } + } + request_cv_.notify_one(); } } // namespace webcc diff --git a/webcc/client.h b/webcc/client.h index e75ce7e..14997c5 100644 --- a/webcc/client.h +++ b/webcc/client.h @@ -1,8 +1,9 @@ #ifndef WEBCC_CLIENT_H_ #define WEBCC_CLIENT_H_ -#include +#include #include +#include #include #include @@ -19,18 +20,22 @@ namespace webcc { // Synchronous HTTP & HTTPS client. -// In synchronous mode, a request won't return until the response is received -// or timeout occurs. -// Please don't use the same client object in multiple threads. +// A request won't return until the response is received or timeout occurs. class Client { public: - Client(); - - ~Client() = default; + // TODO +#if WEBCC_ENABLE_SSL + Client(boost::asio::io_context& io_context, + boost::asio::ssl::context& ssl_context); +#else + explicit Client(boost::asio::io_context& io_context); +#endif Client(const Client&) = delete; Client& operator=(const Client&) = delete; + ~Client() = default; + void set_ssl_verify(bool ssl_verify) { ssl_verify_ = ssl_verify; } @@ -41,19 +46,35 @@ public: } } - // Set the timeout (in seconds) for reading response. - void set_timeout(int timeout) { + void set_connect_timeout(int timeout) { + if (timeout > 0) { + connect_timeout_ = timeout; + } + } + + void set_read_timeout(int timeout) { if (timeout > 0) { - timeout_ = timeout; + read_timeout_ = timeout; } } - // Connect to server, send request, wait until response is received. - Error Request(RequestPtr request, bool connect = true, bool stream = false); + // Set progress callback to be informed about the read progress. + // TODO: What about write? + // TODO: std::move? + void set_progress_callback(ProgressCallback callback) { + progress_callback_ = callback; + } + + // Connect, send request, wait until response is received. + Error Request(RequestPtr request, bool stream = false); // Close the socket. void Close(); + bool connected() const { + return connected_; + } + ResponsePtr response() const { return response_; } @@ -66,57 +87,82 @@ public: response_parser_.Init(nullptr, false); } - bool closed() const { - return closed_; - } - private: - void Connect(RequestPtr request); + void AsyncConnect(); + + void AsyncResolve(const std::string& default_port); - void DoConnect(RequestPtr request, const std::string& default_port); + void OnResolve(boost::system::error_code ec, + boost::asio::ip::tcp::resolver::results_type endpoints); - void WriteRequest(RequestPtr request); + void OnConnect(boost::system::error_code ec, boost::asio::ip::tcp::endpoint); - void ReadResponse(); + void AsyncWrite(); + void OnWrite(boost::system::error_code ec, std::size_t length); - void DoReadResponse(); + void AsyncWriteBody(); + void OnWriteBody(boost::system::error_code ec, std::size_t length); - void DoWaitTimer(); - void OnTimer(boost::system::error_code ec); + void HandleWriteError(boost::system::error_code ec); - // Cancel any async-operations waiting on the timer. - void CancelTimer(); + void AsyncRead(); + void OnRead(boost::system::error_code ec, std::size_t length); + + void AsyncWaitDeadlineTimer(int seconds); + void OnDeadlineTimer(boost::system::error_code ec); + void StopDeadlineTimer(); + + void FinishRequest(); private: - boost::asio::io_context io_context_; + boost::asio::io_context& io_context_; + +#if WEBCC_ENABLE_SSL + boost::asio::ssl::context& ssl_context_; +#endif - // Socket connection. std::unique_ptr socket_; + boost::asio::ip::tcp::resolver resolver_; + + bool request_finished_ = true; + std::condition_variable request_cv_; + std::mutex request_mutex_; + + RequestPtr request_; + ResponsePtr response_; ResponseParser response_parser_; - // Timer for the timeout control. - boost::asio::steady_timer timer_; + // The length already read. + std::size_t length_read_ = 0; // The buffer for reading response. std::vector buffer_; // Verify the certificate of the peer or not (for HTTPS). - bool ssl_verify_; + bool ssl_verify_ = true; // The size of the buffer for reading response. // 0 means default value will be used. - std::size_t buffer_size_; + std::size_t buffer_size_ = kBufferSize; + + // Timeout (seconds) for connecting to server. + // Default as 0 to disable our own control (i.e., deadline_timer_). + int connect_timeout_ = 0; + + // Timeout (seconds) for reading response. + int read_timeout_ = kMaxReadSeconds; - // Timeout (seconds) for receiving response. - int timeout_; + // Deadline timer for connecting to server. + boost::asio::steady_timer deadline_timer_; + bool deadline_timer_stopped_ = true; - // Connection closed. - bool closed_; + // Socket connected or not. + bool connected_ = false; - // Deadline timer canceled. - bool timer_canceled_; + // Progress callback (optional). + ProgressCallback progress_callback_; Error error_; }; diff --git a/webcc/client_pool.cc b/webcc/client_pool.cc index 90e3c2f..cdf2d0f 100644 --- a/webcc/client_pool.cc +++ b/webcc/client_pool.cc @@ -6,7 +6,7 @@ namespace webcc { ClientPool::~ClientPool() { if (!clients_.empty()) { - LOG_INFO("Close socket for all (%u) connections in the pool.", + LOG_INFO("Close socket for all (%u) connections in the pool", clients_.size()); for (auto& pair : clients_) { @@ -21,21 +21,21 @@ ClientPtr ClientPool::Get(const Key& key) const { if (it != clients_.end()) { return it->second; } else { - return ClientPtr{}; + return {}; } } void ClientPool::Add(const Key& key, ClientPtr client) { clients_[key] = client; - LOG_INFO("Added connection to pool (%s, %s, %s).", + LOG_INFO("Connection added to pool (%s, %s, %s)", key.scheme.c_str(), key.host.c_str(), key.port.c_str()); } void ClientPool::Remove(const Key& key) { clients_.erase(key); - LOG_INFO("Removed connection from pool (%s, %s, %s).", + LOG_INFO("Connection removed from pool (%s, %s, %s)", key.scheme.c_str(), key.host.c_str(), key.port.c_str()); } diff --git a/webcc/client_session.cc b/webcc/client_session.cc index 4472347..a8e187c 100644 --- a/webcc/client_session.cc +++ b/webcc/client_session.cc @@ -1,16 +1,123 @@ #include "webcc/client_session.h" +#include + #include "webcc/base64.h" #include "webcc/logger.h" #include "webcc/url.h" #include "webcc/utility.h" + +namespace ssl = boost::asio::ssl; namespace webcc { -ClientSession::ClientSession(int timeout, bool ssl_verify, - std::size_t buffer_size) - : timeout_(timeout), ssl_verify_(ssl_verify), buffer_size_(buffer_size) { +#if WEBCC_ENABLE_SSL +#if (defined(_WIN32) || defined(_WIN64)) + +// Let OpenSSL on Windows use the system certificate store +// 1. Load your certificate (in PCCERT_CONTEXT structure) from Windows Cert +// store using Crypto APIs. +// 2. Get encrypted content of it in binary format as it is. +// [PCCERT_CONTEXT->pbCertEncoded]. +// 3. Parse this binary buffer into X509 certificate Object using OpenSSL's +// d2i_X509() method. +// 4. Get handle to OpenSSL's trust store using SSL_CTX_get_cert_store() +// method. +// 5. Load above parsed X509 certificate into this trust store using +// X509_STORE_add_cert() method. +// 6. You are done! +// NOTES: Enum Windows store with "ROOT" (not "CA"). +// See: https://stackoverflow.com/a/11763389/6825348 + +static bool UseSystemCertificateStore(SSL_CTX* ssl_ctx) { + // NOTE: Cannot use nullptr to replace NULL. + HCERTSTORE cert_store = ::CertOpenSystemStoreW(NULL, L"ROOT"); + if (cert_store == nullptr) { + LOG_ERRO("Cannot open Windows system certificate store."); + return false; + } + + X509_STORE* x509_store = SSL_CTX_get_cert_store(ssl_ctx); + PCCERT_CONTEXT cert_context = nullptr; + + while (cert_context = CertEnumCertificatesInStore(cert_store, cert_context)) { + auto in = (const unsigned char**)&cert_context->pbCertEncoded; + X509* x509 = d2i_X509(nullptr, in, cert_context->cbCertEncoded); + + if (x509 != nullptr) { + if (X509_STORE_add_cert(x509_store, x509) == 0) { + LOG_ERRO("Cannot add Windows root certificate."); + } + + X509_free(x509); + } + } + + CertFreeCertificateContext(cert_context); + CertCloseStore(cert_store, 0); + return true; +} + +#endif // defined(_WIN32) || defined(_WIN64) +#endif // WEBCC_ENABLE_SSL + +ClientSession::ClientSession(bool ssl_verify, std::size_t buffer_size) + : work_guard_(boost::asio::make_work_guard(io_context_)), +#if WEBCC_ENABLE_SSL + ssl_context_(ssl::context::sslv23), +#endif + ssl_verify_(ssl_verify), buffer_size_(buffer_size) { +#if WEBCC_ENABLE_SSL +#if (defined(_WIN32) || defined(_WIN64)) + // if (ssl_verify_) { + UseSystemCertificateStore(ssl_context_.native_handle()); + // } +#else + // Use the default paths for finding CA certificates. + ssl_context_.set_default_verify_paths(); +#endif // defined(_WIN32) || defined(_WIN64) +#endif // WEBCC_ENABLE_SSL + InitHeaders(); + + Start(); +} + +ClientSession::~ClientSession() { + Stop(); +} + +void ClientSession::Start() { + if (started_) { + return; + } + + started_ = true; + + io_context_.restart(); + + // Run the io context off in its own thread so that it operates completely + // asynchronously with respect to the rest of the program. + + io_thread_.reset(new std::thread{ [this]() { io_context_.run(); }}); + + LOG_INFO("Loop is now running"); +} + +void ClientSession::Stop() { + if (!started_) { + return; + } + + Cancel(); + + io_context_.stop(); + + io_thread_->join(); + + LOG_INFO("Loop stopped"); + + started_ = false; } void ClientSession::Accept(const std::string& content_types) { @@ -68,9 +175,16 @@ void ClientSession::AuthToken(const std::string& token) { return Auth("Token", token); } -ResponsePtr ClientSession::Send(RequestPtr request, bool stream) { +ResponsePtr ClientSession::Send(RequestPtr request, bool stream, + ProgressCallback callback) { assert(request); + std::lock_guard lock{ mutex_ }; + + if (!started_) { + throw Error{ Error::kStateError, "Loop is not running" }; + } + for (auto& h : headers_.data()) { if (!request->HasHeader(h.first)) { request->SetHeader(h.first, h.second); @@ -84,7 +198,13 @@ ResponsePtr ClientSession::Send(RequestPtr request, bool stream) { request->Prepare(); - return DoSend(request, stream); + return DoSend(request, stream, callback); +} + +void ClientSession::Cancel() { + if (client_) { + client_->Close(); + } } void ClientSession::InitHeaders() { @@ -99,34 +219,40 @@ void ClientSession::InitHeaders() { headers_.Set(headers::kConnection, "Keep-Alive"); } -ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream) { +ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream, + ProgressCallback callback) { const ClientPool::Key key{ request->url() }; // Reuse a pooled connection. bool reuse = false; ClientPtr client = pool_.Get(key); + if (!client) { - client.reset(new Client{}); +#if WEBCC_ENABLE_SSL + client.reset(new Client{ io_context_, ssl_context_ }); +#else + client.reset(new Client{ io_context_ }); +#endif // WEBCC_ENABLE_SSL reuse = false; } else { - LOG_VERB("Reuse an existing connection."); + LOG_VERB("Reuse an existing connection"); reuse = true; } client->set_ssl_verify(ssl_verify_); client->set_buffer_size(buffer_size_); - client->set_timeout(timeout_); + client->set_connect_timeout(connect_timeout_); + client->set_read_timeout(read_timeout_); - Error error = client->Request(request, !reuse, stream); + client->set_progress_callback(callback); - if (error) { - if (reuse && error.code() == Error::kSocketWriteError) { - LOG_WARN("Cannot send request with the reused connection. " - "The server must have closed it, reconnect and try again."); - error = client->Request(request, true, stream); - } - } + // Save current client for cancel. + client_ = client; + + Error error = client->Request(request, stream); + + client_.reset(); if (error) { // Remove the failed connection from pool. @@ -139,11 +265,11 @@ ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream) { // Update connection pool. if (reuse) { - if (client->closed()) { + if (!client->connected()) { pool_.Remove(key); } } else { - if (!client->closed()) { + if (client->connected()) { pool_.Add(key, client); } } diff --git a/webcc/client_session.h b/webcc/client_session.h index b420c36..cc5af47 100644 --- a/webcc/client_session.h +++ b/webcc/client_session.h @@ -1,9 +1,14 @@ #ifndef WEBCC_CLIENT_SESSION_H_ #define WEBCC_CLIENT_SESSION_H_ +#include +#include #include +#include #include +#include "boost/asio/io_context.hpp" + #include "webcc/client_pool.h" #include "webcc/request_builder.h" #include "webcc/response.h" @@ -11,18 +16,32 @@ namespace webcc { // HTTP requests session providing connection-pooling, configuration and more. -// A session shouldn't be shared by multiple threads. Please create a new -// session for each thread instead. +// NOTE: If a session is shared by multiple threads, the requests sent through +// it will be serialized by using a mutex. class ClientSession { public: - explicit ClientSession(int timeout = 0, bool ssl_verify = true, - std::size_t buffer_size = 0); + explicit ClientSession(bool ssl_verify = true, std::size_t buffer_size = 0); + + ~ClientSession(); + + // Start Asio loop in a thread. + // You don't have to call Start() manually because it's called by the + // constructor. + void Start(); + + // Stop Asio loop. + // You can call Start() to resume the loop. + void Stop(); - ~ClientSession() = default; + void set_connect_timeout(int timeout) { + if (timeout > 0) { + connect_timeout_ = timeout; + } + } - void set_timeout(int timeout) { + void set_read_timeout(int timeout) { if (timeout > 0) { - timeout_ = timeout; + read_timeout_ = timeout; } } @@ -73,14 +92,35 @@ public: // the response body will be FileBody, and you can easily move the temp file // to another path with FileBody::Move(). So, |stream| is really useful for // downloading files (JPEG, etc.) or saving memory for huge data responses. - ResponsePtr Send(RequestPtr request, bool stream = false); + ResponsePtr Send(RequestPtr request, bool stream = false, + ProgressCallback callback = {}); + + // Cancel any in-progress connecting, writing or reading. + void Cancel(); private: void InitHeaders(); - ResponsePtr DoSend(RequestPtr request, bool stream); + ResponsePtr DoSend(RequestPtr request, bool stream, + ProgressCallback callback); private: + boost::asio::io_context io_context_; + + // The thread to run Asio loop. + std::unique_ptr io_thread_; + + using ExecutorType = boost::asio::io_context::executor_type; + boost::asio::executor_work_guard work_guard_; + + // TODO +#if WEBCC_ENABLE_SSL + boost::asio::ssl::context ssl_context_; +#endif + + // Is Asio loop running? + bool started_ = false; + // The media (or MIME) type of `Content-Type` header. // E.g., "application/json". std::string media_type_; @@ -92,18 +132,27 @@ private: // Additional headers for each request. Headers headers_; - // Timeout in seconds for receiving response. - int timeout_; + // Timeout (seconds) for connecting to server. + int connect_timeout_ = 0; + + // Timeout (seconds) for reading response. + int read_timeout_ = 0; // Verify the certificate of the peer or not. - bool ssl_verify_; + bool ssl_verify_ = true; // The size of the buffer for reading response. // 0 means default value will be used. std::size_t buffer_size_; - // Pool for Keep-Alive client connections. + // Keep-Alive client connections. ClientPool pool_; + + // Current requested client. + ClientPtr client_; + + // The mutex to guard the request. + std::mutex mutex_; }; } // namespace webcc diff --git a/webcc/connection.cc b/webcc/connection.cc index 7538e90..b9cba51 100644 --- a/webcc/connection.cc +++ b/webcc/connection.cc @@ -32,7 +32,7 @@ void Connection::Start() { } void Connection::Close() { - LOG_INFO("Shutdown socket..."); + LOG_INFO("Shutdown socket"); // Initiate graceful connection closure. // Socket close VS. shutdown: @@ -41,17 +41,17 @@ void Connection::Close() { socket_.shutdown(tcp::socket::shutdown_both, ec); if (ec) { - LOG_WARN("Socket shutdown error (%s).", ec.message().c_str()); + LOG_WARN("Socket shutdown error (%s)", ec.message().c_str()); ec.clear(); // Don't return, try to close the socket anywhere. } - LOG_INFO("Close socket..."); + LOG_INFO("Close socket"); socket_.close(ec); if (ec) { - LOG_ERRO("Socket close error (%s).", ec.message().c_str()); + LOG_ERRO("Socket close error (%s)", ec.message().c_str()); } } @@ -92,13 +92,13 @@ void Connection::DoRead() { void Connection::OnRead(boost::system::error_code ec, std::size_t length) { if (ec) { if (ec == boost::asio::error::eof) { - LOG_INFO("Socket read EOF (%s).", ec.message().c_str()); + LOG_INFO("Socket read EOF (%s)", ec.message().c_str()); } else if (ec == boost::asio::error::operation_aborted) { // The socket of this connection has been closed. // This happens, e.g., when the server was stopped by a signal (Ctrl-C). - LOG_WARN("Socket operation aborted (%s).", ec.message().c_str()); + LOG_WARN("Socket operation aborted (%s)", ec.message().c_str()); } else { - LOG_ERRO("Socket read error (%s).", ec.message().c_str()); + LOG_ERRO("Socket read error (%s)", ec.message().c_str()); } // Don't try to send any response back. @@ -111,7 +111,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) { } if (!request_parser_.Parse(buffer_.data(), length)) { - LOG_ERRO("Failed to parse HTTP request."); + LOG_ERRO("Failed to parse request"); // Send Bad Request (400) to the client and no Keep-Alive. SendResponse(Status::kBadRequest, true); // Close the socket connection. @@ -125,7 +125,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) { return; } - LOG_VERB("HTTP request:\n%s", request_->Dump().c_str()); + LOG_VERB("Request:\n%s", request_->Dump().c_str()); // Enqueue this connection once the request has been read. // Some worker thread will handle the request later. @@ -133,7 +133,7 @@ void Connection::OnRead(boost::system::error_code ec, std::size_t length) { } void Connection::DoWrite() { - LOG_VERB("HTTP response:\n%s", response_->Dump().c_str()); + LOG_VERB("Response:\n%s", response_->Dump().c_str()); // Firstly, write the headers. boost::asio::async_write(socket_, response_->GetPayload(), @@ -177,11 +177,11 @@ void Connection::OnWriteBody(boost::system::error_code ec, std::size_t length) { } void Connection::OnWriteOK() { - LOG_INFO("Response has been sent back."); + LOG_INFO("Response has been sent back"); if (request_->IsConnectionKeepAlive()) { - LOG_INFO("The client asked for a keep-alive connection."); - LOG_INFO("Continue to read the next request..."); + LOG_INFO("The client asked for a keep-alive connection"); + LOG_INFO("Continue to read the next request"); Start(); } else { pool_->Close(shared_from_this()); @@ -189,7 +189,7 @@ void Connection::OnWriteOK() { } void Connection::OnWriteError(boost::system::error_code ec) { - LOG_ERRO("Socket write error (%s).", ec.message().c_str()); + LOG_ERRO("Socket write error (%s)", ec.message().c_str()); if (ec != boost::asio::error::operation_aborted) { pool_->Close(shared_from_this()); diff --git a/webcc/connection_pool.cc b/webcc/connection_pool.cc index e99868c..e6e0c2e 100644 --- a/webcc/connection_pool.cc +++ b/webcc/connection_pool.cc @@ -5,11 +5,11 @@ namespace webcc { void ConnectionPool::Start(ConnectionPtr c) { - LOG_VERB("Starting connection..."); + LOG_VERB("Start connection"); { // Lock the container only. - std::lock_guard lock(mutex_); + std::lock_guard lock{ mutex_ }; connections_.insert(c); } @@ -17,11 +17,11 @@ void ConnectionPool::Start(ConnectionPtr c) { } void ConnectionPool::Close(ConnectionPtr c) { - LOG_VERB("Closing connection..."); + LOG_VERB("Close connection"); { // Lock the container only. - std::lock_guard lock(mutex_); + std::lock_guard lock{ mutex_ }; connections_.erase(c); } @@ -30,10 +30,10 @@ void ConnectionPool::Close(ConnectionPtr c) { void ConnectionPool::Clear() { // Lock all since we are going to stop anyway. - std::lock_guard lock(mutex_); + std::lock_guard lock{ mutex_ }; if (!connections_.empty()) { - LOG_VERB("Closing all (%u) connections...", connections_.size()); + LOG_VERB("Close all (%u) connections", connections_.size()); for (auto& c : connections_) { c->Close(); } diff --git a/webcc/globals.h b/webcc/globals.h index 0c636df..afa554b 100644 --- a/webcc/globals.h +++ b/webcc/globals.h @@ -3,6 +3,7 @@ #include #include +#include #include #include #include @@ -23,11 +24,14 @@ using UrlArgs = std::vector; using Payload = std::vector; +using ProgressCallback = + std::function; + // ----------------------------------------------------------------------------- const char* const kCRLF = "\r\n"; -const std::size_t kInvalidLength = std::string::npos; +const std::size_t kInvalidLength = -1; // Default timeout for reading response. const int kMaxReadSeconds = 30; @@ -150,6 +154,7 @@ public: enum Code { kUnknownError = -1, kOK = 0, + kStateError, kSyntaxError, kResolveError, kConnectError, @@ -166,7 +171,7 @@ public: } // Note that `noexcept` is required by GCC. - const char* what() const noexcept override{ + const char* what() const noexcept override { return message_.c_str(); } diff --git a/webcc/logger.cc b/webcc/logger.cc index a10439e..22bbf85 100644 --- a/webcc/logger.cc +++ b/webcc/logger.cc @@ -45,7 +45,12 @@ static FILE* FOpen(const bfs::path& path, bool overwrite) { } struct Logger { - Logger() : file(nullptr), modes(0) { + Logger() = default; + + ~Logger() { + if (file != nullptr) { + fclose(file); + } } void Init(const bfs::path& path, int _modes) { @@ -57,14 +62,8 @@ struct Logger { } } - ~Logger() { - if (file != nullptr) { - fclose(file); - } - } - - FILE* file; - int modes; + FILE* file = nullptr; + int modes = 0; std::mutex mutex; }; @@ -161,10 +160,10 @@ static bfs::path InitLogPath(const bfs::path& dir) { return bfs::current_path() / WEBCC_LOG_FILE_NAME; } - if (!bfs::exists(dir) || !bfs::is_directory(dir)) { - boost::system::error_code ec; + boost::system::error_code ec; + if (!bfs::exists(dir, ec) || !bfs::is_directory(dir, ec)) { if (!bfs::create_directories(dir, ec) || ec) { - return bfs::path{}; + return {}; } } @@ -216,7 +215,7 @@ void Log(int level, const char* file, int line, const char* format, ...) { va_list args; va_start(args, format); - fprintf(g_logger.file, "%s, %s, %7s, %20s, %4d, ", + fprintf(g_logger.file, "%s, %s, %7s, %25s, %4d, ", timestamp.c_str(), kLevelNames[level], thread_id.c_str(), file, line); @@ -239,12 +238,12 @@ void Log(int level, const char* file, int line, const char* format, ...) { if (g_terminal_has_color) { if (level < WEBCC_WARN) { - fprintf(stderr, "%s%s, %s, %7s, %20s, %4d, ", + fprintf(stderr, "%s%s, %s, %7s, %25s, %4d, ", TERM_RESET, timestamp.c_str(), kLevelNames[level], thread_id.c_str(), file, line); } else { - fprintf(stderr, "%s%s%s, %s, %7s, %20s, %4d, ", + fprintf(stderr, "%s%s%s, %s, %7s, %25s, %4d, ", TERM_RESET, level == WEBCC_WARN ? TERM_YELLOW : TERM_RED, timestamp.c_str(), kLevelNames[level], thread_id.c_str(), @@ -255,7 +254,7 @@ void Log(int level, const char* file, int line, const char* format, ...) { fprintf(stderr, "%s\n", TERM_RESET); } else { - fprintf(stderr, "%s, %s, %7s, %20s, %4d, ", + fprintf(stderr, "%s, %s, %7s, %25s, %4d, ", timestamp.c_str(), kLevelNames[level], thread_id.c_str(), file, line); diff --git a/webcc/message.cc b/webcc/message.cc index c231e79..d940ba0 100644 --- a/webcc/message.cc +++ b/webcc/message.cc @@ -8,7 +8,7 @@ namespace webcc { -Message::Message() : body_(new Body{}), content_length_(kInvalidLength) { +Message::Message() : body_(new Body{}) { } void Message::SetBody(BodyPtr body, bool set_length) { diff --git a/webcc/message.h b/webcc/message.h index d2cf3b9..73f5dbe 100644 --- a/webcc/message.h +++ b/webcc/message.h @@ -118,7 +118,7 @@ protected: std::string start_line_; - std::size_t content_length_; + std::size_t content_length_ = kInvalidLength; }; } // namespace webcc diff --git a/webcc/parser.cc b/webcc/parser.cc index c6807bf..dbbdab5 100644 --- a/webcc/parser.cc +++ b/webcc/parser.cc @@ -49,14 +49,14 @@ bool StringBodyHandler::Finish() { auto body = std::make_shared(std::move(content_), IsCompressed()); #if WEBCC_ENABLE_GZIP - LOG_INFO("Decompress the HTTP content..."); + LOG_INFO("Decompress the HTTP content"); if (!body->Decompress()) { - LOG_ERRO("Cannot decompress the HTTP content!"); + LOG_ERRO("Cannot decompress the HTTP content"); return false; } #else if (body->compressed()) { - LOG_WARN("Compressed HTTP content remains untouched."); + LOG_WARN("Compressed HTTP content remains untouched"); } #endif // WEBCC_ENABLE_GZIP @@ -79,7 +79,7 @@ bool FileBodyHandler::OpenFile() { temp_path_.string().c_str()); } catch (const bfs::filesystem_error&) { - LOG_ERRO("Failed to generate temp path for streaming."); + LOG_ERRO("Failed to generate temp path for streaming"); return false; } @@ -132,6 +132,8 @@ bool Parser::Parse(const char* data, std::size_t length) { return ParseContent(data, length); } + header_length_ += length; + // Append the new data to the pending data. pending_data_.append(data, length); @@ -140,11 +142,13 @@ bool Parser::Parse(const char* data, std::size_t length) { } if (!header_ended_) { - LOG_INFO("HTTP headers will continue in next read."); + LOG_INFO("HTTP headers will continue in next read"); return true; } - LOG_INFO("HTTP headers just ended."); + LOG_INFO("HTTP headers just ended"); + + header_length_ -= pending_data_.size(); if (!OnHeadersEnd()) { // Only request parser can reach here when no view matches the request. @@ -170,6 +174,7 @@ void Parser::Reset() { stream_ = false; pending_data_.clear(); + header_length_ = 0; content_length_ = kInvalidLength; content_type_.Reset(); @@ -336,7 +341,7 @@ bool Parser::ParseChunkedContent(const char* data, std::size_t length) { return false; } - LOG_VERB("Chunk size: %u.", chunk_size_); + LOG_VERB("Chunk size: %u", chunk_size_); } if (chunk_size_ == 0) { @@ -378,14 +383,14 @@ bool Parser::ParseChunkedContent(const char* data, std::size_t length) { } bool Parser::ParseChunkSize() { - LOG_VERB("Parse chunk size."); + LOG_VERB("Parse chunk size"); std::string line; if (!GetNextLine(0, &line, true)) { return true; } - LOG_VERB("Chunk size line: [%s].", line.c_str()); + LOG_VERB("Chunk size line: [%s]", line.c_str()); std::string hex_str; // e.g., "cf0" (3312) @@ -397,7 +402,7 @@ bool Parser::ParseChunkSize() { } if (!to_size_t(hex_str, 16, &chunk_size_)) { - LOG_ERRO("Invalid chunk-size: %s.", hex_str.c_str()); + LOG_ERRO("Invalid chunk-size: %s", hex_str.c_str()); return false; } diff --git a/webcc/parser.h b/webcc/parser.h index 973065e..ca9bd01 100644 --- a/webcc/parser.h +++ b/webcc/parser.h @@ -94,17 +94,35 @@ private: class Parser { public: Parser(); - virtual ~Parser() = default; Parser(const Parser&) = delete; Parser& operator=(const Parser&) = delete; + virtual ~Parser() = default; + void Init(Message* message); bool finished() const { return finished_; } + // If the headers part has been parsed or not. + bool header_ended() const { + return header_ended_; + } + + // Get the length of the headers part. + // Available after the headers have been parsed (see header_ended()). + std::size_t header_length() const { + return header_length_; + } + + // The content length parsed from `Content-Length` header. + // kInvalidLength if the content is chunked. + std::size_t content_length() const { + return content_length_; + } + bool Parse(const char* data, std::size_t length); protected: @@ -144,25 +162,28 @@ protected: bool Finish(); protected: - Message* message_; + Message* message_ = nullptr; std::unique_ptr body_handler_; // Data streaming or not. - bool stream_; + bool stream_ = false; // Data waiting to be parsed. std::string pending_data_; + // The length of the headers part. + std::size_t header_length_ = 0; + // Temporary data and helper flags for parsing. - std::size_t content_length_; + std::size_t content_length_ = kInvalidLength; ContentType content_type_; - bool start_line_parsed_; - bool content_length_parsed_; - bool header_ended_; - bool chunked_; - std::size_t chunk_size_; - bool finished_; + bool start_line_parsed_ = false; + bool content_length_parsed_ = false; + bool header_ended_ = false; + bool chunked_ = false; + std::size_t chunk_size_ = kInvalidLength; + bool finished_ = false; }; } // namespace webcc diff --git a/webcc/server.cc b/webcc/server.cc index c8c759f..6ccd156 100644 --- a/webcc/server.cc +++ b/webcc/server.cc @@ -23,9 +23,6 @@ Server::Server(boost::asio::ip::tcp protocol, std::uint16_t port, : protocol_(protocol), port_(port), doc_root_(doc_root), - buffer_size_(kBufferSize), - file_chunk_size_(1024), - running_(false), acceptor_(io_context_), signals_(io_context_) { AddSignals(); @@ -35,12 +32,12 @@ void Server::Run(std::size_t workers, std::size_t loops) { assert(workers > 0); { - std::lock_guard lock(state_mutex_); + std::lock_guard lock{ state_mutex_ }; assert(worker_threads_.empty()); if (IsRunning()) { - LOG_WARN("Server is already running."); + LOG_WARN("Server is already running"); return; } @@ -48,11 +45,11 @@ void Server::Run(std::size_t workers, std::size_t loops) { io_context_.restart(); if (!Listen(port_)) { - LOG_ERRO("Server is NOT going to run."); + LOG_ERRO("Server is NOT going to run"); return; } - LOG_INFO("Server is going to run..."); + LOG_INFO("Server is going to run"); AsyncWaitSignals(); @@ -70,7 +67,7 @@ void Server::Run(std::size_t workers, std::size_t loops) { // asynchronous operation outstanding: the asynchronous accept call waiting // for new incoming connections. - LOG_INFO("Loop is running in %u thread(s).", loops); + LOG_INFO("Loop is running in %u thread(s)", loops); if (loops == 1) { // Run the loop in current thread. @@ -88,7 +85,7 @@ void Server::Run(std::size_t workers, std::size_t loops) { } void Server::Stop() { - std::lock_guard lock(state_mutex_); + std::lock_guard lock{ state_mutex_ }; DoStop(); } @@ -112,7 +109,7 @@ void Server::AsyncWaitSignals() { // The server is stopped by canceling all outstanding asynchronous // operations. Once all operations have finished the io_context::run() // call will exit. - LOG_INFO("On signal %d, stopping the server...", signo); + LOG_INFO("On signal %d, stop the server", signo); DoStop(); }); @@ -121,12 +118,12 @@ void Server::AsyncWaitSignals() { bool Server::Listen(std::uint16_t port) { boost::system::error_code ec; - tcp::endpoint endpoint(protocol_, port); + tcp::endpoint endpoint{ protocol_, port }; // Open the acceptor. acceptor_.open(endpoint.protocol(), ec); if (ec) { - LOG_ERRO("Acceptor open error (%s).", ec.message().c_str()); + LOG_ERRO("Acceptor open error (%s)", ec.message().c_str()); return false; } @@ -141,7 +138,7 @@ bool Server::Listen(std::uint16_t port) { // Bind to the server address. acceptor_.bind(endpoint, ec); if (ec) { - LOG_ERRO("Acceptor bind error (%s).", ec.message().c_str()); + LOG_ERRO("Acceptor bind error (%s)", ec.message().c_str()); return false; } @@ -150,7 +147,7 @@ bool Server::Listen(std::uint16_t port) { // has not started to accept the connection yet. acceptor_.listen(boost::asio::socket_base::max_listen_connections, ec); if (ec) { - LOG_ERRO("Acceptor listen error (%s).", ec.message().c_str()); + LOG_ERRO("Acceptor listen error (%s)", ec.message().c_str()); return false; } @@ -167,9 +164,10 @@ void Server::AsyncAccept() { } if (!ec) { - LOG_INFO("Accepted a connection."); + LOG_INFO("Accepted a connection"); using namespace std::placeholders; + auto view_matcher = std::bind(&Server::MatchViewOrStatic, this, _1, _2, _3); @@ -205,16 +203,16 @@ void Server::DoStop() { } void Server::WorkerRoutine() { - LOG_INFO("Worker is running."); + LOG_INFO("Worker is running"); for (;;) { auto connection = queue_.PopOrWait(); if (!connection) { - LOG_INFO("Worker is going to stop."); + LOG_INFO("Worker is going to stop"); // For stopping next worker. - queue_.Push(ConnectionPtr()); + queue_.Push({}); // Stop this worker. break; @@ -225,13 +223,13 @@ void Server::WorkerRoutine() { } void Server::StopWorkers() { - LOG_INFO("Stopping workers..."); + LOG_INFO("Stop workers"); // Clear/drop pending connections. // The connections will be closed later (see DoStop). // Alternatively, we can wait for the pending connections to be handled. if (queue_.Size() != 0) { - LOG_INFO("Clear pending connections..."); + LOG_INFO("Clear pending connections"); queue_.Clear(); } @@ -252,7 +250,7 @@ void Server::StopWorkers() { // last worker thread. queue_.Clear(); - LOG_INFO("All workers have been stopped."); + LOG_INFO("Workers stopped"); } void Server::Handle(ConnectionPtr connection) { @@ -307,7 +305,9 @@ bool Server::MatchViewOrStatic(const std::string& method, // Try to match a static file. if (method == methods::kGet && !doc_root_.empty()) { bfs::path path = doc_root_ / url; - if (!bfs::is_directory(path) && bfs::exists(path)) { + + boost::system::error_code ec; + if (!bfs::is_directory(path, ec) && bfs::exists(path, ec)) { return true; } } @@ -319,7 +319,7 @@ ResponsePtr Server::ServeStatic(RequestPtr request) { assert(request->method() == methods::kGet); if (doc_root_.empty()) { - LOG_INFO("The doc root was not specified."); + LOG_INFO("The doc root was not specified"); return {}; } @@ -340,7 +340,7 @@ ResponsePtr Server::ServeStatic(RequestPtr request) { return response; } catch (const Error& error) { - LOG_ERRO("File error: %s.", error.message().c_str()); + LOG_ERRO("File error: %s", error.message().c_str()); return {}; } } diff --git a/webcc/server.h b/webcc/server.h index c4d6cf1..5615a1e 100644 --- a/webcc/server.h +++ b/webcc/server.h @@ -24,11 +24,11 @@ public: Server(boost::asio::ip::tcp protocol, std::uint16_t port, const boost::filesystem::path& doc_root = {}); - ~Server() = default; - Server(const Server&) = delete; Server& operator=(const Server&) = delete; + ~Server() = default; + void set_buffer_size(std::size_t buffer_size) { if (buffer_size > 0) { buffer_size_ = buffer_size; @@ -105,20 +105,20 @@ private: boost::asio::ip::tcp protocol_; // Port number. - std::uint16_t port_; + std::uint16_t port_ = 0; // The directory with the static files to be served. boost::filesystem::path doc_root_; // The size of the buffer for reading request. - std::size_t buffer_size_; + std::size_t buffer_size_ = kBufferSize; // The size of the chunk loaded into memory each time when serving a // static file. - std::size_t file_chunk_size_; + std::size_t file_chunk_size_ = 1024; // Is the server running? - bool running_; + bool running_ = false; // The mutex for guarding the state of the server. std::mutex state_mutex_; diff --git a/webcc/socket.cc b/webcc/socket.cc index 86577f8..f2c8ae5 100644 --- a/webcc/socket.cc +++ b/webcc/socket.cc @@ -3,9 +3,9 @@ #if WEBCC_ENABLE_SSL #if (defined(_WIN32) || defined(_WIN64)) -#include -#include #include +#include +#include #include "openssl/x509.h" @@ -18,6 +18,9 @@ #include "webcc/logger.h" +using boost::asio::ip::tcp; +using namespace std::placeholders; + namespace webcc { // ----------------------------------------------------------------------------- @@ -25,48 +28,37 @@ namespace webcc { Socket::Socket(boost::asio::io_context& io_context) : socket_(io_context) { } -bool Socket::Connect(const std::string& /*host*/, const Endpoints& endpoints) { - boost::system::error_code ec; - boost::asio::connect(socket_, endpoints, ec); - - if (ec) { - LOG_ERRO("Socket connect error (%s).", ec.message().c_str()); - return false; - } - - return true; -} - -bool Socket::Write(const Payload& payload, boost::system::error_code* ec) { - boost::asio::write(socket_, payload, *ec); - return !(*ec); +void Socket::AsyncConnect(const std::string& host, const Endpoints& endpoints, + ConnectHandler&& handler) { + boost::asio::async_connect(socket_, endpoints, std::move(handler)); } -bool Socket::ReadSome(std::vector* buffer, std::size_t* size, - boost::system::error_code* ec) { - *size = socket_.read_some(boost::asio::buffer(*buffer), *ec); - return (*size != 0 && !(*ec)); +void Socket::AsyncWrite(const Payload& payload, WriteHandler&& handler) { + boost::asio::async_write(socket_, payload, std::move(handler)); } void Socket::AsyncReadSome(ReadHandler&& handler, std::vector* buffer) { socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); } -bool Socket::Close() { +bool Socket::Shutdown() { boost::system::error_code ec; - - socket_.shutdown(boost::asio::ip::tcp::socket::shutdown_both, ec); + socket_.shutdown(tcp::socket::shutdown_both, ec); if (ec) { - LOG_WARN("Socket shutdown error (%s).", ec.message().c_str()); - ec.clear(); - // Don't return, try to close the socket anywhere. + LOG_WARN("Socket shutdown error (%s)", ec.message().c_str()); + return false; } + return true; +} + +bool Socket::Close() { + boost::system::error_code ec; socket_.close(ec); if (ec) { - LOG_WARN("Socket close error (%s).", ec.message().c_str()); + LOG_WARN("Socket close error (%s)", ec.message().c_str()); return false; } @@ -77,91 +69,40 @@ bool Socket::Close() { #if WEBCC_ENABLE_SSL -#if (defined(_WIN32) || defined(_WIN64)) - -// Let OpenSSL on Windows use the system certificate store -// 1. Load your certificate (in PCCERT_CONTEXT structure) from Windows Cert -// store using Crypto APIs. -// 2. Get encrypted content of it in binary format as it is. -// [PCCERT_CONTEXT->pbCertEncoded]. -// 3. Parse this binary buffer into X509 certificate Object using OpenSSL's -// d2i_X509() method. -// 4. Get handle to OpenSSL's trust store using SSL_CTX_get_cert_store() -// method. -// 5. Load above parsed X509 certificate into this trust store using -// X509_STORE_add_cert() method. -// 6. You are done! -// NOTES: Enum Windows store with "ROOT" (not "CA"). -// See: https://stackoverflow.com/a/11763389/6825348 - -static bool UseSystemCertificateStore(SSL_CTX* ssl_ctx) { - // NOTE: Cannot use nullptr to replace NULL. - HCERTSTORE cert_store = ::CertOpenSystemStoreW(NULL, L"ROOT"); - if (cert_store == nullptr) { - LOG_ERRO("Cannot open Windows system certificate store."); - return false; - } - - X509_STORE* x509_store = SSL_CTX_get_cert_store(ssl_ctx); - PCCERT_CONTEXT cert_context = nullptr; - - while (cert_context = CertEnumCertificatesInStore(cert_store, cert_context)) { - auto in = (const unsigned char**)&cert_context->pbCertEncoded; - X509* x509 = d2i_X509(nullptr, in, cert_context->cbCertEncoded); - - if (x509 != nullptr) { - if (X509_STORE_add_cert(x509_store, x509) == 0) { - LOG_ERRO("Cannot add Windows root certificate."); - } - - X509_free(x509); - } - } - - CertFreeCertificateContext(cert_context); - CertCloseStore(cert_store, 0); - return true; -} - -#endif // defined(_WIN32) || defined(_WIN64) - namespace ssl = boost::asio::ssl; -SslSocket::SslSocket(boost::asio::io_context& io_context, bool ssl_verify) - : ssl_context_(ssl::context::sslv23), - ssl_socket_(io_context, ssl_context_), +SslSocket::SslSocket(boost::asio::io_context& io_context, + ssl::context& ssl_context, bool ssl_verify) + : ssl_socket_(io_context, ssl_context), ssl_verify_(ssl_verify) { -#if (defined(_WIN32) || defined(_WIN64)) - if (ssl_verify_) { - UseSystemCertificateStore(ssl_context_.native_handle()); - } -#else - // Use the default paths for finding CA certificates. - ssl_context_.set_default_verify_paths(); -#endif // defined(_WIN32) || defined(_WIN64) } -bool SslSocket::Connect(const std::string& host, const Endpoints& endpoints) { - boost::system::error_code ec; - boost::asio::connect(ssl_socket_.lowest_layer(), endpoints, ec); +void SslSocket::AsyncConnect(const std::string& host, + const Endpoints& endpoints, + ConnectHandler&& handler) { + connect_handler_ = std::move(handler); - if (ec) { - LOG_ERRO("Socket connect error (%s).", ec.message().c_str()); - return false; + + if (ssl_verify_) { + ssl_socket_.set_verify_mode(ssl::verify_peer); + } else { + ssl_socket_.set_verify_mode(ssl::verify_none); } - return Handshake(host); -} + // ssl::host_name_verification has been added since Boost 1.73 to replace + // ssl::rfc2818_verification. +#if BOOST_VERSION < 107300 + ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host)); +#else + ssl_socket_.set_verify_callback(ssl::host_name_verification(host)); +#endif // BOOST_VERSION < 107300 -bool SslSocket::Write(const Payload& payload, boost::system::error_code* ec) { - boost::asio::write(ssl_socket_, payload, *ec); - return !(*ec); + boost::asio::async_connect(ssl_socket_.lowest_layer(), endpoints, + std::bind(&SslSocket::OnConnect, this, _1, _2)); } -bool SslSocket::ReadSome(std::vector* buffer, std::size_t* size, - boost::system::error_code* ec) { - *size = ssl_socket_.read_some(boost::asio::buffer(*buffer), *ec); - return (*size != 0 && !(*ec)); +void SslSocket::AsyncWrite(const Payload& payload, WriteHandler&& handler) { + boost::asio::async_write(ssl_socket_, payload, std::move(handler)); } void SslSocket::AsyncReadSome(ReadHandler&& handler, @@ -169,39 +110,50 @@ void SslSocket::AsyncReadSome(ReadHandler&& handler, ssl_socket_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); } -bool SslSocket::Close() { +bool SslSocket::Shutdown() { boost::system::error_code ec; - ssl_socket_.lowest_layer().close(ec); - return !ec; -} + ssl_socket_.lowest_layer().shutdown(tcp::socket::shutdown_both, ec); -bool SslSocket::Handshake(const std::string& host) { - if (ssl_verify_) { - ssl_socket_.set_verify_mode(ssl::verify_peer); - } else { - ssl_socket_.set_verify_mode(ssl::verify_none); + if (ec) { + LOG_WARN("Socket shutdown error (%s)", ec.message().c_str()); + return false; } - // ssl::host_name_verification has been added since Boost 1.73 to replace - // ssl::rfc2818_verification. -#if BOOST_VERSION < 107300 - ssl_socket_.set_verify_callback(ssl::rfc2818_verification(host)); -#else - ssl_socket_.set_verify_callback(ssl::host_name_verification(host)); -#endif // BOOST_VERSION < 107300 + return true; +} - // Use sync API directly since we don't need timeout control. +bool SslSocket::Close() { boost::system::error_code ec; - ssl_socket_.handshake(ssl::stream_base::client, ec); + ssl_socket_.lowest_layer().close(ec); if (ec) { - LOG_ERRO("Handshake error (%s).", ec.message().c_str()); + LOG_WARN("Socket close error (%s)", ec.message().c_str()); return false; } return true; } +void SslSocket::OnConnect(boost::system::error_code ec, + tcp::endpoint endpoint) { + if (ec) { + connect_handler_(ec, std::move(endpoint)); + return; + } + + // Backup endpoint + endpoint_ = std::move(endpoint); + + ssl_socket_.async_handshake(ssl::stream_base::client, + [this](boost::system::error_code ec) { + if (ec) { + LOG_ERRO("Handshake error (%s)", ec.message().c_str()); + } + + connect_handler_(ec, std::move(endpoint_)); + }); +} + #endif // WEBCC_ENABLE_SSL } // namespace webcc diff --git a/webcc/socket.h b/webcc/socket.h index 0c628d2..1982c51 100644 --- a/webcc/socket.h +++ b/webcc/socket.h @@ -18,24 +18,29 @@ namespace webcc { class SocketBase { public: - virtual ~SocketBase() = default; - using Endpoints = boost::asio::ip::tcp::resolver::results_type; + using ConnectHandler = std::function; + + using WriteHandler = + std::function; + using ReadHandler = std::function; - // TODO: Remove |host| - virtual bool Connect(const std::string& host, const Endpoints& endpoints) = 0; + virtual ~SocketBase() = default; - virtual bool Write(const Payload& payload, boost::system::error_code* ec) = 0; + virtual void AsyncConnect(const std::string& host, const Endpoints& endpoints, + ConnectHandler&& handler) = 0; - virtual bool ReadSome(std::vector* buffer, std::size_t* size, - boost::system::error_code* ec) = 0; + virtual void AsyncWrite(const Payload& payload, WriteHandler&& handler) = 0; virtual void AsyncReadSome(ReadHandler&& handler, std::vector* buffer) = 0; + virtual bool Shutdown() = 0; + virtual bool Close() = 0; }; @@ -45,15 +50,15 @@ class Socket : public SocketBase { public: explicit Socket(boost::asio::io_context& io_context); - bool Connect(const std::string& host, const Endpoints& endpoints) override; - - bool Write(const Payload& payload, boost::system::error_code* ec) override; + void AsyncConnect(const std::string& host, const Endpoints& endpoints, + ConnectHandler&& handler) override; - bool ReadSome(std::vector* buffer, std::size_t* size, - boost::system::error_code* ec) override; + void AsyncWrite(const Payload& payload, WriteHandler&& handler) override; void AsyncReadSome(ReadHandler&& handler, std::vector* buffer) override; + bool Shutdown() override; + bool Close() override; private: @@ -66,29 +71,32 @@ private: class SslSocket : public SocketBase { public: - explicit SslSocket(boost::asio::io_context& io_context, - bool ssl_verify = true); + SslSocket(boost::asio::io_context& io_context, + boost::asio::ssl::context& ssl_context, + bool ssl_verify = true); - bool Connect(const std::string& host, const Endpoints& endpoints) override; + void AsyncConnect(const std::string& host, const Endpoints& endpoints, + ConnectHandler&& handler) override; - bool Write(const Payload& payload, boost::system::error_code* ec) override; - - bool ReadSome(std::vector* buffer, std::size_t* size, - boost::system::error_code* ec) override; + void AsyncWrite(const Payload& payload, WriteHandler&& handler) override; void AsyncReadSome(ReadHandler&& handler, std::vector* buffer) override; + bool Shutdown() override; + bool Close() override; private: - bool Handshake(const std::string& host); + void OnConnect(boost::system::error_code ec, + boost::asio::ip::tcp::endpoint endpoint); - boost::asio::ssl::context ssl_context_; + ConnectHandler connect_handler_; + boost::asio::ip::tcp::endpoint endpoint_; boost::asio::ssl::stream ssl_socket_; // Verify the certificate of the peer (remote server) or not. - bool ssl_verify_; + bool ssl_verify_ = true; }; #endif // WEBCC_ENABLE_SSL