diff --git a/net/adc_netserver.h b/net/adc_netserver.h index 4bb055a..eeb88a8 100644 --- a/net/adc_netserver.h +++ b/net/adc_netserver.h @@ -260,7 +260,7 @@ public: protected: - // started sessions waek pointers + // started sessions weak pointers template static std::unordered_map> _serverSessions; std::vector> _stopSessionFunc; diff --git a/net/adc_netservice.h b/net/adc_netservice.h index 558c4b3..c90492e 100644 --- a/net/adc_netservice.h +++ b/net/adc_netservice.h @@ -28,12 +28,11 @@ concept adc_time_duration_c = requires { } // namespace traits -template -class AdcNetService -{ -protected: - ImplT _impl; +/* The class incapsulates network operations */ +template +class AdcNetService : public ImplT +{ public: using impl_t = ImplT; @@ -42,7 +41,7 @@ public: using typename ImplT::timeout_t; template - AdcNetService(ImplCtorArgTs&&... ctor_args) : _impl(std::forward(ctor_args)...) + AdcNetService(ImplCtorArgTs&&... ctor_args) : impl_t(std::forward(ctor_args)...) { } @@ -52,53 +51,60 @@ public: /* asynchronuos operations */ - // open connection + // start accepting client connections (server side) + template + auto asyncAccept(const endpoint_t& end_point, const timeout_t& timeout, ArgTs&&... args) + { + return impl_t::asyncAccept(end_point, timeout, std::forward(args)...); + } + + // open connection (client side) template auto asyncConnect(const endpoint_t& end_point, const timeout_t& timeout, ArgTs&&... args) { - return _impl.asyncConnect(end_point, timeout, std::forward(args)...); + return impl_t::asyncConnect(end_point, timeout, std::forward(args)...); } template auto asyncSend(const NetMessageT& msg, const timeout_t& timeout, ArgTs&&... args) { - return _impl.asyncSend(msg, timeout, std::forward(args)...); + return impl_t::asyncSend(msg, timeout, std::forward(args)...); } template auto asyncReceive(const timeout_t& timeout, ArgTs&&... args) { - return _impl.asyncReceive(timeout, std::forward(args)...); + return impl_t::asyncReceive(timeout, std::forward(args)...); } - /* blocking operations */ + /* blocking operations (throw exceptions if there is an error) */ template auto connect(const endpoint_t& endpoint, const timeout_t& timeout, ArgTs&&... args) { - return _impl.connect(endpoint, timeout, std::forward(args)...); + return impl_t::connect(endpoint, timeout, std::forward(args)...); } template auto send(const NetMessageT& msg, const timeout_t& timeout, ArgTs&&... args) { - return _impl.send(msg, timeout, std::forward(args)...); + return impl_t::send(msg, timeout, std::forward(args)...); } template NetMessageT receive(const timeout_t& timeout, ArgTs&&... args) { - return _impl.receive(timeout, std::forward(args)...); + return impl_t::receive(timeout, std::forward(args)...); } template auto close(ArgTs&&... args) { - return _impl.close(std::forward(args)...); + return impl_t::close(std::forward(args)...); } }; diff --git a/net/adc_netservice_asio.h b/net/adc_netservice_asio.h index 1e5527a..08d501d 100644 --- a/net/adc_netservice_asio.h +++ b/net/adc_netservice_asio.h @@ -46,11 +46,13 @@ namespace adc::traits // still only TCP, UDP and UNIX template -concept adc_asio_inet_proto_c = requires { - requires std::derived_from || std::derived_from || - std::derived_from || - std::derived_from; -}; +concept adc_asio_inet_proto_c = + std::derived_from || std::derived_from || + std::derived_from || std::derived_from; + +template +concept adc_asio_inet_stream_proto_c = + std::derived_from || std::derived_from; } // namespace adc::traits @@ -58,46 +60,43 @@ namespace adc::impl { -template +template class AdcNetServiceASIO : public InetProtoT { public: using socket_t = typename InetProtoT::socket; using endpoint_t = typename InetProtoT::endpoint; - using inet_proto_t = InetProtoT; - using high_level_socket_t = HighLevelSocketT; + using acceptor_t = + std::conditional_t>, + std::nullptr_t, + typename InetProtoT::acceptor>; + using inet_proto_t = InetProtoT; typedef std::chrono::steady_clock::duration timeout_t; // nanoseconds resolution -#ifdef USE_OPENSSL_WITH_ASIO - - static_assert(std::is_same_v - ? true - : std::is_same_v>, - "ONLY BASIC SOCKETS AND SSL::STREAM ARE SUPPORTED!!!"); - - static constexpr bool IsBasicSocket = std::is_same_v; -#else - static_assert(std::is_same_v, - "HighLevelSocketT AND InetProtoT::socket TYPES MUST BE THE SAME!!!"); -#endif - using streambuff_iter_t = asio::buffers_iterator; - AdcNetServiceASIO(high_level_socket_t& sock) : _socket(sock) {} + AdcNetServiceASIO(asio::io_context& io_context) : _ioContext(io_context), _socket(io_context), _acceptor(io_context) + { + } virtual ~AdcNetServiceASIO() = default; template CompletionTokenT> - auto asyncConnect(const endpoint_t& endpoint, const TimeoutT& timeout, CompletionTokenT&& token) + auto asyncAccept(const endpoint_t& endpoint, const TimeoutT& timeout, CompletionTokenT&& token) { + // no acceptor for UDP-sockets + if constexpr (!std::is_null_pointer_v) { + return; + } + auto timer = getDeadlineTimer(timeout); - enum { starting, cancel_timer }; + enum { starting, finishing }; // wrapper return asio::async_compose( @@ -106,21 +105,57 @@ public: if (!ec) { switch (state) { case starting: - state = cancel_timer; - if constexpr (AdcNetServiceASIO::IsBasicSocket) { - return _socket.async_connect(endpoint, std::move(self)); - } else { - return _socket.lowest_layer().async_connect(endpoint, std::move(self)); + state = finishing; + + try { + _acceptor = acceptor_t(_ioContext, endpoint); + } catch (std::system_error err) { + timer->cancel(); + self.complete(err.code()); + return; } + + return _acceptor.async_accept(_socket, std::move(self)); break; - case cancel_timer: - timer->cancel(); + case finishing: break; default: break; } } + timer->cancel(); + self.complete(ec); + }, + token, _socket); + } + + + template CompletionTokenT> + auto asyncConnect(const endpoint_t& endpoint, const TimeoutT& timeout, CompletionTokenT&& token) + { + auto timer = getDeadlineTimer(timeout); + + enum { starting, finishing }; + + // wrapper + return asio::async_compose( + [timer = std::move(timer), state = starting, &endpoint, this](auto& self, + const std::error_code& ec = {}) mutable { + if (!ec) { + switch (state) { + case starting: + state = finishing; + return _socket.async_connect(endpoint, std::move(self)); + break; + case finishing: + break; + default: + break; + } + } + + timer->cancel(); self.complete(ec); }, token, _socket); @@ -129,9 +164,9 @@ public: template CompletionTokenT> - auto asynSend(const NetMessageT& msg, const TimeoutT& timeout, CompletionTokenT&& token) + auto asyncSend(const NetMessageT& msg, const TimeoutT& timeout, CompletionTokenT&& token) { - enum { starting, cancel_timer }; + enum { starting, finishing }; // create buffer sequence std::vector buff; @@ -147,7 +182,7 @@ public: if (!ec) { switch (state) { case starting: - state = cancel_timer; + state = finishing; if constexpr (std::derived_from< socket_t, asio::basic_stream_socket>) { return asio::async_write(_socket, buff, std::move(self)); @@ -161,7 +196,7 @@ public: static_assert(false, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!"); } break; - case cancel_timer: + case finishing: timer->cancel(); break; default: @@ -178,7 +213,7 @@ public: template auto asyncReceive(const TimeoutT& timeout, CompletionTokenT&& token) { - enum { starting, cancel_timer }; + enum { starting, finishing }; std::unique_ptr out_flags; @@ -190,7 +225,7 @@ public: if (!ec) { switch (state) { case starting: - state = cancel_timer; + state = finishing; if constexpr (std::derived_from< socket_t, asio::basic_stream_socket>) { return asio::async_read_until( @@ -209,7 +244,7 @@ public: static_assert(false, "UNKNOWN ASIO-LIBRARY SOCKET TYPE!!!"); } break; - case cancel_timer: + case finishing: timer->cancel(); break; default: @@ -245,6 +280,8 @@ public: token, _socket); } + + template auto connect(const endpoint_t& endpoint, const TimeoutT& timeout) { @@ -271,26 +308,20 @@ public: { std::error_code ec; - if constexpr (AdcNetServiceASIO::IsBasicSocket) { - _socket.shutdown(stype, ec); - if (!ec) { - _socket.close(ec); - } - } else { - _socket.shutdown(ec); // shutdown OpenSSL stream - if (!ec) { - _socket.lowest_layer().shutdown(stype, ec); - if (!ec) { - _socket.lowest_layer().close(ec); - } - } + _socket.shutdown(stype, ec); + if (!ec) { + _socket.close(ec); } return ec; } protected: - high_level_socket_t& _socket; + asio::io_context& _ioContext; + + socket_t _socket; + + acceptor_t _acceptor; asio::streambuf _streamBuffer; @@ -314,8 +345,371 @@ protected: }; +#ifdef USE_OPENSSL_WITH_ASIO +template +class AdcNetServiceAsioTls : public InetProtoT +{ +public: + using socket_t = typename InetProtoT::socket; + using stream_t = asio::ssl::stream; + using endpoint_t = typename InetProtoT::endpoint; + + using acceptor_t = typename InetProtoT::acceptor; + + using inet_proto_t = InetProtoT; + + typedef std::chrono::steady_clock::duration timeout_t; // nanoseconds resolution + + // TLS certificate attributes comparison function: + // 'serial' - as returned by OpenSSL BN_bn2hex + // 'fingerprint' - as returned by OpenSSL X509_digest + // 'depth' - depth in chain + // the function must return 0 - if comparison failed; otherwise - something != 0 + typedef std::function& fingerprint, int depth)> + cert_comp_func_t; + + using streambuff_iter_t = asio::buffers_iterator; + + AdcNetServiceAsioTls(asio::io_context& io_context, + asio::ssl::context&& tls_context = asio::ssl::context(asio::ssl::context::tlsv13), + asio::ssl::verify_mode verify_mode = asio::ssl::verify_peer) + : _ioContext(io_context), + _tlsStream(socket_t(io_context), io_context), + _tlsContext(std::move(tls_context)), + _acceptor(io_context), + _tlsPeerVerifyMode(verify_mode), + _tlsCertFingerprintDigest("sha256"), + _tlsCertCompFunc(nullptr) + { + } + + template CompletionTokenT> + auto asyncAccept(const endpoint_t& endpoint, const TimeoutT& timeout, CompletionTokenT&& token) + { + auto sock = std::make_unique(_ioContext); + + auto timer = getDeadlineTimer(timeout); + + enum { starting, handshaking, finishing }; + + // wrapper + return asio::async_compose( + [timer = std::move(timer), sock = std::move(sock), state = starting, &endpoint, this]( + auto& self, std::error_code ec = {}) mutable { + if (!ec) { + switch (state) { + case starting: + state = handshaking; + + try { + _acceptor = acceptor_t(_ioContext, endpoint); + return _acceptor.async_accept(*sock, std::move(self)); + } catch (std::system_error err) { + ec = err.code(); + } + break; + case handshaking: + state = finishing; + _tlsStream = stream_t(std::move(sock), _tlsContext); + _tlsStream->set_verify_mode(_tlsPeerVerifyMode, ec); + if (!ec) { + _tlsStream->set_verify_callback( + [this](bool preverified, asio::ssl::verify_context& ctx) { + return verifyCertificate(preverified ? 1 : 0, ctx.native_handle()); + }); + + return _tlsStream->async_handshake(asio::ssl::stream_base::server, std::move(self)); + } + break; + case finishing: + break; + default: + break; + } + } + + timer->cancel(); + self.complete(ec); + }, + token, _ioContext); + } + + + template CompletionTokenT> + auto asyncConnect(const endpoint_t& endpoint, const TimeoutT& timeout, CompletionTokenT&& token) + { + auto sock = std::make_unique(_ioContext); + + auto timer = getDeadlineTimer(timeout); + + enum { starting, handshaking, finishing }; + + // wrapper + return asio::async_compose( + [timer = std::move(timer), sock = std::move(sock), state = starting, &endpoint, this]( + auto& self, std::error_code ec = {}) mutable { + if (!ec) { + switch (state) { + case starting: + state = handshaking; + return sock->async_connect(endpoint, std::move(self)); + break; + case handshaking: + state = finishing; + _tlsStream = stream_t(std::move(*sock), _tlsContext); + return _tlsStream.async_handshake(asio::ssl::stream_base::client, std::move(self)); + break; + case finishing: + break; + default: + break; + } + } + + timer->cancel(); + self.complete(ec); + }, + token, _ioContext); + } + + + + template CompletionTokenT> + auto asyncSend(const NetMessageT& msg, const TimeoutT& timeout, CompletionTokenT&& token) + { + enum { starting, finishing }; + + // create buffer sequence + std::vector buff; + std::ranges::for_each(msg.template bytesView>(), + [&buff](const auto& el) { buff.emplace_back(el); }); + + auto timer = getDeadlineTimer(timeout); + + // wrapper + return asio::async_compose( + [buff = std::move(buff), timer = std::move(timer), state = starting, this]( + auto& self, const std::error_code& ec = {}, size_t = 0) mutable { + if (!ec) { + switch (state) { + case starting: + state = finishing; + return asio::async_write(_tlsStream, buff, std::move(self)); + break; + case finishing: + timer->cancel(); + break; + default: + break; + } + } + + self.complete(ec); + }, + token, _ioContext); + } + + + template + auto asyncReceive(const TimeoutT& timeout, CompletionTokenT&& token) + { + enum { starting, finishing }; + + std::unique_ptr out_flags; + + auto timer = getDeadlineTimer(timeout); // armed timer + + return asio::async_compose( + [timer = std::move(timer), out_flags = std::move(out_flags), state = starting, this]( + auto& self, const std::error_code& ec = {}, size_t = 0) mutable { + if (!ec) { + switch (state) { + case starting: + state = finishing; + return asio::async_read_until( + _tlsStream, _streamBuffer, + [this](auto begin, auto end) { return this->matchCondition(begin, end); }, + std::move(self)); + break; + case finishing: + timer->cancel(); + break; + default: + break; + } + + auto begin_it = streambuff_iter_t::begin(_streamBuffer.data()); + auto end_it = begin_it + _streamBuffer.data().size(); + + // check for byte sequence is valid byte sequence and find the limits + // (stream buffer may contain number of bytes more than requred by protocol) + auto res = this->matchCondition(begin_it, end_it); + + if (!res.second) { + self.complete(std::make_error_code(std::errc::protocol_error), + NetMessageT()); // return an empty message + } else { + auto nbytes = std::distance(begin_it, res.first); + NetMessageT msg; + + auto msg_it = this->fromLowLevel(begin_it, res.first); + msg.setFromBytes(msg_it.first, msg_it.second); + + _streamBuffer.consume(nbytes); + + self.complete(ec, msg); + } + } else { + self.complete(ec, NetMessageT()); // return an empty message + return; + } + }, + token, _ioContext); + } + + + template + auto connect(const endpoint_t& endpoint, const TimeoutT& timeout) + { + std::future ftr = asyncConnect(endpoint, timeout, asio::use_future); + ftr.get(); + } + + template + auto send(const NetMessageT& msg, const TimeoutT& timeout) + { + std::future ftr = asyncSend(msg, timeout, asio::use_future); + ftr.get(); + } + + template + auto receive(const TimeoutT& timeout) + { + std::future ftr = asyncReceive(timeout, asio::use_future); + return ftr.get(); + } + + std::error_code close(asio::socket_base::shutdown_type stype = asio::socket_base::shutdown_both) + { + std::error_code ec; + + _tlsStream.shutdown(ec); // shutdown OpenSSL stream + if (!ec) { + _tlsStream.lowest_layer().shutdown(stype, ec); + if (!ec) { + _tlsStream.lowest_layer().close(ec); + } + } + + return ec; + } + + // special TLS-related methods + void setPeerVerifyMode(asio::ssl::verify_mode mode = asio::ssl::verify_peer) + { + // restart TLS server?!! + _tlsPeerVerifyMode = mode; + } + + asio::ssl::verify_mode getPeerVerifyMode() const + { + return _tlsPeerVerifyMode; + } + + AdcNetServiceAsioTls& setTLSCertFingerprintDigest(const std::string& digest = "sha256") + { + // check for validness?!! + _tlsCertFingerprintDigest = digest; + } + + std::string getTLSCertFingerprintDigest() const + { + return _tlsCertFingerprintDigest; + } + + template &, int> FuncT> + AdcNetServiceAsioTls& setTLSCertCompFunc(FuncT&& func) + { + _tlsCertCompFunc = static_cast(std::forward(func)); + } + + +protected: + asio::io_context& _ioContext; + + acceptor_t _acceptor; + stream_t _tlsStream; + asio::ssl::context _tlsContext; + asio::ssl::verify_mode _tlsPeerVerifyMode; + std::string _tlsCertFingerprintDigest; + cert_comp_func_t _tlsCertCompFunc; + + asio::streambuf _streamBuffer; + + + // reference implementation + virtual bool verifyCertificate(int preverified_ok, X509_STORE_CTX* store) + { + if (preverified_ok == 0) { + int err = X509_STORE_CTX_get_error(store); + auto err_str = X509_verify_cert_error_string(err); + // log_error("TLS certificate verification error: {}", err_str); + + return preverified_ok; + } + + char subject_name[256]; + + int depth; + ASN1_INTEGER* serial; + BIGNUM* bnser; + + X509* cert = X509_STORE_CTX_get_current_cert(store); + + if (cert != NULL) { + depth = X509_STORE_CTX_get_error_depth(store); + X509_NAME_oneline(X509_get_subject_name(cert), subject_name, 256); + serial = X509_get_serialNumber(cert); // IT IS INTERNAL POINTER SO IT MUST NOT BE FREED UP!!! + bnser = ASN1_INTEGER_to_BN(serial, NULL); + auto serial_hex = BN_bn2hex(bnser); + + // log_debug("Received TLS certificate: SUBJECT = {}, SERIAL = {}, DEPTH = {}", subject_name, serial_hex, + // depth); + + // if no compare function then do not compute fingerprint + if (_tlsCertCompFunc) { + // compute certificate fingerprint + unsigned char digest_buff[EVP_MAX_MD_SIZE]; + const EVP_MD* digest = EVP_get_digestbyname(_tlsCertFingerprintDigest.c_str()); + unsigned int N; + + if (X509_digest(cert, digest, digest_buff, &N)) { + preverified_ok = _tlsCertCompFunc(std::string(serial_hex), + std::vector(digest_buff, digest_buff + N), depth); + + } else { + // log_error("Cannot compute client certificate fingerprint! Cannot verify the certificate!"); + preverified_ok = 0; + } + } + + BN_free(bnser); + OPENSSL_free(serial_hex); + + } else { + // log_error("OpenSSL error: cannot get current certificate"); + preverified_ok = 0; + } + + return preverified_ok; + } +}; +#endif + typedef AdcNetService> AdcNetServiceAsioTcp; -typedef AdcNetService>> AdcNetServiceAsioTls; +typedef AdcNetService> AdcNetServiceAsioUdp; typedef AdcNetService> AdcNetServiceAsioLocalSeqPack; typedef AdcNetService> AdcNetServiceAsioLocalStream; @@ -328,10 +722,13 @@ namespace adc::traits template concept adc_netservice_asio_c = requires { typename T::inet_proto_t; - typename T::high_level_socket_t; - requires std::derived_from>; +#ifdef USE_OPENSSL_WITH_ASIO + requires std::derived_from> || + std::derived_from>; +#else + requires std::derived_from>; +#endif }; } // namespace adc::traits