This commit is contained in:
Timur A. Fatkhullin 2024-10-22 18:35:53 +03:00
parent 0e937bb2d7
commit addd13d826
2 changed files with 205 additions and 43 deletions

View File

@ -5,6 +5,7 @@
#include <asio/ip/udp.hpp>
#include <asio/local/seq_packet_protocol.hpp>
#include <asio/local/stream_protocol.hpp>
#include <asio/signal_set.hpp>
#include "../adc_device_netserver.h"
#include "../adc_endpoint.h"
@ -67,9 +68,36 @@ public:
// some default endpoint?!!
void start() {}
template <std::ranges::range RST, std::ranges::range RRT>
void setupSignals(const RST& stop_sig_num = std::vector<int>{SIGINT, SIGTERM},
const RRT& restart_sig_num = std::vector<int>{SIGUSR1})
requires(std::convertible_to<std::ranges::range_value_t<RST>, int> &&
std::convertible_to<std::ranges::range_value_t<RRT>, int>)
{
for (const int sig : stop_sig_num) {
_stopSignal.add(sig);
}
_stopSignal.async_wait([this](std::error_code ec, int signo) {
signalReceived(ec, signo);
this->stopAllSessions();
});
for (const int sig : restart_sig_num) {
_restartSignal.add(sig);
}
_restartSignal.async_wait([this](std::error_code ec, int signo) {
signalReceived(ec, signo);
// ?!!!!!!!
});
}
protected:
asio::io_context& _ioContext;
asio::signal_set _stopSignal, _restartSignal;
// demonizing ASIO-related methods
virtual void daemonizePrepare()
{
@ -80,6 +108,8 @@ protected:
{
_ioContext.notify_fork(asio::io_context::fork_child);
}
virtual void signalReceived(std::error_code, int signo) {};
};
} // namespace adc::impl

View File

