rework client classes

master
Chunting Gu 4 years ago
parent d294cda3c1
commit 28ae628d56

@ -1,9 +1,5 @@
#include "gtest/gtest.h" #include "gtest/gtest.h"
#include <vector>
#include "boost/algorithm/string.hpp"
#include "webcc/string.h" #include "webcc/string.h"
TEST(StringTest, Trim) { TEST(StringTest, Trim) {

@ -18,7 +18,7 @@ include(GNUInstallDirs)
set(SOURCES set(SOURCES
base64.cc base64.cc
body.cc body.cc
client.cc client_base.cc
client_pool.cc client_pool.cc
client_session.cc client_session.cc
common.cc common.cc
@ -45,6 +45,7 @@ set(SOURCES
set(HEADERS set(HEADERS
base64.h base64.h
body.h body.h
client_base.h
client.h client.h
client_pool.h client_pool.h
client_session.h client_session.h
@ -75,13 +76,8 @@ set(HEADERS
) )
if(WEBCC_ENABLE_SSL) if(WEBCC_ENABLE_SSL)
set(SOURCES ${SOURCES} set(SOURCES ${SOURCES} ssl_socket.cc)
ssl_socket.cc set(HEADERS ${HEADERS} ssl_socket.h ssl_client.h)
ssl_client.cc)
set(HEADERS ${HEADERS}
ssl_socket.h
ssl_client.h)
endif() endif()
if(WEBCC_ENABLE_GZIP) if(WEBCC_ENABLE_GZIP)

@ -1,162 +1,27 @@
#ifndef WEBCC_CLIENT_H_ #ifndef WEBCC_CLIENT_H_
#define WEBCC_CLIENT_H_ #define WEBCC_CLIENT_H_
#include <condition_variable> #include "webcc/client_base.h"
#include <memory> #include "webcc/socket.h"
#include <mutex>
#include <string>
#include <vector>
#include "boost/asio/io_context.hpp"
#include "boost/asio/ip/tcp.hpp"
#include "boost/asio/steady_timer.hpp"
#include "webcc/globals.h"
#include "webcc/request.h"
#include "webcc/response.h"
#include "webcc/response_parser.h"
#include "webcc/socket_base.h"
namespace webcc { namespace webcc {
class Client; class Client final : public ClientBase {
using ClientPtr = std::shared_ptr<Client>;
// Synchronous HTTP & HTTPS client.
// A request won't return until the response is received or timeout occurs.
class Client {
public: public:
explicit Client(boost::asio::io_context& io_context); explicit Client(boost::asio::io_context& io_context)
: ClientBase(io_context) {
Client(const Client&) = delete;
Client& operator=(const Client&) = delete;
~Client() = default;
void set_buffer_size(std::size_t buffer_size) {
if (buffer_size > 0) {
buffer_size_ = buffer_size;
}
}
void set_connect_timeout(int timeout) {
if (timeout > 0) {
connect_timeout_ = timeout;
}
} }
void set_read_timeout(int timeout) { ~Client() = default;
if (timeout > 0) {
read_timeout_ = timeout;
}
}
// Set progress callback to be informed about the read progress.
// NOTE: Don't use move semantics because in practice, there is no difference
// between copying and moving an object of a closure type.
// TODO: Support write progress
void set_progress_callback(ProgressCallback callback) {
progress_callback_ = callback;
}
// Connect, send request, wait until response is received.
Error Request(RequestPtr request, bool stream = false);
// Close the socket.
void Close();
bool connected() const {
return connected_;
}
ResponsePtr response() const { protected:
return response_; void CreateSocket() override {
socket_.reset(new Socket{ io_context_ });
} }
// Reset response object. void Resolve() override {
// Used to make sure the response object will released even the client object AsyncResolve("80");
// itself will be cached for keep-alive purpose.
void Reset() {
response_.reset();
response_parser_.Init(nullptr, false);
} }
protected:
void DoClose();
// TODO: Rename
// TODO: Add class ClientBase ?
virtual void AsyncConnect();
void AsyncResolve(string_view default_port);
void OnResolve(boost::system::error_code ec,
boost::asio::ip::tcp::resolver::results_type endpoints);
void OnConnect(boost::system::error_code ec, boost::asio::ip::tcp::endpoint);
void AsyncWrite();
void OnWrite(boost::system::error_code ec, std::size_t length);
void AsyncWriteBody();
void OnWriteBody(boost::system::error_code ec, std::size_t length);
void HandleWriteError(boost::system::error_code ec);
void AsyncRead();
void OnRead(boost::system::error_code ec, std::size_t length);
void AsyncWaitDeadlineTimer(int seconds);
void OnDeadlineTimer(boost::system::error_code ec);
void StopDeadlineTimer();
void FinishRequest();
protected:
boost::asio::io_context& io_context_;
std::unique_ptr<SocketBase> socket_;
boost::asio::ip::tcp::resolver resolver_;
bool request_finished_ = true;
std::condition_variable request_cv_;
std::mutex request_mutex_;
RequestPtr request_;
ResponsePtr response_;
ResponseParser response_parser_;
// The length already read.
std::size_t length_read_ = 0;
// The buffer for reading response.
std::vector<char> buffer_;
// The size of the buffer for reading response.
// 0 means default value will be used.
std::size_t buffer_size_ = kBufferSize;
// Timeout (seconds) for connecting to server.
// Default as 0 to disable our own control (i.e., deadline_timer_).
int connect_timeout_ = 0;
// Timeout (seconds) for reading response.
int read_timeout_ = kMaxReadSeconds;
// Deadline timer for connecting to server.
boost::asio::steady_timer deadline_timer_;
bool deadline_timer_stopped_ = true;
// Socket connected or not.
bool connected_ = false;
// Progress callback (optional).
ProgressCallback progress_callback_;
// Current error.
Error error_;
}; };
} // namespace webcc } // namespace webcc

@ -1,4 +1,4 @@
#include "webcc/client.h" #include "webcc/client_base.h"
#include "boost/algorithm/string.hpp" #include "boost/algorithm/string.hpp"
@ -10,13 +10,13 @@ using namespace std::placeholders;
namespace webcc { namespace webcc {
Client::Client(boost::asio::io_context& io_context) ClientBase::ClientBase(boost::asio::io_context& io_context)
: io_context_(io_context), : io_context_(io_context),
resolver_(io_context), resolver_(io_context),
deadline_timer_(io_context) { deadline_timer_(io_context) {
} }
Error Client::Request(RequestPtr request, bool stream) { Error ClientBase::Request(RequestPtr request, bool stream) {
LOG_VERB("Request begin"); LOG_VERB("Request begin");
request_finished_ = false; request_finished_ = false;
@ -48,7 +48,8 @@ Error Client::Request(RequestPtr request, bool stream) {
} }
if (!connected_) { if (!connected_) {
AsyncConnect(); CreateSocket();
Resolve();
} else { } else {
AsyncWrite(); AsyncWrite();
} }
@ -62,14 +63,14 @@ Error Client::Request(RequestPtr request, bool stream) {
return error_; return error_;
} }
void Client::Close() { void ClientBase::Close() {
DoClose(); CloseSocket();
// Don't call FinishRequest() from here! It will be called in the handler // Don't call FinishRequest() from here! It will be called in the handler
// OnXxx with `error::operation_aborted`. // OnXxx with `error::operation_aborted`.
} }
void Client::DoClose() { void ClientBase::CloseSocket() {
if (connected_) { if (connected_) {
connected_ = false; connected_ = false;
if (socket_) { if (socket_) {
@ -87,19 +88,7 @@ void Client::DoClose() {
} }
} }
void Client::AsyncConnect() { void ClientBase::AsyncResolve(string_view default_port) {
if (boost::iequals(request_->url().scheme(), "http")) {
socket_.reset(new Socket{ io_context_ });
AsyncResolve("80");
} else {
LOG_ERRO("URL scheme (%s) is not supported",
request_->url().scheme().c_str());
error_.Set(Error::kSyntaxError, "URL scheme not supported");
FinishRequest();
}
}
void Client::AsyncResolve(string_view default_port) {
std::string port = request_->port(); std::string port = request_->port();
if (port.empty()) { if (port.empty()) {
port = ToString(default_port); port = ToString(default_port);
@ -109,11 +98,11 @@ void Client::AsyncResolve(string_view default_port) {
// The protocol depends on the `host`, both V4 and V6 are supported. // The protocol depends on the `host`, both V4 and V6 are supported.
resolver_.async_resolve(request_->host(), port, resolver_.async_resolve(request_->host(), port,
std::bind(&Client::OnResolve, this, _1, _2)); std::bind(&ClientBase::OnResolve, this, _1, _2));
} }
void Client::OnResolve(boost::system::error_code ec, void ClientBase::OnResolve(boost::system::error_code ec,
tcp::resolver::results_type endpoints) { tcp::resolver::results_type endpoints) {
if (ec) { if (ec) {
LOG_ERRO("Host resolve error (%s)", ec.message().c_str()); LOG_ERRO("Host resolve error (%s)", ec.message().c_str());
error_.Set(Error::kResolveError, "Host resolve error"); error_.Set(Error::kResolveError, "Host resolve error");
@ -126,17 +115,17 @@ void Client::OnResolve(boost::system::error_code ec,
AsyncWaitDeadlineTimer(connect_timeout_); AsyncWaitDeadlineTimer(connect_timeout_);
socket_->AsyncConnect(request_->host(), endpoints, socket_->AsyncConnect(request_->host(), endpoints,
std::bind(&Client::OnConnect, this, _1, _2)); std::bind(&ClientBase::OnConnect, this, _1, _2));
} }
void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) { void ClientBase::OnConnect(boost::system::error_code ec, tcp::endpoint) {
LOG_VERB("On connect"); LOG_VERB("On connect");
StopDeadlineTimer(); StopDeadlineTimer();
if (ec) { if (ec) {
if (ec == boost::asio::error::operation_aborted) { if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or DoClose(). // Socket has been closed by OnDeadlineTimer() or CloseSocket().
LOG_WARN("Connect operation aborted"); LOG_WARN("Connect operation aborted");
} else { } else {
LOG_INFO("Connect error"); LOG_INFO("Connect error");
@ -156,14 +145,14 @@ void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) {
AsyncWrite(); AsyncWrite();
} }
void Client::AsyncWrite() { void ClientBase::AsyncWrite() {
LOG_VERB("Request:\n%s", request_->Dump().c_str()); LOG_VERB("Request:\n%s", request_->Dump().c_str());
socket_->AsyncWrite(request_->GetPayload(), socket_->AsyncWrite(request_->GetPayload(),
std::bind(&Client::OnWrite, this, _1, _2)); std::bind(&ClientBase::OnWrite, this, _1, _2));
} }
void Client::OnWrite(boost::system::error_code ec, std::size_t length) { void ClientBase::OnWrite(boost::system::error_code ec, std::size_t length) {
if (ec) { if (ec) {
HandleWriteError(ec); HandleWriteError(ec);
return; return;
@ -174,11 +163,11 @@ void Client::OnWrite(boost::system::error_code ec, std::size_t length) {
AsyncWriteBody(); AsyncWriteBody();
} }
void Client::AsyncWriteBody() { void ClientBase::AsyncWriteBody() {
auto p = request_->body()->NextPayload(true); auto p = request_->body()->NextPayload(true);
if (!p.empty()) { if (!p.empty()) {
socket_->AsyncWrite(p, std::bind(&Client::OnWriteBody, this, _1, _2)); socket_->AsyncWrite(p, std::bind(&ClientBase::OnWriteBody, this, _1, _2));
} else { } else {
LOG_INFO("Request send"); LOG_INFO("Request send");
@ -190,7 +179,7 @@ void Client::AsyncWriteBody() {
} }
} }
void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) { void ClientBase::OnWriteBody(boost::system::error_code ec, std::size_t legnth) {
if (ec) { if (ec) {
HandleWriteError(ec); HandleWriteError(ec);
return; return;
@ -200,33 +189,34 @@ void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) {
AsyncWriteBody(); AsyncWriteBody();
} }
void Client::HandleWriteError(boost::system::error_code ec) { void ClientBase::HandleWriteError(boost::system::error_code ec) {
if (ec == boost::asio::error::operation_aborted) { if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or DoClose(). // Socket has been closed by OnDeadlineTimer() or CloseSocket().
LOG_WARN("Write operation aborted"); LOG_WARN("Write operation aborted");
} else { } else {
LOG_ERRO("Socket write error (%s)", ec.message().c_str()); LOG_ERRO("Socket write error (%s)", ec.message().c_str());
DoClose(); CloseSocket();
} }
error_.Set(Error::kSocketWriteError, "Socket write error"); error_.Set(Error::kSocketWriteError, "Socket write error");
FinishRequest(); FinishRequest();
} }
void Client::AsyncRead() { void ClientBase::AsyncRead() {
socket_->AsyncReadSome(std::bind(&Client::OnRead, this, _1, _2), &buffer_); socket_->AsyncReadSome(std::bind(&ClientBase::OnRead, this, _1, _2),
&buffer_);
} }
void Client::OnRead(boost::system::error_code ec, std::size_t length) { void ClientBase::OnRead(boost::system::error_code ec, std::size_t length) {
StopDeadlineTimer(); StopDeadlineTimer();
if (ec) { if (ec) {
if (ec == boost::asio::error::operation_aborted) { if (ec == boost::asio::error::operation_aborted) {
// Socket has been closed by OnDeadlineTimer() or DoClose(). // Socket has been closed by OnDeadlineTimer() or CloseSocket().
LOG_WARN("Read operation aborted"); LOG_WARN("Read operation aborted");
} else { } else {
LOG_ERRO("Socket read error (%s)", ec.message().c_str()); LOG_ERRO("Socket read error (%s)", ec.message().c_str());
DoClose(); CloseSocket();
} }
error_.Set(Error::kSocketReadError, "Socket read error"); error_.Set(Error::kSocketReadError, "Socket read error");
@ -241,7 +231,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) {
// Parse the piece of data just read. // Parse the piece of data just read.
if (!response_parser_.Parse(buffer_.data(), length)) { if (!response_parser_.Parse(buffer_.data(), length)) {
LOG_ERRO("Failed to parse the response"); LOG_ERRO("Failed to parse the response");
DoClose(); CloseSocket();
error_.Set(Error::kParseError, "Response parse error"); error_.Set(Error::kParseError, "Response parse error");
FinishRequest(); FinishRequest();
return; return;
@ -262,7 +252,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) {
if (response_->IsConnectionKeepAlive()) { if (response_->IsConnectionKeepAlive()) {
LOG_INFO("Keep the socket connection alive"); LOG_INFO("Keep the socket connection alive");
} else { } else {
DoClose(); CloseSocket();
} }
// Stop trying to read once all content has been received, because some // Stop trying to read once all content has been received, because some
@ -277,7 +267,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) {
AsyncRead(); AsyncRead();
} }
void Client::AsyncWaitDeadlineTimer(int seconds) { void ClientBase::AsyncWaitDeadlineTimer(int seconds) {
if (seconds <= 0) { if (seconds <= 0) {
deadline_timer_stopped_ = true; deadline_timer_stopped_ = true;
return; return;
@ -288,10 +278,10 @@ void Client::AsyncWaitDeadlineTimer(int seconds) {
deadline_timer_stopped_ = false; deadline_timer_stopped_ = false;
deadline_timer_.expires_after(std::chrono::seconds(seconds)); deadline_timer_.expires_after(std::chrono::seconds(seconds));
deadline_timer_.async_wait(std::bind(&Client::OnDeadlineTimer, this, _1)); deadline_timer_.async_wait(std::bind(&ClientBase::OnDeadlineTimer, this, _1));
} }
void Client::OnDeadlineTimer(boost::system::error_code ec) { void ClientBase::OnDeadlineTimer(boost::system::error_code ec) {
LOG_VERB("On deadline timer"); LOG_VERB("On deadline timer");
deadline_timer_stopped_ = true; deadline_timer_stopped_ = true;
@ -307,7 +297,7 @@ void Client::OnDeadlineTimer(boost::system::error_code ec) {
// Cancel the async operations on the socket. // Cancel the async operations on the socket.
// OnXxx() will be called with `error::operation_aborted`. // OnXxx() will be called with `error::operation_aborted`.
if (connected_) { if (connected_) {
DoClose(); CloseSocket();
} else { } else {
socket_->Close(); socket_->Close();
} }
@ -315,7 +305,7 @@ void Client::OnDeadlineTimer(boost::system::error_code ec) {
error_.set_timeout(true); error_.set_timeout(true);
} }
void Client::StopDeadlineTimer() { void ClientBase::StopDeadlineTimer() {
if (deadline_timer_stopped_) { if (deadline_timer_stopped_) {
return; return;
} }
@ -332,7 +322,7 @@ void Client::StopDeadlineTimer() {
deadline_timer_stopped_ = true; deadline_timer_stopped_ = true;
} }
void Client::FinishRequest() { void ClientBase::FinishRequest() {
request_mutex_.lock(); request_mutex_.lock();
if (!request_finished_) { if (!request_finished_) {

@ -0,0 +1,164 @@
#ifndef WEBCC_CLIENT_BASE_H_
#define WEBCC_CLIENT_BASE_H_
#include <condition_variable>
#include <memory>
#include <mutex>
#include <string>
#include <vector>
#include "boost/asio/io_context.hpp"
#include "boost/asio/ip/tcp.hpp"
#include "boost/asio/steady_timer.hpp"
#include "webcc/globals.h"
#include "webcc/request.h"
#include "webcc/response.h"
#include "webcc/response_parser.h"
#include "webcc/socket_base.h"
namespace webcc {
class ClientBase {
public:
explicit ClientBase(boost::asio::io_context& io_context);
ClientBase(const ClientBase&) = delete;
ClientBase& operator=(const ClientBase&) = delete;
~ClientBase() = default;
void set_buffer_size(std::size_t buffer_size) {
if (buffer_size > 0) {
buffer_size_ = buffer_size;
}
}
void set_connect_timeout(int timeout) {
if (timeout > 0) {
connect_timeout_ = timeout;
}
}
void set_read_timeout(int timeout) {
if (timeout > 0) {
read_timeout_ = timeout;
}
}
// Set progress callback to be informed about the read progress.
// NOTE: Don't use move semantics because in practice, there is no difference
// between copying and moving an object of a closure type.
// TODO: Support write progress
void set_progress_callback(ProgressCallback callback) {
progress_callback_ = callback;
}
// Connect, send request, wait until response is received.
Error Request(RequestPtr request, bool stream = false);
// Close the connection.
// The async operation on the socket will be canceled.
void Close();
bool connected() const {
return connected_;
}
ResponsePtr response() const {
return response_;
}
// Reset response object.
// Used to make sure the response object will released even the client object
// itself will be cached for keep-alive purpose.
void Reset() {
response_.reset();
response_parser_.Init(nullptr, false);
}
protected:
// Create Socket or SslSocket.
virtual void CreateSocket() = 0;
// Resolve host.
virtual void Resolve() = 0;
void CloseSocket();
void AsyncResolve(string_view default_port);
void OnResolve(boost::system::error_code ec,
boost::asio::ip::tcp::resolver::results_type endpoints);
void OnConnect(boost::system::error_code ec, boost::asio::ip::tcp::endpoint);
void AsyncWrite();
void OnWrite(boost::system::error_code ec, std::size_t length);
void AsyncWriteBody();
void OnWriteBody(boost::system::error_code ec, std::size_t length);
void HandleWriteError(boost::system::error_code ec);
void AsyncRead();
void OnRead(boost::system::error_code ec, std::size_t length);
void AsyncWaitDeadlineTimer(int seconds);
void OnDeadlineTimer(boost::system::error_code ec);
void StopDeadlineTimer();
void FinishRequest();
protected:
boost::asio::io_context& io_context_;
std::unique_ptr<SocketBase> socket_;
boost::asio::ip::tcp::resolver resolver_;
bool request_finished_ = true;
std::condition_variable request_cv_;
std::mutex request_mutex_;
RequestPtr request_;
ResponsePtr response_;
ResponseParser response_parser_;
// The length already read.
std::size_t length_read_ = 0;
// The buffer for reading response.
std::vector<char> buffer_;
// The size of the buffer for reading response.
// 0 means default value will be used.
std::size_t buffer_size_ = kBufferSize;
// Timeout (seconds) for connecting to server.
// Default as 0 to disable our own control (i.e., deadline_timer_).
int connect_timeout_ = 0;
// Timeout (seconds) for reading response.
int read_timeout_ = kMaxReadSeconds;
// Deadline timer for connecting to server.
boost::asio::steady_timer deadline_timer_;
bool deadline_timer_stopped_ = true;
// Socket connected or not.
bool connected_ = false;
// Progress callback (optional).
ProgressCallback progress_callback_;
// Current error.
Error error_;
};
using ClientPtr = std::shared_ptr<ClientBase>;
} // namespace webcc
#endif // WEBCC_CLIENT_BASE_H_

@ -1,23 +0,0 @@
#include "webcc/ssl_client.h"
#include "boost/algorithm/string.hpp"
#include "webcc/ssl_socket.h"
namespace webcc {
SslClient::SslClient(boost::asio::io_context& io_context,
boost::asio::ssl::context& ssl_context)
: Client(io_context), ssl_context_(ssl_context) {
}
void SslClient::AsyncConnect() {
if (boost::iequals(request_->url().scheme(), "https")) {
socket_.reset(new SslSocket{ io_context_, ssl_context_ });
AsyncResolve("443");
} else {
Client::AsyncConnect();
}
}
} // namespace webcc

@ -1,25 +1,34 @@
#ifndef WEBCC_SSL_CLIENT_H_ #ifndef WEBCC_SSL_CLIENT_H_
#define WEBCC_SSL_CLIENT_H_ #define WEBCC_SSL_CLIENT_H_
#include "webcc/client.h"
#include "boost/asio/ssl/context.hpp" #include "boost/asio/ssl/context.hpp"
#include "webcc/client_base.h"
#include "webcc/ssl_socket.h"
#if !WEBCC_ENABLE_SSL #if !WEBCC_ENABLE_SSL
#error SSL must be enabled! #error SSL must be enabled!
#endif #endif
namespace webcc { namespace webcc {
class SslClient final : public Client { class SslClient final : public ClientBase {
public: public:
SslClient(boost::asio::io_context& io_context, SslClient(boost::asio::io_context& io_context,
boost::asio::ssl::context& ssl_context); boost::asio::ssl::context& ssl_context)
: ClientBase(io_context), ssl_context_(ssl_context) {
}
~SslClient() = default; ~SslClient() = default;
protected: protected:
void AsyncConnect() override; void CreateSocket() override {
socket_.reset(new SslSocket{ io_context_, ssl_context_ });
}
void Resolve() override {
AsyncResolve("443");
}
private: private:
boost::asio::ssl::context& ssl_context_; boost::asio::ssl::context& ssl_context_;

Loading…
Cancel
Save