diff --git a/autotest/client_autotest/client_autotest.cc b/autotest/client_autotest/client_autotest.cc index ac62415..cebca36 100644 --- a/autotest/client_autotest/client_autotest.cc +++ b/autotest/client_autotest/client_autotest.cc @@ -418,7 +418,7 @@ TEST(ClientTest, KeepAlive) { // Keep-Alive by default. auto r = session.Send(webcc::RequestBuilder{}.Get(url)()); - EXPECT_TRUE(boost::iequals(r->GetHeader("Connection"), "Keep-alive")); + EXPECT_TRUE(boost::iequals(r->GetHeader("Connection"), "Keep-Alive")); // Close by setting Connection header directly. r = session.Send(webcc::RequestBuilder{}.Get(url). @@ -435,7 +435,7 @@ TEST(ClientTest, KeepAlive) { // Keep-Alive explicitly by using request builder. r = session.Send(webcc::RequestBuilder{}.Get(url).KeepAlive(true)()); - EXPECT_TRUE(boost::iequals(r->GetHeader("Connection"), "Keep-alive")); + EXPECT_TRUE(boost::iequals(r->GetHeader("Connection"), "Keep-Alive")); } catch (const webcc::Error& error) { std::cerr << error << std::endl; diff --git a/webcc/CMakeLists.txt b/webcc/CMakeLists.txt index 653c6ea..cbb1f20 100644 --- a/webcc/CMakeLists.txt +++ b/webcc/CMakeLists.txt @@ -15,15 +15,78 @@ configure_file( # Adhere to GNU filesystem layout conventions. include(GNUInstallDirs) -file(GLOB SOURCES RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}/*.cc) +set(SOURCES + base64.cc + body.cc + client.cc + client_pool.cc + client_session.cc + common.cc + connection.cc + connection_pool.cc + globals.cc + logger.cc + message.cc + parser.cc + request.cc + request_builder.cc + request_parser.cc + response.cc + response_builder.cc + response_parser.cc + router.cc + server.cc + socket.cc + string.cc + url.cc + utility.cc + ) + +set(HEADERS + base64.h + body.h + client.h + client_pool.h + client_session.h + common.h + connection.h + connection_pool.h + fs.h + globals.h + logger.h + message.h + parser.h + queue.h + request.h + request_builder.h + request_parser.h + response.h + response_builder.h + response_parser.h + router.h + server.h + socket_base.h + socket.h + string.h + url.h + utility.h + version.h + view.h + ) -file(GLOB HEADERS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} - ${CMAKE_CURRENT_SOURCE_DIR}/*.h) +if(WEBCC_ENABLE_SSL) + set(SOURCES ${SOURCES} + ssl_socket.cc + ssl_client.cc) -if(NOT WEBCC_ENABLE_GZIP) - list(REMOVE_ITEM SOURCES "gzip.cc") - list(REMOVE_ITEM HEADERS "gzip.h") + set(HEADERS ${HEADERS} + ssl_socket.h + ssl_client.h) +endif() + +if(WEBCC_ENABLE_GZIP) + set(SOURCES ${SOURCES} "gzip.cc") + set(HEADERS ${HEADERS} "gzip.h") endif() set(CMAKE_DEBUG_POSTFIX "d" CACHE STRING "Add a postfix to the debug library") diff --git a/webcc/client.cc b/webcc/client.cc index db8cf10..7225efc 100644 --- a/webcc/client.cc +++ b/webcc/client.cc @@ -3,32 +3,19 @@ #include "boost/algorithm/string.hpp" #include "webcc/logger.h" +#include "webcc/socket.h" using boost::asio::ip::tcp; using namespace std::placeholders; namespace webcc { -#if WEBCC_ENABLE_SSL - -Client::Client(boost::asio::io_context& io_context, - boost::asio::ssl::context& ssl_context) - : io_context_(io_context), - ssl_context_(ssl_context), - resolver_(io_context), - deadline_timer_(io_context) { -} - -#else - Client::Client(boost::asio::io_context& io_context) : io_context_(io_context), resolver_(io_context), deadline_timer_(io_context) { } -#endif // WEBCC_ENABLE_SSL - Error Client::Request(RequestPtr request, bool stream) { LOG_VERB("Request begin"); @@ -104,21 +91,12 @@ void Client::AsyncConnect() { if (boost::iequals(request_->url().scheme(), "http")) { socket_.reset(new Socket{ io_context_ }); AsyncResolve("80"); - return; - } - -#if WEBCC_ENABLE_SSL - if (boost::iequals(request_->url().scheme(), "https")) { - socket_.reset(new SslSocket{ io_context_, ssl_context_ }); - AsyncResolve("443"); - return; + } else { + LOG_ERRO("URL scheme (%s) is not supported", + request_->url().scheme().c_str()); + error_.Set(Error::kSyntaxError, "URL scheme not supported"); + FinishRequest(); } -#endif // WEBCC_ENABLE_SSL - - LOG_ERRO("URL scheme (%s) is not supported", - request_->url().scheme().c_str()); - error_.Set(Error::kSyntaxError, "URL scheme not supported"); - FinishRequest(); } void Client::AsyncResolve(string_view default_port) { diff --git a/webcc/client.h b/webcc/client.h index 4da2e16..83042c4 100644 --- a/webcc/client.h +++ b/webcc/client.h @@ -15,21 +15,18 @@ #include "webcc/request.h" #include "webcc/response.h" #include "webcc/response_parser.h" -#include "webcc/socket.h" +#include "webcc/socket_base.h" namespace webcc { +class Client; +using ClientPtr = std::shared_ptr; + // Synchronous HTTP & HTTPS client. // A request won't return until the response is received or timeout occurs. class Client { public: - // TODO -#if WEBCC_ENABLE_SSL - Client(boost::asio::io_context& io_context, - boost::asio::ssl::context& ssl_context); -#else explicit Client(boost::asio::io_context& io_context); -#endif Client(const Client&) = delete; Client& operator=(const Client&) = delete; @@ -84,10 +81,12 @@ public: response_parser_.Init(nullptr, false); } -private: +protected: void DoClose(); - void AsyncConnect(); + // TODO: Rename + // TODO: Add class ClientBase ? + virtual void AsyncConnect(); void AsyncResolve(string_view default_port); @@ -113,13 +112,9 @@ private: void FinishRequest(); -private: +protected: boost::asio::io_context& io_context_; -#if WEBCC_ENABLE_SSL - boost::asio::ssl::context& ssl_context_; -#endif - std::unique_ptr socket_; boost::asio::ip::tcp::resolver resolver_; @@ -164,8 +159,6 @@ private: Error error_; }; -using ClientPtr = std::shared_ptr; - } // namespace webcc #endif // WEBCC_CLIENT_H_ diff --git a/webcc/client_session.cc b/webcc/client_session.cc index 5f90c9d..6fc7ee0 100644 --- a/webcc/client_session.cc +++ b/webcc/client_session.cc @@ -21,8 +21,14 @@ #include "webcc/url.h" #include "webcc/utility.h" +#if WEBCC_ENABLE_SSL +#include "webcc/ssl_client.h" +#endif + namespace webcc { +// ----------------------------------------------------------------------------- + #if WEBCC_ENABLE_SSL #if (defined(_WIN32) || defined(_WIN64)) @@ -73,21 +79,11 @@ static bool UseSystemCertificateStore(SSL_CTX* ssl_ctx) { #endif // defined(_WIN32) || defined(_WIN64) #endif // WEBCC_ENABLE_SSL +// ----------------------------------------------------------------------------- + ClientSession::ClientSession(std::size_t buffer_size) : work_guard_(boost::asio::make_work_guard(io_context_)), -#if WEBCC_ENABLE_SSL - ssl_context_(boost::asio::ssl::context::sslv23_client), -#endif buffer_size_(buffer_size) { -#if WEBCC_ENABLE_SSL -#if (defined(_WIN32) || defined(_WIN64)) - UseSystemCertificateStore(ssl_context_.native_handle()); -#else - // Use the default paths for finding CA certificates. - ssl_context_.set_default_verify_paths(); -#endif // defined(_WIN32) || defined(_WIN64) -#endif // WEBCC_ENABLE_SSL - InitHeaders(); Start(); @@ -95,6 +91,12 @@ ClientSession::ClientSession(std::size_t buffer_size) ClientSession::~ClientSession() { Stop(); + +#if WEBCC_ENABLE_SSL + if (ssl_context_ != nullptr) { + delete ssl_context_; + } +#endif // WEBCC_ENABLE_SSL } void ClientSession::Start() { @@ -195,10 +197,6 @@ ResponsePtr ClientSession::Send(RequestPtr request, bool stream, throw Error{ Error::kStateError, "Loop is not running" }; } - if (!CheckUrlScheme(request)) { - throw Error{ Error::kSyntaxError, "Invalid URL scheme" }; - } - for (auto& h : headers_.data()) { if (!request->HasHeader(h.first)) { request->SetHeader(h.first, h.second); @@ -235,20 +233,42 @@ void ClientSession::InitHeaders() { headers_.Set(headers::kConnection, "Keep-Alive"); } -bool ClientSession::CheckUrlScheme(RequestPtr request) { - if (boost::iequals(request->url().scheme(), "http")) { - return true; +ClientPtr ClientSession::CreateClient(const std::string& url_scheme) { + if (boost::iequals(url_scheme, "http")) { + return std::make_shared(io_context_); } #if WEBCC_ENABLE_SSL - if (boost::iequals(request->url().scheme(), "https")) { - return true; + if (boost::iequals(url_scheme, "https")) { + CreateSslContext(); // If it's not created yet + return std::make_shared(io_context_, *ssl_context_); } #endif // WEBCC_ENABLE_SSL - return false; + return {}; +} + +#if WEBCC_ENABLE_SSL + +void ClientSession::CreateSslContext() { + if (ssl_context_ != nullptr) { + return; + } + + namespace ssl = boost::asio::ssl; + + ssl_context_ = new ssl::context{ ssl::context::sslv23_client }; + +#if (defined(_WIN32) || defined(_WIN64)) + UseSystemCertificateStore(ssl_context_->native_handle()); +#else + // Use the default paths for finding CA certificates. + ssl_context_->set_default_verify_paths(); +#endif // defined(_WIN32) || defined(_WIN64) } +#endif // WEBCC_ENABLE_SSL + ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream, ProgressCallback callback) { const ClientPool::Key key{ request->url() }; @@ -259,11 +279,10 @@ ResponsePtr ClientSession::DoSend(RequestPtr request, bool stream, ClientPtr client = pool_.Get(key); if (!client) { -#if WEBCC_ENABLE_SSL - client.reset(new Client{ io_context_, ssl_context_ }); -#else - client.reset(new Client{ io_context_ }); -#endif // WEBCC_ENABLE_SSL + client = CreateClient(request->url().scheme()); + if (!client) { + throw Error{ Error::kSyntaxError, "Invalid URL scheme" }; + } reuse = false; } else { LOG_VERB("Reuse an existing connection"); diff --git a/webcc/client_session.h b/webcc/client_session.h index ffd9c29..ce5eee6 100644 --- a/webcc/client_session.h +++ b/webcc/client_session.h @@ -13,6 +13,10 @@ #include "webcc/request_builder.h" #include "webcc/response.h" +#if WEBCC_ENABLE_SSL +#include "boost/asio/ssl/context.hpp" +#endif + namespace webcc { // Client session provides connection-pooling, configuration and more. @@ -101,8 +105,13 @@ public: private: void InitHeaders(); - // Check if the scheme of the request is valid. - bool CheckUrlScheme(RequestPtr request); + // Create a client object according to the URL scheme. + ClientPtr CreateClient(const std::string& url_scheme); + +#if WEBCC_ENABLE_SSL + // Create SSL context if it's not created. + void CreateSslContext(); +#endif // WEBCC_ENABLE_SSL ResponsePtr DoSend(RequestPtr request, bool stream, ProgressCallback callback); @@ -117,7 +126,8 @@ private: boost::asio::executor_work_guard work_guard_; #if WEBCC_ENABLE_SSL - boost::asio::ssl::context ssl_context_; + // SSL context is lazily created on the first HTTPS request. + boost::asio::ssl::context* ssl_context_ = nullptr; #endif // Is Asio loop running? diff --git a/webcc/socket.cc b/webcc/socket.cc index a58a525..ceb46e1 100644 --- a/webcc/socket.cc +++ b/webcc/socket.cc @@ -2,18 +2,14 @@ #include "boost/asio/connect.hpp" #include "boost/asio/read.hpp" -#include "boost/asio/ssl.hpp" #include "boost/asio/write.hpp" #include "webcc/logger.h" using boost::asio::ip::tcp; -using namespace std::placeholders; namespace webcc { -// ----------------------------------------------------------------------------- - Socket::Socket(boost::asio::io_context& io_context) : socket_(io_context) { } @@ -54,119 +50,4 @@ bool Socket::Close() { return true; } -// ----------------------------------------------------------------------------- - -#if WEBCC_ENABLE_SSL - -namespace ssl = boost::asio::ssl; - -SslSocket::SslSocket(boost::asio::io_context& io_context, - ssl::context& ssl_context) - : ssl_stream_(io_context, ssl_context) { -} - -void SslSocket::AsyncConnect(const std::string& host, - const Endpoints& endpoints, - ConnectHandler&& handler) { - connect_handler_ = std::move(handler); - - // Set SNI (server name indication) host name. - // Many hosts need this to handshake successfully (e.g., google.com). - // Inspired by Boost.Beast. - if (!SSL_set_tlsext_host_name(ssl_stream_.native_handle(), host.c_str())) { - // TODO: Call ERR_get_error() to get error. - LOG_ERRO("Failed to set SNI host name for SSL"); - } - - // Modes `ssl::verify_fail_if_no_peer_cert` and `ssl::verify_client_once` are - // for server only. `ssl::verify_none` is not secure. - // See: https://stackoverflow.com/a/12621528 - ssl_stream_.set_verify_mode(ssl::verify_peer); - - // ssl::host_name_verification has been added since Boost 1.73 to replace - // ssl::rfc2818_verification. -#if BOOST_VERSION < 107300 - ssl_stream_.set_verify_callback(ssl::rfc2818_verification(host)); -#else - ssl_stream_.set_verify_callback(ssl::host_name_verification(host)); -#endif // BOOST_VERSION < 107300 - - boost::asio::async_connect(ssl_stream_.lowest_layer(), endpoints, - std::bind(&SslSocket::OnConnect, this, _1, _2)); -} - -void SslSocket::AsyncWrite(const Payload& payload, WriteHandler&& handler) { - boost::asio::async_write(ssl_stream_, payload, std::move(handler)); -} - -void SslSocket::AsyncReadSome(ReadHandler&& handler, - std::vector* buffer) { - ssl_stream_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); -} - -bool SslSocket::Shutdown() { - boost::system::error_code ec; - - ssl_stream_.lowest_layer().cancel(ec); - - // Shutdown SSL - // TODO: Use async_shutdown()? - ssl_stream_.shutdown(ec); - - if (ec == boost::asio::error::eof) { - // See: https://stackoverflow.com/a/25703699 - ec = {}; - } - - if (ec) { - LOG_WARN("SSL shutdown error (%s)", ec.message().c_str()); - return false; - } - - // Shutdown TCP - // TODO: Not sure if this is necessary? - ssl_stream_.lowest_layer().shutdown(tcp::socket::shutdown_both, ec); - - if (ec) { - LOG_WARN("Socket shutdown error (%s)", ec.message().c_str()); - return false; - } - - return true; -} - -bool SslSocket::Close() { - boost::system::error_code ec; - ssl_stream_.lowest_layer().close(ec); - - if (ec) { - LOG_WARN("Socket close error (%s)", ec.message().c_str()); - return false; - } - - return true; -} - -void SslSocket::OnConnect(boost::system::error_code ec, - tcp::endpoint endpoint) { - if (ec) { - connect_handler_(ec, std::move(endpoint)); - return; - } - - // Backup endpoint - endpoint_ = std::move(endpoint); - - ssl_stream_.async_handshake(ssl::stream_base::client, - [this](boost::system::error_code ec) { - if (ec) { - LOG_ERRO("Handshake error (%s)", ec.message().c_str()); - } - - connect_handler_(ec, std::move(endpoint_)); - }); -} - -#endif // WEBCC_ENABLE_SSL - } // namespace webcc diff --git a/webcc/socket.h b/webcc/socket.h index e6dad79..b7ad033 100644 --- a/webcc/socket.h +++ b/webcc/socket.h @@ -1,56 +1,10 @@ #ifndef WEBCC_SOCKET_H_ #define WEBCC_SOCKET_H_ -#include - -#include "boost/asio/ip/tcp.hpp" - -#include "webcc/config.h" -#include "webcc/request.h" - -#if WEBCC_ENABLE_SSL -#include "boost/asio/ssl.hpp" -#endif // WEBCC_ENABLE_SSL +#include "webcc/socket_base.h" namespace webcc { -// ----------------------------------------------------------------------------- - -class SocketBase { -public: - using Endpoints = boost::asio::ip::tcp::resolver::results_type; - - using ConnectHandler = std::function; - - using WriteHandler = - std::function; - - using ReadHandler = - std::function; - - SocketBase() = default; - - SocketBase(const SocketBase&) = delete; - SocketBase& operator=(const SocketBase&) = delete; - - virtual ~SocketBase() = default; - - virtual void AsyncConnect(const std::string& host, const Endpoints& endpoints, - ConnectHandler&& handler) = 0; - - virtual void AsyncWrite(const Payload& payload, WriteHandler&& handler) = 0; - - virtual void AsyncReadSome(ReadHandler&& handler, - std::vector* buffer) = 0; - - virtual bool Shutdown() = 0; - - virtual bool Close() = 0; -}; - -// ----------------------------------------------------------------------------- - class Socket : public SocketBase { public: explicit Socket(boost::asio::io_context& io_context); @@ -70,38 +24,6 @@ private: boost::asio::ip::tcp::socket socket_; }; -// ----------------------------------------------------------------------------- - -#if WEBCC_ENABLE_SSL - -class SslSocket : public SocketBase { -public: - SslSocket(boost::asio::io_context& io_context, - boost::asio::ssl::context& ssl_context); - - void AsyncConnect(const std::string& host, const Endpoints& endpoints, - ConnectHandler&& handler) override; - - void AsyncWrite(const Payload& payload, WriteHandler&& handler) override; - - void AsyncReadSome(ReadHandler&& handler, std::vector* buffer) override; - - bool Shutdown() override; - - bool Close() override; - -private: - void OnConnect(boost::system::error_code ec, - boost::asio::ip::tcp::endpoint endpoint); - - ConnectHandler connect_handler_; - boost::asio::ip::tcp::endpoint endpoint_; - - boost::asio::ssl::stream ssl_stream_; -}; - -#endif // WEBCC_ENABLE_SSL - } // namespace webcc #endif // WEBCC_SOCKET_H_ diff --git a/webcc/socket_base.h b/webcc/socket_base.h new file mode 100644 index 0000000..8f0072b --- /dev/null +++ b/webcc/socket_base.h @@ -0,0 +1,45 @@ +#ifndef WEBCC_SOCKET_BASE_H_ +#define WEBCC_SOCKET_BASE_H_ + +#include "boost/asio/ip/tcp.hpp" + +#include "webcc/globals.h" + +namespace webcc { + +class SocketBase { +public: + using Endpoints = boost::asio::ip::tcp::resolver::results_type; + + using ConnectHandler = std::function; + + using WriteHandler = + std::function; + + using ReadHandler = + std::function; + + SocketBase() = default; + + SocketBase(const SocketBase&) = delete; + SocketBase& operator=(const SocketBase&) = delete; + + virtual ~SocketBase() = default; + + virtual void AsyncConnect(const std::string& host, const Endpoints& endpoints, + ConnectHandler&& handler) = 0; + + virtual void AsyncWrite(const Payload& payload, WriteHandler&& handler) = 0; + + virtual void AsyncReadSome(ReadHandler&& handler, + std::vector* buffer) = 0; + + virtual bool Shutdown() = 0; + + virtual bool Close() = 0; +}; + +} // namespace webcc + +#endif // WEBCC_SOCKET_BASE_H_ diff --git a/webcc/ssl_client.cc b/webcc/ssl_client.cc new file mode 100644 index 0000000..058ddc8 --- /dev/null +++ b/webcc/ssl_client.cc @@ -0,0 +1,23 @@ +#include "webcc/ssl_client.h" + +#include "boost/algorithm/string.hpp" + +#include "webcc/ssl_socket.h" + +namespace webcc { + +SslClient::SslClient(boost::asio::io_context& io_context, + boost::asio::ssl::context& ssl_context) + : Client(io_context), ssl_context_(ssl_context) { +} + +void SslClient::AsyncConnect() { + if (boost::iequals(request_->url().scheme(), "https")) { + socket_.reset(new SslSocket{ io_context_, ssl_context_ }); + AsyncResolve("443"); + } else { + Client::AsyncConnect(); + } +} + +} // namespace webcc diff --git a/webcc/ssl_client.h b/webcc/ssl_client.h new file mode 100644 index 0000000..0bfdf4d --- /dev/null +++ b/webcc/ssl_client.h @@ -0,0 +1,30 @@ +#ifndef WEBCC_SSL_CLIENT_H_ +#define WEBCC_SSL_CLIENT_H_ + +#include "webcc/client.h" + +#include "boost/asio/ssl/context.hpp" + +#if !WEBCC_ENABLE_SSL +#error SSL must be enabled! +#endif + +namespace webcc { + +class SslClient final : public Client { +public: + SslClient(boost::asio::io_context& io_context, + boost::asio::ssl::context& ssl_context); + + ~SslClient() = default; + +protected: + void AsyncConnect() override; + +private: + boost::asio::ssl::context& ssl_context_; +}; + +} // namespace webcc + +#endif // WEBCC_SSL_CLIENT_H_ diff --git a/webcc/ssl_socket.cc b/webcc/ssl_socket.cc new file mode 100644 index 0000000..84a9258 --- /dev/null +++ b/webcc/ssl_socket.cc @@ -0,0 +1,123 @@ +#include "webcc/ssl_socket.h" + +#include "boost/asio/connect.hpp" +#include "boost/asio/read.hpp" +#include "boost/asio/write.hpp" + +#include "webcc/logger.h" + +using namespace std::placeholders; + +using boost::asio::ip::tcp; +namespace ssl = boost::asio::ssl; + +namespace webcc { + +SslSocket::SslSocket(boost::asio::io_context& io_context, + ssl::context& ssl_context) + : ssl_stream_(io_context, ssl_context) { +} + +void SslSocket::AsyncConnect(const std::string& host, + const Endpoints& endpoints, + ConnectHandler&& handler) { + connect_handler_ = std::move(handler); + + // Set SNI (server name indication) host name. + // Many hosts need this to handshake successfully (e.g., google.com). + // Inspired by Boost.Beast. + if (!SSL_set_tlsext_host_name(ssl_stream_.native_handle(), host.c_str())) { + // TODO: Call ERR_get_error() to get error. + LOG_ERRO("Failed to set SNI host name for SSL"); + } + + // Modes `ssl::verify_fail_if_no_peer_cert` and `ssl::verify_client_once` are + // for server only. `ssl::verify_none` is not secure. + // See: https://stackoverflow.com/a/12621528 + ssl_stream_.set_verify_mode(ssl::verify_peer); + + // ssl::host_name_verification has been added since Boost 1.73 to replace + // ssl::rfc2818_verification. +#if BOOST_VERSION < 107300 + ssl_stream_.set_verify_callback(ssl::rfc2818_verification(host)); +#else + ssl_stream_.set_verify_callback(ssl::host_name_verification(host)); +#endif // BOOST_VERSION < 107300 + + boost::asio::async_connect(ssl_stream_.lowest_layer(), endpoints, + std::bind(&SslSocket::OnConnect, this, _1, _2)); +} + +void SslSocket::AsyncWrite(const Payload& payload, WriteHandler&& handler) { + boost::asio::async_write(ssl_stream_, payload, std::move(handler)); +} + +void SslSocket::AsyncReadSome(ReadHandler&& handler, + std::vector* buffer) { + ssl_stream_.async_read_some(boost::asio::buffer(*buffer), std::move(handler)); +} + +bool SslSocket::Shutdown() { + boost::system::error_code ec; + + ssl_stream_.lowest_layer().cancel(ec); + + // Shutdown SSL + // TODO: Use async_shutdown()? + ssl_stream_.shutdown(ec); + + if (ec == boost::asio::error::eof) { + // See: https://stackoverflow.com/a/25703699 + ec = {}; + } + + if (ec) { + LOG_WARN("SSL shutdown error (%s)", ec.message().c_str()); + return false; + } + + // Shutdown TCP + // TODO: Not sure if this is necessary? + ssl_stream_.lowest_layer().shutdown(tcp::socket::shutdown_both, ec); + + if (ec) { + LOG_WARN("Socket shutdown error (%s)", ec.message().c_str()); + return false; + } + + return true; +} + +bool SslSocket::Close() { + boost::system::error_code ec; + ssl_stream_.lowest_layer().close(ec); + + if (ec) { + LOG_WARN("Socket close error (%s)", ec.message().c_str()); + return false; + } + + return true; +} + +void SslSocket::OnConnect(boost::system::error_code ec, + tcp::endpoint endpoint) { + if (ec) { + connect_handler_(ec, std::move(endpoint)); + return; + } + + // Backup endpoint + endpoint_ = std::move(endpoint); + + ssl_stream_.async_handshake(ssl::stream_base::client, + [this](boost::system::error_code ec) { + if (ec) { + LOG_ERRO("Handshake error (%s)", ec.message().c_str()); + } + + connect_handler_(ec, std::move(endpoint_)); + }); +} + +} // namespace webcc diff --git a/webcc/ssl_socket.h b/webcc/ssl_socket.h new file mode 100644 index 0000000..9910419 --- /dev/null +++ b/webcc/ssl_socket.h @@ -0,0 +1,42 @@ +#ifndef WEBCC_SSL_SOCKET_H_ +#define WEBCC_SSL_SOCKET_H_ + +#include "webcc/socket_base.h" + +#include "boost/asio/ssl.hpp" + +#if !WEBCC_ENABLE_SSL +#error SSL must be enabled! +#endif + +namespace webcc { + +class SslSocket : public SocketBase { +public: + SslSocket(boost::asio::io_context& io_context, + boost::asio::ssl::context& ssl_context); + + void AsyncConnect(const std::string& host, const Endpoints& endpoints, + ConnectHandler&& handler) override; + + void AsyncWrite(const Payload& payload, WriteHandler&& handler) override; + + void AsyncReadSome(ReadHandler&& handler, std::vector* buffer) override; + + bool Shutdown() override; + + bool Close() override; + +private: + void OnConnect(boost::system::error_code ec, + boost::asio::ip::tcp::endpoint endpoint); + + ConnectHandler connect_handler_; + boost::asio::ip::tcp::endpoint endpoint_; + + boost::asio::ssl::stream ssl_stream_; +}; + +} // namespace webcc + +#endif // WEBCC_SSL_SOCKET_H_