ADC/net/adc_netservice_asio.h
2024-06-17 13:38:01 +03:00

340 lines
13 KiB
C++

#pragma once
/*
ABSTRACT DEVICE COMPONENTS LIBRARY
ASIO-library implementation of network service
*/
#ifdef USE_ASIO_LIBRARY
#include <algorithm>
#include <chrono>
#include <concepts>
#include <cstddef>
#include <future>
#include <iterator>
#include <memory>
#include <string_view>
#include <system_error>
#include <type_traits>
#include <utility>
#include <vector>
#include "adc_netservice.h"
#include <asio/basic_datagram_socket.hpp>
#include <asio/basic_seq_packet_socket.hpp>
#include <asio/basic_stream_socket.hpp>
#include <asio/compose.hpp>
#include <asio/ip/tcp.hpp>
#include <asio/ip/udp.hpp>
#include <asio/local/seq_packet_protocol.hpp>
#include <asio/local/stream_protocol.hpp>
#include <asio/read_until.hpp>
#include <asio/steady_timer.hpp>
#include <asio/streambuf.hpp>
#include <asio/use_future.hpp>
#include <asio/write.hpp>
#ifdef USE_OPENSSL_WITH_ASIO
#include <asio/ssl.hpp>
#include <asio/ssl/stream.hpp>
#endif
#include "adc_netmsg.h"
namespace adc::traits
{
// Protocol types accepted by the ASIO network service. Only four families are supported
// at present: TCP, UDP, and the two UNIX-domain flavours (seq-packet and stream).
template <typename T>
concept adc_asio_inet_proto_c =
    std::derived_from<T, asio::ip::tcp> ||
    std::derived_from<T, asio::ip::udp> ||
    std::derived_from<T, asio::local::seq_packet_protocol> ||
    std::derived_from<T, asio::local::stream_protocol>;
} // namespace adc::traits
namespace adc::impl
{
/// @brief ASIO-based network service for TCP, UDP and UNIX-domain sockets (optionally TLS over TCP).
/// @details The service does not own the socket. It keeps a reference to a caller-provided socket of type
///          HighLevelSocketT: either InetProtoT::socket or, with USE_OPENSSL_WITH_ASIO, an
///          asio::ssl::stream wrapping it. The socket must therefore outlive the service. Every asynchronous
///          operation is bounded by a deadline timer; asynchronous operations capture `this`, so the
///          service must also outlive them.
///          NOTE(review): matchCondition() and fromLowLevel() are called through `this` but are not
///          declared in this class. They are presumably provided elsewhere — confirm where they are defined.
template <traits::adc_asio_inet_proto_c InetProtoT, typename HighLevelSocketT = typename InetProtoT::socket>
class AdcNetServiceASIO : public InetProtoT
{
public:
    using socket_t = typename InetProtoT::socket;
    using endpoint_t = typename InetProtoT::endpoint;
    using inet_proto_t = InetProtoT;
    using high_level_socket_t = HighLevelSocketT;
    // steady_clock's native duration (nanoseconds on the common standard libraries)
    typedef std::chrono::steady_clock::duration timeout_t;

#ifdef USE_OPENSSL_WITH_ASIO
    static_assert(std::is_same_v<HighLevelSocketT, socket_t> ||
                      std::is_same_v<HighLevelSocketT, asio::ssl::stream<socket_t>>,
                  "ONLY BASIC SOCKETS AND SSL::STREAM ARE SUPPORTED!!!");
    static constexpr bool IsBasicSocket = std::is_same_v<high_level_socket_t, socket_t>;
#else
    static_assert(std::is_same_v<high_level_socket_t, socket_t>,
                  "HighLevelSocketT AND InetProtoT::socket TYPES MUST BE THE SAME!!!");
    // without OpenSSL support only plain sockets pass the assertion above
    static constexpr bool IsBasicSocket = true;
#endif

    using streambuff_iter_t = asio::buffers_iterator<asio::streambuf::const_buffers_type>;

    /// @param sock socket used for all I/O; stored by reference, so it must outlive this object.
    /// NOTE(review): asio::ip::tcp / asio::ip::udp declare no public default constructor, so default-initializing
    /// the InetProtoT base here may not compile for those protocols — confirm against the asio version in use.
    AdcNetServiceASIO(high_level_socket_t& sock) : _socket(sock) {}

    virtual ~AdcNetServiceASIO() = default;

