#pragma once

/* MOUNT CONTROL COMPONENTS LIBRARY */

/* A VERY SIMPLE NETWORK SERVER IMPLEMENTATION */

#include <algorithm>
#include <chrono>
#include <concepts>
#include <csignal>
#include <cstdlib>
#include <filesystem>
#include <format>
#include <functional>
#include <mutex>
#include <ranges>
#include <set>
#include <span>
#include <sstream>
#include <string>
#include <string_view>
#include <system_error>
#include <thread>
#include <variant>
#include <vector>

#include <asio.hpp>
#include <asio/experimental/awaitable_operators.hpp>

#if __has_include(<unistd.h>)  // POSIX
#define FORK_EXISTS 1
#include <sys/stat.h>
#include <unistd.h>
#endif

#include "mcc_generics.h"
#include "mcc_netserver_endpoint.h"
#include "mcc_netserver_proto.h"
#include "mcc_traits.h"

namespace mcc::network
{

namespace traits
{

template <typename T>
concept mcc_endpoint_c = std::derived_from<T, asio::ip::tcp::endpoint> ||
                         std::derived_from<T, asio::local::stream_protocol::endpoint> ||
                         std::derived_from<T, asio::local::seq_packet_protocol::endpoint> ||
                         std::derived_from<T, asio::serial_port>;

template <typename T>
static constexpr bool is_serial_proto = std::derived_from<T, asio::serial_port>;

template <typename T>
static constexpr bool is_tcp_proto =
    std::derived_from<T, asio::ip::tcp::endpoint> || std::derived_from<T, asio::ip::tcp::socket>;

template <typename T>
static constexpr bool is_local_stream_proto = std::derived_from<T, asio::local::stream_protocol::endpoint> ||
                                              std::derived_from<T, asio::local::stream_protocol::socket>;

template <typename T>
static constexpr bool is_local_seqpack_proto = std::derived_from<T, asio::local::seq_packet_protocol::endpoint> ||
                                               std::derived_from<T, asio::local::seq_packet_protocol::socket>;

}  // namespace traits


template <typename LoggerT = MccNullLogger>
class MccNetworkServer : public LoggerT
{
public:
    using LoggerT::logDebug;
    using LoggerT::logError;
    using LoggerT::logInfo;
    using LoggerT::logTrace;
    using LoggerT::logWarn;

    static constexpr std::chrono::milliseconds DEFAULT_RCV_TIMEOUT = std::chrono::hours(12);
    static constexpr std::chrono::milliseconds DEFAULT_SND_TIMEOUT = std::chrono::milliseconds(2000);

    // user function that handles a received message and returns the response bytes
    typedef std::function<std::vector<char>(std::string_view)> handle_message_func_t;

    MccNetworkServer(asio::io_context& ctx, const handle_message_func_t& func, LoggerT logger = MccNullLogger{})
        : LoggerT(std::move(logger)),  // the logger argument initializes the logging base class
          _asioContext(ctx),
          _handleMessageFunc(func),
          _stopSignal(ctx),
          _restartSignal(ctx)
    {
        std::stringstream st;
        st << std::this_thread::get_id();

        logInfo(std::format("Create mount server instance (thread ID = {})", st.str()));
    }

    ~MccNetworkServer()
    {
        std::stringstream st;
        st << std::this_thread::get_id();

        logInfo(std::format("Delete mount server instance (thread ID = {}) ...", st.str()));

        stopListening();
        disconnectClients();
    }

    // 'endpoint' is a server endpoint object declared in "mcc_netserver_endpoint.h".
    // Coroutine parameters are taken by value: references to the caller's temporaries would dangle,
    // because the coroutine body runs after the caller's full-expression has ended.
    template <typename... CtorArgTs>
    asio::awaitable<void> listen(auto endpoint, CtorArgTs... ctor_args)
    {
        if (!endpoint.isValid()) {
            logError(std::format("Cannot start listening! Invalid endpoint string representation ('{}')!",
                                 endpoint.endpoint()));
            co_return;
        }

        // prepend the root path to the endpoint path
        std::filesystem::path pt("/");

        if (endpoint.isLocalSerial()) {
            pt += endpoint.path();

            asio::serial_port s_port(_asioContext);
            std::error_code ec;

            s_port.open(pt.string(), ec);
            if (ec) {
                logError(std::format("Cannot open serial device '{}' (Error = '{}')!", pt.string(), ec.message()));
                co_return;
            }

            // serial port options can be applied only to an already opened port
            if constexpr (sizeof...(CtorArgTs)) {
                setSerialOpts(s_port, ctor_args...);
            }

            // asio::co_spawn(_asioContext, listen(std::move(s_port)), asio::detached);
            co_await listen(std::move(s_port));
        } else if (endpoint.isLocal()) {
            // create an abstract namespace socket endpoint if its path starts with the '@' symbol
            endpoint.makeAbstract('@');

            // if (endpoint.path()[0] == '\0') {  // abstract namespace
            //     std::string p;
            //     std::ranges::copy(endpoint.path(), std::back_inserter(p));
            //     p.insert(p.begin() + 1, '/');  // insert after '\0' symbol
            //     pt = p;
            // } else {
            //     pt += endpoint.path();
            // }

            if (endpoint.isLocalStream()) {
                co_await listen(asio::local::stream_protocol::endpoint(endpoint.path(pt.string())));
            } else if (endpoint.isLocalSeqpacket()) {
                co_await listen(asio::local::seq_packet_protocol::endpoint(endpoint.path(pt.string())));
            } else {
                co_return;  // must never happen: unknown local socket type
            }
        } else if (endpoint.isTCP()) {
            // resolve hostname
            try {
                asio::ip::tcp::resolver res(_asioContext);

                auto r_result = co_await res.async_resolve(endpoint.host(), endpoint.portView(), asio::use_awaitable);

                logInfo(std::format("Resolved hostname <{}> to {} IP address(es)", endpoint.host(), r_result.size()));

                bool exit_flag = false;
                asio::ip::tcp::acceptor acc(_asioContext);

                for (auto const& epn : r_result) {
                    try {
                        // std::stringstream st;

                        // logDebug("Create connection acceptor for endpoint <{}> ...",
                        // epn.address().to_string());
                        acc = asio::ip::tcp::acceptor(_asioContext, epn);
                        // st << acc.local_endpoint();

                        exit_flag = true;
                        break;
                    } catch (const std::system_error& err) {
                        logError(std::format("An error occurred while creating a connection acceptor (ec = {})",
                                             err.what()));
                        continue;
                    }
                }

                if (!exit_flag) {
                    logError("Cannot start listening on any resolved endpoint!");
                    co_return;
                }

                _tcpAcceptors.emplace_back(&acc);

                logInfo(
                    std::format("Start listening at <{}> endpoint ...", acc.local_endpoint().address().to_string()));

                // start accepting connections
                for (;;) {
                    auto sock = co_await acc.async_accept(asio::use_awaitable);
                    // start new client session
                    asio::co_spawn(_asioContext, startSession(std::move(sock)), asio::detached);
                }
            } catch (const std::system_error& err) {
                logError(std::format("An error occurred while trying to start accepting connections! ec = '{}'",
                                     err.what()));
            }
        }
    }

    template <traits::mcc_endpoint_c EpnT>
    asio::awaitable<void> listen(EpnT endpoint)
    {
        using epn_t = std::decay_t<EpnT>;

        if constexpr (traits::is_serial_proto<epn_t>) {
            // first, check if the port is open
            if (!endpoint.is_open()) {
                logError("Serial port was not open! Do not start waiting for commands!");
            } else {
                asio::co_spawn(_asioContext, startSession(std::move(endpoint)), asio::detached);
            }
        } else if constexpr (traits::is_tcp_proto<epn_t> || traits::is_local_stream_proto<epn_t> ||
                             traits::is_local_seqpack_proto<epn_t>) {
            try {
                std::stringstream st;
                st << endpoint;

                logDebug(std::format("Create connection acceptor for endpoint <{}> ...", st.str()));
                auto acc = typename epn_t::protocol_type::acceptor(_asioContext, endpoint);

                st.str("");
                st << acc.local_endpoint();
                logInfo(std::format("Start listening at <{}> endpoint ...", st.str()));

                if constexpr (traits::is_tcp_proto<epn_t>) {
                    _tcpAcceptors.emplace_back(&acc);
                } else if constexpr (traits::is_local_stream_proto<epn_t>) {
                    _localStreamAcceptors.emplace_back(&acc);
                } else if constexpr (traits::is_local_seqpack_proto<epn_t>) {
                    _localSeqpackAcceptors.emplace_back(&acc);
                } else {
                    static_assert(false, "INVALID ENDPOINT!!!");
                }

                // start accepting connections
                for (;;) {
                    auto sock = co_await acc.async_accept(asio::use_awaitable);
                    // start new client session
                    asio::co_spawn(_asioContext, startSession(std::move(sock)), asio::detached);
                }
            } catch (const std::system_error& err) {
                logError(std::format("An error occurred while trying to start accepting connections! ec = '{}'",
                                     err.what()));
            }
        } else {
            static_assert(false, "INVALID ENDPOINT!!!");
        }

        co_return;
    }

    // close listening on all endpoints
    void stopListening()
    {
        logInfo("Close all listening endpoints ...");

        auto num = _serialPorts.size() + _tcpAcceptors.size() + _localStreamAcceptors.size() +
                   _localSeqpackAcceptors.size();

        if (!num) {
            logInfo("There are no listening ports/sockets!");
            return;
        }

        auto close_func = [this](auto& acc_ptrs, std::string_view desc) {
            size_t N = 0, M = 0;
            std::error_code ec;

            if (acc_ptrs.size()) {
                logInfo(std::format("Close {} acceptors ...", desc));
                for (auto& acc : acc_ptrs) {
                    acc->close(ec);
                    if (ec) {
                        logError(std::format("Cannot close {} acceptor! ec = '{}'", desc, ec.message()));
                    } else {
                        ++M;
                    }

                    ++N;
                }

                logDebug(std::format("{} of {} {} acceptors were closed!", M, N, desc));

                // the stored pointers become invalid after closing, so clear the container
                acc_ptrs.clear();
            }
        };

        close_func(_tcpAcceptors, "TCP socket");
        close_func(_localStreamAcceptors, "local stream socket");
        close_func(_localSeqpackAcceptors, "local seqpack socket");

        logInfo("All server listening endpoints were closed!");
    }

    void disconnectClients()
    {
        auto disconn_func = [this](std::ranges::input_range auto& ptrs) {
            std::error_code ec;

            for (auto& ptr : ptrs) {
                // ptr->cancel(ec);
                // if (ec) {
                //     logWarn("socket_base::cancel: an error occurred (ec = {})", ec.message());
                // }

                ptr->shutdown(asio::socket_base::shutdown_both, ec);
                if (ec) {
                    logWarn(std::format("socket_base::shutdown: an error occurred (ec = {})", ec.message()));
                }

                ptr->close(ec);
                if (ec) {
                    logWarn(std::format("socket_base::close: an error occurred (ec = {})", ec.message()));
                }
            }
        };

        logInfo("Close all client connections ...");

        if (_serialPorts.empty() && _localStreamSockets.empty() && _localSeqpackSockets.empty() &&
            _tcpSockets.empty()) {
            logInfo("There were no active client connections! Skip!");
            return;
        }

        if (_serialPorts.size()) {
            std::lock_guard lock_g(_serialPortsMutex);

            std::error_code ec;

            logInfo(std::format("Close serial port clients ({} in total) ...", _serialPorts.size()));
            for (auto& ptr : _serialPorts) {
                ptr->cancel(ec);
                if (ec) {
                    logWarn(std::format("serial_port::cancel: an error occurred (ec = {})", ec.message()));
                }

                ptr->close(ec);
                if (ec) {
                    logWarn(std::format("serial_port::close: an error occurred (ec = {})", ec.message()));
                }
            }
        }

        if (_localStreamSockets.size()) {
            std::lock_guard lock_g(_localStreamSocketsMutex);

            logInfo(
                std::format("Close local stream socket-type clients ({} in total) ...", _localStreamSockets.size()));
            disconn_func(_localStreamSockets);
        }

        if (_localSeqpackSockets.size()) {
            std::lock_guard lock_g(_localSeqpackSocketsMutex);

            logInfo(
                std::format("Close local seqpack socket-type clients ({} in total) ...", _localSeqpackSockets.size()));
            disconn_func(_localSeqpackSockets);
        }

        if (_tcpSockets.size()) {
            std::lock_guard lock_g(_tcpSocketsMutex);

            logInfo(std::format("Close TCP socket-type clients ({} in total) ...", _tcpSockets.size()));
            disconn_func(_tcpSockets);
        }

        logInfo("Client connections were closed!");
    }

    void daemonize()
    {
#ifdef FORK_EXISTS
        logInfo("Daemonize the server ...");

        _asioContext.notify_fork(asio::execution_context::fork_prepare);

        auto tmp_path = std::filesystem::temp_directory_path();
        if (tmp_path.empty()) {
            tmp_path = std::filesystem::current_path().root_path();
        }

        if (pid_t pid = fork()) {
            if (pid > 0) {  // parent process
                exit(0);
            } else {
                // throw std::system_error(errno, std::generic_category(), "CANNOT FORK 1-STAGE");
                logError("CANNOT FORK 1-STAGE! The server was not daemonized!");
                return;
            }
        }

        if (setsid() == -1) {
            // throw std::system_error(errno, std::generic_category(), "CANNOT FORK SETSID");
            logError("CANNOT CREATE A NEW SESSION (setsid)! The server was not daemonized!");
            return;
        }

        logInfo(std::format("Try to set the daemon current path to '{}' ...", tmp_path.string()));
        std::error_code ec{};
        std::filesystem::current_path(tmp_path, ec);
        if (ec) {
            logWarn(std::format("Cannot change current path to '{}'! Ignore!", tmp_path.string()));
        }

        umask(0);

        if (pid_t pid = fork()) {
            if (pid > 0) {  // parent process
                exit(0);
            } else {
                // throw std::system_error(errno, std::generic_category(), "CANNOT FORK 2-STAGE");
                logError("CANNOT FORK 2-STAGE! The server was not daemonized!");
                return;
            }
        }

        // stdin, stdout, stderr
        close(0);
        close(1);
        close(2);

        _asioContext.notify_fork(asio::io_context::fork_child);

        logInfo("The server was daemonized successfully!");
#else
        logWarn("The host platform is not POSIX, so the server cannot be daemonized!");
#endif
    }

    template <std::ranges::range RST = std::vector<int>, std::ranges::range RRT = std::vector<int>>
    void setupSignals(const RST& stop_sig_num = {SIGINT, SIGTERM}, const RRT& restart_sig_num = {SIGUSR1})
        requires(std::convertible_to<std::ranges::range_value_t<RST>, int> &&
                 std::convertible_to<std::ranges::range_value_t<RRT>, int>)
    {
        for (const int sig : stop_sig_num) {
            _stopSignal.add(sig);
        }

        _stopSignal.async_wait([this](std::error_code ec, int signo) {
            if (ec) {  // e.g. operation_aborted: the signal set was cancelled
                return;
            }

            logInfo(std::format("Stop signal was received (signo = {})", signo));

            stopListening();
            disconnectClients();

            _asioContext.stop();
        });

        for (const int sig : restart_sig_num) {
            _restartSignal.add(sig);
        }

        _restartSignal.async_wait([this](std::error_code ec, int signo) {
            if (ec) {  // e.g. operation_aborted: the signal set was cancelled
                return;
            }

            logInfo(std::format("Restart signal was received (signo = {})", signo));
            restart();
        });
    }

    void restart()
    {
        disconnectClients();

        // re-arm the restart signal handler
        _restartSignal.async_wait([this](std::error_code ec, int signo) {
            if (ec) {
                return;
            }

            logInfo(std::format("Restart signal was received (signo = {})", signo));
            restart();
        });
    }

private:
    asio::io_context& _asioContext;
    handle_message_func_t _handleMessageFunc;

    asio::signal_set _stopSignal, _restartSignal;

    std::set<asio::serial_port*> _serialPorts;
    // std::vector<asio::serial_port*> _serialPorts;

    std::vector<asio::ip::tcp::acceptor*> _tcpAcceptors;
    std::vector<asio::local::stream_protocol::acceptor*> _localStreamAcceptors;
    std::vector<asio::local::seq_packet_protocol::acceptor*> _localSeqpackAcceptors;

    std::set<asio::ip::tcp::socket*> _tcpSockets;
    std::set<asio::local::stream_protocol::socket*> _localStreamSockets;
    std::set<asio::local::seq_packet_protocol::socket*> _localSeqpackSockets;
    // std::vector<asio::ip::tcp::socket*> _tcpSockets;
    // std::vector<asio::local::stream_protocol::socket*> _localStreamSockets;
    // std::vector<asio::local::seq_packet_protocol::socket*> _localSeqpackSockets;

    std::mutex _serialPortsMutex, _tcpSocketsMutex, _localStreamSocketsMutex, _localSeqpackSocketsMutex;

    // helpers

    template <typename OptT, typename... OptTs>
    void setSerialOpts(asio::serial_port& s_port, OptT&& opt, OptTs&&... opts)
    {
        using opt_t = std::decay_t<OptT>;

        std::error_code ec;

        s_port.set_option(opt, ec);
        if (ec) {
            std::string_view opt_name;

            if constexpr (std::same_as<opt_t, asio::serial_port::baud_rate>) {
                opt_name = "baud rate";
            } else if constexpr (std::same_as<opt_t, asio::serial_port::parity>) {
                opt_name = "parity";
            } else if constexpr (std::same_as<opt_t, asio::serial_port::flow_control>) {
                opt_name = "flow control";
            } else if constexpr (std::same_as<opt_t, asio::serial_port::stop_bits>) {
                opt_name = "stop bits";
            } else if constexpr (std::same_as<opt_t, asio::serial_port::character_size>) {
                opt_name = "char size";
            }

            logError(std::format("Cannot set serial port '{}' option! Just skip!", opt_name));
        }

        if constexpr (sizeof...(OptTs)) {
            setSerialOpts(s_port, std::forward<OptTs>(opts)...);
        }
    }

    std::vector<char> handleClientCommand(std::string_view command)
    {
        std::vector<char> resp{MCC_COMMPROTO_KEYWORD_SERVER_ACK_STR.begin(),
                               MCC_COMMPROTO_KEYWORD_SERVER_ACK_STR.end()};

        return resp;
    }

    template <typename RCVT = std::remove_cvref_t<decltype(DEFAULT_RCV_TIMEOUT)>,
              typename SNDT = std::remove_cvref_t<decltype(DEFAULT_SND_TIMEOUT)>>
    asio::awaitable<void> startSession(auto socket,
                                       const RCVT& rcv_timeout = DEFAULT_RCV_TIMEOUT,
                                       const SNDT& snd_timeout = DEFAULT_SND_TIMEOUT)
    {
        using namespace asio::experimental::awaitable_operators;

        using sock_t = std::decay_t<decltype(socket)>;

        // returns the bytes of the first complete message (stop sequence included),
        // or an empty span if no stop sequence has been received yet
        auto look_for_whole_msg = [](auto const& bytes) {
            auto found = std::ranges::search(bytes, MCC_COMMPROTO_STOP_SEQ);
            return found.empty() ? std::span(bytes.begin(), bytes.begin()) : std::span(bytes.begin(), found.end());
        };

        asio::streambuf sbuff;
        size_t nbytes;
        std::stringstream st;
        std::string r_epn;

        st << std::this_thread::get_id();
        std::string thr_id = st.str();
        st.str("");

        if constexpr (traits::is_serial_proto<sock_t>) {
            st << "serial port: " << socket.native_handle();
        } else {  // network sockets
            st << socket.remote_endpoint();
        }
        r_epn = st.str();
        if (r_epn.empty()) {  // UNIX domain sockets
            r_epn = "local";
        }

        logInfo(std::format("Start client session: remote endpoint <{}> (session thread ID = {})", r_epn, thr_id));

        try {
            if constexpr (!traits::is_serial_proto<sock_t>) {
                logTrace("Set socket option KEEP_ALIVE to TRUE");
                socket.set_option(asio::socket_base::keep_alive(true));
            }

            if constexpr (traits::is_serial_proto<sock_t>) {
                std::lock_guard lock_g(_serialPortsMutex);
                _serialPorts.insert(&socket);
            } else if constexpr (traits::is_tcp_proto<sock_t>) {
                std::lock_guard lock_g(_tcpSocketsMutex);
                // _tcpSockets.emplace_back(&socket);
                _tcpSockets.insert(&socket);
            } else if constexpr (traits::is_local_stream_proto<sock_t>) {
                std::lock_guard lock_g(_localStreamSocketsMutex);
                // _localStreamSockets.emplace_back(&socket);
                _localStreamSockets.insert(&socket);
            } else if constexpr (traits::is_local_seqpack_proto<sock_t>) {
                std::lock_guard lock_g(_localSeqpackSocketsMutex);
                // _localSeqpackSockets.emplace_back(&socket);
                _localSeqpackSockets.insert(&socket);
            } else {
                static_assert(false, "INVALID SOCKET TYPE!!!");
            }

            // send buffer sequence:
            // the second element is initialized with the "stop-sequence" symbols
            std::vector<asio::const_buffer> snd_buff_seq{
                {}, {MCC_COMMPROTO_STOP_SEQ.data(), MCC_COMMPROTO_STOP_SEQ.size()}};

            asio::steady_timer timeout_timer(_asioContext);
            std::variant<size_t, std::monostate> op_res;  // index 0: bytes transferred, index 1: timeout

            std::error_code ec;
            bool do_read = true;

            // main "client request -- server response" cycle
            for (;;) {
                // receive message
                if (do_read) {
                    logTrace(std::format("Start socket/port reading operation with timeout {} ...", rcv_timeout));

                    if constexpr (traits::is_serial_proto<sock_t>) {
                        nbytes = 1024;
                    } else {
                        nbytes = socket.available();
                    }

                    auto buff = sbuff.prepare(nbytes ? nbytes : 1);

                    // timeout_timer.expires_after(std::chrono::seconds(5));
                    timeout_timer.expires_after(rcv_timeout);

                    if constexpr (traits::is_local_seqpack_proto<sock_t>) {
                        asio::socket_base::message_flags oflags;
                        op_res = co_await (
                            socket.async_receive(buff, oflags, asio::redirect_error(asio::use_awaitable, ec)) ||
                            timeout_timer.async_wait(asio::use_awaitable));
                    } else {
                        op_res = co_await (asio::async_read(socket, buff, asio::transfer_at_least(1),
                                                            asio::redirect_error(asio::use_awaitable, ec)) ||
                                           timeout_timer.async_wait(asio::use_awaitable));
                    }

                    // check the timeout first: the cancelled read may also have stored an error in 'ec'
                    if (op_res.index()) {
                        throw std::system_error(std::make_error_code(std::errc::timed_out));
                    }

                    if (ec) {
                        throw std::system_error(ec);
                    }

                    nbytes = std::get<0>(op_res);
                    logTrace(std::format("{} bytes were received", nbytes));

                    if constexpr (traits::is_local_seqpack_proto<sock_t>) {
                        if (!nbytes) {  // EOF!
                            throw std::system_error(std::error_code(asio::error::misc_errors::eof));
                        }
                    }

                    sbuff.commit(nbytes);
                }

                // here the input stream buffer may still contain remaining bytes, so try to handle them
                auto start_ptr = static_cast<const char*>(sbuff.data().data());

                auto msg = look_for_whole_msg(std::span(start_ptr, sbuff.size()));
                if (msg.empty()) {  // not a whole message yet
                    logTrace("It seems a partial command message was received, so waiting for the remaining part ...");
                    do_read = true;
                    continue;
                }

                // extract the command without the stop sequence symbols
                // std::string comm;
                // std::ranges::copy(msg | std::views::take(msg.size() - MCC_COMMPROTO_STOP_SEQ.size()),
                //                   std::back_inserter(comm));
                std::string_view comm{msg.begin(), msg.end() - MCC_COMMPROTO_STOP_SEQ.size()};

                logDebug(std::format("A command [{}] was received from client (remote endpoint <{}>, thread ID = {})",
                                     comm, r_epn, thr_id));

                // auto resp = handleClientCommand(comm);
                auto resp = _handleMessageFunc(comm);

                // remove the received message from the input stream buffer.
                // NOTE: after this call 'msg' and 'comm' are invalidated!!!
                sbuff.consume(msg.size());

                do_read = sbuff.size() == 0;

                logDebug(std::format("Send response [{}] to client (remote endpoint <{}>, thread ID = {})",
                                     std::string_view(resp.begin(), resp.end()), r_epn, thr_id));

                // send the server response to the client
                snd_buff_seq[0] = {resp.data(), resp.size()};

                timeout_timer.expires_after(snd_timeout);

                if constexpr (traits::is_local_seqpack_proto<sock_t>) {
                    op_res = co_await (
                        socket.async_send(snd_buff_seq, 0, asio::redirect_error(asio::use_awaitable, ec)) ||
                        timeout_timer.async_wait(asio::use_awaitable));
                } else {
                    // nbytes = co_await asio::async_write(socket, snd_buff_seq, asio::use_awaitable);
                    op_res = co_await (
                        asio::async_write(socket, snd_buff_seq, asio::redirect_error(asio::use_awaitable, ec)) ||
                        timeout_timer.async_wait(asio::use_awaitable));
                }

                // check the timeout first: the cancelled write may also have stored an error in 'ec'
                if (op_res.index()) {
                    throw std::system_error(std::make_error_code(std::errc::timed_out));
                }

                if (ec) {
                    throw std::system_error(ec);
                }

                nbytes = std::get<0>(op_res);
                logTrace(std::format("{} bytes were sent", nbytes));

                if (nbytes != (resp.size() + MCC_COMMPROTO_STOP_SEQ.size())) {
                    logWarn(std::format("Only {} of {} response bytes were sent (remote endpoint <{}>, thread ID = {})",
                                        nbytes, resp.size() + MCC_COMMPROTO_STOP_SEQ.size(), r_epn, thr_id));
                }
            }
        } catch (const std::system_error& ex) {
            if (ex.code() == std::error_code(asio::error::misc_errors::eof)) {
                logInfo(std::format(
                    "It seems the client or server closed the connection (remote endpoint <{}>, thread ID = {})",
                    r_epn, thr_id));
            } else {
                logError(std::format("An error '{}' occurred in client session (remote endpoint <{}>, thread ID = {})",
                                     ex.what(), r_epn, thr_id));
            }
        } catch (const std::exception& ex) {
            logError(
                std::format("An unhandled error '{}' occurred in client session (remote endpoint <{}>, thread ID = {})",
                            ex.what(), r_epn, thr_id));
        } catch (...) {
            logError(std::format("An unhandled error occurred in client session (remote endpoint <{}>, thread ID = {})",
                                 r_epn, thr_id));
        }

        // remove the pointer as it becomes invalid here (at the exit of the method);
        // lock the same mutex that guarded the insertion
        if constexpr (traits::is_serial_proto<sock_t>) {
            std::lock_guard lock_g(_serialPortsMutex);
            _serialPorts.erase(&socket);
        } else if constexpr (traits::is_tcp_proto<sock_t>) {
            std::lock_guard lock_g(_tcpSocketsMutex);
            _tcpSockets.erase(&socket);
        } else if constexpr (traits::is_local_stream_proto<sock_t>) {
            std::lock_guard lock_g(_localStreamSocketsMutex);
            _localStreamSockets.erase(&socket);
        } else if constexpr (traits::is_local_seqpack_proto<sock_t>) {
            std::lock_guard lock_g(_localSeqpackSocketsMutex);
            _localSeqpackSockets.erase(&socket);
        } else {
            static_assert(false, "INVALID SOCKET TYPE!!!");
        }

        logInfo(std::format("Close client session: remote endpoint <{}> (thread ID = {})", r_epn, thr_id));
    }
};

}  // namespace mcc::network
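
/*
 * Usage sketch (not part of the library; a minimal example under stated assumptions).
 *
 * It shows the intended call sequence: construct the server with a message handler, install
 * the stop/restart signal handlers, spawn listen() on the io_context and run it. The endpoint
 * type name `MccNetServerEndpoint`, its string constructor and the "tcp://localhost:12345"
 * syntax are HYPOTHETICAL placeholders for whatever "mcc_netserver_endpoint.h" really
 * provides; the logger namespace `mcc::MccNullLogger` and the header file name are assumed too.
 *
 *     #include <asio.hpp>
 *     #include "mcc_netserver.h"  // hypothetical name of this header
 *
 *     int main()
 *     {
 *         asio::io_context ctx;
 *
 *         // echo handler: return the received command bytes unchanged
 *         auto echo = [](std::string_view cmd) { return std::vector<char>(cmd.begin(), cmd.end()); };
 *
 *         mcc::network::MccNetworkServer<mcc::MccNullLogger> server(ctx, echo);  // namespace assumed
 *         server.setupSignals();  // SIGINT/SIGTERM stop the server, SIGUSR1 disconnects clients
 *
 *         mcc::network::MccNetServerEndpoint epn("tcp://localhost:12345");  // HYPOTHETICAL
 *         asio::co_spawn(ctx, server.listen(epn), asio::detached);
 *
 *         ctx.run();  // returns after the stop signal handler calls ctx.stop()
 *     }
 *
 * Framing, as implemented by startSession(): assuming for illustration only that
 * MCC_COMMPROTO_STOP_SEQ is "\n", a single read delivering "ping\npo" yields one command
 * "ping" passed to the handler; the response is sent followed by the stop sequence, and the
 * leftover "po" stays in the stream buffer until a later read completes that message.
 */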