diff --git a/net/asio/adc_netservice_asio.h b/net/asio/adc_netservice_asio.h index 718e1d6..faac8bd 100644 --- a/net/asio/adc_netservice_asio.h +++ b/net/asio/adc_netservice_asio.h @@ -683,6 +683,7 @@ protected: #ifdef USE_OPENSSL_WITH_ASIO +/* template SESSION_PROTOT, traits::adc_output_char_range RMSGT = @@ -885,7 +886,7 @@ protected: return preverified_ok; } }; - +*/ #endif diff --git a/net/asio/adc_netsrv_asio.h b/net/asio/adc_netsrv_asio.h new file mode 100644 index 0000000..082935e --- /dev/null +++ b/net/asio/adc_netsrv_asio.h @@ -0,0 +1,739 @@ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include +#include + +#ifdef USE_OPENSSL_WITH_ASIO + +#include +#include + +#endif + + +#include "../../common/adc_traits.h" +#include "../adc_net_concepts.h" +#include "../adc_netproto.h" + +namespace adc::traits +{ + +// special ASIO-related template specializations + +template <> +struct adc_func_traits> { + using ret_t = std::nullptr_t; + using args_t = std::tuple; + using arg1_t = std::nullptr_t; + static constexpr size_t arity = 0; +}; + +template <> +struct adc_func_traits> { + using ret_t = std::nullptr_t; + using args_t = std::tuple; + using arg1_t = std::nullptr_t; + static constexpr size_t arity = 0; +}; + +template <> +struct adc_func_traits { + using ret_t = std::nullptr_t; + using args_t = std::tuple; + using arg1_t = std::nullptr_t; + static constexpr size_t arity = 0; +}; + +} // namespace adc::traits + +namespace adc::impl +{ + +template +concept adc_asio_transport_proto_c = + std::derived_from || std::derived_from || + std::derived_from || std::derived_from || + std::derived_from; + + +template +concept adc_asio_tls_transport_proto_c = + std::derived_from || std::derived_from || + std::derived_from; + +template +concept adc_asio_stream_transport_proto_c = + std::derived_from || std::derived_from; + + +template +concept adc_asio_is_future = requires { + // [](std::type_identity>) {}(std::type_identity>()); + [](std::type_identity>) { + }(std::type_identity>{}); +}; + +template +concept adc_asio_is_awaitable = requires { + [](std::type_identity>) { + }(std::type_identity>{}); +}; + + +template +concept adc_asio_special_comp_token_c = + adc_asio_is_future || adc_asio_is_awaitable || std::same_as, asio::deferred_t>; + + +namespace details +{ + + +// template +template +class AdcAcceptorASIO +{ +public: + using netservice_t = SRVT; + + // deduce needed types + using transport_proto_t = typename SRVT::endpoint_t::protocol_type; + using socket_t = typename SRVT::endpoint_t::protocol_type::socket; + using acceptor_t = std::conditional_t>, + std::nullptr_t, // there is no acceptor + typename transport_proto_t::acceptor>; + + static constexpr std::chrono::duration DEFAULT_ACCEPT_TIMEOUT = std::chrono::seconds::max(); + + AdcAcceptorASIO(asio::io_context& io_ctx, const netservice_t::endpoint_t& endpoint) + : _ioContext(io_ctx), _acceptor(io_ctx, endpoint) + { + } + + AdcAcceptorASIO(const AdcAcceptorASIO& other) + : _ioContext(other._ioContext), _acceptor(std::move(other._acceptor)) { + + }; + + template TokenT, + traits::adc_time_duration_c DT = decltype(DEFAULT_ACCEPT_TIMEOUT)> + auto asyncAccept(TokenT&& token, const DT& timeout = DEFAULT_ACCEPT_TIMEOUT) + { + if constexpr (std::is_null_pointer_v) { + static_assert(false, "INVALID TRANSPORT PROTOCOL TYPE!"); + } + + enum { starting, native_accept, post_accept, finishing }; + + auto timer = netservice_t::getDeadlineTimer(_acceptor, timeout); + + // auto srv = std::make_shared(_ioContext); + auto srv = std::make_unique(_ioContext); + + // return asio::async_compose( + // asyncAcceptImplementation{this, _acceptor, std::move(timer), srv, AdcAcceptorASIO::starting}, token, + // _ioContext); + return asio::async_compose( + [timer = std::move(timer), srv = std::move(srv), state = AdcAcceptorASIO::starting, this]( + auto& self, std::error_code ec = {}) mutable { + if (!ec) { + switch (state) { + case starting: + // _starting(srv, state, self); + _starting(srv, state, std::move(self)); + break; + case native_accept: + // _native_accept(srv, state, self); + _native_accept(srv, state, std::move(self)); + break; + case post_accept: + // _post_accept(srv, state, self); + _post_accept(srv, state, std::move(self)); + break; + case finishing: + // _finishing(srv, state, self); + _finishing(srv, state, std::move(self)); + break; + default: + break; + } + } + + if (netservice_t::isTimeout(timer, ec)) { + ec = std::make_error_code(std::errc::timed_out); + } else { // an error occured in async_accept + timer->cancel(); + } + + self.complete(ec, std::move(*srv)); + }, + token, _ioContext); + } + + template + auto accept(const DT& timeout = DEFAULT_ACCEPT_TIMEOUT) + { + auto f = asyncAccept(asio::use_future, timeout); + + return f.get(); + } + +protected: + asio::io_context& _ioContext; + acceptor_t _acceptor; + + enum state_t { starting, native_accept, post_accept, finishing }; + + // using self_t = std::function; + using self_t = std::function; + + + + typedef std::function&, state_t&, self_t)> stage_func_t; + // typedef std::function&, state_t&, self_t)> stage_func_t; + + stage_func_t _starting = [this](auto&, state_t& state, self_t self) mutable { + state = native_accept; + // asio::post(_ioContext, std::bind([](auto, auto) {}, std::move(self), std::error_code{})); + }; + + stage_func_t _native_accept = [this](auto& srv, state_t& state, self_t self) mutable { + state = post_accept; + _acceptor.async_accept(srv->_socket, std::move(self)); + }; + + stage_func_t _post_accept = [this](auto&, state_t& state, self_t self) mutable { state = finishing; }; + + stage_func_t _finishing = [](auto&, state_t&, self_t) mutable {}; + + /* + struct asyncAcceptImplementation { + AdcAcceptorASIO* acp; + acceptor_t& _acceptor; + std::shared_ptr timer; + std::shared_ptr srv; + // std::unique_ptr timer; + // std::unique_ptr srv; + state_t state; + + asyncAcceptImplementation(AdcAcceptorASIO* a, + acceptor_t& ar, + std::shared_ptr tm, + std::shared_ptr s, + state_t st) + : acp(a), _acceptor(ar), timer(tm), srv(s), state(st) + { + } + + asyncAcceptImplementation(const asyncAcceptImplementation& other) : _acceptor(other._acceptor) + { + acp = other.acp; + timer = other.timer; + srv = other.srv; + state = other.state; + } + + void operator()(auto& self, std::error_code ec = {}) + { + if (!ec) { + switch (state) { + case starting: + acp->_starting(srv, state, self); + break; + case native_accept: + acp->_native_accept(srv, state, self); + break; + case post_accept: + acp->_post_accept(srv, state, self); + break; + case finishing: + acp->_finishing(srv, state, self); + break; + default: + break; + } + } + + if (netservice_t::isTimeout(timer, ec)) { + ec = std::make_error_code(std::errc::timed_out); + } else { // an error occured in async_accept + timer->cancel(); + } + + self.complete(ec, std::move(*srv)); + } + }; + */ +}; + +} // namespace details + +template SESSION_PROTOT, + traits::adc_output_char_range RMSGT = + std::vector> // used only for inner storing of message byte sequence +class AdcNetServiceASIOBase : public SESSION_PROTOT +{ +public: + friend details::AdcAcceptorASIO; + + // typedefs to satisfy 'adc_netservice_c' concept + typedef std::string_view netservice_ident_t; + + typedef std::vector send_msg_t; // in general, only one of several possible + typedef RMSGT recv_msg_t; // in general, only one of several possible (see class template arguments declaration) + typedef traits::adc_common_duration_t timeout_t; + using endpoint_t = typename TRANSPORT_PROTOT::endpoint; + + // typedefs for completion tokens (callbacks, required by 'adc_netservice_c' concept) + typedef std::function async_connect_callback_t; + typedef std::function async_send_callback_t; + typedef std::function async_receive_callback_t; + + + // typedefs from transport protocol + using socket_t = typename TRANSPORT_PROTOT::socket; + + typedef details::AdcAcceptorASIO acceptor_t; + + static constexpr std::chrono::duration DEFAULT_CONNECT_TIMEOUT = std::chrono::seconds(10); + static constexpr std::chrono::duration DEFAULT_SEND_TIMEOUT = std::chrono::seconds(5); + static constexpr std::chrono::duration DEFAULT_RECEIVE_TIMEOUT = std::chrono::seconds(5); + + AdcNetServiceASIOBase(asio::io_context& ctx) + : SESSION_PROTOT(), _ioContext(ctx), _receiveStrand(_ioContext), _socket(_ioContext), _receiveQueue() + { + } + + + AdcNetServiceASIOBase(socket_t socket) + : SESSION_PROTOT(), + _ioContext(static_cast(socket.get_executor().context())), + _socket(std::move(socket)), + _receiveStrand(_ioContext), + _receiveQueue() + { + } + + + // NOTE: CANNOT MOVE asio::streambuf CORRECTLY?!!! + // AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = default; + AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) + : _ioContext(other._ioContext), + _receiveStrand(std::move(other._receiveStrand)), + _socket(std::move(other._socket)), + _streamBuffer(), + _receiveQueue(std::move(other._receiveQueue)) + + { + auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data()); + _streamBuffer.commit(bytes); + } + + + // AdcNetServiceASIOBase(AdcNetServiceASIOBase&& other) = delete; + AdcNetServiceASIOBase(const AdcNetServiceASIOBase&) = delete; // no copy constructor! + + virtual ~AdcNetServiceASIOBase() {} + + + AdcNetServiceASIOBase& operator=(const AdcNetServiceASIOBase&) = delete; + + // AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = delete; + // AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) = default; + AdcNetServiceASIOBase& operator=(AdcNetServiceASIOBase&& other) + { + _ioContext = other._ioContext; + _receiveStrand = std::move(other._receiveStrand); + _receiveQueue = std::move(other._receiveQueue); + _socket = std::move(other._socket); + + _streamBuffer.consume(_streamBuffer.size()); + + auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data()); + _streamBuffer.commit(bytes); + + return *this; + }; + + + constexpr netservice_ident_t ident() const + { + return _ident; + } + + + /* asynchronuos methods */ + + template TokenT, + traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_CONNECT_TIMEOUT)> + auto asyncConnect(const endpoint_t& endpoint, TokenT&& token, const TimeoutT& timeout = DEFAULT_CONNECT_TIMEOUT) + { + auto timer = getDeadlineTimer(_socket, timeout); + + return asio::async_compose( + [start = true, endpoint, timer = std::move(timer), this](auto& self, std::error_code ec = {}) mutable { + if (!ec) { + if (start) { + start = false; + return _socket.async_connect(endpoint, std::move(self)); + } + } + + if (isTimeout(timer, ec)) { + ec = std::make_error_code(std::errc::timed_out); + } else { // an error occured in async_connect + timer->cancel(); + } + + self.complete(ec); + }, + token, _socket); + } + + + template TokenT, + traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_SEND_TIMEOUT)> + auto asyncSend(const MessageT& msg, TokenT&& token, const TimeoutT& timeout = DEFAULT_SEND_TIMEOUT) + { + // create buffer sequence of sending session protocol representation of the input message + std::vector buff_seq; + // std::ranges::for_each(this->toProto(msg), [&buff_seq](const auto& el) { buff_seq.emplace_back(el); }); + std::ranges::for_each(this->toProto(msg), + [&buff_seq](const auto& el) { buff_seq.emplace_back(el.data(), el.size()); }); + + auto timer = getDeadlineTimer(_socket, timeout); + + return asio::async_compose( + [start = true, buff_seq = std::move(buff_seq), timer = std::move(timer), this]( + auto& self, std::error_code ec = {}, size_t = 0) mutable { + if (!ec) { + if (start) { + start = false; + if constexpr (std::derived_from>) { + return asio::async_write(_socket, buff_seq, std::move(self)); + } else if constexpr (std::derived_from>) { + return _socket.async_send(buff_seq, std::move(self)); + } else if constexpr (std::derived_from>) { + return _socket.async_send(buff_seq, std::move(self)); + } else { + static_assert(false, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!"); + } + } + } + + if (isTimeout(timer, ec)) { + ec = std::make_error_code(std::errc::timed_out); + } else { // an error occured in async_write/async_send + timer->cancel(); + } + + self.complete(ec); + }, + token, _socket); + } + + + template + auto asyncReceive(TokenT&& token, const TimeoutT& timeout = DEFAULT_RECEIVE_TIMEOUT) + { + // static asio::streambuf _streamBuffer; + + // check completion token signature and deduce message type + // if constexpr (!adc_asio_special_comp_token_c && !is_async_ctx_t) { + if constexpr (!adc_asio_special_comp_token_c) { + static_assert(traits::adc_func_traits::arity == 2, "INVALID COMPLETION TOKEN SIGNATURE!"); + static_assert(std::is_same_v>, std::error_code>, + "INVALID COMPLETION TOKEN SIGNATURE!"); + static_assert(traits::adc_output_char_range< + std::tuple_element_t<1, typename traits::adc_func_traits::args_t>>, + "INVALID COMPLETION TOKEN SIGNATURE!"); + } + + using msg_t = std::conditional_t< + // adc_asio_special_comp_token_c || is_async_ctx_t, RMSGT, + adc_asio_special_comp_token_c, RMSGT, + std::remove_cvref_t::args_t>>>; + + auto out_flags = std::make_shared(); + + auto timer = getDeadlineTimer(_socket, timeout); + + return asio::async_compose( + [out_flags, do_read = true, timer = std::move(timer), this](auto& self, std::error_code ec = {}, + size_t nbytes = 0) mutable { + msg_t msg; + + if (!ec) { + if (do_read) { + do_read = false; + if (_receiveQueue.size()) { // return message from queue + timer->cancel(); + auto imsg = _receiveQueue.front(); + _receiveQueue.pop(); + if constexpr (std::is_same_v) { + self.complete(std::error_code(), std::move(imsg)); + } else { + self.complete(std::error_code(), {imsg.begin(), imsg.end()}); + } + return; + } + + auto n_avail = _socket.available(); + auto buff = _streamBuffer.prepare(n_avail ? n_avail : 1); + + if constexpr (std::derived_from>) { + return asio::async_read(_socket, std::move(buff), asio::transfer_at_least(1), + std::move(self)); + } else if constexpr (std::derived_from>) { + // datagram, so it should be received at once + return _socket.async_receive(std::move(buff), std::move(self)); + } else if constexpr (std::derived_from>) { + // datagram, so it should be received at once + return _socket.async_receive(std::move(buff), *out_flags, std::move(self)); + } else { + static_assert(false, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!"); + } + } + + // zero-length message for SEQ_PACK sockets is EOF + if constexpr (std::derived_from>) { + if (!nbytes) { + timer->cancel(); + self.complete(std::error_code(asio::error::misc_errors::eof), std::move(msg)); + return; + } + } + + _streamBuffer.commit(nbytes); + + // if (!nbytes) { + // do_read = true; + // asio::post(std::move(self)); // initiate consequence socket's read operation + // return; + // } + + auto start_ptr = static_cast(_streamBuffer.data().data()); + + auto net_pack = this->search(std::span(start_ptr, _streamBuffer.size())); + if (net_pack.empty()) { + do_read = true; + asio::post(std::move(self)); // initiate consequence socket's read operation + return; + } + + timer->cancel(); // there were no errors in the asynchronous read-operation, so stop timer + + // here one has at least a single message + + std::ranges::copy(this->fromProto(net_pack), std::back_inserter(msg)); + _streamBuffer.consume(net_pack.size()); + + + while (_streamBuffer.size()) { // search for possible additional session protocol packets + start_ptr = static_cast(_streamBuffer.data().data()); + + net_pack = this->search(std::span(start_ptr, _streamBuffer.size())); + + if (!net_pack.empty()) { + _receiveQueue.emplace(); + std::ranges::copy(this->fromProto(net_pack), std::back_inserter(_receiveQueue.back())); + _streamBuffer.consume(net_pack.size()); + } else { + break; // exit and hold remaining bytes in stream buffer + } + } + } + + if (isTimeout(timer, ec)) { + ec = std::make_error_code(std::errc::timed_out); + } else { // an error occured in async_* + timer->cancel(); + } + + if constexpr (std::is_same_v) { + self.complete(ec, std::move(msg)); + } else { + self.complete(ec, {msg.begin(), msg.end()}); + } + + // if constexpr (adc_asio_special_comp_token_c) { + // self.complete(ec, std::move(msg)); + // } else { // may be of non-RMSGT type + // self.complete(ec, {msg.begin(), msg.end()}); + // } + }, + token, _receiveStrand); + } + + /* blocking methods */ + + template + auto connect(const endpoint_t& endpoint, const TimeoutT& timeout = DEFAULT_CONNECT_TIMEOUT) + { + std::future ftr = asyncConnect(endpoint, asio::use_future, timeout); + ftr.get(); + } + + template + auto send(const R& msg, const TimeoutT& timeout = DEFAULT_SEND_TIMEOUT) + { + std::future ftr = asyncSend(msg, asio::use_future, timeout); + ftr.get(); + } + + template + auto receive(const TimeoutT& timeout = DEFAULT_RECEIVE_TIMEOUT) + { + std::future ftr = asyncReceive(asio::use_future, timeout); + return ftr.get(); + } + + // one still may receive messages from queue! + std::error_code close() + { + std::error_code ec; + + _socket.shutdown(_shutdownType, ec); + if (!ec) { + _socket.close(ec); + } + + return ec; + } + + /* additional ASIO-related methods */ + + void clearRcvQueue() + { + // clear receiving messages queue + // NOTE: there is no racing condition here since using asio::strand! + asio::post(_receiveStrand, [this]() { + // + _receiveQueue = {}; + }); + } + + void clearRcvBuff() + { + asio::post(_receiveStrand, [this]() { + // + _streamBuffer.consume(_streamBuffer.size()); + }); + } + + void clearRcvData() + { + asio::post(_receiveStrand, [this]() { + _receiveQueue = {}; + _streamBuffer.consume(_streamBuffer.size()); + }); + } + + void setShutdownType(asio::socket_base::shutdown_type shutdown_type) + { + _shutdownType = shutdown_type; + } + + asio::socket_base::shutdown_type getShutdownType() const + { + return _shutdownType; + } + +protected: + static constexpr netservice_ident_t _ident = + std::derived_from> + ? "STREAM-SOCKET NETWORK SERVICE" + : std::derived_from> + ? "DATAGRAM-SOCKET NETWORK SERVICE" + : std::derived_from> + ? "SEQPACKET-SOCKET NETWORK SERVICE" + : "UNKNOWN"; + + asio::io_context& _ioContext; + asio::io_context::strand _receiveStrand; + + socket_t _socket; + + asio::streambuf _streamBuffer; + + std::queue> _receiveQueue; + + asio::socket_base::shutdown_type _shutdownType = asio::socket_base::shutdown_both; + + + // public: + template + static std::unique_ptr getDeadlineTimer(CancelableT& obj, + const TimeoutT& timeout, + bool arm = true) + { + auto timer = std::make_unique(obj.get_executor()); + + // if (timeout == std::chrono::duration::max()) { + // return timer; // do not arm the timer if MAX duration are given + // } + + if (arm) { + std::chrono::seconds max_d = std::chrono::duration_cast( + std::chrono::steady_clock::time_point::max() - std::chrono::steady_clock::now() - + std::chrono::seconds(1)); + + timer->expires_after(timeout < max_d ? timeout : max_d); // to avoid overflow! + // timer->expires_after(timeout); + + timer->async_wait([&obj](const std::error_code& ec) mutable { + if (!ec) { + obj.cancel(); + } + }); + } + + return timer; + } + + template + static bool isTimeout(const std::unique_ptr& timer, const std::error_code& ec) + { + auto exp_time = timer->expiry(); + return (exp_time < std::chrono::steady_clock::now()) && (ec == asio::error::operation_aborted); + } + + template + static bool isTimeout(const std::shared_ptr& timer, const std::error_code& ec) + { + auto exp_time = timer->expiry(); + return (exp_time < std::chrono::steady_clock::now()) && (ec == asio::error::operation_aborted); + } +}; + +} // namespace adc::impl diff --git a/tests/adc_netservice_test.cpp b/tests/adc_netservice_test.cpp index 6bf00b4..dd53a9c 100644 --- a/tests/adc_netservice_test.cpp +++ b/tests/adc_netservice_test.cpp @@ -3,6 +3,7 @@ #include "../net/adc_netproto.h" #include "../net/asio/adc_netservice_asio.h" +// #include "../net/asio/adc_netsrv_asio.h" template void receive(T srv)