diff --git a/unittest/string_unittest.cc b/unittest/string_unittest.cc index f8652f0..7678dfb 100644 --- a/unittest/string_unittest.cc +++ b/unittest/string_unittest.cc @@ -1,9 +1,5 @@ #include "gtest/gtest.h" -#include - -#include "boost/algorithm/string.hpp" - #include "webcc/string.h" TEST(StringTest, Trim) { diff --git a/webcc/CMakeLists.txt b/webcc/CMakeLists.txt index cbb1f20..85a0747 100644 --- a/webcc/CMakeLists.txt +++ b/webcc/CMakeLists.txt @@ -18,7 +18,7 @@ include(GNUInstallDirs) set(SOURCES base64.cc body.cc - client.cc + client_base.cc client_pool.cc client_session.cc common.cc @@ -45,6 +45,7 @@ set(SOURCES set(HEADERS base64.h body.h + client_base.h client.h client_pool.h client_session.h @@ -75,13 +76,8 @@ set(HEADERS ) if(WEBCC_ENABLE_SSL) - set(SOURCES ${SOURCES} - ssl_socket.cc - ssl_client.cc) - - set(HEADERS ${HEADERS} - ssl_socket.h - ssl_client.h) + set(SOURCES ${SOURCES} ssl_socket.cc) + set(HEADERS ${HEADERS} ssl_socket.h ssl_client.h) endif() if(WEBCC_ENABLE_GZIP) diff --git a/webcc/client.h b/webcc/client.h index 83042c4..250c329 100644 --- a/webcc/client.h +++ b/webcc/client.h @@ -1,162 +1,27 @@ #ifndef WEBCC_CLIENT_H_ #define WEBCC_CLIENT_H_ -#include -#include -#include -#include -#include - -#include "boost/asio/io_context.hpp" -#include "boost/asio/ip/tcp.hpp" -#include "boost/asio/steady_timer.hpp" - -#include "webcc/globals.h" -#include "webcc/request.h" -#include "webcc/response.h" -#include "webcc/response_parser.h" -#include "webcc/socket_base.h" +#include "webcc/client_base.h" +#include "webcc/socket.h" namespace webcc { -class Client; -using ClientPtr = std::shared_ptr; - -// Synchronous HTTP & HTTPS client. -// A request won't return until the response is received or timeout occurs. -class Client { +class Client final : public ClientBase { public: - explicit Client(boost::asio::io_context& io_context); - - Client(const Client&) = delete; - Client& operator=(const Client&) = delete; - - ~Client() = default; - - void set_buffer_size(std::size_t buffer_size) { - if (buffer_size > 0) { - buffer_size_ = buffer_size; - } - } - - void set_connect_timeout(int timeout) { - if (timeout > 0) { - connect_timeout_ = timeout; - } + explicit Client(boost::asio::io_context& io_context) + : ClientBase(io_context) { } - void set_read_timeout(int timeout) { - if (timeout > 0) { - read_timeout_ = timeout; - } - } - - // Set progress callback to be informed about the read progress. - // NOTE: Don't use move semantics because in practice, there is no difference - // between copying and moving an object of a closure type. - // TODO: Support write progress - void set_progress_callback(ProgressCallback callback) { - progress_callback_ = callback; - } - - // Connect, send request, wait until response is received. - Error Request(RequestPtr request, bool stream = false); - - // Close the socket. - void Close(); - - bool connected() const { - return connected_; - } + ~Client() = default; - ResponsePtr response() const { - return response_; +protected: + void CreateSocket() override { + socket_.reset(new Socket{ io_context_ }); } - // Reset response object. - // Used to make sure the response object will released even the client object - // itself will be cached for keep-alive purpose. - void Reset() { - response_.reset(); - response_parser_.Init(nullptr, false); + void Resolve() override { + AsyncResolve("80"); } - -protected: - void DoClose(); - - // TODO: Rename - // TODO: Add class ClientBase ? - virtual void AsyncConnect(); - - void AsyncResolve(string_view default_port); - - void OnResolve(boost::system::error_code ec, - boost::asio::ip::tcp::resolver::results_type endpoints); - - void OnConnect(boost::system::error_code ec, boost::asio::ip::tcp::endpoint); - - void AsyncWrite(); - void OnWrite(boost::system::error_code ec, std::size_t length); - - void AsyncWriteBody(); - void OnWriteBody(boost::system::error_code ec, std::size_t length); - - void HandleWriteError(boost::system::error_code ec); - - void AsyncRead(); - void OnRead(boost::system::error_code ec, std::size_t length); - - void AsyncWaitDeadlineTimer(int seconds); - void OnDeadlineTimer(boost::system::error_code ec); - void StopDeadlineTimer(); - - void FinishRequest(); - -protected: - boost::asio::io_context& io_context_; - - std::unique_ptr socket_; - - boost::asio::ip::tcp::resolver resolver_; - - bool request_finished_ = true; - std::condition_variable request_cv_; - std::mutex request_mutex_; - - RequestPtr request_; - - ResponsePtr response_; - ResponseParser response_parser_; - - // The length already read. - std::size_t length_read_ = 0; - - // The buffer for reading response. - std::vector buffer_; - - // The size of the buffer for reading response. - // 0 means default value will be used. - std::size_t buffer_size_ = kBufferSize; - - // Timeout (seconds) for connecting to server. - // Default as 0 to disable our own control (i.e., deadline_timer_). - int connect_timeout_ = 0; - - // Timeout (seconds) for reading response. - int read_timeout_ = kMaxReadSeconds; - - // Deadline timer for connecting to server. - boost::asio::steady_timer deadline_timer_; - bool deadline_timer_stopped_ = true; - - // Socket connected or not. - bool connected_ = false; - - // Progress callback (optional). - ProgressCallback progress_callback_; - - // Current error. - Error error_; }; } // namespace webcc diff --git a/webcc/client.cc b/webcc/client_base.cc similarity index 75% rename from webcc/client.cc rename to webcc/client_base.cc index 7225efc..b7d21d4 100644 --- a/webcc/client.cc +++ b/webcc/client_base.cc @@ -1,4 +1,4 @@ -#include "webcc/client.h" +#include "webcc/client_base.h" #include "boost/algorithm/string.hpp" @@ -10,13 +10,13 @@ using namespace std::placeholders; namespace webcc { -Client::Client(boost::asio::io_context& io_context) +ClientBase::ClientBase(boost::asio::io_context& io_context) : io_context_(io_context), resolver_(io_context), deadline_timer_(io_context) { } -Error Client::Request(RequestPtr request, bool stream) { +Error ClientBase::Request(RequestPtr request, bool stream) { LOG_VERB("Request begin"); request_finished_ = false; @@ -48,7 +48,8 @@ Error Client::Request(RequestPtr request, bool stream) { } if (!connected_) { - AsyncConnect(); + CreateSocket(); + Resolve(); } else { AsyncWrite(); } @@ -62,14 +63,14 @@ Error Client::Request(RequestPtr request, bool stream) { return error_; } -void Client::Close() { - DoClose(); +void ClientBase::Close() { + CloseSocket(); // Don't call FinishRequest() from here! It will be called in the handler // OnXxx with `error::operation_aborted`. } -void Client::DoClose() { +void ClientBase::CloseSocket() { if (connected_) { connected_ = false; if (socket_) { @@ -87,19 +88,7 @@ void Client::DoClose() { } } -void Client::AsyncConnect() { - if (boost::iequals(request_->url().scheme(), "http")) { - socket_.reset(new Socket{ io_context_ }); - AsyncResolve("80"); - } else { - LOG_ERRO("URL scheme (%s) is not supported", - request_->url().scheme().c_str()); - error_.Set(Error::kSyntaxError, "URL scheme not supported"); - FinishRequest(); - } -} - -void Client::AsyncResolve(string_view default_port) { +void ClientBase::AsyncResolve(string_view default_port) { std::string port = request_->port(); if (port.empty()) { port = ToString(default_port); @@ -109,11 +98,11 @@ void Client::AsyncResolve(string_view default_port) { // The protocol depends on the `host`, both V4 and V6 are supported. resolver_.async_resolve(request_->host(), port, - std::bind(&Client::OnResolve, this, _1, _2)); + std::bind(&ClientBase::OnResolve, this, _1, _2)); } -void Client::OnResolve(boost::system::error_code ec, - tcp::resolver::results_type endpoints) { +void ClientBase::OnResolve(boost::system::error_code ec, + tcp::resolver::results_type endpoints) { if (ec) { LOG_ERRO("Host resolve error (%s)", ec.message().c_str()); error_.Set(Error::kResolveError, "Host resolve error"); @@ -126,17 +115,17 @@ void Client::OnResolve(boost::system::error_code ec, AsyncWaitDeadlineTimer(connect_timeout_); socket_->AsyncConnect(request_->host(), endpoints, - std::bind(&Client::OnConnect, this, _1, _2)); + std::bind(&ClientBase::OnConnect, this, _1, _2)); } -void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) { +void ClientBase::OnConnect(boost::system::error_code ec, tcp::endpoint) { LOG_VERB("On connect"); StopDeadlineTimer(); if (ec) { if (ec == boost::asio::error::operation_aborted) { - // Socket has been closed by OnDeadlineTimer() or DoClose(). + // Socket has been closed by OnDeadlineTimer() or CloseSocket(). LOG_WARN("Connect operation aborted"); } else { LOG_INFO("Connect error"); @@ -156,14 +145,14 @@ void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) { AsyncWrite(); } -void Client::AsyncWrite() { +void ClientBase::AsyncWrite() { LOG_VERB("Request:\n%s", request_->Dump().c_str()); socket_->AsyncWrite(request_->GetPayload(), - std::bind(&Client::OnWrite, this, _1, _2)); + std::bind(&ClientBase::OnWrite, this, _1, _2)); } -void Client::OnWrite(boost::system::error_code ec, std::size_t length) { +void ClientBase::OnWrite(boost::system::error_code ec, std::size_t length) { if (ec) { HandleWriteError(ec); return; @@ -174,11 +163,11 @@ void Client::OnWrite(boost::system::error_code ec, std::size_t length) { AsyncWriteBody(); } -void Client::AsyncWriteBody() { +void ClientBase::AsyncWriteBody() { auto p = request_->body()->NextPayload(true); if (!p.empty()) { - socket_->AsyncWrite(p, std::bind(&Client::OnWriteBody, this, _1, _2)); + socket_->AsyncWrite(p, std::bind(&ClientBase::OnWriteBody, this, _1, _2)); } else { LOG_INFO("Request send"); @@ -190,7 +179,7 @@ void Client::AsyncWriteBody() { } } -void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) { +void ClientBase::OnWriteBody(boost::system::error_code ec, std::size_t legnth) { if (ec) { HandleWriteError(ec); return; @@ -200,33 +189,34 @@ void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) { AsyncWriteBody(); } -void Client::HandleWriteError(boost::system::error_code ec) { +void ClientBase::HandleWriteError(boost::system::error_code ec) { if (ec == boost::asio::error::operation_aborted) { - // Socket has been closed by OnDeadlineTimer() or DoClose(). + // Socket has been closed by OnDeadlineTimer() or CloseSocket(). LOG_WARN("Write operation aborted"); } else { LOG_ERRO("Socket write error (%s)", ec.message().c_str()); - DoClose(); + CloseSocket(); } error_.Set(Error::kSocketWriteError, "Socket write error"); FinishRequest(); } -void Client::AsyncRead() { - socket_->AsyncReadSome(std::bind(&Client::OnRead, this, _1, _2), &buffer_); +void ClientBase::AsyncRead() { + socket_->AsyncReadSome(std::bind(&ClientBase::OnRead, this, _1, _2), + &buffer_); } -void Client::OnRead(boost::system::error_code ec, std::size_t length) { +void ClientBase::OnRead(boost::system::error_code ec, std::size_t length) { StopDeadlineTimer(); if (ec) { if (ec == boost::asio::error::operation_aborted) { - // Socket has been closed by OnDeadlineTimer() or DoClose(). + // Socket has been closed by OnDeadlineTimer() or CloseSocket(). LOG_WARN("Read operation aborted"); } else { LOG_ERRO("Socket read error (%s)", ec.message().c_str()); - DoClose(); + CloseSocket(); } error_.Set(Error::kSocketReadError, "Socket read error"); @@ -241,7 +231,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) { // Parse the piece of data just read. if (!response_parser_.Parse(buffer_.data(), length)) { LOG_ERRO("Failed to parse the response"); - DoClose(); + CloseSocket(); error_.Set(Error::kParseError, "Response parse error"); FinishRequest(); return; @@ -262,7 +252,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) { if (response_->IsConnectionKeepAlive()) { LOG_INFO("Keep the socket connection alive"); } else { - DoClose(); + CloseSocket(); } // Stop trying to read once all content has been received, because some @@ -277,7 +267,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) { AsyncRead(); } -void Client::AsyncWaitDeadlineTimer(int seconds) { +void ClientBase::AsyncWaitDeadlineTimer(int seconds) { if (seconds <= 0) { deadline_timer_stopped_ = true; return; @@ -288,10 +278,10 @@ void Client::AsyncWaitDeadlineTimer(int seconds) { deadline_timer_stopped_ = false; deadline_timer_.expires_after(std::chrono::seconds(seconds)); - deadline_timer_.async_wait(std::bind(&Client::OnDeadlineTimer, this, _1)); + deadline_timer_.async_wait(std::bind(&ClientBase::OnDeadlineTimer, this, _1)); } -void Client::OnDeadlineTimer(boost::system::error_code ec) { +void ClientBase::OnDeadlineTimer(boost::system::error_code ec) { LOG_VERB("On deadline timer"); deadline_timer_stopped_ = true; @@ -307,7 +297,7 @@ void Client::OnDeadlineTimer(boost::system::error_code ec) { // Cancel the async operations on the socket. // OnXxx() will be called with `error::operation_aborted`. if (connected_) { - DoClose(); + CloseSocket(); } else { socket_->Close(); } @@ -315,7 +305,7 @@ void Client::OnDeadlineTimer(boost::system::error_code ec) { error_.set_timeout(true); } -void Client::StopDeadlineTimer() { +void ClientBase::StopDeadlineTimer() { if (deadline_timer_stopped_) { return; } @@ -332,7 +322,7 @@ void Client::StopDeadlineTimer() { deadline_timer_stopped_ = true; } -void Client::FinishRequest() { +void ClientBase::FinishRequest() { request_mutex_.lock(); if (!request_finished_) { diff --git a/webcc/client_base.h b/webcc/client_base.h new file mode 100644 index 0000000..81e9960 --- /dev/null +++ b/webcc/client_base.h @@ -0,0 +1,164 @@ +#ifndef WEBCC_CLIENT_BASE_H_ +#define WEBCC_CLIENT_BASE_H_ + +#include +#include +#include +#include +#include + +#include "boost/asio/io_context.hpp" +#include "boost/asio/ip/tcp.hpp" +#include "boost/asio/steady_timer.hpp" + +#include "webcc/globals.h" +#include "webcc/request.h" +#include "webcc/response.h" +#include "webcc/response_parser.h" +#include "webcc/socket_base.h" + +namespace webcc { + +class ClientBase { +public: + explicit ClientBase(boost::asio::io_context& io_context); + + ClientBase(const ClientBase&) = delete; + ClientBase& operator=(const ClientBase&) = delete; + + ~ClientBase() = default; + + void set_buffer_size(std::size_t buffer_size) { + if (buffer_size > 0) { + buffer_size_ = buffer_size; + } + } + + void set_connect_timeout(int timeout) { + if (timeout > 0) { + connect_timeout_ = timeout; + } + } + + void set_read_timeout(int timeout) { + if (timeout > 0) { + read_timeout_ = timeout; + } + } + + // Set progress callback to be informed about the read progress. + // NOTE: Don't use move semantics because in practice, there is no difference + // between copying and moving an object of a closure type. + // TODO: Support write progress + void set_progress_callback(ProgressCallback callback) { + progress_callback_ = callback; + } + + // Connect, send request, wait until response is received. + Error Request(RequestPtr request, bool stream = false); + + // Close the connection. + // The async operation on the socket will be canceled. + void Close(); + + bool connected() const { + return connected_; + } + + ResponsePtr response() const { + return response_; + } + + // Reset response object. + // Used to make sure the response object will released even the client object + // itself will be cached for keep-alive purpose. + void Reset() { + response_.reset(); + response_parser_.Init(nullptr, false); + } + +protected: + // Create Socket or SslSocket. + virtual void CreateSocket() = 0; + + // Resolve host. + virtual void Resolve() = 0; + + void CloseSocket(); + + void AsyncResolve(string_view default_port); + + void OnResolve(boost::system::error_code ec, + boost::asio::ip::tcp::resolver::results_type endpoints); + + void OnConnect(boost::system::error_code ec, boost::asio::ip::tcp::endpoint); + + void AsyncWrite(); + void OnWrite(boost::system::error_code ec, std::size_t length); + + void AsyncWriteBody(); + void OnWriteBody(boost::system::error_code ec, std::size_t length); + + void HandleWriteError(boost::system::error_code ec); + + void AsyncRead(); + void OnRead(boost::system::error_code ec, std::size_t length); + + void AsyncWaitDeadlineTimer(int seconds); + void OnDeadlineTimer(boost::system::error_code ec); + void StopDeadlineTimer(); + + void FinishRequest(); + +protected: + boost::asio::io_context& io_context_; + + std::unique_ptr socket_; + + boost::asio::ip::tcp::resolver resolver_; + + bool request_finished_ = true; + std::condition_variable request_cv_; + std::mutex request_mutex_; + + RequestPtr request_; + + ResponsePtr response_; + ResponseParser response_parser_; + + // The length already read. + std::size_t length_read_ = 0; + + // The buffer for reading response. + std::vector buffer_; + + // The size of the buffer for reading response. + // 0 means default value will be used. + std::size_t buffer_size_ = kBufferSize; + + // Timeout (seconds) for connecting to server. + // Default as 0 to disable our own control (i.e., deadline_timer_). + int connect_timeout_ = 0; + + // Timeout (seconds) for reading response. + int read_timeout_ = kMaxReadSeconds; + + // Deadline timer for connecting to server. + boost::asio::steady_timer deadline_timer_; + bool deadline_timer_stopped_ = true; + + // Socket connected or not. + bool connected_ = false; + + // Progress callback (optional). + ProgressCallback progress_callback_; + + // Current error. + Error error_; +}; + +using ClientPtr = std::shared_ptr; + +} // namespace webcc + +#endif // WEBCC_CLIENT_BASE_H_ diff --git a/webcc/ssl_client.cc b/webcc/ssl_client.cc deleted file mode 100644 index 058ddc8..0000000 --- a/webcc/ssl_client.cc +++ /dev/null @@ -1,23 +0,0 @@ -#include "webcc/ssl_client.h" - -#include "boost/algorithm/string.hpp" - -#include "webcc/ssl_socket.h" - -namespace webcc { - -SslClient::SslClient(boost::asio::io_context& io_context, - boost::asio::ssl::context& ssl_context) - : Client(io_context), ssl_context_(ssl_context) { -} - -void SslClient::AsyncConnect() { - if (boost::iequals(request_->url().scheme(), "https")) { - socket_.reset(new SslSocket{ io_context_, ssl_context_ }); - AsyncResolve("443"); - } else { - Client::AsyncConnect(); - } -} - -} // namespace webcc diff --git a/webcc/ssl_client.h b/webcc/ssl_client.h index 0bfdf4d..a416507 100644 --- a/webcc/ssl_client.h +++ b/webcc/ssl_client.h @@ -1,25 +1,34 @@ #ifndef WEBCC_SSL_CLIENT_H_ #define WEBCC_SSL_CLIENT_H_ -#include "webcc/client.h" - #include "boost/asio/ssl/context.hpp" +#include "webcc/client_base.h" +#include "webcc/ssl_socket.h" + #if !WEBCC_ENABLE_SSL #error SSL must be enabled! #endif namespace webcc { -class SslClient final : public Client { +class SslClient final : public ClientBase { public: SslClient(boost::asio::io_context& io_context, - boost::asio::ssl::context& ssl_context); + boost::asio::ssl::context& ssl_context) + : ClientBase(io_context), ssl_context_(ssl_context) { + } ~SslClient() = default; protected: - void AsyncConnect() override; + void CreateSocket() override { + socket_.reset(new SslSocket{ io_context_, ssl_context_ }); + } + + void Resolve() override { + AsyncResolve("443"); + } private: boost::asio::ssl::context& ssl_context_;