    /// @brief Asynchronously connect the lowest-layer socket to @p endpoint.
    /// @details For a TLS stream only the TCP connection is established; no TLS handshake is done here.
    ///          If the deadline expires first, all pending socket operations are cancelled and the operation
    ///          completes with std::errc::timed_out.
    /// @param endpoint remote endpoint (copied into the operation state)
    /// @param timeout  maximal duration of the operation
    /// @param token    ASIO completion token; completion signature is void(std::error_code)
    template <traits::adc_time_duration_c TimeoutT, asio::completion_token_for<void(std::error_code)> CompletionTokenT>
    auto asyncConnect(const endpoint_t& endpoint, const TimeoutT& timeout, CompletionTokenT&& token)
    {
        enum { starting, finishing };
        auto timer = getDeadlineTimer(timeout);

        return asio::async_compose<CompletionTokenT, void(std::error_code)>(
            [timer = std::move(timer), state = starting, endpoint, this](auto& self,
                                                                         const std::error_code& ec = {}) mutable {
                if (state == starting) {
                    state = finishing;
                    _socket.lowest_layer().async_connect(endpoint, std::move(self));
                    return;
                }
                // finished (successfully or not): the deadline is no longer needed
                timer->cancel();
                self.complete(completionError(*timer, ec));
            },
            token, _socket);
    }

