ADC/net/asio/adc_netservice_asio.h
Timur A. Fatkhullin 2cf0b1f94c ...
2024-09-24 21:51:05 +03:00

624 lines
25 KiB
C++

#pragma once
#include <asio/basic_datagram_socket.hpp>
#include <asio/basic_seq_packet_socket.hpp>
#include <asio/basic_stream_socket.hpp>
#include <asio/bind_executor.hpp>
#include <asio/buffers_iterator.hpp>
#include <asio/compose.hpp>
#include <asio/deferred.hpp>
#include <asio/execution/context.hpp>
#include <asio/io_context.hpp>
#include <asio/ip/tcp.hpp>
#include <asio/ip/udp.hpp>
#include <asio/local/seq_packet_protocol.hpp>
#include <asio/local/stream_protocol.hpp>
#include <asio/query.hpp>
#include <asio/read_until.hpp>
#include <asio/steady_timer.hpp>
#include <asio/strand.hpp>
#include <asio/streambuf.hpp>
#include <asio/use_awaitable.hpp>
#include <asio/use_future.hpp>
#include <asio/write.hpp>
#include <algorithm>
#include <chrono>
#include <concepts>
#include <functional>
#include <future>
#include <iterator>
#include <memory>
#include <queue>
#include <span>
#include <string_view>
#include <system_error>
#include <tuple>
#include <type_traits>
#include <vector>
#ifdef USE_OPENSSL_WITH_ASIO
#include <asio/ssl.hpp>
#include <asio/ssl/stream.hpp>
#endif
#include "../../common/adc_traits.h"
#include "../adc_net_concepts.h"
namespace adc::impl
{
// typedef for ASIO streambuf iterators
using asio_streambuff_iter_t = asio::buffers_iterator<asio::streambuf::const_buffers_type>;
template <typename T>
concept adc_asio_transport_proto_c =
std::derived_from<T, asio::ip::tcp> || std::derived_from<T, asio::ip::udp> ||
std::derived_from<T, asio::local::seq_packet_protocol> || std::derived_from<T, asio::local::stream_protocol>;
template <typename T>
concept adc_asio_stream_transport_proto_c =
std::derived_from<T, asio::ip::tcp> || std::derived_from<T, asio::local::stream_protocol>;
template <typename T>
concept adc_asio_is_future = requires {
// [](std::type_identity<asio::use_future_t<>>) {}(std::type_identity<std::remove_cvref_t<T>>());
[]<typename AllocatorT>(std::type_identity<asio::use_future_t<AllocatorT>>) {
}(std::type_identity<std::remove_cvref_t<T>>{});
};
template <typename T>
concept adc_asio_is_awaitable = requires {
[]<typename ExecutorT>(std::type_identity<asio::use_awaitable_t<ExecutorT>>) {
}(std::type_identity<std::remove_cvref_t<T>>{});
};
template <typename T>
concept adc_asio_special_comp_token =
adc_asio_is_future<T> || adc_asio_is_awaitable<T> || std::same_as<std::remove_cvref_t<T>, asio::deferred_t>;
// template <typename T>
// static constexpr bool adc_is_asio_special_comp_token = std::is_same_v<std::remove_cvref_t<T>, asio::use_future_t<>>
// ||
// std::is_same_v<std::remove_cvref_t<T>, asio::deferred_t> ||
// std::is_same_v<std::remove_cvref_t<T>,
// asio::use_awaitable_t<>>;
// Tag type of the "asynchronous call context" (the service re-exports it as 'async_call_ctx_t').
struct adc_asio_async_call_ctx_t {
};

// Completion token accepted by the ASIO-based network service. It is one of:
//  - exactly the 'adc_asio_async_call_ctx_t' tag type;
//  - a callable, or one of ASIO's special tokens (use_future, use_awaitable, deferred).
//    If 'SignatureT' is given (i.e. it is not 'void'), the token must also be a valid ASIO
//    completion token for that signature.
// NOTE: the special tokens are listed explicitly because they are not plain callables, yet the
// service accepts them (e.g. the blocking methods pass 'asio::use_future').
template <typename T, typename SignatureT = void>
concept adc_completion_token_c =
    std::same_as<T, adc_asio_async_call_ctx_t> ||
    ((traits::adc_is_callable<T> || adc_asio_special_comp_token<T>) &&
     (std::same_as<SignatureT, void> || asio::completion_token_for<T, SignatureT>));
template <adc_asio_transport_proto_c TRANSPORT_PROTOT,
interfaces::adc_netsession_proto_c<std::string_view> SESSION_PROTOT,
traits::adc_output_char_range RMSGT = std::vector<char>>
class AdcNetServiceASIOBase : public SESSION_PROTOT
{
public:
typedef std::string_view netservice_ident_t;
using socket_t = typename TRANSPORT_PROTOT::socket;
using endpoint_t = typename TRANSPORT_PROTOT::endpoint;
using acceptor_t =
std::conditional_t<std::derived_from<socket_t, asio::basic_datagram_socket<typename socket_t::protocol_type>>,
std::nullptr_t, // there is no acceptor
typename TRANSPORT_PROTOT::acceptor>;
struct asio_async_ctx_t {
bool use_future = false;
std::function<void(std::error_code)> accept_comp_token;
std::function<void(std::error_code)> connect_comp_token;
std::function<void(std::error_code)> send_comp_token;
std::function<void(std::error_code, RMSGT)> receive_comp_token;
};
class contx_t
{
std::function<void(std::error_code)> _errc_comp_token;
std::function<void(std::error_code, RMSGT)> _errc_msg_comp_token;
public:
contx_t() = default;
contx_t(contx_t&) = default;
contx_t(contx_t&&) = default;
contx_t(const contx_t&) = default;
template <asio::completion_token_for<void(std::error_code)> TokenT>
contx_t(TokenT&& token)
{
_errc_comp_token = std::forward<TokenT>(token);
}
template <asio::completion_token_for<void(std::error_code, RMSGT)> TokenT>
contx_t(TokenT&& token)
{
_errc_msg_comp_token = std::forward<TokenT>(token);
}
template <asio::completion_token_for<void(std::error_code)> TokenT>
contx_t& operator=(TokenT&& token)
{
_errc_comp_token = std::forward<TokenT>(token);
return *this;
}
template <asio::completion_token_for<void(std::error_code, RMSGT)> TokenT>
contx_t& operator=(TokenT&& token)
{
_errc_msg_comp_token = std::forward<TokenT>(token);
return *this;
}
auto operator()(std::error_code ec)
{
return _errc_comp_token(std::move(ec));
}
auto operator()(std::error_code ec, RMSGT msg)
{
return _errc_msg_comp_token(std::move(ec), std::move(msg));
}
template <asio::completion_token_for<void(std::error_code)> TokenT>
operator TokenT() const
{
return _errc_comp_token;
}
template <asio::completion_token_for<void(std::error_code, RMSGT)> TokenT>
operator TokenT() const
{
return _errc_msg_comp_token;
}
};
// to satisfy 'adc_netservice_c' concept
using async_call_ctx_t = adc_asio_async_call_ctx_t;
static constexpr std::chrono::duration DEFAULT_ACCEPT_TIMEOUT = std::chrono::years::max();
static constexpr std::chrono::duration DEFAULT_CONNECT_TIMEOUT = std::chrono::seconds(10);
static constexpr std::chrono::duration DEFAULT_SEND_TIMEOUT = std::chrono::seconds(5);
static constexpr std::chrono::duration DEFAULT_RECEIVE_TIMEOUT = std::chrono::seconds(5);
AdcNetServiceASIOBase(asio::io_context& ctx)
: SESSION_PROTOT(),
_ioContext(ctx),
_receiveStrand(_ioContext),
_receiveQueue(),
_acceptor(_ioContext),
_socket(_ioContext)
{
}
AdcNetServiceASIOBase(socket_t socket)
: SESSION_PROTOT(),
_ioContext(socket.get_executor()),
_receiveStrand(_ioContext),
_receiveQueue(),
_socket(std::move(socket))
{
}
AdcNetServiceASIOBase(const AdcNetServiceASIOBase&) = delete; // no copy constructor!
virtual ~AdcNetServiceASIOBase() {}
constexpr netservice_ident_t ident() const
{
return _ident;
}
/* asynchronuos methods */
template <asio::completion_token_for<void(std::error_code)> TokenT,
traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_ACCEPT_TIMEOUT)>
auto asyncAccept(const endpoint_t& endpoint, TokenT&& token, const TimeoutT& timeout = DEFAULT_ACCEPT_TIMEOUT)
{
// no acceptor for UDP-sockets
if constexpr (std::is_null_pointer_v<acceptor_t>) {
static_assert(false, "INVALID TRANSPORT PROTOCOL TYPE!");
}
// auto acc = acceptor_t(_ioContext);
// auto timer = getDeadlineTimer(acc, timeout);
auto timer = getDeadlineTimer(_acceptor, timeout);
return asio::async_compose<TokenT, void(std::error_code)>(
// [acc = std::move(acc), timer = std::move(timer), start = true, &endpoint, this](
[timer = std::move(timer), start = true, &endpoint, this](auto& self, std::error_code ec = {}) mutable {
if (!ec) {
if (start) {
start = false;
try {
// acc = acceptor_t(_ioContext, endpoint);
if (!_acceptor.is_open() || (_acceptor.local_endpoint() != endpoint)) {
_acceptor = acceptor_t(_ioContext, endpoint);
}
} catch (std::system_error err) {
timer->cancel();
self.complete(err.code());
return;
}
// return acc.async_accept(_socket, std::move(self));
return _acceptor.async_accept(_socket, std::move(self));
}
}
if (isTimeout(timer, ec)) {
ec = std::make_error_code(std::errc::timed_out);
} else { // an error occured in async_connect
timer->cancel();
}
self.complete(ec);
},
token, _ioContext);
}
template <asio::completion_token_for<void(std::error_code)> TokenT,
traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_CONNECT_TIMEOUT)>
auto asyncConnect(const endpoint_t& endpoint, TokenT&& token, const TimeoutT& timeout = DEFAULT_CONNECT_TIMEOUT)
{
auto timer = getDeadlineTimer(_socket, timeout);
return asio::async_compose<TokenT, void(asio::error_code)>(
[start = true, endpoint, timer = std::move(timer), this](auto& self, asio::error_code ec = {}) mutable {
if (!ec) {
if (start) {
start = false;
return _socket.async_connect(endpoint, std::move(self));
}
}
if (isTimeout(timer, ec)) {
ec = std::make_error_code(std::errc::timed_out);
} else { // an error occured in async_connect
timer->cancel();
}
self.complete(ec);
},
token, _socket);
}
template <traits::adc_input_char_range MessageT,
adc_completion_token_c<void(std::error_code ec)> TokenT,
traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_SEND_TIMEOUT)>
auto asyncSend(const MessageT& msg, TokenT&& token, const TimeoutT& timeout = DEFAULT_SEND_TIMEOUT)
{
// create buffer sequence of sending session protocol representation of the input message
std::vector<asio::const_buffer> buff_seq;
std::ranges::for_each(this->toProto(msg), [&buff_seq](const auto& el) { buff_seq.emplace_back(el); });
auto timer = getDeadlineTimer(_socket, timeout);
return asio::async_compose<TokenT, void(asio::error_code)>(
[start = true, buff_seq = std::move(buff_seq), timer = std::move(timer), this](
auto& self, asio::error_code ec = {}) mutable {
if (!ec) {
if (start) {
start = false;
if constexpr (std::derived_from<socket_t,
asio::basic_stream_socket<typename socket_t::protocol_type>>) {
return asio::async_write(_socket, buff_seq, std::move(self));
} else if constexpr (std::derived_from<socket_t, asio::basic_datagram_socket<
typename socket_t::protocol_type>>) {
return _socket.async_send(buff_seq, std::move(self));
} else if constexpr (std::derived_from<socket_t, asio::basic_seq_packet_socket<
typename socket_t::protocol_type>>) {
return _socket.async_send(buff_seq, std::move(self));
} else {
static_assert(false, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!");
}
}
}
if (isTimeout(timer, ec)) {
ec = std::make_error_code(std::errc::timed_out);
} else { // an error occured in async_write/async_send
timer->cancel();
}
self.complete(ec);
},
token, _socket);
}
template <adc_completion_token_c TokenT, traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_RECEIVE_TIMEOUT)>
auto asyncReceive(TokenT&& token, const TimeoutT& timeout = DEFAULT_RECEIVE_TIMEOUT)
{
static_assert(!std::is_same_v<TokenT, async_call_ctx_t>, "'async_call_ctx_t'-TYPE MUST NOT BE USED!");
// check completion token signature and deduce message type
if constexpr (!adc_asio_special_comp_token<TokenT>) {
static_assert(traits::adc_func_traits<TokenT>::arity == 2, "INVALID COMPLETION TOKEN SIGNATURE!");
static_assert(std::is_same_v<std::remove_cvref_t<traits::adc_func_arg1_t<TokenT>>, std::error_code>,
"INVALID COMPLETION TOKEN SIGNATURE!");
static_assert(traits::adc_output_char_range<
std::tuple_element_t<1, typename traits::adc_func_traits<TokenT>::args_t>>,
"INVALID COMPLETION TOKEN SIGNATURE!");
}
using msg_t = std::conditional_t<
adc_asio_special_comp_token<TokenT>, RMSGT,
std::remove_cvref_t<std::tuple_element_t<1, typename traits::adc_func_traits<TokenT>::args_t>>>;
// auto s_res = std::make_shared<std::invoke_result_t<decltype(this->template search<RMSGT>), RMSGT>>();
// auto tp = this->search(std::span<const char>());
// auto s_res = std::make_shared<decltype(tp)>();
auto s_res = std::make_shared<std::tuple<asio_streambuff_iter_t, asio_streambuff_iter_t, bool>>();
auto out_flags = std::make_shared<asio::socket_base::message_flags>();
auto timer = getDeadlineTimer(_socket, timeout);
return asio::async_compose<TokenT, void(asio::error_code, msg_t)>(
[s_res, out_flags, start = true, timer = std::move(timer), this](auto& self, asio::error_code ec = {},
size_t = 0) mutable {
RMSGT msg;
if (!ec) {
if (start) {
start = false;
if (_receiveQueue.size()) { // return message from queue
msg = _receiveQueue.front();
_receiveQueue.pop();
if constexpr (std::is_same_v<msg_t, RMSGT>) {
self.complete(std::error_code(), std::move(msg));
} else {
// msg_t user_msg{msg.begin(), msg.end()};
self.complete(std::error_code(), {msg.begin(), msg.end()});
}
return;
}
if constexpr (std::derived_from<socket_t,
asio::basic_stream_socket<typename socket_t::protocol_type>>) {
// adapt to ASIO's MatchCondition
// auto match_func = [s_res, this]<typename IT>(IT begin, IT end) {
// *s_res = this->search(std::span(begin, end));
// // return std::make_pair(std::get<1>(*s_res), std::get<2>(*s_res));
// std::pair<IT, bool> res{std::get<1>(*s_res), std::get<2>(*s_res)};
// return res;
// };
// auto match_func = [s_res, this](asio_streambuff_iter_t begin, asio_streambuff_iter_t end)
// {
// *s_res = this->search(std::span(&*begin, &*end));
// // return std::make_pair(std::get<1>(*s_res), std::get<2>(*s_res));
// auto N = std::distance(std::get<0>(*s_res), std::get<1>(*s_res));
// std::pair<asio_streambuff_iter_t, bool> res{begin + N, std::get<2>(*s_res)};
// return res;
// };
// return asio::async_read_until(_socket, _streamBuffer, std::move(match_func),
// std::move(self));
return asio::async_read_until(
_socket, _streamBuffer,
std::bind(&AdcNetServiceASIOBase::template MatchCondition<decltype(s_res)>, this,
std::placeholders::_1, std::placeholders::_2, s_res),
std::move(self));
} else if constexpr (std::derived_from<socket_t, asio::basic_datagram_socket<
typename socket_t::protocol_type>>) {
// datagram, so it should be received at once
return _socket.receive(_streamBuffer, std::move(self));
} else if constexpr (std::derived_from<socket_t, asio::basic_seq_packet_socket<
typename socket_t::protocol_type>>) {
// datagram, so it should be received at once
return _socket.receive(_streamBuffer, *out_flags, std::move(self));
} else {
static_assert(false, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!");
}
}
// here, the iterators were computed in MatchCondition called by asio::async_read_until function!!!
// std::string_view net_pack{std::get<0>(*s_res), std::get<1>(*s_res)};
size_t N = std::distance(std::get<0>(*s_res), std::get<1>(*s_res));
std::span net_pack{&*std::get<0>(*s_res), N};
std::ranges::copy(this->fromProto(net_pack), std::back_inserter(msg));
_streamBuffer.consume(net_pack.size());
while (_streamBuffer.size()) { // search for possible additional session protocol packets
// auto begin_it = (const char*)asio_streambuff_iter_t::begin(_streamBuffer.data());
// auto end_it = (const char*)asio_streambuff_iter_t::end(_streamBuffer.data());
// auto begin_it = asio_streambuff_iter_t::begin(_streamBuffer.data());
// auto end_it = asio_streambuff_iter_t::end(_streamBuffer.data());
// auto begin_it =
// static_cast<std::ranges::iterator_t<std::string_view>>(_streamBuffer.data().data());
// auto end_it = begin_it + _streamBuffer.data().size();
auto begin_it = asio::buffers_begin(_streamBuffer.data());
auto end_it = asio::buffers_end(_streamBuffer.data());
// *s_res = this->search(std::span(begin_it, end_it));
*s_res = this->search(begin_it, end_it);
if (std::get<2>(*s_res)) {
// net_pack = std::string_view{std::get<0>(*s_res), std::get<1>(*s_res)};
N = std::distance(std::get<0>(*s_res), std::get<1>(*s_res));
net_pack = std::span{&*std::get<0>(*s_res), N};
_receiveQueue.emplace();
std::ranges::copy(this->fromProto(net_pack), std::back_inserter(_receiveQueue.back()));
_streamBuffer.consume(net_pack.size());
} else {
break;
}
}
}
if (isTimeout(timer, ec)) {
ec = std::make_error_code(std::errc::timed_out);
} else { // an error occured in async_connect
timer->cancel();
}
if constexpr (std::is_same_v<msg_t, RMSGT>) {
self.complete(ec, std::move(msg));
} else {
// msg_t user_msg{msg.begin(), msg.end()};
self.complete(ec, {msg.begin(), msg.end()});
}
},
token, _socket);
}
/* blocking methods */
template <traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_ACCEPT_TIMEOUT)>
auto accept(const endpoint_t& endpoint, const TimeoutT& timeout = DEFAULT_ACCEPT_TIMEOUT)
{
std::future<void> ftr = asyncAccept(endpoint, asio::use_future, timeout);
ftr.get();
}
template <traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_CONNECT_TIMEOUT)>
auto connect(const endpoint_t& endpoint, const TimeoutT& timeout = DEFAULT_CONNECT_TIMEOUT)
{
std::future<void> ftr = asyncConnect(endpoint, asio::use_future, timeout);
ftr.get();
}
template <traits::adc_input_char_range R, traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_SEND_TIMEOUT)>
auto send(const R& msg, const TimeoutT& timeout = DEFAULT_SEND_TIMEOUT)
{
std::future<void> ftr = asyncSend(msg, timeout, asio::use_future);
ftr.get();
}
template <traits::adc_time_duration_c TimeoutT = decltype(DEFAULT_RECEIVE_TIMEOUT)>
auto receive(const TimeoutT& timeout = DEFAULT_RECEIVE_TIMEOUT)
{
std::future<RMSGT> ftr = asyncReceive(timeout, asio::use_future);
return ftr.get();
}
std::error_code close()
{
std::error_code ec;
_socket.shutdown(_shutdownType, ec);
if (!ec) {
_socket.close(ec);
}
return ec;
}
/* additional ASIO-related methods */
void clear()
{
// clear receiving messages queue
// NOTE: there is no racing condition here since using asio::strand!
asio::post(_receiveStrand, [this]() { _receiveQueue = {}; });
}
void setShutdownType(asio::socket_base::shutdown_type shutdown_type)
{
_shutdownType = shutdown_type;
}
asio::socket_base::shutdown_type getShutdownType() const
{
return _shutdownType;
}
protected:
static constexpr netservice_ident_t _ident =
std::derived_from<socket_t, asio::basic_stream_socket<typename socket_t::protocol_type>>
? "STREAM-SOCKET NETWORK SERVICE"
: std::derived_from<socket_t, asio::basic_datagram_socket<typename socket_t::protocol_type>>
? "DATAGRAM-SOCKET NETWORK SERVICE"
: std::derived_from<socket_t, asio::basic_seq_packet_socket<typename socket_t::protocol_type>>
? "SEQPACKET-SOCKET NETWORK SERVICE"
: "UNKNOWN";
asio::io_context& _ioContext;
asio::io_context::strand _receiveStrand;
socket_t _socket;
acceptor_t _acceptor;
asio::streambuf _streamBuffer;
std::queue<std::vector<char>> _receiveQueue;
asio::socket_base::shutdown_type _shutdownType = asio::socket_base::shutdown_both;
template <typename T>
auto MatchCondition(asio_streambuff_iter_t begin, asio_streambuff_iter_t end, T& s_res)
{
// if (begin == end) {
// *s_res = this->search(std::span<const char>());
// } else {
// *s_res = this->search(std::span(&*begin, &*end));
// }
*s_res = this->search(begin, end);
// return std::make_pair(std::get<1>(*s_res), std::get<2>(*s_res));
std::pair<asio_streambuff_iter_t, bool> res{end, false};
typename std::iterator_traits<asio_streambuff_iter_t>::difference_type N = 0;
if (std::get<2>(*s_res)) {
N = std::distance(std::get<0>(*s_res), std::get<1>(*s_res));
res = std::make_pair(begin + N, true);
}
return res;
};
template <typename CancelableT, traits::adc_time_duration_c TimeoutT>
static std::unique_ptr<asio::steady_timer> getDeadlineTimer(CancelableT& obj,
const TimeoutT& timeout,
bool arm = true)
{
auto timer = std::make_unique<asio::steady_timer>(obj.get_executor());
if (arm) {
timer->expires_after(timeout);
timer->async_wait([&obj](const std::error_code& ec) mutable {
if (!ec) {
obj.cancel();
}
});
}
return timer;
}
template <typename TimerT>
static bool isTimeout(const std::unique_ptr<TimerT>& timer, const std::error_code& ec)
{
auto exp_time = timer->expiry();
return (exp_time < std::chrono::steady_clock::now()) && (ec == asio::error::operation_aborted);
}
};
} // namespace adc::impl