diff --git a/net/asio/adc_device_netserver_asio.h b/net/asio/adc_device_netserver_asio.h index 2bf0bf2..bb6ab00 100644 --- a/net/asio/adc_device_netserver_asio.h +++ b/net/asio/adc_device_netserver_asio.h @@ -5,6 +5,7 @@ #include #include #include +#include #include "../adc_device_netserver.h" #include "../adc_endpoint.h" @@ -67,9 +68,36 @@ public: // some default endpoint?!! void start() {} + template + void setupSignals(const RST& stop_sig_num = std::vector{SIGINT, SIGTERM}, + const RRT& restart_sig_num = std::vector{SIGUSR1}) + requires(std::convertible_to, int> && + std::convertible_to, int>) + { + for (const int sig : stop_sig_num) { + _stopSignal.add(sig); + } + + _stopSignal.async_wait([this](std::error_code ec, int signo) { + signalReceived(ec, signo); + this->stopAllSessions(); + }); + + for (const int sig : restart_sig_num) { + _restartSignal.add(sig); + } + + _restartSignal.async_wait([this](std::error_code ec, int signo) { + signalReceived(ec, signo); + // ?!!!!!!! + }); + } + protected: asio::io_context& _ioContext; + asio::signal_set _stopSignal, _restartSignal; + // demonizing ASIO-related methods virtual void daemonizePrepare() { @@ -80,6 +108,8 @@ protected: { _ioContext.notify_fork(asio::io_context::fork_child); } + + virtual void signalReceived(std::error_code, int signo) {}; }; } // namespace adc::impl diff --git a/net/asio/adc_netservice_asio.h b/net/asio/adc_netservice_asio.h index 574cbf3..f8e8d5c 100644 --- a/net/asio/adc_netservice_asio.h +++ b/net/asio/adc_netservice_asio.h @@ -135,7 +135,7 @@ public: // typedefs from transport protocol using socket_t = typename TRANSPORT_PROTOT::socket; - using sock_stream_t = socket_t&; + using tls_stream_t = std::nullptr_t; // acceptor class acceptor_t @@ -147,12 +147,9 @@ public: : _ioContext(io_ctx), _endpoint(), _socket(_ioContext), _acceptor(_ioContext) { } - acceptor_t(asio::io_context& io_ctx, const AdcNetServiceASIOBase::endpoint_t& endpoint) : acceptor_t(io_ctx) + acceptor_t(asio::io_context& io_ctx, const AdcNetServiceASIOBase::endpoint_t& endpoint) + : _ioContext(io_ctx), _endpoint(endpoint), _socket(_ioContext), _acceptor(_ioContext, endpoint) { - if (_endpoint != endpoint) { - _endpoint = endpoint; - _acceptor = _acceptor_t(_ioContext, _endpoint); - } } @@ -169,32 +166,50 @@ public: traits::adc_time_duration_c DT = decltype(DEFAULT_ACCEPT_TIMEOUT)> auto asyncAccept(TokenT&& token, const DT& timeout = DEFAULT_ACCEPT_TIMEOUT) { + enum { start_state, handshake_state, stop_state }; + // no acceptor for UDP-sockets if constexpr (std::is_null_pointer_v<_acceptor_t>) { static_assert(false, "INVALID TRANSPORT PROTOCOL TYPE!"); } + netservice_t srv{_ioContext}; + srv._socket = AdcNetServiceASIOBase::socket_t{_ioContext}; + _socket = AdcNetServiceASIOBase::socket_t{_ioContext}; auto timer = getDeadlineTimer(_acceptor, timeout); // return asio::async_compose( return asio::async_compose( - [timer = std::move(timer), start = true, this](auto& self, std::error_code ec = {}) mutable { + [timer = std::move(timer), srv = std::move(srv), state = start_state, this]( + auto& self, std::error_code ec = {}) mutable { if (!ec) { - if (start) { - start = false; - try { - if (!_acceptor.is_open() || (_acceptor.local_endpoint() != _endpoint)) { - _acceptor = _acceptor_t(_ioContext, _endpoint); + switch (state) { + case start_state: + state = handshake_state; + try { + if (!_acceptor.is_open() || (_acceptor.local_endpoint() != _endpoint)) { + _acceptor = _acceptor_t(_ioContext, _endpoint); + } + } catch (std::system_error err) { + timer->cancel(); + self.complete(err.code(), netservice_t{_ioContext}); + // self.complete(err.code(), std::make_shared(_ioContext)); + return; } - } catch (std::system_error err) { - timer->cancel(); - self.complete(err.code(), netservice_t{_ioContext}); - // self.complete(err.code(), std::make_shared(_ioContext)); - return; - } - return _acceptor.async_accept(_socket, std::move(self)); + return _acceptor.async_accept(_socket, std::move(self)); + break; + case handshake_state: + state = stop_state; + + handshake(); + break; + case stop_state: + finalize(); + break; + default: + break; } } @@ -245,6 +260,9 @@ public: typename TRANSPORT_PROTOT::acceptor>; _acceptor_t _acceptor; + + virtual void handshake() {} + virtual void finalize() {} }; @@ -270,7 +288,21 @@ public: // NOTE: CANNOT MOVE asio::streambuf CORRECTLY?!!! - AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = default; + // AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = default; + AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) + : _ioContext(other._ioContext), + _receiveStrand(std::move(other._receiveStrand)), + _receiveQueue(), + _socket(std::move(other._socket)), + _streamBuffer() + + { + _receiveQueue = std::move(_receiveQueue); + auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data()); + _streamBuffer.commit(bytes); + } + + // AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = delete; AdcNetServiceASIOBase(const AdcNetServiceASIOBase&) = delete; // no copy constructor! @@ -280,7 +312,21 @@ public: AdcNetServiceASIOBase& operator=(const AdcNetServiceASIOBase&) = delete; // AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = delete; - AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = default; + // AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = default; + AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) + { + _ioContext = other._ioContext; + _receiveStrand = std::move(other._receiveStrand); + _receiveQueue = std::move(_receiveQueue); + _socket = std::move(other._socket); + + _streamBuffer.consume(_streamBuffer.size()); + + auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data()); + _streamBuffer.commit(bytes); + + return *this; + }; constexpr netservice_ident_t ident() const @@ -368,7 +414,7 @@ public: template auto asyncReceive(TokenT&& token, const TimeoutT& timeout = DEFAULT_RECEIVE_TIMEOUT) { - static asio::streambuf _streamBuffer; + // static asio::streambuf _streamBuffer; // check completion token signature and deduce message type // if constexpr (!adc_asio_special_comp_token_c && !is_async_ctx_t) { @@ -524,6 +570,7 @@ public: return ftr.get(); } + // one still may receive messages from queue! std::error_code close() { std::error_code ec; @@ -538,11 +585,30 @@ public: /* additional ASIO-related methods */ - void clear() + void clearRcvQueue() { // clear receiving messages queue // NOTE: there is no racing condition here since using asio::strand! - asio::post(_receiveStrand, [this]() { _receiveQueue = {}; }); + asio::post(_receiveStrand, [this]() { + // + _receiveQueue = {}; + }); + } + + void clearRcvBuff() + { + asio::post(_receiveStrand, [this]() { + // + _streamBuffer.consume(_streamBuffer.size()); + }); + } + + void clearRcvData() + { + asio::post(_receiveStrand, [this]() { + _receiveQueue = {}; + _streamBuffer.consume(_streamBuffer.size()); + }); } void setShutdownType(asio::socket_base::shutdown_type shutdown_type) @@ -569,8 +635,9 @@ protected: asio::io_context::strand _receiveStrand; socket_t _socket; + tls_stream_t _tlsStream; - // asio::streambuf _streamBuffer; + asio::streambuf _streamBuffer; std::queue> _receiveQueue; @@ -622,10 +689,12 @@ template SESSION_PROTOT, traits::adc_output_char_range RMSGT = std::vector> // used only for inner storing of message byte sequence -class AdcNetServiceASIOTLS : public TRANSPORT_PROTOT +class AdcNetServiceASIOTLS : public AdcNetServiceASIOBase { + typedef AdcNetServiceASIOBase service_base_t; + public: - using socket_t = typename TRANSPORT_PROTOT::socket; + using typename service_base_t::socket_t; typedef asio::ssl::stream tls_stream_t; // TLS certificate attributes comparison function: @@ -638,39 +707,101 @@ public: // reimplement acceptor class - class acceptor_t : public AdcNetServiceASIOBase::acceptor_t + class acceptor_t { - using base_t = AdcNetServiceASIOBase; - public: + static constexpr std::chrono::duration DEFAULT_ACCEPT_TIMEOUT = std::chrono::seconds::max(); + typedef AdcNetServiceASIOTLS netservice_t; typedef std::shared_ptr sptr_netservice_t; - using base_t::acceptor_t::acceptor_t; - - typedef std::function async_accept_callback_t; - - template TokenT, - traits::adc_time_duration_c DT = decltype(base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT)> - auto asyncAccept(TokenT&& token, const DT& timeout = base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT) + acceptor_t(asio::io_context& io_ctx, asio::ssl::context tls_context) + : _ioContext(io_ctx), _endpoint(), _socket(_ioContext), _acceptor(_ioContext) { - enum { starting, handshaking, finishing }; + } + acceptor_t(asio::io_context& io_ctx, const service_base_t::endpoint_t& endpoint, asio::ssl::context tls_context) + : _ioContext(io_ctx), _endpoint(endpoint), _socket(_ioContext), _acceptor(_ioContext, endpoint) + { + } - this->_socket = base_t::socket_t(this->_ioContext); + typedef std::function async_accept_callback_t; + + template TokenT, + traits::adc_time_duration_c DT = decltype(DEFAULT_ACCEPT_TIMEOUT)> + auto asyncAccept(TokenT&& token, const DT& timeout = DEFAULT_ACCEPT_TIMEOUT) + { + enum { start, handshake, stop }; + + this->_socket = AdcNetServiceASIOTLS::socket_t(this->_ioContext); auto timer = getDeadlineTimer(this->_acceptor, timeout); + + netservice_t srv(_ioContext); + + return asio::async_compose( + [timer = std::move(timer), srv = std::move(srv), state = start, this](auto& self, + std::error_code ec = {}) mutable { + if (!ec) { + switch (state) { + case start: + state = handshake; + try { + if (!_acceptor.is_open() || (_acceptor.local_endpoint() != _endpoint)) { + _acceptor = _acceptor_t(_ioContext, _endpoint); + } + } catch (std::system_error err) { + timer->cancel(); + self.complete(err.code(), netservice_t{_ioContext}); + // self.complete(err.code(), std::make_shared(_ioContext)); + return; + } + + return _acceptor.async_accept(_socket, std::move(self)); + break; + case handshake: + state = stop; + srv._socket = std::move(_socket); + srv._tlsStream = asio::ssl::stream(srv._socket, _tlsContext); + return srv._tlsStream.async_handshake(asio::ssl::stream_base::server, std::move(self)); + default: + break; + } + } + + if (isTimeout(timer, ec)) { + ec = std::make_error_code(std::errc::timed_out); + } else { // an error occured in async_accept od async_handshake + timer->cancel(); + } + + self.complete(ec, std::move(srv)); + }, + token, this->_ioContext); } - template TokenT, - traits::adc_time_duration_c DT = decltype(base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT)> - auto asyncAccept(const base_t::endpoint_t& endpoint, + template TokenT, + traits::adc_time_duration_c DT = decltype(DEFAULT_ACCEPT_TIMEOUT)> + auto asyncAccept(const AdcNetServiceASIOTLS::endpoint_t& endpoint, TokenT&& token, - const DT& timeout = base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT) + const DT& timeout = DEFAULT_ACCEPT_TIMEOUT) { this->_endpoint = endpoint; return asyncAccept(std::forward(token), timeout); } + + protected: + asio::io_context& _ioContext; + AdcNetServiceASIOTLS::endpoint_t _endpoint; + AdcNetServiceASIOTLS::socket_t _socket; + asio::ssl::context& _tlsContext; + + using _acceptor_t = std::conditional_t< + std::derived_from>, + std::nullptr_t, // there is no acceptor + typename TRANSPORT_PROTOT::acceptor>; + + _acceptor_t _acceptor; }; @@ -690,6 +821,7 @@ public: } protected: + tls_stream_t _tlsStream; asio::ssl::context _tlsContext; asio::ssl::verify_mode _tlsPeerVerifyMode; std::string _tlsCertFingerprintDigest;