This commit is contained in:
Timur A. Fatkhullin
2024-06-15 21:23:57 +03:00
parent daf4e1eab9
commit 9a2baa702d
5 changed files with 182 additions and 116 deletions

View File

@@ -10,15 +10,19 @@
#include <future>
#include "adc_netservice.h"
#ifdef USE_ASIO_LIBRARY
#include <asio/awaitable.hpp>
// #include <asio/awaitable.hpp>
#include <asio/basic_datagram_socket.hpp>
#include <asio/basic_seq_packet_socket.hpp>
#include <asio/basic_stream_socket.hpp>
#include <asio/compose.hpp>
#include <asio/experimental/awaitable_operators.hpp>
// #include <asio/experimental/awaitable_operators.hpp>
#include <asio/ip/tcp.hpp>
#include <asio/read_until.hpp>
#include <asio/ssl.hpp>
#include <asio/ssl/stream.hpp>
#include <asio/steady_timer.hpp>
#include <asio/streambuf.hpp>
#include <asio/use_future.hpp>
@@ -26,11 +30,14 @@
#include <concepts>
#include "adc_netmsg.h"
namespace adc::impl
{
template <typename NetMessageT, typename InetProtoT>
template <typename InetProtoT>
class AdcNetServiceASIO : public InetProtoT
{
public:
@@ -63,7 +70,12 @@ public:
switch (state) {
case starting:
state = cancel_timer;
return _socket.async_connect(endpoint, std::move(self));
if constexpr (std::derived_from<socket_t,
asio::ssl::stream<typename socket_t::lowest_layer_type>>) {
return _socket.lowest_layer().async_connect(endpoint, std::move(self));
} else {
return _socket.async_connect(endpoint, std::move(self));
}
break;
case cancel_timer:
timer->cancel();
@@ -78,7 +90,9 @@ public:
token, _socket);
}
template <typename TimeoutT, asio::completion_token_for<void(std::error_code)> CompletionTokenT>
template <traits::adc_netmessage_c NetMessageT,
typename TimeoutT,
asio::completion_token_for<void(std::error_code)> CompletionTokenT>
auto asynSend(const NetMessageT& msg, const TimeoutT& timeout, CompletionTokenT&& token)
{
enum { starting, cancel_timer };
@@ -93,10 +107,11 @@ public:
// wrapper
return asio::async_compose<CompletionTokenT, void(std::error_code)>(
[buff = std::move(buff), timer = std::move(timer), state = starting, this](
auto& self, const std::error_code& ec = {}, size_t sz = 0) mutable {
auto& self, const std::error_code& ec = {}, size_t = 0) mutable {
if (!ec) {
switch (state) {
case starting:
state = cancel_timer;
if constexpr (std::derived_from<
socket_t, asio::basic_stream_socket<typename socket_t::protocol_type>>) {
return asio::async_write(_socket, buff, std::move(self));
@@ -124,7 +139,7 @@ public:
}
template <typename TimeoutT, typename CompletionTokenT>
template <traits::adc_netmessage_c NetMessageT, typename TimeoutT, typename CompletionTokenT>
auto asyncReceive(const TimeoutT& timeout, CompletionTokenT&& token)
{
enum { starting, cancel_timer };
@@ -135,15 +150,16 @@ public:
return asio::async_compose<CompletionTokenT, void(const std::error_code&, const NetMessageT&)>(
[timer = std::move(timer), out_flags = std::move(out_flags), state = starting, this](
auto& self, const std::error_code& ec = {}, size_t sz = 0) mutable {
auto& self, const std::error_code& ec = {}, size_t = 0) mutable {
if (!ec) {
switch (state) {
case starting:
state = cancel_timer;
if constexpr (std::derived_from<
socket_t, asio::basic_stream_socket<typename socket_t::protocol_type>>) {
return asio::async_read_until(
_socket, _streamBuffer,
[this](auto begin, auto end) { this->matchCondition(begin, end); },
[this](auto begin, auto end) { return this->matchCondition(begin, end); },
std::move(self));
} else if constexpr (std::derived_from<socket_t, asio::basic_datagram_socket<
typename socket_t::protocol_type>>) {
@@ -200,20 +216,43 @@ public:
ftr.get();
}
template <typename TimeoutT>
template <traits::adc_netmessage_c NetMessageT, typename TimeoutT>
auto send(const NetMessageT& msg, const TimeoutT& timeout)
{
std::future<void> ftr = asyncSend(msg, timeout, asio::use_future);
ftr.get();
}
template <typename TimeoutT>
template <traits::adc_netmessage_c NetMessageT, typename TimeoutT>
auto receive(const TimeoutT& timeout)
{
std::future<NetMessageT> ftr = asyncReceive(timeout, asio::use_future);
return ftr.get();
}
std::error_code close(asio::socket_base::shutdown_type stype = asio::socket_base::shutdown_both)
{
std::error_code ec;
if constexpr (std::derived_from<socket_t, asio::ssl::stream<typename socket_t::lowest_layer_type>>) {
_socket.shutdown(ec); // shutdown OpenSSL stream
if (!ec) {
_socket.lowest_layer().shutdown(stype, ec);
if (!ec) {
_socket.lowest_layer().close(ec);
}
}
} else {
_socket.shutdown(stype, ec);
if (!ec) {
_socket.close(ec);
}
}
return ec;
}
protected:
socket_t& _socket;
@@ -242,4 +281,11 @@ protected:
} // namespace adc::impl
namespace adc
{
typedef AdcNetService<impl::AdcNetServiceASIO<asio::ip::tcp>> AdcNetServiceAsioTcp;
} // namespace adc
#endif