From 4e3e3ec60e1c7fb7a2be66fa8034e2f38f7afe21 Mon Sep 17 00:00:00 2001 From: "Timur A. Fatkhullin" Date: Tue, 29 Oct 2024 01:21:24 +0300 Subject: [PATCH] ... --- net/adc_device_netmsg.h | 2 +- net/adc_device_netserver.h | 2 +- net/adc_endpoint.h | 75 +++++++-------------------- net/adc_netserver.h | 15 +++--- net/asio/adc_device_netserver_asio.h | 15 ++++-- net/asio/adc_netservice_asio.h | 76 +++++++++++++++++++++++----- tests/adc_asio_netserver_test.cpp | 11 ++-- 7 files changed, 109 insertions(+), 87 deletions(-) diff --git a/net/adc_device_netmsg.h b/net/adc_device_netmsg.h index cf2b64b..7f3fddb 100644 --- a/net/adc_device_netmsg.h +++ b/net/adc_device_netmsg.h @@ -105,7 +105,7 @@ public: } else { T res; - auto bs = _byteSequence | std::views::drop_while([](const auto& ch) { return ch = ' '; }); + auto bs = _byteSequence | std::views::drop_while([](const auto& ch) { return ch == ' '; }); auto found = std::ranges::search(bs, keyValueDelimiter); if (!found.empty()) { diff --git a/net/adc_device_netserver.h b/net/adc_device_netserver.h index a49715c..9a61a93 100644 --- a/net/adc_device_netserver.h +++ b/net/adc_device_netserver.h @@ -299,7 +299,7 @@ public: auto dev_name = get_elem(0); bool found = false; for (auto& [ptr, dev_wr] : _serverPtr->_devices) { - if (dev_wr.ident() == dev_name) { + if (std::ranges::equal(dev_wr.ident(), dev_name)) { _bindDevice = dev_wr; found = true; break; diff --git a/net/adc_endpoint.h b/net/adc_endpoint.h index e1f96c8..0856bdc 100644 --- a/net/adc_endpoint.h +++ b/net/adc_endpoint.h @@ -40,8 +40,6 @@ namespace adc class AdcEndpointParser { - typedef std::span host_part_t; - public: static constexpr std::string_view protoHostDelim = "://"; static constexpr std::string_view hostPortDelim = ":"; @@ -117,19 +115,16 @@ public: _proto = validProtoMarks[idx]; - // _host = std::string_view{found.end(), _endpoint.end()}; - _host = host_part_t{found.end(), _endpoint.end()}; + _host = std::string_view{found.end(), _endpoint.end()}; auto f1 = std::ranges::search(_host, portPathDelim); std::string_view port_sv; if (f1.empty() && isLocal()) { // no path, but it is mandatory for 'local'! return _isValid; } else { - // _host = std::string_view(_host.begin(), f1.begin()); - _host = host_part_t{found.end(), _endpoint.end()}; + _host = std::string_view(_host.begin(), f1.begin()); - // _path = std::string_view(f1.end(), &*_endpoint.end()); - _path = std::string_view(&*f1.end(), &*_endpoint.end()); + _path = std::string_view(f1.end(), &*_endpoint.end()); f1 = std::ranges::search(_host, hostPortDelim); if (f1.empty() && !isLocal()) { // no port, but it is mandatory for non-local! @@ -138,8 +133,7 @@ public: port_sv = std::string_view(f1.end(), _host.end()); if (port_sv.size()) { - // _host = std::string_view(_host.begin(), f1.begin()); - _host = host_part_t{found.end(), _endpoint.end()}; + _host = std::string_view(_host.begin(), f1.begin()); if (!isLocal()) { // convert port string to int @@ -165,7 +159,7 @@ public: } return ok; })) { - // _host = validLocalProtoTypes[idx]; + _host = validLocalProtoTypes[idx]; } else { return _isValid; } @@ -229,20 +223,17 @@ public: bool isLocalStream() const { - // return host() == localProtoTypeStream; - return utils::AdcCharRangeCompare(host(), localProtoTypeStream, true); + return host() == localProtoTypeStream; } bool isLocalDatagram() const { - // return host() == localProtoTypeDatagram; - return utils::AdcCharRangeCompare(host(), localProtoTypeDatagram, true); + return host() == localProtoTypeDatagram; } bool isLocalSeqpacket() const { - // return host() == localProtoTypeSeqpacket; - return utils::AdcCharRangeCompare(host(), localProtoTypeSeqpacket, true); + return host() == localProtoTypeSeqpacket; } @@ -273,9 +264,7 @@ public: protected: std::string _endpoint; - // std::string_view _proto, _host, _path; - std::string_view _proto, _path; - host_part_t _host; + std::string_view _proto, _host, _path; int _port; bool _isValid; @@ -309,54 +298,28 @@ protected: // return res; // } - // auto part = _proto; - - // switch (what) { - // case PROTO_PART: - // part = _proto; - // break; - // case HOST_PART: - // part = _host; - // break; - // case PATH_PART: - // part = _path; - // break; - // default: - // break; - // } - - // if constexpr (std::ranges::view) { - // return R(part.begin(), part.end()); - // } else { - // std::ranges::copy(part, std::back_inserter(res)); - // } + auto part = _proto; switch (what) { case PROTO_PART: - if constexpr (std::ranges::view) { - res = R(_proto.begin(), _proto.size()); - } else { - std::ranges::copy(_proto, std::back_inserter(res)); - } + part = _proto; break; case HOST_PART: - if constexpr (std::ranges::view) { - res = R(_host.begin(), _host.end()); - } else { - std::ranges::copy(_host, std::back_inserter(res)); - } + part = _host; break; case PATH_PART: - if constexpr (std::ranges::view) { - res = R(_path.begin(), _path.end()); - } else { - std::ranges::copy(_path, std::back_inserter(res)); - } + part = _path; break; default: break; } + if constexpr (std::ranges::view) { + return {part.begin(), part.end()}; + } else { + std::ranges::copy(part, std::back_inserter(res)); + } + return res; } }; diff --git a/net/adc_netserver.h b/net/adc_netserver.h index 2596419..93092b4 100644 --- a/net/adc_netserver.h +++ b/net/adc_netserver.h @@ -11,6 +11,7 @@ ABSTRACT DEVICE COMPONENTS LIBRARY #include #include #include +#include #if __has_include() // POSIX #define FORK_EXISTS 1 @@ -138,7 +139,7 @@ protected: // started sessions weak pointers template - static std::unordered_map>> _serverSessions; + static std::unordered_map>> _serverSessions; std::vector> _stopSessionFunc; std::vector> _moveCtorFunc; @@ -147,11 +148,11 @@ protected: { auto res = _serverSessions[this].emplace(sess_ptr); if (res.second) { - sess_ptr.start(); + sess_ptr->start(); _stopSessionFunc.emplace_back([res, this]() { - if (!res.first.expired()) { // session is still existing - auto sess = res.first.lock(); + if (!res.first->expired()) { // session is still existing + auto sess = res.first->lock(); sess->stop(); _serverSessions[this].erase(res.first); return true; @@ -308,9 +309,9 @@ public: // only once per SessionT if (_isListening[this].size() == 1) { - _moveCtorFunc = [this](const AdcGenericNetServer* new_instance) { + _moveCtorFunc.emplace_back([this](const AdcGenericNetServer* new_instance) { _isListening[new_instance] = std::move(_isListening[this]); - }; + }); } }; @@ -358,7 +359,7 @@ protected: startSession(sess); _isListening[this][id] = true; - doAccept(acceptor, id, sess_ctx); + doAccept(acceptor, id, sess_ctx); } else { _isListening[this][id] = false; } diff --git a/net/asio/adc_device_netserver_asio.h b/net/asio/adc_device_netserver_asio.h index ed7f980..431bd3e 100644 --- a/net/asio/adc_device_netserver_asio.h +++ b/net/asio/adc_device_netserver_asio.h @@ -25,7 +25,13 @@ public: template EptT> +#ifdef USE_OPENSSL_WITH_ASIO + void start(const EptT& endpoint, + asio::ssl::context tls_context = asio::ssl::context(asio::ssl::context::tlsv13_server), + asio::ssl::verify_mode tls_verify_mode = asio::ssl::context_base::verify_peer) +#else void start(const EptT& endpoint) +#endif { if (!endpoint.isValid()) { return; @@ -38,10 +44,11 @@ public: asio::ip::tcp::endpoint ept(asio::ip::make_address(endpoint.host()), endpoint.port()); if (endpoint.isTCP()) { using srv_t = AdcNetServiceASIO; - AdcDeviceNetServer::start>("TCP", this, _ioContext, ept); + AdcGenericNetServer::start>("TCP", this, _ioContext, ept); + } else { using srv_t = AdcNetServiceASIOTLS; - AdcDeviceNetServer::start>("TLS", this, _ioContext, ept); + AdcGenericNetServer::start>("TLS", this, _ioContext, ept, tls_context, tls_verify_mode); } #else if (endpoint.isTCP()) { @@ -53,7 +60,7 @@ public: if (endpoint.isLocalStream()) { asio::local::stream_protocol::endpoint ept(endpoint.template path()); using srv_t = AdcNetServiceASIO; - AdcDeviceNetServer::start>("LOCAL STREAM", this, _ioContext, ept); + AdcGenericNetServer::start>("LOCAL STREAM", this, _ioContext, ept); // } else if (endpoint.isLocalDatagram()) { // asio::local::datagram_protocol::endpoint ept(endpoint.template path()); // using srv_t = AdcNetServiceASIO; @@ -61,7 +68,7 @@ public: } else if (endpoint.isLocalSeqpacket()) { asio::local::seq_packet_protocol::endpoint ept(endpoint.template path()); using srv_t = AdcNetServiceASIO; - AdcDeviceNetServer::start>("LOCAL SEQPACK", this, _ioContext, ept); + AdcGenericNetServer::start>("LOCAL SEQPACK", this, _ioContext, ept); } } else { throw std::system_error(std::make_error_code(std::errc::protocol_not_supported)); diff --git a/net/asio/adc_netservice_asio.h b/net/asio/adc_netservice_asio.h index a5c2cb0..e081675 100644 --- a/net/asio/adc_netservice_asio.h +++ b/net/asio/adc_netservice_asio.h @@ -183,7 +183,9 @@ public: auto timer = netservice_t::getDeadlineTimer(_acceptor, timeout); - auto srv = std::make_unique(_ioContext); + // auto srv = std::make_unique(_ioContext); + auto srv = netservice_t::isTLS ? std::make_unique(_ioContext, srv->_tlsContext) + : std::make_unique(_ioContext); return asio::async_compose( [timer = std::move(timer), srv = std::move(srv), state = sock_accept, this]( @@ -254,6 +256,11 @@ public: return accept(timeout); } + void close(std::error_code& ec) + { + _acceptor.close(ec); + } + private: asio::io_context& _ioContext; srv_acceptor_t _acceptor; @@ -273,14 +280,14 @@ public: } - AdcBaseNetServiceASIO(socket_t socket) - : SESSION_PROTOT(), - _ioContext(static_cast(socket.get_executor().context())), - _socket(std::move(socket)), - _receiveStrand(_ioContext), - _receiveQueue() - { - } + // AdcBaseNetServiceASIO(socket_t socket) + // : SESSION_PROTOT(), + // _ioContext(static_cast(socket.get_executor().context())), + // _socket(std::move(socket)), + // _receiveStrand(_ioContext), + // _receiveQueue() + // { + // } #ifdef USE_OPENSSL_WITH_ASIO AdcBaseNetServiceASIO(asio::io_context& ctx, @@ -300,6 +307,7 @@ public: AdcBaseNetServiceASIO(AdcBaseNetServiceASIO&& other) + requires(!isTLS) : _ioContext(other._ioContext), _receiveStrand(std::move(other._receiveStrand)), _socket(std::move(other._socket)), @@ -312,6 +320,23 @@ public: _streamBuffer.commit(bytes); } +#ifdef USE_OPENSSL_WITH_ASIO + AdcBaseNetServiceASIO(AdcBaseNetServiceASIO&& other) + requires isTLS + : _ioContext(other._ioContext), + _receiveStrand(std::move(other._receiveStrand)), + _socket(std::move(other._socket)), + _sessSocket(std::move(other._sessSocket)), + _streamBuffer(), + _receiveQueue(std::move(other._receiveQueue)), + _tlsContext(std::move(other._tlsContext)), + _tlsPeerVerifyMode(std::move(other._tlsPeerVerifyMode)) + + { + auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data()); + _streamBuffer.commit(bytes); + } +#endif AdcBaseNetServiceASIO(const AdcBaseNetServiceASIO&) = delete; // no copy constructor! @@ -321,6 +346,7 @@ public: AdcBaseNetServiceASIO& operator=(const AdcBaseNetServiceASIO&) = delete; AdcBaseNetServiceASIO& operator=(AdcBaseNetServiceASIO&& other) + requires(!isTLS) { _ioContext = other._ioContext; _receiveStrand = std::move(other._receiveStrand); @@ -335,7 +361,29 @@ public: return *this; }; +#ifdef USE_OPENSSL_WITH_ASIO + AdcBaseNetServiceASIO& operator=(AdcBaseNetServiceASIO&& other) + requires isTLS + { + _ioContext = other._ioContext; + _receiveStrand = std::move(other._receiveStrand); + _receiveQueue = std::move(other._receiveQueue); + _socket = std::move(other._socket); + _sessSocket = std::move(other._sessSocket); + _tlsContext = std::move(other._tlsContext); + _tlsPeerVerifyMode = std::move(other._tlsPeerVerifyMode); + + + _streamBuffer.consume(_streamBuffer.size()); + + auto bytes = asio::buffer_copy(_streamBuffer.prepare(other._streamBuffer.size()), other._streamBuffer.data()); + _streamBuffer.commit(bytes); + + return *this; + }; + +#endif constexpr netservice_ident_t ident() const { return _ident; @@ -424,7 +472,7 @@ public: return _socket.async_send(buff_seq, std::move(self)); } else if constexpr (std::derived_from>) { - return _socket.async_send(buff_seq, std::move(self)); + return _socket.async_send(buff_seq, 0, std::move(self)); } else { static_assert(false, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!"); } @@ -772,8 +820,12 @@ protected: std::chrono::steady_clock::time_point::max() - std::chrono::steady_clock::now() - std::chrono::seconds(1)); - timer->expires_after(timeout < max_d ? timeout : max_d); // to avoid overflow! - // timer->expires_after(timeout); + if (timeout < max_d) { + timer->expires_after(timeout); + } else { + timer->expires_after(max_d); // to avoid overflow! + } + timer->async_wait([&obj](const std::error_code& ec) mutable { if (!ec) { diff --git a/tests/adc_asio_netserver_test.cpp b/tests/adc_asio_netserver_test.cpp index c0b65dd..beff25a 100644 --- a/tests/adc_asio_netserver_test.cpp +++ b/tests/adc_asio_netserver_test.cpp @@ -15,7 +15,7 @@ int main(int argc, char* argv[]) options.add_options()("h,help", "Print usage"); options.add_options()( "endpoints", "endpoints server will be listening for", - cxxopts::value>()->default_value("local://stream/tmp/ADC_ASIO_TEST_SERVER")); + cxxopts::value>()->default_value("local://stream/@ADC_ASIO_TEST_SERVER")); options.parse_positional({"endpoints"}); @@ -37,11 +37,10 @@ int main(int argc, char* argv[]) adc::AdcEndpointParser epn(ep); if (epn.isValid()) { if (epn.isLocalSeqpacket() || epn.isLocalStream()) { - if (opt_result["abstract"].as()) { - auto s = epn.path>(); - if (s[0] == '@') { // replace '@' to '\0' (use of UNIX abstract namespace) - s[0] = '\0'; - } + if (epn.path()[0] == '@') { // replace '@' to '\0' (use of UNIX abstract namespace) + auto it = std::ranges::find(ep, '@'); + *it = '\0'; + epn = adc::AdcEndpointParser(ep); } }