From c6df547acf11a04f9a3b05da2871ceba3d5b535f Mon Sep 17 00:00:00 2001 From: ick3 Date: Fri, 1 May 2026 02:19:05 +0200 Subject: [PATCH] Added OpenSSL to make use of ADS Secure Protocols. Added also a test_connections.sh and a test_connections.cpp to show usage and test usual commands like Read Write. Supported Are Shared CA Certificate PSK SSC Plain is still working Tested on TwinCAT on Linux. --- AdsLib/AdsLib.h | 14 + AdsLib/AmsConnection.h | 40 +- AdsLib/AmsRouter.h | 11 +- AdsLib/CMakeLists.txt | 5 +- AdsLib/SecureAdsConfig.h | 26 ++ AdsLib/standalone/AdsLib.cpp | 17 +- AdsLib/standalone/AdsLib.h | 1 + AdsLib/standalone/AmsConnection.cpp | 3 +- AdsLib/standalone/AmsRouter.cpp | 62 ++- AdsLib/standalone/SecureAmsConnection.cpp | 416 ++++++++++++++++++ AdsLib/standalone/SecureAmsConnection.h | 63 +++ AdsLib/standalone/TlsConnectInfo.h | 35 ++ AdsLib/standalone/TlsSocket.cpp | 340 +++++++++++++++ AdsLib/standalone/TlsSocket.h | 58 +++ certs/.gitkeep | 0 meson.build | 193 +++++---- tools/test_connections.cpp | 494 ++++++++++++++++++++++ tools/test_connections.sh | 227 ++++++++++ 18 files changed, 1886 insertions(+), 119 deletions(-) create mode 100644 AdsLib/SecureAdsConfig.h create mode 100644 AdsLib/standalone/SecureAmsConnection.cpp create mode 100644 AdsLib/standalone/SecureAmsConnection.h create mode 100644 AdsLib/standalone/TlsConnectInfo.h create mode 100644 AdsLib/standalone/TlsSocket.cpp create mode 100644 AdsLib/standalone/TlsSocket.h create mode 100644 certs/.gitkeep create mode 100644 tools/test_connections.cpp create mode 100755 tools/test_connections.sh diff --git a/AdsLib/AdsLib.h b/AdsLib/AdsLib.h index 42d3748a..e912320b 100644 --- a/AdsLib/AdsLib.h +++ b/AdsLib/AdsLib.h @@ -11,6 +11,7 @@ #include "standalone/AdsLib.h" #endif +#include "SecureAdsConfig.h" #include "Sockets.h" #ifdef BHF_ADS_EXPORT_C @@ -158,6 +159,19 @@ namespace ads */ long AddLocalRoute(AmsNetId ams, const char *ip); +/** + * Add a Secure ADS route (TLS 1.2, port 8016) to a target system. + * Supports SSC (self-signed) and SCA (shared CA) modes. + * For SSC first-time registration, populate config.username and config.password. + * For SCA, set config.caPath to the CA certificate used to verify the server. + * @param[in] ams AmsNetId of the target system + * @param[in] host hostname or IP of the target (port 8016 used by default) + * @param[in] config TLS certificate paths and optional credentials + * @return [ADS Return Code](https://infosys.beckhoff.com/content/1031/tcadscommon/html/ads_returncodes.htm?id=1666172286265530469) + */ +long AddSecureRoute(AmsNetId ams, const char *host, + const SecureAdsConfig &config); + /** * Delete ams route that had previously been added with AddLocalRoute(). * @param[in] ams address of the target system diff --git a/AdsLib/AmsConnection.h b/AdsLib/AmsConnection.h index ec0d54c1..35358cb4 100644 --- a/AdsLib/AmsConnection.h +++ b/AdsLib/AmsConnection.h @@ -66,32 +66,43 @@ struct AmsResponse { bool wasWritten; }; -struct AmsConnection { +/** + * Common interface for plain and secure AMS connections. + * AmsRouter stores pointers to this base to support both connection types. + */ +struct AmsConnectionBase { + std::atomic refCount{ 0 }; + uint32_t ownIp{ 0 }; + + virtual ~AmsConnectionBase() = default; + virtual bool IsConnectedTo(const struct addrinfo *) const = 0; + virtual long AdsRequest(AmsRequest &, uint32_t timeout) = 0; + virtual SharedDispatcher + CreateNotifyMapping(uint32_t hNotify, + std::shared_ptr notification) = 0; + virtual long DeleteNotification(const AmsAddr &, uint32_t hNotify, + uint32_t tmms, uint16_t port) = 0; +}; + +struct AmsConnection : AmsConnectionBase { AmsConnection(Router &__router, const struct addrinfo *destination = nullptr); ~AmsConnection(); SharedDispatcher CreateNotifyMapping(uint32_t hNotify, - std::shared_ptr notification); + std::shared_ptr notification) override; long DeleteNotification(const AmsAddr &amsAddr, uint32_t hNotify, - uint32_t tmms, uint16_t port); - long AdsRequest(AmsRequest &request, uint32_t timeout); - - /** - * Confirm if this AmsConnection is connected to one of the target addresses. - * @param[in] targetAddresses pointer to a previously allocated list of - * "struct addrinfo" returned by getaddrinfo(3). - * @return true, this connection can be used to reach one of the targetAddresses. - */ - bool IsConnectedTo(const struct addrinfo *targetAddresses) const; + uint32_t tmms, uint16_t port) override; + long AdsRequest(AmsRequest &request, uint32_t timeout) override; + bool + IsConnectedTo(const struct addrinfo *targetAddresses) const override; private: friend struct AmsRouter; Router &router; TcpSocket socket; std::thread receiver; - std::atomic refCount; std::atomic invokeId; std::array queue; @@ -119,7 +130,4 @@ struct AmsConnection { std::recursive_mutex dispatcherListMutex; SharedDispatcher DispatcherListAdd(const VirtualConnection &connection); SharedDispatcher DispatcherListGet(const VirtualConnection &connection); - - public: - const uint32_t ownIp; }; diff --git a/AdsLib/AmsRouter.h b/AdsLib/AmsRouter.h index 37da047a..98a11af3 100644 --- a/AdsLib/AmsRouter.h +++ b/AdsLib/AmsRouter.h @@ -6,6 +6,7 @@ #pragma once #include "AmsConnection.h" +#include "SecureAdsConfig.h" #include struct AmsRouter : Router { @@ -25,8 +26,10 @@ struct AmsRouter : Router { [[deprecated]] long AddRoute(AmsNetId ams, const IpV4 &ip); long AddRoute(AmsNetId ams, const std::string &host); + long AddSecureRoute(AmsNetId ams, const std::string &host, + const bhf::ads::SecureAdsConfig &config); void DelRoute(const AmsNetId &ams); - AmsConnection *GetConnection(const AmsNetId &pAddr); + AmsConnectionBase *GetConnection(const AmsNetId &pAddr); long AdsRequest(AmsRequest &request); private: @@ -34,13 +37,13 @@ struct AmsRouter : Router { std::recursive_mutex mutex; std::condition_variable_any connection_attempt_events; std::map > connection_attempts; - std::unordered_set > connections; - std::map mapping; + std::unordered_set > connections; + std::map mapping; void AwaitConnectionAttempts(const AmsNetId &ams, std::unique_lock &lock); - void DeleteIfLastConnection(const AmsConnection *conn); + void DeleteIfLastConnection(const AmsConnectionBase *conn); std::array ports; }; diff --git a/AdsLib/CMakeLists.txt b/AdsLib/CMakeLists.txt index 4fbbaab3..c8561afc 100644 --- a/AdsLib/CMakeLists.txt +++ b/AdsLib/CMakeLists.txt @@ -22,6 +22,8 @@ set(SOURCES standalone/AmsPort.cpp standalone/AmsRouter.cpp standalone/NotificationDispatcher.cpp + standalone/SecureAmsConnection.cpp + standalone/TlsSocket.cpp ) add_library(ads ${SOURCES}) @@ -39,4 +41,5 @@ if(WIN32 EQUAL 1) target_link_libraries(ads PUBLIC ws2_32) endif() -target_link_libraries(ads PUBLIC Threads::Threads) +find_package(OpenSSL REQUIRED) +target_link_libraries(ads PUBLIC Threads::Threads OpenSSL::SSL OpenSSL::Crypto) diff --git a/AdsLib/SecureAdsConfig.h b/AdsLib/SecureAdsConfig.h new file mode 100644 index 00000000..67d9fe46 --- /dev/null +++ b/AdsLib/SecureAdsConfig.h @@ -0,0 +1,26 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include + +namespace bhf +{ +namespace ads +{ + +struct SecureAdsConfig { + enum class Mode { SSC, SCA, PSK }; + + Mode mode = Mode::SCA; + std::string certPath; ///< SSC/SCA: PEM path to client certificate + std::string keyPath; ///< SSC/SCA: PEM path to client private key + std::string caPath; ///< SCA: PEM path to CA certificate + std::string username; ///< SSC: username for first-time route registration + std::string + password; ///< SSC: first-time registration password; PSK: password for key derivation + std::string + pskIdentity; ///< PSK: identity string; key = SHA-256(UPPER(identity) + password) +}; + +} +} diff --git a/AdsLib/standalone/AdsLib.cpp b/AdsLib/standalone/AdsLib.cpp index 6f8abae9..32c583e8 100644 --- a/AdsLib/standalone/AdsLib.cpp +++ b/AdsLib/standalone/AdsLib.cpp @@ -5,6 +5,7 @@ #include "AdsLib.h" #include "AmsRouter.h" +#include "Log.h" static AmsRouter &GetRouter() { @@ -37,7 +38,8 @@ long AddLocalRoute(const AmsNetId ams, const char *ip) return GetRouter().AddRoute(ams, ip); } catch (const std::bad_alloc &) { return GLOBALERR_NO_MEMORY; - } catch (const std::runtime_error &) { + } catch (const std::runtime_error &e) { + LOG_ERROR("AddLocalRoute failed: " << e.what()); return GLOBALERR_TARGET_PORT; } } @@ -51,6 +53,19 @@ void SetLocalAddress(const AmsNetId ams) { GetRouter().SetLocalAddress(ams); } + +long AddSecureRoute(const AmsNetId ams, const char *host, + const SecureAdsConfig &config) +{ + try { + return GetRouter().AddSecureRoute(ams, host, config); + } catch (const std::bad_alloc &) { + return GLOBALERR_NO_MEMORY; + } catch (const std::runtime_error &e) { + LOG_ERROR("AddSecureRoute failed: " << e.what()); + return GLOBALERR_TARGET_PORT; + } +} } } diff --git a/AdsLib/standalone/AdsLib.h b/AdsLib/standalone/AdsLib.h index 1f05af9a..d48f25f0 100644 --- a/AdsLib/standalone/AdsLib.h +++ b/AdsLib/standalone/AdsLib.h @@ -5,6 +5,7 @@ #pragma once #include "AdsDef.h" +#include "SecureAdsConfig.h" #ifdef BHF_ADS_EXPORT_C extern "C" { diff --git a/AdsLib/standalone/AmsConnection.cpp b/AdsLib/standalone/AmsConnection.cpp index d09e41e9..29703330 100644 --- a/AdsLib/standalone/AmsConnection.cpp +++ b/AdsLib/standalone/AmsConnection.cpp @@ -70,10 +70,9 @@ AmsConnection::AmsConnection(Router &__router, const struct addrinfo *const destination) : router(__router) , socket(destination) - , refCount(0) , invokeId(0) - , ownIp(socket.Connect()) { + ownIp = socket.Connect(); receiver = std::thread(&AmsConnection::TryRecv, this); } diff --git a/AdsLib/standalone/AmsRouter.cpp b/AdsLib/standalone/AmsRouter.cpp index 97c2b96e..2d638a89 100644 --- a/AdsLib/standalone/AmsRouter.cpp +++ b/AdsLib/standalone/AmsRouter.cpp @@ -4,6 +4,7 @@ */ #include "AmsRouter.h" +#include "SecureAmsConnection.h" #include "Log.h" #include @@ -61,7 +62,7 @@ long AmsRouter::AddRoute(AmsNetId ams, const std::string &host) lock.unlock(); try { - auto new_connection = std::unique_ptr( + auto new_connection = std::unique_ptr( new AmsConnection{ *this, hostAddresses.get() }); lock.lock(); connection_attempts.erase(ams); @@ -96,7 +97,7 @@ void AmsRouter::DelRoute(const AmsNetId &ams) auto route = mapping.find(ams); if (route != mapping.end()) { - AmsConnection *conn = route->second; + AmsConnectionBase *conn = route->second; if (0 == --conn->refCount) { mapping.erase(route); DeleteIfLastConnection(conn); @@ -104,7 +105,7 @@ void AmsRouter::DelRoute(const AmsNetId &ams) } } -void AmsRouter::DeleteIfLastConnection(const AmsConnection *const conn) +void AmsRouter::DeleteIfLastConnection(const AmsConnectionBase *const conn) { if (conn) { for (const auto &r : mapping) { @@ -188,7 +189,60 @@ long AmsRouter::SetTimeout(uint16_t port, uint32_t timeout) return 0; } -AmsConnection *AmsRouter::GetConnection(const AmsNetId &amsDest) +long AmsRouter::AddSecureRoute(AmsNetId ams, const std::string &host, + const bhf::ads::SecureAdsConfig &config) +{ + auto hostAddresses = bhf::ads::GetListOfAddresses(host, "8016"); + + std::unique_lock lock(mutex); + + AwaitConnectionAttempts(ams, lock); + + const auto oldConnection = GetConnection(ams); + if (oldConnection && + !oldConnection->IsConnectedTo(hostAddresses.get())) { + return ROUTERERR_PORTALREADYINUSE; + } + + for (const auto &conn : connections) { + if (conn->IsConnectedTo(hostAddresses.get())) { + conn->refCount++; + mapping[ams] = conn.get(); + return 0; + } + } + + connection_attempts[ams] = {}; + lock.unlock(); + + try { + auto new_connection = std::unique_ptr( + new SecureAmsConnection{ *this, hostAddresses.get(), + config, localAddr }); + lock.lock(); + connection_attempts.erase(ams); + connection_attempt_events.notify_all(); + + auto conn = connections.emplace(std::move(new_connection)); + if (conn.second) { + if (!localAddr) { + localAddr = + AmsNetId{ conn.first->get()->ownIp }; + } + conn.first->get()->refCount++; + mapping[ams] = conn.first->get(); + return !conn.first->get()->ownIp; + } + return -1; + } catch (std::exception &e) { + lock.lock(); + connection_attempts.erase(ams); + connection_attempt_events.notify_all(); + throw; + } +} + +AmsConnectionBase *AmsRouter::GetConnection(const AmsNetId &amsDest) { std::lock_guard lock(mutex); const auto it = mapping.find(amsDest); diff --git a/AdsLib/standalone/SecureAmsConnection.cpp b/AdsLib/standalone/SecureAmsConnection.cpp new file mode 100644 index 00000000..f084e46d --- /dev/null +++ b/AdsLib/standalone/SecureAmsConnection.cpp @@ -0,0 +1,416 @@ +// SPDX-License-Identifier: MIT +#include "SecureAmsConnection.h" +#include "TlsConnectInfo.h" +#include "Log.h" +#include "wrap_endian.h" + +#include +#include +#include +#include + +// ===== Constructor / Destructor ===== + +SecureAmsConnection::SecureAmsConnection( + Router &router, const struct addrinfo *destination, + const bhf::ads::SecureAdsConfig &config, const AmsNetId &localNetId) + : m_Router(router) + , m_Socket(destination, config) +{ + ownIp = m_Socket.Connect(); + Handshake(config, localNetId); + m_Receiver = std::thread(&SecureAmsConnection::TryRecv, this); +} + +SecureAmsConnection::~SecureAmsConnection() +{ + m_Socket.Shutdown(); + m_Receiver.join(); +} + +// ===== Handshake ===== + +void SecureAmsConnection::Handshake(const bhf::ads::SecureAdsConfig &config, + const AmsNetId &localNetId) +{ + char hostname[32] = {}; + gethostname(hostname, sizeof(hostname) - 1); + + TlsConnectInfoBase req = {}; + req.version = TlsConnectInfo::VERSION; + memcpy(&req.netId, &localNetId, sizeof(AmsNetId)); + memcpy(req.hostname, hostname, + std::min(strlen(hostname), sizeof(req.hostname) - 1)); + + const bool isSSC = + (config.mode == bhf::ads::SecureAdsConfig::Mode::SSC); + const bool firstTime = isSSC && !config.username.empty(); + + if (firstTime) { + const auto userLen = static_cast( + std::min(config.username.size(), size_t(255))); + const auto pwdLen = static_cast( + std::min(config.password.size(), size_t(255))); + + req.flags = bhf::ads::htole( + TlsConnectInfo::FLAG_ADD_REMOTE | + TlsConnectInfo::FLAG_SELF_SIGNED | + TlsConnectInfo::FLAG_IP_ADDR | + TlsConnectInfo::FLAG_IGNORE_CN); + req.userLen = userLen; + req.pwdLen = pwdLen; + req.totalLength = + bhf::ads::htole(static_cast( + sizeof(TlsConnectInfoBase) + userLen + pwdLen)); + + m_Socket.write_raw(reinterpret_cast(&req), + sizeof(req)); + m_Socket.write_raw(reinterpret_cast( + config.username.c_str()), + userLen); + m_Socket.write_raw(reinterpret_cast( + config.password.c_str()), + pwdLen); + } else { + req.flags = bhf::ads::htole( + isSSC ? TlsConnectInfo::FLAG_SELF_SIGNED : uint16_t(0)); + req.totalLength = bhf::ads::htole( + static_cast(sizeof(TlsConnectInfoBase))); + + m_Socket.write_raw(reinterpret_cast(&req), + sizeof(req)); + } + + TlsConnectInfoBase resp = {}; + m_Socket.read_raw(reinterpret_cast(&resp), sizeof(resp)); + + const auto respFlags = bhf::ads::letoh(resp.flags); + if (!(respFlags & TlsConnectInfo::FLAG_RESPONSE)) { + throw std::runtime_error( + "SecureADS handshake: server response missing Response flag"); + } + if (resp.error != 0) { + throw std::runtime_error( + "SecureADS handshake error code: " + + std::to_string(static_cast(resp.error))); + } +} + +// ===== Public interface ===== + +bool SecureAmsConnection::IsConnectedTo( + const struct addrinfo *targetAddresses) const +{ + return m_Socket.IsConnectedTo(targetAddresses); +} + +long SecureAmsConnection::AdsRequest(AmsRequest &request, + const uint32_t timeout) +{ + AmsAddr srcAddr; + const auto status = m_Router.GetLocalAddress(request.port, &srcAddr); + if (status) { + return status; + } + request.SetDeadline(timeout); + AmsResponse *response = Write(request, srcAddr); + if (response) { + const auto errorCode = response->Wait(); + response->Release(); + return errorCode; + } + return -1; +} + +SharedDispatcher SecureAmsConnection::CreateNotifyMapping( + uint32_t hNotify, std::shared_ptr notification) +{ + auto dispatcher = DispatcherListAdd(notification->connection); + notification->hNotify(hNotify); + dispatcher->Emplace(hNotify, notification); + return dispatcher; +} + +long SecureAmsConnection::DeleteNotification(const AmsAddr &amsAddr, + uint32_t hNotify, uint32_t tmms, + uint16_t port) +{ + AmsRequest request{ amsAddr, + port, + AoEHeader::DEL_DEVICE_NOTIFICATION, + 0, + nullptr, + nullptr, + sizeof(hNotify) }; + request.frame.prepend(bhf::ads::htole(hNotify)); + return AdsRequest(request, tmms); +} + +// ===== Write — no AmsTcpHeader prepend ===== + +AmsResponse *SecureAmsConnection::Write(AmsRequest &request, + const AmsAddr srcAddr) +{ + const AoEHeader aoeHeader{ request.destAddr.netId, + request.destAddr.port, + srcAddr.netId, + srcAddr.port, + request.cmdId, + static_cast(request.frame.size()), + GetInvokeId() }; + request.frame.prepend(aoeHeader); + + auto response = Reserve(&request, srcAddr.port); + if (!response) { + return nullptr; + } + + response->invokeId.store(aoeHeader.invokeId()); + if (request.frame.size() != m_Socket.write(request.frame)) { + response->Release(); + return nullptr; + } + return response; +} + +// ===== Recv — reads AoEHeader directly, no AmsTcpHeader ===== + +void SecureAmsConnection::Recv() +{ + AoEHeader aoeHeader; + for (;;) { + Receive(aoeHeader); + + if (aoeHeader.cmdId() == AoEHeader::DEVICE_NOTIFICATION) { + ReceiveNotification(aoeHeader); + continue; + } + + auto response = GetPending(aoeHeader.invokeId(), + aoeHeader.targetPort()); + if (!response) { + LOG_WARN("No response pending"); + ReceiveJunk(aoeHeader.length()); + continue; + } + + switch (aoeHeader.cmdId()) { + case AoEHeader::READ_DEVICE_INFO: + case AoEHeader::WRITE: + case AoEHeader::READ_STATE: + case AoEHeader::WRITE_CONTROL: + case AoEHeader::ADD_DEVICE_NOTIFICATION: + case AoEHeader::DEL_DEVICE_NOTIFICATION: + ReceiveFrame(response, + aoeHeader.length(), + aoeHeader.errorCode()); + continue; + + case AoEHeader::READ: + case AoEHeader::READ_WRITE: + ReceiveFrame( + response, aoeHeader.length(), + aoeHeader.errorCode()); + continue; + + default: + LOG_WARN("Unknown AMS command id"); + response->Notify(ADSERR_CLIENT_SYNCRESINVALID); + ReceiveJunk(aoeHeader.length()); + } + } +} + +void SecureAmsConnection::TryRecv() +{ + try { + Recv(); + } catch (const std::runtime_error &e) { + LOG_INFO(e.what()); + } +} + +// ===== Receive helpers ===== + +void SecureAmsConnection::Receive(void *buffer, size_t bytesToRead, + timeval *timeout) const +{ + auto pos = reinterpret_cast(buffer); + while (bytesToRead) { + const size_t n = m_Socket.read(pos, bytesToRead, timeout); + bytesToRead -= n; + pos += n; + } +} + +void SecureAmsConnection::Receive(void *buffer, size_t bytesToRead, + const Timepoint &deadline) const +{ + const auto now = std::chrono::steady_clock::now(); + const auto usec = std::chrono::duration_cast( + deadline - now) + .count(); + if (usec <= 0) { + throw TlsSocket::TimeoutEx("deadline reached already!!!"); + } + timeval timeout{ (long)(usec / 1000000), (int)(usec % 1000000) }; + Receive(buffer, bytesToRead, &timeout); +} + +void SecureAmsConnection::ReceiveJunk(size_t bytesToRead) const +{ + uint8_t buffer[1024]; + while (bytesToRead > sizeof(buffer)) { + Receive(buffer, sizeof(buffer)); + bytesToRead -= sizeof(buffer); + } + Receive(buffer, bytesToRead); +} + +template +void SecureAmsConnection::ReceiveFrame(AmsResponse *const response, + size_t bytesLeft, + uint32_t aoeError) const +{ + AmsRequest *const request = response->request.load(); + const auto responseId = response->invokeId.load(); + T header; + + if (aoeError) { + response->Notify(aoeError); + ReceiveJunk(bytesLeft); + return; + } + if (bytesLeft > sizeof(header) + request->bufferLength) { + LOG_WARN("Frame too long: " + << std::dec << bytesLeft << '>' + << sizeof(header) + request->bufferLength); + response->Notify(ADSERR_DEVICE_INVALIDSIZE); + ReceiveJunk(bytesLeft); + return; + } + + try { + Receive(&header, sizeof(header), request->deadline); + bytesLeft -= sizeof(header); + Receive(request->buffer, bytesLeft, request->deadline); + if (request->bytesRead) { + *(request->bytesRead) = + static_cast::type>( + bytesLeft); + } + response->Notify(header.result()); + } catch (const TlsSocket::TimeoutEx &) { + LOG_WARN("InvokeId " << std::dec << responseId << " timed out"); + response->Notify(ADSERR_CLIENT_SYNCTIMEOUT); + ReceiveJunk(bytesLeft); + } +} + +bool SecureAmsConnection::ReceiveNotification(const AoEHeader &header) +{ + const auto dispatcher = DispatcherListGet( + VirtualConnection{ header.targetPort(), header.sourceAms() }); + if (!dispatcher) { + ReceiveJunk(header.length()); + LOG_WARN("No dispatcher found for notification"); + return false; + } + + auto &ring = dispatcher->ring; + auto bytesLeft = header.length(); + if (bytesLeft + sizeof(bytesLeft) > ring.BytesFree()) { + ReceiveJunk(bytesLeft); + LOG_WARN("port " << std::dec << header.targetPort() + << " receive buffer was full"); + return false; + } + + for (size_t i = 0; i < sizeof(bytesLeft); ++i) { + *ring.write = (bytesLeft >> (8 * i)) & 0xFF; + ring.Write(1); + } + + auto chunk = ring.WriteChunk(); + while (bytesLeft > chunk) { + Receive(ring.write, chunk); + ring.Write(chunk); + bytesLeft -= static_cast(chunk); + chunk = ring.WriteChunk(); + } + Receive(ring.write, bytesLeft); + ring.Write(bytesLeft); + dispatcher->Notify(); + return true; +} + +// ===== Dispatcher helpers ===== + +SharedDispatcher +SecureAmsConnection::DispatcherListAdd(const VirtualConnection &connection) +{ + const auto dispatcher = DispatcherListGet(connection); + if (dispatcher) { + return dispatcher; + } + std::lock_guard lock(m_DispatcherListMutex); + return m_DispatcherList + .emplace(connection, + std::make_shared(std::bind( + &SecureAmsConnection::DeleteNotification, this, + connection.second, std::placeholders::_1, + std::placeholders::_2, connection.first))) + .first->second; +} + +SharedDispatcher +SecureAmsConnection::DispatcherListGet(const VirtualConnection &connection) +{ + std::lock_guard lock(m_DispatcherListMutex); + const auto it = m_DispatcherList.find(connection); + if (it != m_DispatcherList.end()) { + return it->second; + } + return {}; +} + +// ===== Invoke / queue helpers ===== + +uint32_t SecureAmsConnection::GetInvokeId() +{ + uint32_t result; + do { + result = m_InvokeId.fetch_add(1); + } while (!result); + return result; +} + +AmsResponse *SecureAmsConnection::Reserve(AmsRequest *request, + const uint16_t port) +{ + AmsRequest *isFree = nullptr; + if (!m_Queue[port - Router::PORT_BASE].request.compare_exchange_strong( + isFree, request)) { + LOG_WARN("Port: " << port << " already in use as " << isFree); + return nullptr; + } + return &m_Queue[port - Router::PORT_BASE]; +} + +AmsResponse *SecureAmsConnection::GetPending(const uint32_t id, + const uint16_t port) +{ + const uint16_t portIndex = port - Router::PORT_BASE; + if (portIndex >= Router::NUM_PORTS_MAX) { + LOG_WARN("Port 0x" << std::hex << port << " is out of range"); + return nullptr; + } + auto currentId = id; + if (m_Queue[portIndex].invokeId.compare_exchange_strong(currentId, 0)) { + return &m_Queue[portIndex]; + } + LOG_WARN("InvokeId mismatch: waiting for 0x" << std::hex << currentId + << " received 0x" << id); + return nullptr; +} diff --git a/AdsLib/standalone/SecureAmsConnection.h b/AdsLib/standalone/SecureAmsConnection.h new file mode 100644 index 00000000..5e1dce20 --- /dev/null +++ b/AdsLib/standalone/SecureAmsConnection.h @@ -0,0 +1,63 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "AmsConnection.h" +#include "SecureAdsConfig.h" +#include "TlsSocket.h" + +#include +#include +#include +#include + +struct SecureAmsConnection : AmsConnectionBase { + SecureAmsConnection(Router &router, const struct addrinfo *destination, + const bhf::ads::SecureAdsConfig &config, + const AmsNetId &localNetId); + ~SecureAmsConnection(); + + bool + IsConnectedTo(const struct addrinfo *targetAddresses) const override; + long AdsRequest(AmsRequest &request, uint32_t timeout) override; + SharedDispatcher CreateNotifyMapping( + uint32_t hNotify, + std::shared_ptr notification) override; + long DeleteNotification(const AmsAddr &amsAddr, uint32_t hNotify, + uint32_t tmms, uint16_t port) override; + + private: + friend struct AmsRouter; + Router &m_Router; + TlsSocket m_Socket; + std::thread m_Receiver; + std::atomic m_InvokeId{ 0 }; + std::array m_Queue; + + void Handshake(const bhf::ads::SecureAdsConfig &config, + const AmsNetId &localNetId); + + template + void ReceiveFrame(AmsResponse *response, size_t bytesLeft, + uint32_t aoeError) const; + bool ReceiveNotification(const AoEHeader &header); + void ReceiveJunk(size_t bytesToRead) const; + void Receive(void *buffer, size_t bytesToRead, + timeval *timeout = nullptr) const; + void Receive(void *buffer, size_t bytesToRead, + const Timepoint &deadline) const; + template void Receive(T &buffer) const + { + Receive(&buffer, sizeof(T)); + } + AmsResponse *Write(AmsRequest &request, AmsAddr srcAddr); + void Recv(); + void TryRecv(); + uint32_t GetInvokeId(); + AmsResponse *Reserve(AmsRequest *request, uint16_t port); + AmsResponse *GetPending(uint32_t id, uint16_t port); + + std::map m_DispatcherList; + std::recursive_mutex m_DispatcherListMutex; + SharedDispatcher DispatcherListAdd(const VirtualConnection &connection); + SharedDispatcher DispatcherListGet(const VirtualConnection &connection); +}; diff --git a/AdsLib/standalone/TlsConnectInfo.h b/AdsLib/standalone/TlsConnectInfo.h new file mode 100644 index 00000000..eb1bec36 --- /dev/null +++ b/AdsLib/standalone/TlsConnectInfo.h @@ -0,0 +1,35 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "AdsDef.h" +#include + +namespace TlsConnectInfo +{ +static const uint8_t VERSION = 1; +static const uint16_t FLAG_RESPONSE = 0x0001; +static const uint16_t FLAG_AMS_ALLOWED = 0x0002; +static const uint16_t FLAG_SERVER_INFO = 0x0004; +static const uint16_t FLAG_OWN_FILE = 0x0008; +static const uint16_t FLAG_SELF_SIGNED = 0x0010; +static const uint16_t FLAG_IP_ADDR = 0x0020; +static const uint16_t FLAG_IGNORE_CN = 0x0040; +static const uint16_t FLAG_ADD_REMOTE = 0x0080; +} + +#pragma pack(push, 1) +struct TlsConnectInfoBase { + uint16_t totalLength; + uint16_t flags; + uint8_t version; + uint8_t error; + AmsNetId netId; + uint8_t userLen; + uint8_t pwdLen; + uint8_t reserved[18]; + char hostname[32]; +}; +#pragma pack(pop) + +static_assert(sizeof(TlsConnectInfoBase) == 64, + "TlsConnectInfoBase wire layout must be 64 bytes"); diff --git a/AdsLib/standalone/TlsSocket.cpp b/AdsLib/standalone/TlsSocket.cpp new file mode 100644 index 00000000..570dc4ac --- /dev/null +++ b/AdsLib/standalone/TlsSocket.cpp @@ -0,0 +1,340 @@ +// SPDX-License-Identifier: MIT +#include "TlsSocket.h" +#include "Log.h" + +#include +#include +#include +#include +#include + +#include +#include +#include +#include + +#include + +static std::string opensslError() +{ + const unsigned long e = ERR_get_error(); + if (!e) { + return "unknown OpenSSL error"; + } + char buf[256]; + ERR_error_string_n(e, buf, sizeof(buf)); + return std::string(buf); +} + +int TlsSocket::pskExIndex() +{ + static int idx = + SSL_get_ex_new_index(0, nullptr, nullptr, nullptr, nullptr); + return idx; +} + +std::vector TlsSocket::derivePsk(const std::string &identity, + const std::string &password) +{ + std::string upper; + upper.reserve(identity.size()); + for (unsigned char c : identity) { + upper += static_cast(toupper(c)); + } + const std::string data = upper + password; + std::vector digest(SHA256_DIGEST_LENGTH); + SHA256(reinterpret_cast(data.data()), data.size(), + digest.data()); + return digest; +} + +unsigned int TlsSocket::pskClientCallback(SSL *ssl, const char * /*hint*/, + char *identity, unsigned int maxIdLen, + unsigned char *psk, + unsigned int maxPskLen) +{ + auto *sock = + static_cast(SSL_get_ex_data(ssl, pskExIndex())); + if (!sock || sock->m_DerivedPsk.empty()) { + return 0; + } + + const size_t idLen = + std::min(sock->m_PskIdentity.size(), size_t(maxIdLen - 1)); + memcpy(identity, sock->m_PskIdentity.c_str(), idLen); + identity[idLen] = '\0'; + + const size_t keyLen = + std::min(sock->m_DerivedPsk.size(), size_t(maxPskLen)); + memcpy(psk, sock->m_DerivedPsk.data(), keyLen); + return static_cast(keyLen); +} + +static int sscVerifyCallback(int /*preverify_ok*/, X509_STORE_CTX *ctx) +{ + switch (X509_STORE_CTX_get_error(ctx)) { + case X509_V_OK: + case X509_V_ERR_DEPTH_ZERO_SELF_SIGNED_CERT: + case X509_V_ERR_SELF_SIGNED_CERT_IN_CHAIN: + case X509_V_ERR_UNABLE_TO_GET_ISSUER_CERT_LOCALLY: + case X509_V_ERR_UNABLE_TO_VERIFY_LEAF_SIGNATURE: + return 1; + default: + return 0; + } +} + +TlsSocket::TlsSocket(const struct addrinfo *host, + const bhf::ads::SecureAdsConfig &config) +{ + for (const auto *rp = host; rp; rp = rp->ai_next) { + m_Fd = socket(rp->ai_family, SOCK_STREAM, 0); + if (INVALID_SOCKET == m_Fd) { + continue; + } + if (0 == connect(m_Fd, rp->ai_addr, + static_cast(rp->ai_addrlen))) { + memcpy(&m_SockAddress, rp->ai_addr, + std::min(sizeof(m_SockAddress), + rp->ai_addrlen)); + m_AddrLen = static_cast(rp->ai_addrlen); + break; + } + closesocket(m_Fd); + m_Fd = INVALID_SOCKET; + } + + if (INVALID_SOCKET == m_Fd) { + throw std::system_error(WSAGetLastError(), + std::system_category(), + "TlsSocket: TCP connect failed"); + } + + const int nodelay = 0; + setsockopt(m_Fd, IPPROTO_TCP, TCP_NODELAY, + reinterpret_cast(&nodelay), sizeof(nodelay)); + + InitCtx(config); + + m_Ssl = SSL_new(m_Ctx); + if (!m_Ssl) { + throw std::runtime_error("SSL_new failed: " + opensslError()); + } + SSL_set_fd(m_Ssl, static_cast(m_Fd)); + + if (m_IsPsk) { + SSL_set_ex_data(m_Ssl, pskExIndex(), this); + SSL_set_psk_client_callback(m_Ssl, pskClientCallback); + } + + if (SSL_connect(m_Ssl) != 1) { + throw std::runtime_error("TLS handshake failed: " + + opensslError()); + } +} + +void TlsSocket::InitCtx(const bhf::ads::SecureAdsConfig &config) +{ + m_Ctx = SSL_CTX_new(TLS_client_method()); + if (!m_Ctx) { + throw std::runtime_error("SSL_CTX_new failed: " + + opensslError()); + } + + SSL_CTX_set_min_proto_version(m_Ctx, TLS1_2_VERSION); + SSL_CTX_set_max_proto_version(m_Ctx, TLS1_2_VERSION); + + if (config.mode == bhf::ads::SecureAdsConfig::Mode::PSK) { + m_PskIdentity = config.pskIdentity; + m_DerivedPsk = derivePsk(config.pskIdentity, config.password); + m_IsPsk = true; + + if (SSL_CTX_set_cipher_list(m_Ctx, "PSK-AES256-CBC-SHA384:" + "PSK-AES128-CBC-SHA256:" + "PSK-AES256-CBC-SHA:" + "PSK-AES128-CBC-SHA") != 1) { + throw std::runtime_error( + "Failed to set PSK cipher list: " + + opensslError()); + } + SSL_CTX_set_options(m_Ctx, SSL_OP_NO_ENCRYPT_THEN_MAC | + SSL_OP_NO_TICKET); +#ifdef SSL_OP_NO_EXTENDED_MASTER_SECRET + SSL_CTX_set_options(m_Ctx, SSL_OP_NO_EXTENDED_MASTER_SECRET); +#endif + SSL_CTX_set_verify(m_Ctx, SSL_VERIFY_NONE, nullptr); + return; + } + + if (SSL_CTX_use_certificate_file(m_Ctx, config.certPath.c_str(), + SSL_FILETYPE_PEM) != 1) { + throw std::runtime_error("Failed to load cert '" + + config.certPath + + "': " + opensslError()); + } + if (SSL_CTX_use_PrivateKey_file(m_Ctx, config.keyPath.c_str(), + SSL_FILETYPE_PEM) != 1) { + throw std::runtime_error("Failed to load key '" + + config.keyPath + + "': " + opensslError()); + } + if (!SSL_CTX_check_private_key(m_Ctx)) { + throw std::runtime_error("Cert/key mismatch: " + + opensslError()); + } + + if (config.mode == bhf::ads::SecureAdsConfig::Mode::SCA) { + if (!config.caPath.empty() && + SSL_CTX_load_verify_locations(m_Ctx, config.caPath.c_str(), + nullptr) != 1) { + throw std::runtime_error("Failed to load CA '" + + config.caPath + + "': " + opensslError()); + } + SSL_CTX_set_verify(m_Ctx, + SSL_VERIFY_PEER | + SSL_VERIFY_FAIL_IF_NO_PEER_CERT, + nullptr); + } else { + SSL_CTX_set_verify(m_Ctx, SSL_VERIFY_PEER, sscVerifyCallback); + } +} + +TlsSocket::~TlsSocket() +{ + Shutdown(); + if (m_Ssl) { + SSL_free(m_Ssl); + m_Ssl = nullptr; + } + if (m_Ctx) { + SSL_CTX_free(m_Ctx); + m_Ctx = nullptr; + } + if (INVALID_SOCKET != m_Fd) { + closesocket(m_Fd); + m_Fd = INVALID_SOCKET; + } +} + +void TlsSocket::Shutdown() +{ + if (m_Ssl) { + SSL_shutdown(m_Ssl); + } +} + +bool TlsSocket::Select(timeval *timeout) const +{ + fd_set readFds; + FD_ZERO(&readFds); + FD_SET(m_Fd, &readFds); + + const int state = NATIVE_SELECT(static_cast(m_Fd) + 1, &readFds, + nullptr, nullptr, timeout); + if (0 == state) { + LOG_ERROR("TlsSocket select() timeout"); + throw TimeoutEx("TlsSocket select() timeout"); + } + if (state < 0 || !FD_ISSET(m_Fd, &readFds)) { + LOG_ERROR("TlsSocket select() error: " + << std::strerror(WSAGetLastError())); + return false; + } + return true; +} + +size_t TlsSocket::read(uint8_t *buffer, size_t maxBytes, timeval *timeout) const +{ + if (SSL_pending(m_Ssl) == 0) { + if (!Select(timeout)) { + return 0; + } + } + + const int toRead = static_cast( + std::min(maxBytes, std::numeric_limits::max())); + const int n = SSL_read(m_Ssl, reinterpret_cast(buffer), toRead); + if (n > 0) { + return static_cast(n); + } + const int err = SSL_get_error(m_Ssl, n); + if (err == SSL_ERROR_ZERO_RETURN || err == SSL_ERROR_SYSCALL) { + throw std::runtime_error("TLS connection closed by remote"); + } + LOG_ERROR("SSL_read failed: " << opensslError()); + return 0; +} + +size_t TlsSocket::write(const Frame &frame) const +{ + const int len = static_cast(frame.size()); + const int n = SSL_write( + m_Ssl, reinterpret_cast(frame.data()), len); + if (n <= 0) { + LOG_ERROR("SSL_write failed: " << opensslError()); + return 0; + } + return static_cast(n); +} + +void TlsSocket::write_raw(const uint8_t *data, size_t length) const +{ + size_t sent = 0; + while (sent < length) { + const int toSend = static_cast(std::min( + length - sent, std::numeric_limits::max())); + const int n = SSL_write( + m_Ssl, reinterpret_cast(data + sent), + toSend); + if (n <= 0) { + throw std::runtime_error("TLS write_raw failed: " + + opensslError()); + } + sent += static_cast(n); + } +} + +void TlsSocket::read_raw(uint8_t *data, size_t length) const +{ + size_t received = 0; + while (received < length) { + timeval timeout{ 5, 0 }; + const size_t n = + read(data + received, length - received, &timeout); + if (0 == n) { + throw std::runtime_error( + "TLS read_raw: connection closed unexpectedly"); + } + received += n; + } +} + +uint32_t TlsSocket::Connect() const +{ + sockaddr_storage source; + socklen_t len = sizeof(source); + + if (getsockname(m_Fd, reinterpret_cast(&source), &len)) { + LOG_ERROR("TlsSocket: getsockname failed"); + return 0; + } + if (source.ss_family == AF_INET) { + return ntohl(reinterpret_cast(&source) + ->sin_addr.s_addr); + } + return 0xffffffff; +} + +bool TlsSocket::IsConnectedTo(const struct addrinfo *targetAddresses) const +{ + for (const auto *rp = targetAddresses; rp; rp = rp->ai_next) { + if (m_SockAddress.ss_family == rp->ai_family && + !memcmp(&m_SockAddress, rp->ai_addr, + std::min(sizeof(m_SockAddress), + rp->ai_addrlen))) { + return true; + } + } + return false; +} diff --git a/AdsLib/standalone/TlsSocket.h b/AdsLib/standalone/TlsSocket.h new file mode 100644 index 00000000..4a75d576 --- /dev/null +++ b/AdsLib/standalone/TlsSocket.h @@ -0,0 +1,58 @@ +// SPDX-License-Identifier: MIT +#pragma once + +#include "Frame.h" +#include "SecureAdsConfig.h" +#include "wrap_socket.h" + +#include +#include +#include +#include + +struct TlsSocket { + TlsSocket(const struct addrinfo *host, + const bhf::ads::SecureAdsConfig &config); + ~TlsSocket(); + + TlsSocket(const TlsSocket &) = delete; + TlsSocket &operator=(const TlsSocket &) = delete; + + size_t read(uint8_t *buffer, size_t maxBytes, timeval *timeout) const; + size_t write(const Frame &frame) const; + void write_raw(const uint8_t *data, size_t length) const; + void read_raw(uint8_t *data, size_t length) const; + void Shutdown(); + bool IsConnectedTo(const struct addrinfo *targetAddresses) const; + uint32_t Connect() const; + + struct TimeoutEx : std::runtime_error { + TimeoutEx(const char *msg) + : std::runtime_error(msg) + { + } + }; + + private: + SOCKET m_Fd{ INVALID_SOCKET }; + SSL_CTX *m_Ctx{ nullptr }; + SSL *m_Ssl{ nullptr }; + sockaddr_storage m_SockAddress{}; + socklen_t m_AddrLen{ 0 }; + + std::vector m_DerivedPsk; + std::string m_PskIdentity; + bool m_IsPsk{ false }; + + void InitCtx(const bhf::ads::SecureAdsConfig &config); + bool Select(timeval *timeout) const; + + static int pskExIndex(); + static unsigned int pskClientCallback(SSL *ssl, const char *hint, + char *identity, + unsigned int maxIdLen, + unsigned char *psk, + unsigned int maxPskLen); + static std::vector derivePsk(const std::string &identity, + const std::string &password); +}; diff --git a/certs/.gitkeep b/certs/.gitkeep new file mode 100644 index 00000000..e69de29b diff --git a/meson.build b/meson.build index e44f1897..e65edc40 100644 --- a/meson.build +++ b/meson.build @@ -1,11 +1,11 @@ project('AdsLib', 'cpp', - version : '0.1', - default_options : [ - 'buildtype=release', - 'warning_level=3', - 'werror=true', - 'b_pie=true', - ] + version : '0.1', + default_options : [ + 'buildtype=release', + 'warning_level=3', + 'werror=true', + 'b_pie=true', + ] ) # some hardening options @@ -13,33 +13,36 @@ add_project_arguments('-D_FORTIFY_SOURCE=2', language: 'cpp') add_project_arguments(get_option('default_loglevel'), language: 'cpp') common_files = files([ - 'AdsLib/AdsDef.cpp', - 'AdsLib/AdsDevice.cpp', - 'AdsLib/AdsFile.cpp', - 'AdsLib/AdsLib.cpp', - 'AdsLib/MasterDcStatAccess.cpp', - 'AdsLib/ECatAccess.cpp', - 'AdsLib/Frame.cpp', - 'AdsLib/LicenseAccess.cpp', - 'AdsLib/Log.cpp', - 'AdsLib/RTimeAccess.cpp', - 'AdsLib/RegistryAccess.cpp', - 'AdsLib/RouterAccess.cpp', - 'AdsLib/Sockets.cpp', - 'AdsLib/SymbolAccess.cpp', - 'AdsLib/bhf/ParameterList.cpp', -]) + 'AdsLib/AdsDef.cpp', + 'AdsLib/AdsDevice.cpp', + 'AdsLib/AdsFile.cpp', + 'AdsLib/AdsLib.cpp', + 'AdsLib/MasterDcStatAccess.cpp', + 'AdsLib/ECatAccess.cpp', + 'AdsLib/Frame.cpp', + 'AdsLib/LicenseAccess.cpp', + 'AdsLib/Log.cpp', + 'AdsLib/RTimeAccess.cpp', + 'AdsLib/RegistryAccess.cpp', + 'AdsLib/RouterAccess.cpp', + 'AdsLib/Sockets.cpp', + 'AdsLib/SymbolAccess.cpp', + 'AdsLib/bhf/ParameterList.cpp', + ]) router_files = files([ - 'AdsLib/standalone/AdsLib.cpp', - 'AdsLib/standalone/AmsConnection.cpp', - 'AdsLib/standalone/AmsNetId.cpp', - 'AdsLib/standalone/AmsPort.cpp', - 'AdsLib/standalone/AmsRouter.cpp', - 'AdsLib/standalone/NotificationDispatcher.cpp', -]) + 'AdsLib/standalone/AdsLib.cpp', + 'AdsLib/standalone/AmsConnection.cpp', + 'AdsLib/standalone/AmsNetId.cpp', + 'AdsLib/standalone/AmsPort.cpp', + 'AdsLib/standalone/AmsRouter.cpp', + 'AdsLib/standalone/NotificationDispatcher.cpp', + 'AdsLib/standalone/SecureAmsConnection.cpp', + 'AdsLib/standalone/TlsSocket.cpp', + ]) install_headers( + 'AdsLib/SecureAdsConfig.h', 'AdsLib/AdsDef.h', 'AdsLib/AdsDevice.h', 'AdsLib/AdsException.h', @@ -80,30 +83,31 @@ install_headers( ) inc = include_directories([ - 'AdsLib', - 'tools', -]) + 'AdsLib', + 'tools', + ]) libs = [ meson.get_compiler('cpp').find_library('ws2_32', required: false), dependency('threads'), + dependency('openssl'), ] adslib = static_library('AdsLib', - [common_files, router_files], - include_directories: inc, - install: true, + [common_files, router_files], + include_directories : inc, + install : true, ) adslib_so = shared_library('adslib', - [common_files, router_files], - cpp_args: [ - '-DBHF_ADS_EXPORT_C', - '-DBHF_ADS_USE_TWINCAT_ORDER', - ], - include_directories: inc, - install: true, - dependencies: libs, + [common_files, router_files], + cpp_args : [ + '-DBHF_ADS_EXPORT_C', + '-DBHF_ADS_USE_TWINCAT_ORDER', + ], + include_directories : inc, + install : true, + dependencies : libs, ) install_libs = [ adslib, adslib_so ] @@ -114,79 +118,86 @@ adslib_dep = declare_dependency( ) adslibtest = executable('AdsLibTest', - 'AdsLibTest/main.cpp', - include_directories: inc, - dependencies: libs, - link_with: adslib, + 'AdsLibTest/main.cpp', + include_directories : inc, + dependencies : libs, + link_with : adslib, ) adslibooitest = executable('AdsLibOOITest', - 'AdsLibOOITest/main.cpp', - include_directories: inc, - dependencies: libs, - link_with: adslib, + 'AdsLibOOITest/main.cpp', + include_directories : inc, + dependencies : libs, + link_with : adslib, ) adstest_files = files([ - 'AdsTest/main.cpp', - 'AdsTest/RegistryAccessTest.cpp', -]) + 'AdsTest/main.cpp', + 'AdsTest/RegistryAccessTest.cpp', + ]) adstest = executable('AdsTest', - adstest_files, - include_directories: inc, - dependencies: libs, - link_with: adslib, + adstest_files, + include_directories : inc, + dependencies : libs, + link_with : adslib, ) adstool = executable('adstool', - 'AdsTool/main.cpp', - include_directories: inc, - dependencies: libs, - link_with: adslib, + 'AdsTool/main.cpp', + include_directories : inc, + dependencies : libs, + link_with : adslib, +) + +test_connections = executable('test_connections', + 'tools/test_connections.cpp', + include_directories : inc, + dependencies : libs, + link_with : adslib, ) if get_option('tcadsdll_include') != '' tcadslib = static_library('TcAdsLib', - [common_files, 'AdsLib/TwinCAT/AdsLib.cpp'], - cpp_args: '-DUSE_TWINCAT_ROUTER', - include_directories: [ - get_option('tcadsdll_include'), - inc, - ], - install: true, + [common_files, 'AdsLib/TwinCAT/AdsLib.cpp'], + cpp_args : '-DUSE_TWINCAT_ROUTER', + include_directories : [ + get_option('tcadsdll_include'), + inc, + ], + install : true, ) install_libs += [ tcadslib ] tcadsdll_deps = [ - meson.get_compiler('cpp').find_library('TcAdsDll', dirs: get_option('tcadsdll_lib')), + meson.get_compiler('cpp').find_library('TcAdsDll', dirs : get_option('tcadsdll_lib')), libs, ] tcadstool = executable('tcadstool', - 'AdsTool/main.cpp', - cpp_args: '-DUSE_TWINCAT_ROUTER', - include_directories: [ - get_option('tcadsdll_include'), - inc, - ], - dependencies: [ - tcadsdll_deps - ], - link_with: tcadslib, + 'AdsTool/main.cpp', + cpp_args : '-DUSE_TWINCAT_ROUTER', + include_directories : [ + get_option('tcadsdll_include'), + inc, + ], + dependencies : [ + tcadsdll_deps + ], + link_with : tcadslib, ) tcadstest = executable('tcAdsTest', - adstest_files, - cpp_args: '-DUSE_TWINCAT_ROUTER', - include_directories: [ - get_option('tcadsdll_include'), - inc, - ], - dependencies: [ - tcadsdll_deps - ], - link_with: tcadslib, + adstest_files, + cpp_args : '-DUSE_TWINCAT_ROUTER', + include_directories : [ + get_option('tcadsdll_include'), + inc, + ], + dependencies : [ + tcadsdll_deps + ], + link_with : tcadslib, ) endif diff --git a/tools/test_connections.cpp b/tools/test_connections.cpp new file mode 100644 index 00000000..f4dcc1e8 --- /dev/null +++ b/tools/test_connections.cpp @@ -0,0 +1,494 @@ +// SPDX-License-Identifier: MIT +/** + * Connection and data-integrity test for plain, SSC, SCA and PSK ADS. + * Usage: + * test_connections --gw= --localams= + * [--mode=plain|ssc|sca|psk] + * [--cert=] [--key=] [--ca=] + * [--username=] [--password=

] + * [--psk-identity=] + */ +#include + +#include +#include +#include +#include +#include + +// ── TwinCAT primitive sizes ───────────────────────────────────────────────── +using TC_BOOL = uint8_t; // BOOL (1 byte) +using TC_INT = int16_t; // INT (2 bytes) +using TC_DINT = int32_t; // DINT (4 bytes) +using TC_LREAL = double; // LREAL (8 bytes) + +static constexpr uint32_t STR80_LEN = 81; // STRING(80) + NUL +static constexpr uint32_t STR5_LEN = 6; // STRING(5) + NUL + +// ── Argument helper ───────────────────────────────────────────────────────── +static std::string getArg(int argc, char **argv, const char *name, + const char *defVal = "") +{ + const std::string prefix = std::string("--") + name + "="; + for (int i = 2; i < argc; ++i) { + const std::string a(argv[i]); + if (a.substr(0, prefix.size()) == prefix) { + return a.substr(prefix.size()); + } + } + return defVal; +} + +// ── ADS state name ────────────────────────────────────────────────────────── +static const char *adsStateName(uint16_t state) +{ + switch (state) { + case ADSSTATE_INVALID: + return "INVALID"; + case ADSSTATE_IDLE: + return "IDLE"; + case ADSSTATE_RESET: + return "RESET"; + case ADSSTATE_INIT: + return "INIT"; + case ADSSTATE_START: + return "START"; + case ADSSTATE_RUN: + return "RUN"; + case ADSSTATE_STOP: + return "STOP"; + case ADSSTATE_SAVECFG: + return "SAVECFG"; + case ADSSTATE_LOADCFG: + return "LOADCFG"; + case ADSSTATE_POWERFAILURE: + return "POWERFAILURE"; + case ADSSTATE_POWERGOOD: + return "POWERGOOD"; + case ADSSTATE_ERROR: + return "ERROR"; + case ADSSTATE_SHUTDOWN: + return "SHUTDOWN"; + case ADSSTATE_SUSPEND: + return "SUSPEND"; + case ADSSTATE_RESUME: + return "RESUME"; + case ADSSTATE_CONFIG: + return "CONFIG"; + case ADSSTATE_RECONFIG: + return "RECONFIG"; + default: + return "UNKNOWN"; + } +} + +// ── RAII symbol handle ─────────────────────────────────────────────────────── +struct SymbolHandle { + long port; + const AmsAddr *server; + uint32_t handle{ 0 }; + long status{ 0 }; + + SymbolHandle(long p, const AmsAddr *s, const std::string &name) + : port(p) + , server(s) + { + uint32_t bytesRead = 0; + status = AdsSyncReadWriteReqEx2( + port, server, ADSIGRP_SYM_HNDBYNAME, 0, sizeof(handle), + &handle, static_cast(name.size()), + name.c_str(), &bytesRead); + } + + ~SymbolHandle() + { + if (!status) { + AdsSyncWriteReqEx(port, server, ADSIGRP_SYM_RELEASEHND, + 0, sizeof(handle), &handle); + } + } + + bool ok() const + { + return status == 0; + } + + SymbolHandle(const SymbolHandle &) = delete; + SymbolHandle &operator=(const SymbolHandle &) = delete; +}; + +// ── Typed read by symbol name ──────────────────────────────────────────────── +template +static long readVar(long port, const AmsAddr *server, const std::string &name, + T &out) +{ + SymbolHandle sym(port, server, name); + if (!sym.ok()) { + std::cerr << " GetHandle('" << name << "') failed: 0x" + << std::hex << sym.status << std::dec << "\n"; + return sym.status; + } + uint32_t bytesRead = 0; + const long st = AdsSyncReadReqEx2(port, server, ADSIGRP_SYM_VALBYHND, + sym.handle, sizeof(T), &out, + &bytesRead); + if (st) { + std::cerr << " Read('" << name << "') failed: 0x" << std::hex + << st << std::dec << "\n"; + } + return st; +} + +// ── String read by symbol name ─────────────────────────────────────────────── +static long readString(long port, const AmsAddr *server, + const std::string &name, char *buf, uint32_t bufLen) +{ + SymbolHandle sym(port, server, name); + if (!sym.ok()) { + std::cerr << " GetHandle('" << name << "') failed: 0x" + << std::hex << sym.status << std::dec << "\n"; + return sym.status; + } + memset(buf, 0, bufLen); + uint32_t bytesRead = 0; + const long st = AdsSyncReadReqEx2(port, server, ADSIGRP_SYM_VALBYHND, + sym.handle, bufLen, buf, &bytesRead); + if (st) { + std::cerr << " Read('" << name << "') failed: 0x" << std::hex + << st << std::dec << "\n"; + } + buf[bufLen - 1] = '\0'; + return st; +} + +// ── Typed write by symbol name ──────────────────────────────────────────────── +template +static long writeVar(long port, const AmsAddr *server, const std::string &name, + const T &value) +{ + SymbolHandle sym(port, server, name); + if (!sym.ok()) { + std::cerr << " GetHandle('" << name << "') failed: 0x" + << std::hex << sym.status << std::dec << "\n"; + return sym.status; + } + const long st = AdsSyncWriteReqEx(port, server, ADSIGRP_SYM_VALBYHND, + sym.handle, sizeof(T), &value); + if (st) { + std::cerr << " Write('" << name << "') failed: 0x" << std::hex + << st << std::dec << "\n"; + } + return st; +} + +// ── Test: read the six Main variables ──────────────────────────────────────── +static bool testVariableReads(long port, const AmsAddr *server) +{ + std::cout << "\n=== Variable Read Test ===\n"; + bool ok = true; + + // Main.Bool1 [BOOL] + TC_BOOL b1 = 0; + if (readVar(port, server, "Main.Bool1", b1) == 0) { + std::cout << " Main.Bool1 [BOOL] = " + << (b1 ? "TRUE" : "FALSE") << "\n"; + } else { + ok = false; + } + + // Main.count [INT] + TC_INT count = 0; + if (readVar(port, server, "Main.count", count) == 0) { + std::cout << " Main.count [INT] = " << count << "\n"; + } else { + ok = false; + } + + // Main.dint [DINT] + TC_DINT dint = 0; + if (readVar(port, server, "Main.dint1", dint) == 0) { + std::cout << " Main.dint [DINT] = " << dint << "\n"; + } else { + ok = false; + } + + // Main.lreal [LREAL] + TC_LREAL lreal = 0.0; + if (readVar(port, server, "Main.lreal1", lreal) == 0) { + std::cout << " Main.lreal [LREAL] = " << std::fixed + << std::setprecision(6) << lreal << "\n"; + } else { + ok = false; + } + + // Main.str1 [STRING(80)] + char str1[STR80_LEN] = {}; + if (readString(port, server, "Main.str1", str1, STR80_LEN) == 0) { + std::cout << " Main.str1 [STRING(80)] = \"" << str1 + << "\"\n"; + } else { + ok = false; + } + + // Main.str2 [STRING(5)] + char str2[STR5_LEN] = {}; + if (readString(port, server, "Main.str2", str2, STR5_LEN) == 0) { + std::cout << " Main.str2 [STRING(5)] = \"" << str2 + << "\"\n"; + } else { + ok = false; + } + + return ok; +} + +// ── Test: write a sentinel to Main.dint, read back, verify ─────────────────── +static bool testWriteVerify(long port, const AmsAddr *server) +{ + std::cout << "\n=== Write / Verify Round-Trip Test (Main.dint1) ===\n"; + + const TC_DINT sentinel = 0x4242; + + // Save original value + TC_DINT original = 0; + if (readVar(port, server, "Main.dint1", original) != 0) { + return false; + } + std::cout << " Original value = " << original << "\n"; + + // Write sentinel + std::cout << " Writing = " << sentinel << " ... "; + if (writeVar(port, server, "Main.dint1", sentinel) != 0) { + std::cout << "FAIL\n"; + return false; + } + std::cout << "OK\n"; + + // Read back + TC_DINT readback = 0; + std::cout << " Reading back = "; + if (readVar(port, server, "Main.dint1", readback) != 0) { + return false; + } + std::cout << readback << " ... "; + + if (readback != sentinel) { + std::cout << "MISMATCH (expected " << sentinel << ")\n"; + // Restore original before failing + writeVar(port, server, "Main.dint1", original); + return false; + } + std::cout << "MATCH\n"; + + // Restore original value + std::cout << " Restoring = " << original << " ... "; + if (writeVar(port, server, "Main.dint1", original) != 0) { + return false; + } + std::cout << "OK\n"; + + return true; +} + +// ── Test: measure round-trip latency over N reads of Main.dint ─────────────── +static bool testLatency(long port, const AmsAddr *server) +{ + std::cout << "\n=== Latency Test (20 reads of Main.dint1) ===\n"; + + SymbolHandle sym(port, server, "Main.dint1"); + if (!sym.ok()) { + std::cerr << " GetHandle failed: 0x" << std::hex << sym.status + << "\n"; + return false; + } + + using Clock = std::chrono::steady_clock; + using us = std::chrono::microseconds; + + long long minUs = std::numeric_limits::max(); + long long maxUs = 0; + long long sumUs = 0; + const int N = 20; + + for (int i = 0; i < N; ++i) { + TC_DINT val = 0; + uint32_t bytes = 0; + const auto t0 = Clock::now(); + const long st = AdsSyncReadReqEx2(port, server, + ADSIGRP_SYM_VALBYHND, + sym.handle, sizeof(val), &val, + &bytes); + const auto dt = + std::chrono::duration_cast(Clock::now() - t0) + .count(); + + if (st) { + std::cerr << " Read failed on iteration " << i + << ": 0x" << std::hex << st << "\n"; + return false; + } + sumUs += dt; + if (dt < minUs) + minUs = dt; + if (dt > maxUs) + maxUs = dt; + } + + const double avgMs = static_cast(sumUs) / N / 1000.0; + const double minMs = static_cast(minUs) / 1000.0; + const double maxMs = static_cast(maxUs) / 1000.0; + + std::cout << std::fixed << std::setprecision(3); + std::cout << " Avg: " << avgMs << " ms" + << " Min: " << minMs << " ms" + << " Max: " << maxMs << " ms\n"; + return true; +} + +// ── main ───────────────────────────────────────────────────────────────────── +int main(int argc, char **argv) +{ + if (argc < 2) { + std::cerr + << "Usage: test_connections [OPTIONS]\n" + << " --gw= Gateway IP (required)\n" + << " --localams= Local AMS NetId\n" + << " --mode=plain|ssc|sca|psk (default: plain)\n" + << " --cert= Client certificate (PEM, SSC/SCA)\n" + << " --key= Client private key (PEM, SSC/SCA)\n" + << " --ca= CA certificate for SCA (PEM)\n" + << " --username= SSC first-time registration\n" + << " --password=

