diff --git a/autotest/client_autotest/client_autotest.cc b/autotest/client_autotest/client_autotest.cc index dadfe83..400f8e7 100644 --- a/autotest/client_autotest/client_autotest.cc +++ b/autotest/client_autotest/client_autotest.cc @@ -297,7 +297,8 @@ TEST(ClientTest, Post) { static webcc::fs::path GenerateTempFile(const std::string& data) { try { - webcc::fs::path path = webcc::fs::temp_directory_path() / webcc::RandomString(10); + webcc::fs::path path = + webcc::fs::temp_directory_path() / webcc::RandomString(10); webcc::fs::ofstream ofs; ofs.open(path, std::ios::binary); diff --git a/autotest/client_timeout_autotest/client_timeout_autotest.cc b/autotest/client_timeout_autotest/client_timeout_autotest.cc index 6f8f009..7207d84 100644 --- a/autotest/client_timeout_autotest/client_timeout_autotest.cc +++ b/autotest/client_timeout_autotest/client_timeout_autotest.cc @@ -101,3 +101,28 @@ TEST_F(ClientTimeoutTest, Timeout) { EXPECT_TRUE(!r); EXPECT_TRUE(timeout); } + +// Test ClientSession::Cancel() +TEST_F(ClientTimeoutTest, SessionCancel) { + webcc::ClientSession session; + session.set_read_timeout(30); + + bool canceled = false; + + // Create a thread to cancel the session after 3 seconds. + std::thread t{ [&session, &canceled]() { + std::this_thread::sleep_for(std::chrono::seconds(3)); + canceled = session.Cancel(); + } }; + + // Send a request and ask the server to sleep 5 seconds before reply. + try { + auto r = session.Send(WEBCC_GET("http://localhost/sleep/5").Port(kPort)()); + } catch (const webcc::Error&) { + } + + t.join(); + + // The request should be canceled. + EXPECT_TRUE(canceled); +} diff --git a/webcc/client.cc b/webcc/client.cc index 77eae82..4ddd25c 100644 --- a/webcc/client.cc +++ b/webcc/client.cc @@ -15,7 +15,6 @@ Client::Client(boost::asio::io_context& io_context, ssl_context_(ssl_context), resolver_(io_context), deadline_timer_(io_context) { - } #else @@ -75,29 +74,28 @@ Error Client::Request(RequestPtr request, bool stream) { } void Client::Close() { - if (!connected_) { - //resolver_.cancel(); // TODO + DoClose(); + + // Don't call FinishRequest() from here! It will be called in the handler + // OnXxx with `error::operation_aborted`. +} + +void Client::DoClose() { + if (connected_) { + connected_ = false; if (socket_) { - // Cancel any async operations on the socket. - LOG_VERB("Close socket"); + LOG_VERB("Shutdown & close socket"); + socket_->Shutdown(); + socket_->Close(); + } + LOG_INFO("Socket closed"); + } else { + // TODO: resolver_.cancel() ? + if (socket_) { + LOG_INFO("Close socket"); socket_->Close(); - // Make sure the current request, if any, could be finished. - FinishRequest(); } - return; - } - - connected_ = false; - - if (socket_) { - LOG_INFO("Shutdown & close socket"); - socket_->Shutdown(); - socket_->Close(); - // Make sure the current request, if any, could be finished. - FinishRequest(); } - - LOG_INFO("Socket closed"); } void Client::AsyncConnect() { @@ -153,7 +151,7 @@ void Client::OnConnect(boost::system::error_code ec, tcp::endpoint) { if (ec) { if (ec == boost::asio::error::operation_aborted) { - // Socket has been closed by OnDeadlineTimer() or Close(). + // Socket has been closed by OnDeadlineTimer() or DoClose(). LOG_WARN("Connect operation aborted"); } else { LOG_INFO("Connect error"); @@ -219,11 +217,11 @@ void Client::OnWriteBody(boost::system::error_code ec, std::size_t legnth) { void Client::HandleWriteError(boost::system::error_code ec) { if (ec == boost::asio::error::operation_aborted) { - // Socket has been closed by OnDeadlineTimer() or Close(). + // Socket has been closed by OnDeadlineTimer() or DoClose(). LOG_WARN("Write operation aborted"); } else { LOG_ERRO("Socket write error (%s)", ec.message().c_str()); - Close(); + DoClose(); } error_.Set(Error::kSocketWriteError, "Socket write error"); @@ -239,11 +237,11 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) { if (ec) { if (ec == boost::asio::error::operation_aborted) { - // Socket has been closed by OnDeadlineTimer() or Close(). + // Socket has been closed by OnDeadlineTimer() or DoClose(). LOG_WARN("Read operation aborted"); } else { LOG_ERRO("Socket read error (%s)", ec.message().c_str()); - Close(); + DoClose(); } error_.Set(Error::kSocketReadError, "Socket read error"); @@ -258,7 +256,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) { // Parse the piece of data just read. if (!response_parser_.Parse(buffer_.data(), length)) { LOG_ERRO("Failed to parse the response"); - Close(); + DoClose(); error_.Set(Error::kParseError, "Response parse error"); FinishRequest(); return; @@ -279,7 +277,7 @@ void Client::OnRead(boost::system::error_code ec, std::size_t length) { if (response_->IsConnectionKeepAlive()) { LOG_INFO("Keep the socket connection alive"); } else { - Close(); + DoClose(); } // Stop trying to read once all content has been received, because some @@ -324,7 +322,7 @@ void Client::OnDeadlineTimer(boost::system::error_code ec) { // Cancel the async operations on the socket. // OnXxx() will be called with `error::operation_aborted`. if (connected_) { - Close(); + DoClose(); } else { socket_->Close(); } @@ -342,22 +340,24 @@ void Client::StopDeadlineTimer() { try { // Cancel the async wait operation on this timer. deadline_timer_.cancel(); - } catch (const boost::system::system_error&) { + } catch (const boost::system::system_error& e) { + LOG_ERRO("Deadline timer cancel error: %s", e.what()); } deadline_timer_stopped_ = true; } void Client::FinishRequest() { - { - std::lock_guard lock{ request_mutex_ }; - if (!request_finished_) { - request_finished_ = true; - } else { - return; - } + request_mutex_.lock(); + + if (!request_finished_) { + request_finished_ = true; + + request_mutex_.unlock(); + request_cv_.notify_one(); + } else { + request_mutex_.unlock(); } - request_cv_.notify_one(); } } // namespace webcc diff --git a/webcc/client.h b/webcc/client.h index 14997c5..f94cb9b 100644 --- a/webcc/client.h +++ b/webcc/client.h @@ -88,6 +88,8 @@ public: } private: + void DoClose(); + void AsyncConnect(); void AsyncResolve(const std::string& default_port); diff --git a/webcc/client_session.cc b/webcc/client_session.cc index 6967818..83309ad 100644 --- a/webcc/client_session.cc +++ b/webcc/client_session.cc @@ -211,10 +211,12 @@ ResponsePtr ClientSession::Send(RequestPtr request, bool stream, return DoSend(request, stream, callback); } -void ClientSession::Cancel() { +bool ClientSession::Cancel() { if (client_) { client_->Close(); + return true; } + return false; } void ClientSession::InitHeaders() { diff --git a/webcc/client_session.h b/webcc/client_session.h index cc5af47..a545a21 100644 --- a/webcc/client_session.h +++ b/webcc/client_session.h @@ -96,7 +96,9 @@ public: ProgressCallback callback = {}); // Cancel any in-progress connecting, writing or reading. - void Cancel(); + // Return if any client object has been closed. + // It could be used to exit the program as soon as possible. + bool Cancel(); private: void InitHeaders(); diff --git a/webcc/globals.h b/webcc/globals.h index 5421c65..7175898 100644 --- a/webcc/globals.h +++ b/webcc/globals.h @@ -109,6 +109,7 @@ const char* const kAuthorization = "Authorization"; const char* const kContentType = "Content-Type"; const char* const kContentLength = "Content-Length"; const char* const kContentEncoding = "Content-Encoding"; +const char* const kContentMD5 = "Content-MD5"; const char* const kContentDisposition = "Content-Disposition"; const char* const kConnection = "Connection"; const char* const kTransferEncoding = "Transfer-Encoding"; @@ -168,7 +169,7 @@ public: }; public: - Error(Code code = kOK, const std::string& message = "") + explicit Error(Code code = kOK, const std::string& message = "") : code_(code), message_(message), timeout_(false) { }