    /// @brief Asynchronously send the byte representation of @p msg.
    /// @details Stream sockets use asio::async_write, so the whole message is written. Datagram and
    ///          seq-packet sockets send it as one packet. The buffers presumably reference @p msg's storage
    ///          (via bytesView()), so @p msg must stay alive until completion — TODO confirm bytesView() semantics.
    ///          On deadline expiry the operation completes with std::errc::timed_out.
    /// @param msg     message to send
    /// @param timeout maximal duration of the operation
    /// @param token   ASIO completion token; completion signature is void(std::error_code)
    template <traits::adc_netmessage_c NetMessageT,
              traits::adc_time_duration_c TimeoutT,
              asio::completion_token_for<void(std::error_code)> CompletionTokenT>
    auto asyncSend(const NetMessageT& msg, const TimeoutT& timeout, CompletionTokenT&& token)
    {
        enum { starting, finishing };

        // gather-list of the message byte chunks
        std::vector<asio::const_buffer> buff;
        for (const auto& bytes : msg.template bytesView<std::vector<std::string_view>>()) {
            buff.emplace_back(bytes.data(), bytes.size());
        }

        auto timer = getDeadlineTimer(timeout);

        return asio::async_compose<CompletionTokenT, void(std::error_code)>(
            [buff = std::move(buff), timer = std::move(timer), state = starting, this](
                auto& self, const std::error_code& ec = {}, std::size_t = 0) mutable {
                if (state == starting) {
                    state = finishing;
                    if constexpr (IsStreamSocket) {
                        asio::async_write(_socket, buff, std::move(self));
                    } else if constexpr (IsDatagramSocket) {
                        _socket.async_send(buff, std::move(self));
                    } else if constexpr (IsSeqPacketSocket) {
                        // seq-packet sockets have no async_send overload without message flags
                        _socket.async_send(buff, asio::socket_base::message_flags(0), std::move(self));
                    } else {
                        static_assert(AlwaysFalse<decltype(self)>, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!");
                    }
                    return;
                }
                timer->cancel();
                self.complete(completionError(*timer, ec));
            },
            token, _socket);
    }

    /// @brief Backward-compatible alias of asyncSend() (historical misspelling); prefer asyncSend().
    template <traits::adc_netmessage_c NetMessageT,
              traits::adc_time_duration_c TimeoutT,
              asio::completion_token_for<void(std::error_code)> CompletionTokenT>
    auto asynSend(const NetMessageT& msg, const TimeoutT& timeout, CompletionTokenT&& token)
    {
        return asyncSend(msg, timeout, std::forward<CompletionTokenT>(token));
    }

    /// @brief Asynchronously receive one protocol message.
    /// @details Stream sockets read until matchCondition() reports a complete message. Datagram and
    ///          seq-packet sockets receive one packet into the internal buffer. The message boundaries
    ///          are then located with matchCondition(), and the message is built via fromLowLevel() and
    ///          NetMessageT::setFromBytes(). Bytes beyond the message stay buffered for the next call.
    ///          If the bytes do not form a valid message, the operation completes with
    ///          std::errc::protocol_error and an empty message. For packet sockets the invalid packet is
    ///          discarded. On deadline expiry the operation completes with std::errc::timed_out.
    /// @tparam NetMessageT message type (not deducible, must be given explicitly)
    /// @param timeout maximal duration of the operation
    /// @param token   ASIO completion token; completion signature is void(const std::error_code&, const NetMessageT&)
    template <traits::adc_netmessage_c NetMessageT, traits::adc_time_duration_c TimeoutT, typename CompletionTokenT>
    auto asyncReceive(const TimeoutT& timeout, CompletionTokenT&& token)
    {
        enum { starting, finishing };
        // seq-packet receive reports flags through a reference that must remain valid until completion;
        // a heap object keeps a stable address while the operation state is moved around
        auto out_flags = std::make_unique<asio::socket_base::message_flags>(0);
        auto timer = getDeadlineTimer(timeout); // armed timer

        return asio::async_compose<CompletionTokenT, void(const std::error_code&, const NetMessageT&)>(
            [timer = std::move(timer), out_flags = std::move(out_flags), state = starting, this](
                auto& self, const std::error_code& ec = {}, std::size_t nbytes = 0) mutable {
                if (state == starting) {
                    state = finishing;
                    if constexpr (IsStreamSocket) {
                        asio::async_read_until(_socket, _streamBuffer, ReadUntilMatcher{this}, std::move(self));
                    } else if constexpr (IsDatagramSocket) {
                        // datagram, so it should be received at once
                        _socket.async_receive(_streamBuffer.prepare(MaxDatagramSize), std::move(self));
                    } else if constexpr (IsSeqPacketSocket) {
                        // one packet, so it should be received at once
                        _socket.async_receive(_streamBuffer.prepare(MaxDatagramSize), *out_flags, std::move(self));
                    } else {
                        static_assert(AlwaysFalse<decltype(self)>, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!");
                    }
                    return;
                }

                timer->cancel();
                if constexpr (!IsStreamSocket) {
                    // async_receive wrote into prepared (output) space: move it to the readable sequence
                    _streamBuffer.commit(nbytes);
                }

                if (ec) {
                    self.complete(completionError(*timer, ec), NetMessageT()); // return an empty message
                    return;
                }

                const auto bytes = _streamBuffer.data();
                auto begin_it = streambuff_iter_t::begin(bytes);
                auto end_it = streambuff_iter_t::end(bytes);
                // check the byte sequence is a valid message and find its limits
                // (the stream buffer may contain more bytes than required by the protocol)
                auto res = this->matchCondition(begin_it, end_it);
                if (!res.second) {
                    if constexpr (!IsStreamSocket) {
                        // drop the malformed packet, otherwise every later receive would re-parse it
                        _streamBuffer.consume(_streamBuffer.size());
                    }
                    self.complete(std::make_error_code(std::errc::protocol_error), NetMessageT());
                    return;
                }

                const auto nmsg_bytes = static_cast<std::size_t>(std::distance(begin_it, res.first));
                NetMessageT msg;
                auto msg_it = this->fromLowLevel(begin_it, res.first);
                msg.setFromBytes(msg_it.first, msg_it.second);
                _streamBuffer.consume(nmsg_bytes); // only after the iterators above are no longer used
                self.complete(ec, msg);
            },
            token, _socket);
    }

    /// @brief Blocking connect (see asyncConnect()).
    /// @details Waits on a std::future, so the socket's executor must be run by another thread.
    ///          Throws std::system_error (via asio::use_future) on failure.
    template <traits::adc_time_duration_c TimeoutT>
    auto connect(const endpoint_t& endpoint, const TimeoutT& timeout)
    {
        std::future<void> ftr = asyncConnect(endpoint, timeout, asio::use_future);
        ftr.get();
    }

    /// @brief Blocking send (see asyncSend()).
    /// @details Waits on a std::future, so the socket's executor must be run by another thread.
    ///          Throws std::system_error (via asio::use_future) on failure.
    template <traits::adc_netmessage_c NetMessageT, traits::adc_time_duration_c TimeoutT>
    auto send(const NetMessageT& msg, const TimeoutT& timeout)
    {
        std::future<void> ftr = asyncSend(msg, timeout, asio::use_future);
        ftr.get();
    }

    /// @brief Blocking receive (see asyncReceive()).
    /// @details Waits on a std::future, so the socket's executor must be run by another thread.
    ///          Throws std::system_error (via asio::use_future) on failure.
    /// @return the received message
    template <traits::adc_netmessage_c NetMessageT, traits::adc_time_duration_c TimeoutT>
    auto receive(const TimeoutT& timeout)
    {
        std::future<NetMessageT> ftr = asyncReceive<NetMessageT>(timeout, asio::use_future);
        return ftr.get();
    }

    /// @brief Shut down and close the socket.
    /// @details For a TLS stream the TLS session is shut down first. The lowest-layer socket is then
    ///          shut down and closed. Closing is attempted even if a preceding step failed, so the
    ///          descriptor is not leaked.
    /// @param stype direction(s) to shut down on the lowest-layer socket
    /// @return the first error encountered (empty on success)
    std::error_code close(asio::socket_base::shutdown_type stype = asio::socket_base::shutdown_both)
    {
        std::error_code result;
        std::error_code step;

        if constexpr (!IsBasicSocket) {
            _socket.shutdown(result); // shutdown OpenSSL stream
        }

        _socket.lowest_layer().shutdown(stype, step);
        if (!result) {
            result = step;
        }

        _socket.lowest_layer().close(step);
        if (!result) {
            result = step;
        }

        return result;
    }

protected:
    high_level_socket_t& _socket;
    asio::streambuf _streamBuffer;

    // kind of the lowest-layer socket (an ssl::stream always wraps a stream socket)
    static constexpr bool IsStreamSocket =
        std::derived_from<socket_t,
                          asio::basic_stream_socket<typename socket_t::protocol_type, typename socket_t::executor_type>>;
    static constexpr bool IsDatagramSocket =
        std::derived_from<socket_t,
                          asio::basic_datagram_socket<typename socket_t::protocol_type, typename socket_t::executor_type>>;
    static constexpr bool IsSeqPacketSocket = std::derived_from<
        socket_t,
        asio::basic_seq_packet_socket<typename socket_t::protocol_type, typename socket_t::executor_type>>;

    // space reserved in the stream buffer for one incoming datagram/packet: 64 KiB covers any UDP payload;
    // longer UNIX seq-packets would presumably be truncated by the OS
    static constexpr std::size_t MaxDatagramSize = 65536;

    // dependent 'false' for static_assert in discarded if-constexpr branches (static_assert(false) is
    // ill-formed there before C++23)
    template <typename...>
    static constexpr bool AlwaysFalse = false;

    // match-condition object for asio::async_read_until. Asio accepts a class type as a match
    // condition only if it has a nested result_type (or an is_match_condition specialization), so a
    // plain lambda is rejected.
    struct ReadUntilMatcher {
        using result_type = std::pair<streambuff_iter_t, bool>;
        AdcNetServiceASIO* service;

        result_type operator()(streambuff_iter_t begin, streambuff_iter_t end) const
        {
            return service->matchCondition(begin, end);
        }
    };

    // The deadline handler can only cancel() the socket, which completes pending operations with
    // operation_aborted. If the deadline has passed, report that as std::errc::timed_out instead.
    // Assumes @p timer was armed by getDeadlineTimer().
    static std::error_code completionError(const asio::steady_timer& timer, const std::error_code& ec)
    {
        if (ec == asio::error::operation_aborted && timer.expiry() <= asio::steady_timer::clock_type::now()) {
            return std::make_error_code(std::errc::timed_out);
        }
        return ec;
    }

    // Creates a timer on the socket's executor. When armed, expiry after @p timeout cancels all pending
    // operations on the lowest-layer socket. A cancelled wait (timer cancelled or destroyed) does nothing.
    template <traits::adc_time_duration_c TimeoutT>
    std::unique_ptr<asio::steady_timer> getDeadlineTimer(const TimeoutT& timeout, bool arm = true)
    {
        auto timer = std::make_unique<asio::steady_timer>(_socket.get_executor());
        if (arm) {
            timer->expires_after(std::chrono::duration_cast<timeout_t>(timeout));
            timer->async_wait([this](const std::error_code& ec) {
                if (!ec) { // deadline reached
                    std::error_code ignored;
                    _socket.lowest_layer().cancel(ignored);
                }
            });
        }
        return timer;
    }
};
// Ready-made service types for the supported transports.
using AdcNetServiceAsioTcp = AdcNetService<impl::AdcNetServiceASIO<asio::ip::tcp>>;
#ifdef USE_OPENSSL_WITH_ASIO
// TLS over TCP: the high-level socket is an OpenSSL stream wrapping a TCP *socket*,
// the only TLS form accepted by AdcNetServiceASIO's static_assert.
using AdcNetServiceAsioTls =
    AdcNetService<impl::AdcNetServiceASIO<asio::ip::tcp, asio::ssl::stream<asio::ip::tcp::socket>>>;
#endif
using AdcNetServiceAsioLocalSeqPack = AdcNetService<impl::AdcNetServiceASIO<asio::local::seq_packet_protocol>>;
using AdcNetServiceAsioLocalStream = AdcNetService<impl::AdcNetServiceASIO<asio::local::stream_protocol>>;
} // namespace adc::impl
namespace adc::traits
{
template <typename T>
concept adc_netservice_asio_c = requires {
typename T::inet_proto_t;
typename T::high_level_socket_t;
requires std::derived_from<T,
adc::impl::AdcNetServiceASIO<typename T::inet_proto_t, typename T::high_level_socket_t>>;
};
} // namespace adc::traits
#endif