@ -135,7 +135,7 @@ public:
// typedefs from transport protocol
using socket_t = typename TRANSPORT_PROTOT::socket;
using sock_stream_t = socket_t&;
using tls_stream_t = std::nullptr_t;
// acceptor
class acceptor_t
@ -147,12 +147,9 @@ public:
: _ioContext(io_ctx), _endpoint(), _socket(_ioContext), _acceptor(_ioContext)
{
}
acceptor_t(asio::io_context& io_ctx, const AdcNetServiceASIOBase::endpoint_t& endpoint) : acceptor_t(io_ctx)
acceptor_t(asio::io_context& io_ctx, const AdcNetServiceASIOBase::endpoint_t& endpoint)
: _ioContext(io_ctx), _endpoint(endpoint), _socket(_ioContext), _acceptor(_ioContext, endpoint)
{
if (_endpoint != endpoint) {
_endpoint = endpoint;
_acceptor = _acceptor_t(_ioContext, _endpoint);
}
}
@ -169,32 +166,50 @@ public:
traits::adc_time_duration_c DT = decltype(DEFAULT_ACCEPT_TIMEOUT)>
auto asyncAccept(TokenT&& token, const DT& timeout = DEFAULT_ACCEPT_TIMEOUT)
{
enum { start_state, handshake_state, stop_state };
// no acceptor for UDP-sockets
if constexpr (std::is_null_pointer_v<_acceptor_t>) {
static_assert(false, "INVALID TRANSPORT PROTOCOL TYPE!");
}
netservice_t srv{_ioContext};
srv._socket = AdcNetServiceASIOBase::socket_t{_ioContext};
_socket = AdcNetServiceASIOBase::socket_t{_ioContext};
auto timer = getDeadlineTimer(_acceptor, timeout);
// return asio::async_compose<TokenT, void(std::error_code, sptr_netservice_t)>(
return asio::async_compose<TokenT, void(std::error_code, netservice_t)>(
[timer = std::move(timer), start = true, this](auto& self, std::error_code ec = {}) mutable {
[timer = std::move(timer), srv = std::move(srv), state = start_state, this](
auto& self, std::error_code ec = {}) mutable {
if (!ec) {
if (start) {
start = false;
try {
if (!_acceptor.is_open() || (_acceptor.local_endpoint() != _endpoint)) {
_acceptor = _acceptor_t(_ioContext, _endpoint);
switch (state) {
case start_state:
state = handshake_state;
try {
if (!_acceptor.is_open() || (_acceptor.local_endpoint() != _endpoint)) {
_acceptor = _acceptor_t(_ioContext, _endpoint);
}
} catch (std::system_error err) {
timer->cancel();
self.complete(err.code(), netservice_t{_ioContext});
// self.complete(err.code(), std::make_shared<netservice_t>(_ioContext));
return;
}
} catch (std::system_error err) {
timer->cancel();
self.complete(err.code(), netservice_t{_ioContext});
// self.complete(err.code(), std::make_shared<netservice_t>(_ioContext));
return;
}
return _acceptor.async_accept(_socket, std::move(self));
return _acceptor.async_accept(_socket, std::move(self));
break;
case handshake_state:
state = stop_state;
handshake();
break;
case stop_state:
finalize();
break;
default:
break;
}
}
@ -245,6 +260,9 @@ public:
typename TRANSPORT_PROTOT::acceptor>;
_acceptor_t _acceptor;
virtual void handshake() {}
virtual void finalize() {}
};
@ -270,7 +288,21 @@ public:
// NOTE: CANNOT MOVE asio::streambuf CORRECTLY?!!!
AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = default;
// AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = default;
AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other)
: _ioContext(other._ioContext),
_receiveStrand(std::move(other._receiveStrand)),
_receiveQueue(),
_socket(std::move(other._socket)),
_streamBuffer()
{
_receiveQueue = std::move(_receiveQueue);
auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data());
_streamBuffer.commit(bytes);
}
// AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = delete;
AdcNetServiceASIOBase(const AdcNetServiceASIOBase&) = delete; // no copy constructor!
@ -280,7 +312,21 @@ public:
AdcNetServiceASIOBase& operator=(const AdcNetServiceASIOBase&) = delete;
// AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = delete;
AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = default;
// AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = default;
AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other)
{
_ioContext = other._ioContext;
_receiveStrand = std::move(other._receiveStrand);
_receiveQueue = std::move(_receiveQueue);
_socket = std::move(other._socket);
_streamBuffer.consume(_streamBuffer.size());
auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data());
_streamBuffer.commit(bytes);
return *this;
};
constexpr netservice_ident_t ident() const
@ -368,7 +414,7 @@ public:
template <typename TokenT, traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_RECEIVE_TIMEOUT)>
auto asyncReceive(TokenT&& token, const TimeoutT& timeout = DEFAULT_RECEIVE_TIMEOUT)
{
static asio::streambuf _streamBuffer;
// static asio::streambuf _streamBuffer;
// check completion token signature and deduce message type
// if constexpr (!adc_asio_special_comp_token_c<TokenT> && !is_async_ctx_t) {
@ -524,6 +570,7 @@ public:
return ftr.get();
}
// one still may receive messages from queue!
std::error_code close()
{
std::error_code ec;
@ -538,11 +585,30 @@ public:
/* additional ASIO-related methods */
void clear()
void clearRcvQueue()
{
// clear receiving messages queue
// NOTE: there is no racing condition here since using asio::strand!
asio::post(_receiveStrand, [this]() { _receiveQueue = {}; });
asio::post(_receiveStrand, [this]() {
//
_receiveQueue = {};
});
}
void clearRcvBuff()
{
asio::post(_receiveStrand, [this]() {
//
_streamBuffer.consume(_streamBuffer.size());
});
}
void clearRcvData()
{
asio::post(_receiveStrand, [this]() {
_receiveQueue = {};
_streamBuffer.consume(_streamBuffer.size());
});
}
void setShutdownType(asio::socket_base::shutdown_type shutdown_type)
@ -569,8 +635,9 @@ protected:
asio::io_context::strand _receiveStrand;
socket_t _socket;
tls_stream_t _tlsStream;
// asio::streambuf _streamBuffer;
asio::streambuf _streamBuffer;
std::queue<std::vector<char>> _receiveQueue;
@ -622,10 +689,12 @@ template <adc_asio_tls_transport_proto_c TRANSPORT_PROTOT,
interfaces::adc_netsession_proto_c<std::string_view> SESSION_PROTOT,
traits::adc_output_char_range RMSGT =
std::vector<char>> // used only for inner storing of message byte sequence
class AdcNetServiceASIOTLS : public TRANSPORT_PROTOT
class AdcNetServiceASIOTLS : public AdcNetServiceASIOBase<TRANSPORT_PROTOT, SESSION_PROTOT, RMSGT>
{
typedef AdcNetServiceASIOBase<TRANSPORT_PROTOT, SESSION_PROTOT, RMSGT> service_base_t;
public:
using socket_t = typename TRANSPORT_PROTOT::socket;
using typename service_base_t::socket_t;
typedef asio::ssl::stream<socket_t> tls_stream_t;
// TLS certificate attributes comparison function:
@ -638,39 +707,101 @@ public:
// reimplement acceptor class
class acceptor_t : public AdcNetServiceASIOBase<TRANSPORT_PROTOT, SESSION_PROTOT, RMSGT>::acceptor_t
class acceptor_t
{
using base_t = AdcNetServiceASIOBase<TRANSPORT_PROTOT, SESSION_PROTOT, RMSGT>;
public:
static constexpr std::chrono::duration DEFAULT_ACCEPT_TIMEOUT = std::chrono::seconds::max();
typedef AdcNetServiceASIOTLS netservice_t;
typedef std::shared_ptr<netservice_t> sptr_netservice_t;
using base_t::acceptor_t::acceptor_t;
typedef std::function<void(std::error_code, sptr_netservice_t)> async_accept_callback_t;
template <asio::completion_token_for<void(std::error_code, sptr_netservice_t)> TokenT,
traits::adc_time_duration_c DT = decltype(base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT)>
auto asyncAccept(TokenT&& token, const DT& timeout = base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT)
acceptor_t(asio::io_context& io_ctx, asio::ssl::context tls_context)
: _ioContext(io_ctx), _endpoint(), _socket(_ioContext), _acceptor(_ioContext)
{
enum { starting, handshaking, finishing };
}
acceptor_t(asio::io_context& io_ctx, const service_base_t::endpoint_t& endpoint, asio::ssl::context tls_context)
: _ioContext(io_ctx), _endpoint(endpoint), _socket(_ioContext), _acceptor(_ioContext, endpoint)
{
}
this->_socket = base_t::socket_t(this->_ioContext);
typedef std::function<void(std::error_code, netservice_t)> async_accept_callback_t;
template <asio::completion_token_for<void(std::error_code, netservice_t)> TokenT,
traits::adc_time_duration_c DT = decltype(DEFAULT_ACCEPT_TIMEOUT)>
auto asyncAccept(TokenT&& token, const DT& timeout = DEFAULT_ACCEPT_TIMEOUT)
{
enum { start, handshake, stop };
this->_socket = AdcNetServiceASIOTLS::socket_t(this->_ioContext);
auto timer = getDeadlineTimer(this->_acceptor, timeout);
netservice_t srv(_ioContext);
return asio::async_compose<TokenT, void(std::error_code, netservice_t)>(
[timer = std::move(timer), srv = std::move(srv), state = start, this](auto& self,
std::error_code ec = {}) mutable {
if (!ec) {
switch (state) {
case start:
state = handshake;
try {
if (!_acceptor.is_open() || (_acceptor.local_endpoint() != _endpoint)) {
_acceptor = _acceptor_t(_ioContext, _endpoint);
}
} catch (std::system_error err) {
timer->cancel();
self.complete(err.code(), netservice_t{_ioContext});
// self.complete(err.code(), std::make_shared<netservice_t>(_ioContext));
return;
}
return _acceptor.async_accept(_socket, std::move(self));
break;
case handshake:
state = stop;
srv._socket = std::move(_socket);
srv._tlsStream = asio::ssl::stream(srv._socket, _tlsContext);
return srv._tlsStream.async_handshake(asio::ssl::stream_base::server, std::move(self));
default:
break;
}
}
if (isTimeout(timer, ec)) {
ec = std::make_error_code(std::errc::timed_out);
} else { // an error occured in async_accept od async_handshake
timer->cancel();
}
self.complete(ec, std::move(srv));
},
token, this->_ioContext);
}
template <asio::completion_token_for<void(std::error_code, sptr_netservice_t)> TokenT,
traits::adc_time_duration_c DT = decltype(base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT)>
auto asyncAccept(const base_t::endpoint_t& endpoint,
template <asio::completion_token_for<void(std::error_code, netservice_t)> TokenT,
traits::adc_time_duration_c DT = decltype(DEFAULT_ACCEPT_TIMEOUT)>
auto asyncAccept(const AdcNetServiceASIOTLS::endpoint_t& endpoint,
TokenT&& token,
const DT& timeout = base_t::acceptor_t::DEFAULT_ACCEPT_TIMEOUT)
const DT& timeout = DEFAULT_ACCEPT_TIMEOUT)
{
this->_endpoint = endpoint;
return asyncAccept(std::forward<TokenT>(token), timeout);
}
protected:
asio::io_context& _ioContext;
AdcNetServiceASIOTLS::endpoint_t _endpoint;
AdcNetServiceASIOTLS::socket_t _socket;
asio::ssl::context& _tlsContext;
using _acceptor_t = std::conditional_t<
std::derived_from<socket_t, asio::basic_datagram_socket<typename socket_t::protocol_type>>,
std::nullptr_t, // there is no acceptor
typename TRANSPORT_PROTOT::acceptor>;
_acceptor_t _acceptor;
};
@ -690,6 +821,7 @@ public:
}
protected:
tls_stream_t _tlsStream;
asio::ssl::context _tlsContext;
asio::ssl::verify_mode _tlsPeerVerifyMode;
std::string _tlsCertFingerprintDigest;