SSC password / PSK password\n" + << " --psk-identity= PSK identity string\n"; + return 1; + } + + const AmsNetId targetNetId{ argv[1] }; + const AmsAddr server{ targetNetId, AMSPORT_R0_PLC_TC3 }; + + const auto gw = getArg(argc, argv, "gw"); + const auto localAms = getArg(argc, argv, "localams"); + const auto mode = getArg(argc, argv, "mode", "plain"); + const auto certPath = getArg(argc, argv, "cert"); + const auto keyPath = getArg(argc, argv, "key"); + const auto caPath = getArg(argc, argv, "ca"); + const auto username = getArg(argc, argv, "username"); + const auto password = getArg(argc, argv, "password"); + const auto pskIdentity = getArg(argc, argv, "psk-identity"); + + if (gw.empty()) { + std::cerr << "Error: --gw is required\n"; + return 1; + } + + if (!localAms.empty()) { + bhf::ads::SetLocalAddress(AmsNetId{ localAms }); + } + + long routeResult = -1; + + if (mode == "plain") { + routeResult = bhf::ads::AddLocalRoute(targetNetId, gw.c_str()); + } else if (mode == "ssc" || mode == "sca") { + if (certPath.empty() || keyPath.empty()) { + std::cerr << "Error: --cert and --key required for " + << mode << " mode\n"; + return 1; + } + bhf::ads::SecureAdsConfig cfg; + cfg.certPath = certPath; + cfg.keyPath = keyPath; + if (mode == "ssc") { + cfg.mode = bhf::ads::SecureAdsConfig::Mode::SSC; + cfg.username = username; + cfg.password = password; + } else { + cfg.mode = bhf::ads::SecureAdsConfig::Mode::SCA; + cfg.caPath = caPath; + } + routeResult = + bhf::ads::AddSecureRoute(targetNetId, gw.c_str(), cfg); + } else if (mode == "psk") { + if (pskIdentity.empty() || password.empty()) { + std::cerr + << "Error: --psk-identity and --password required for psk mode\n"; + return 1; + } + bhf::ads::SecureAdsConfig cfg; + cfg.mode = bhf::ads::SecureAdsConfig::Mode::PSK; + cfg.pskIdentity = pskIdentity; + cfg.password = password; + routeResult = + bhf::ads::AddSecureRoute(targetNetId, gw.c_str(), cfg); + } else { + std::cerr << "Error: unknown mode '" << mode << "'\n"; + return 1; + } + + if (routeResult) { + std::cerr << "AddRoute failed: 0x" << std::hex << routeResult + << "\n"; + return 1; + } + + const long port = AdsPortOpenEx(); + if (!port) { + std::cerr << "AdsPortOpenEx failed\n"; + bhf::ads::DelLocalRoute(targetNetId); + return 1; + } + + int rc = 0; + + // ── Device info ─────────────────────────────────────────────────────── + std::cout << "\n=== Device Info ===\n"; + char devName[17] = {}; + AdsVersion version{ 0, 0, 0 }; + long status = + AdsSyncReadDeviceInfoReqEx(port, &server, devName, &version); + if (status) { + std::cerr << " ReadDeviceInfo failed: 0x" << std::hex << status + << "\n"; + rc = 1; + } else { + std::cout << " Device: " << devName << " v" + << static_cast(version.version) << "." + << static_cast(version.revision) << "." + << version.build << "\n"; + } + + // ── ADS state ───────────────────────────────────────────────────────── + uint16_t adsState = 0, devState = 0; + status = AdsSyncReadStateReqEx(port, &server, &adsState, &devState); + if (status) { + std::cerr << " ReadState failed: 0x" << std::hex << status + << "\n"; + rc = 1; + } else { + std::cout << " ADS State: " << adsStateName(adsState) << " (" + << std::dec << adsState << ")" + << " Device State: " << devState << "\n"; + } + + // ── Variable reads ──────────────────────────────────────────────────── + if (!testVariableReads(port, &server)) { + rc = 1; + } + + // ── Write / verify round-trip ───────────────────────────────────────── + if (!testWriteVerify(port, &server)) { + rc = 1; + } + + // ── Latency ─────────────────────────────────────────────────────────── + if (!testLatency(port, &server)) { + rc = 1; + } + + std::cout << "\n=== Result: " << (rc == 0 ? "PASS" : "FAIL") + << " ===\n"; + + AdsPortCloseEx(port); + bhf::ads::DelLocalRoute(targetNetId); + return rc; +} \ No newline at end of file diff --git a/tools/test_connections.sh b/tools/test_connections.sh new file mode 100755 index 00000000..4ef5830b --- /dev/null +++ b/tools/test_connections.sh @@ -0,0 +1,227 @@ +#!/usr/bin/env bash +# Interactive connection test for plain, SSC, SCA, and BSK ADS. +set -euo pipefail + +SCRIPT_DIR="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)" +REPO_DIR="$(cd "$SCRIPT_DIR/.." && pwd)" +BUILD_DIR="${BUILD_DIR:-$REPO_DIR/build}" + +# Network / credentials +LOCAL_IP="192.168.1.220" +LOCAL_AMS="192.168.1.220.1.1" +PLC_IP="192.168.1.180" +PLC_AMS="39.178.175.124.1.1" +ADS_USER="Administrator" +ADS_PASS="" + +# Certificate paths +CERTS_DIR="$REPO_DIR/certs" +SSC_KEY="$CERTS_DIR/ssc_client.key" +SSC_CERT="$CERTS_DIR/ssc_client.crt" +SCA_KEY="$CERTS_DIR/sca_client.key" +SCA_CERT="$CERTS_DIR/sca_client.crt" +SCA_CA="/home/ick3/certsonPlc/rootCA.pem" +SCA_CA_KEY="/home/ick3/certsonPlc/rootCA.key" + +ADSTOOL="$BUILD_DIR/adstool" +TEST_BIN="$BUILD_DIR/test_connections" + +PASS=0 +FAIL=0 + +# --------------------------------------------------------------------------- +# Helpers +# --------------------------------------------------------------------------- +green() { printf '\033[0;32m%s\033[0m\n' "$*"; } +red() { printf '\033[0;31m%s\033[0m\n' "$*"; } +yellow() { printf '\033[0;33m%s\033[0m\n' "$*"; } +bold() { printf '\033[1m%s\033[0m\n' "$*"; } + +run_test() { + local label="$1"; shift + bold "--- $label ---" + if "$@"; then + green "PASS: $label" + PASS=$((PASS + 1)) + else + red "FAIL: $label (exit $?)" + FAIL=$((FAIL + 1)) + fi + echo +} + +print_summary() { + bold "=== Results ===" + green "PASSED: $PASS" + if [[ "$FAIL" -gt 0 ]]; then + red "FAILED: $FAIL" + else + echo "FAILED: 0" + fi +} + +# --------------------------------------------------------------------------- +# Build +# --------------------------------------------------------------------------- +do_build() { + bold "=== Building ===" + if [ ! -d "$BUILD_DIR" ]; then + meson setup "$BUILD_DIR" "$REPO_DIR" + fi + ninja -C "$BUILD_DIR" adstool test_connections + echo +} + +# --------------------------------------------------------------------------- +# Certificate helpers +# --------------------------------------------------------------------------- +ensure_ssc_cert() { + mkdir -p "$CERTS_DIR" + if [ ! -f "$SSC_KEY" ] || [ ! -f "$SSC_CERT" ]; then + bold "Generating self-signed SSC client cert (CN=cachyos-x8664)..." + openssl req -x509 -newkey rsa:2048 -nodes \ + -keyout "$SSC_KEY" -out "$SSC_CERT" \ + -days 3650 -subj "/CN=cachyos-x8664" 2>/dev/null + echo "Created $SSC_CERT" + fi +} + +ensure_sca_cert() { + mkdir -p "$CERTS_DIR" + if [ ! -f "$SCA_KEY" ] || [ ! -f "$SCA_CERT" ]; then + if [ ! -f "$SCA_CA" ] || [ ! -f "$SCA_CA_KEY" ]; then + yellow "WARNING: CA files not found at $SCA_CA_KEY / $SCA_CA" + return 1 + fi + bold "Generating SCA client cert signed by rootCA..." + openssl genrsa -out "$SCA_KEY" 2048 2>/dev/null + openssl req -new -key "$SCA_KEY" \ + -out "$CERTS_DIR/sca_client.csr" \ + -subj "/CN=cachyos-x8664" 2>/dev/null + openssl x509 -req \ + -in "$CERTS_DIR/sca_client.csr" \ + -CA "$SCA_CA" -CAkey "$SCA_CA_KEY" -CAcreateserial \ + -out "$SCA_CERT" -days 3650 2>/dev/null + rm -f "$CERTS_DIR/sca_client.csr" + echo "Created $SCA_CERT" + fi + [ -f "$SCA_CA" ] || { yellow "WARNING: CA cert $SCA_CA not found"; return 1; } +} + +# --------------------------------------------------------------------------- +# Individual test runners +# --------------------------------------------------------------------------- +test_plain() { + bold "=== Test: Plain ADS ===" + echo "Registering route on PLC via adstool..." + "$ADSTOOL" "$PLC_IP" addroute \ + --netid="$LOCAL_AMS" --addr="$LOCAL_IP" \ + --password="$ADS_PASS" || true + echo + run_test "plain ADS" \ + "$TEST_BIN" "$PLC_AMS" \ + --gw="$PLC_IP" \ + --localams="$LOCAL_AMS" +} + +test_ssc() { + ensure_ssc_cert + bold "=== Test: SSC first-time registration ===" + run_test "SSC (first-time, with credentials)" \ + "$TEST_BIN" "$PLC_AMS" \ + --gw="$PLC_IP" --localams="$LOCAL_AMS" \ + --mode=ssc --cert="$SSC_CERT" --key="$SSC_KEY" \ + --username="$ADS_USER" --password="$ADS_PASS" + + bold "=== Test: SSC established (no credentials) ===" + run_test "SSC (established)" \ + "$TEST_BIN" "$PLC_AMS" \ + --gw="$PLC_IP" --localams="$LOCAL_AMS" \ + --mode=ssc --cert="$SSC_CERT" --key="$SSC_KEY" +} + +test_sca() { + if ! ensure_sca_cert; then + yellow "SKIP: SCA test (CA files missing)" + return + fi + bold "=== Test: SCA ===" + run_test "SCA" \ + "$TEST_BIN" "$PLC_AMS" \ + --gw="$PLC_IP" --localams="$LOCAL_AMS" \ + --mode=sca --cert="$SCA_CERT" --key="$SCA_KEY" \ + --ca="$SCA_CA" +} + +test_psk() { + bold "=== Test: PSK (Pre-Shared Key) ===" + printf "PSK identity: " + read -r PSK_IDENTITY + printf "PSK password: " + read -rs PSK_PASSWORD + echo + if [[ -z "$PSK_IDENTITY" || -z "$PSK_PASSWORD" ]]; then + yellow "SKIP: PSK test (identity or password empty)" + return + fi + run_test "PSK" \ + "$TEST_BIN" "$PLC_AMS" \ + --gw="$PLC_IP" --localams="$LOCAL_AMS" \ + --mode=psk \ + --psk-identity="$PSK_IDENTITY" \ + --password="$PSK_PASSWORD" +} + +test_bsk_Prefilled() { + bold "=== Test: PSK (Pre-Shared Key) ===" + run_test "PSK" \ + "$TEST_BIN" "$PLC_AMS" \ + --gw="$PLC_IP" --localams="$LOCAL_AMS" \ + --mode=psk \ + --psk-identity="MY_IDENTITY" \ + --password="MySecret" +} + +# --------------------------------------------------------------------------- +# Interactive menu +# --------------------------------------------------------------------------- +echo +bold "================================================" +bold " ADS Connection Test Suite" +bold "================================================" +echo +echo " 1) All tests( Depends on Route Configuration PSK or SCA will fail because only one can be active at time)" +echo " 2) SSC (Self-Signed Certificate)" +echo " 3) SCA (Shared CA Certificate)" +echo " 4) Plain ADS" +echo " 5) PSK (Pre-Shared Key)" +echo " 6) PSK Prefilled" +echo +printf "Select test [1-6]: " +read -r CHOICE +echo + +do_build + +case "$CHOICE" in + 1) test_psk # Does not work If not called first. After a SSC no + test_plain + test_ssc + test_sca + + ;; + 2) test_ssc ;; + 3) test_sca ;; + 4) test_plain ;; + 5) test_psk ;; + 6) test_bsk_Prefilled;; # Only Runable if the Prefilled Values are used for the PSK! + + *) + red "Invalid selection: $CHOICE" + exit 1 + ;; +esac + +echo +print_summary +[[ "$FAIL" -gt 0 ]] && exit 1 || exit 0 \ No newline at end of file