Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 14 additions & 0 deletions AdsLib/AdsLib.h
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@
#include "standalone/AdsLib.h"
#endif

#include "SecureAdsConfig.h"
#include "Sockets.h"

#ifdef BHF_ADS_EXPORT_C
Expand Down Expand Up @@ -158,6 +159,19 @@ namespace ads
*/
long AddLocalRoute(AmsNetId ams, const char *ip);

/**
* Add a Secure ADS route (TLS 1.2, port 8016) to a target system.
* Supports SSC (self-signed) and SCA (shared CA) modes.
* For SSC first-time registration, populate config.username and config.password.
* For SCA, set config.caPath to the CA certificate used to verify the server.
* @param[in] ams AmsNetId of the target system
* @param[in] host hostname or IP of the target (port 8016 used by default)
* @param[in] config TLS certificate paths and optional credentials
* @return [ADS Return Code](https://infosys.beckhoff.com/content/1031/tcadscommon/html/ads_returncodes.htm?id=1666172286265530469)
*/
long AddSecureRoute(AmsNetId ams, const char *host,
const SecureAdsConfig &config);

/**
* Delete ams route that had previously been added with AddLocalRoute().
* @param[in] ams address of the target system
Expand Down
40 changes: 24 additions & 16 deletions AdsLib/AmsConnection.h
Original file line number Diff line number Diff line change
Expand Up @@ -66,32 +66,43 @@ struct AmsResponse {
bool wasWritten;
};

struct AmsConnection {
/**
* Common interface for plain and secure AMS connections.
* AmsRouter stores pointers to this base to support both connection types.
*/
struct AmsConnectionBase {
std::atomic<size_t> refCount{ 0 };
uint32_t ownIp{ 0 };

virtual ~AmsConnectionBase() = default;
virtual bool IsConnectedTo(const struct addrinfo *) const = 0;
virtual long AdsRequest(AmsRequest &, uint32_t timeout) = 0;
virtual SharedDispatcher
CreateNotifyMapping(uint32_t hNotify,
std::shared_ptr<Notification> notification) = 0;
virtual long DeleteNotification(const AmsAddr &, uint32_t hNotify,
uint32_t tmms, uint16_t port) = 0;
};

struct AmsConnection : AmsConnectionBase {
AmsConnection(Router &__router,
const struct addrinfo *destination = nullptr);
~AmsConnection();

SharedDispatcher
CreateNotifyMapping(uint32_t hNotify,
std::shared_ptr<Notification> notification);
std::shared_ptr<Notification> notification) override;
long DeleteNotification(const AmsAddr &amsAddr, uint32_t hNotify,
uint32_t tmms, uint16_t port);
long AdsRequest(AmsRequest &request, uint32_t timeout);

/**
* Confirm if this AmsConnection is connected to one of the target addresses.
* @param[in] targetAddresses pointer to a previously allocated list of
* "struct addrinfo" returned by getaddrinfo(3).
* @return true, this connection can be used to reach one of the targetAddresses.
*/
bool IsConnectedTo(const struct addrinfo *targetAddresses) const;
uint32_t tmms, uint16_t port) override;
long AdsRequest(AmsRequest &request, uint32_t timeout) override;
bool
IsConnectedTo(const struct addrinfo *targetAddresses) const override;

private:
friend struct AmsRouter;
Router &router;
TcpSocket socket;
std::thread receiver;
std::atomic<size_t> refCount;
std::atomic<uint32_t> invokeId;
std::array<AmsResponse, Router::NUM_PORTS_MAX> queue;

Expand Down Expand Up @@ -119,7 +130,4 @@ struct AmsConnection {
std::recursive_mutex dispatcherListMutex;
SharedDispatcher DispatcherListAdd(const VirtualConnection &connection);
SharedDispatcher DispatcherListGet(const VirtualConnection &connection);

public:
const uint32_t ownIp;
};
11 changes: 7 additions & 4 deletions AdsLib/AmsRouter.h
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
#pragma once

#include "AmsConnection.h"
#include "SecureAdsConfig.h"
#include <unordered_set>

struct AmsRouter : Router {
Expand All @@ -25,22 +26,24 @@ struct AmsRouter : Router {
[[deprecated]]
long AddRoute(AmsNetId ams, const IpV4 &ip);
long AddRoute(AmsNetId ams, const std::string &host);
long AddSecureRoute(AmsNetId ams, const std::string &host,
const bhf::ads::SecureAdsConfig &config);
void DelRoute(const AmsNetId &ams);
AmsConnection *GetConnection(const AmsNetId &pAddr);
AmsConnectionBase *GetConnection(const AmsNetId &pAddr);
long AdsRequest(AmsRequest &request);

private:
AmsNetId localAddr;
std::recursive_mutex mutex;
std::condition_variable_any connection_attempt_events;
std::map<AmsNetId, std::tuple<> > connection_attempts;
std::unordered_set<std::unique_ptr<AmsConnection> > connections;
std::map<AmsNetId, AmsConnection *> mapping;
std::unordered_set<std::unique_ptr<AmsConnectionBase> > connections;
std::map<AmsNetId, AmsConnectionBase *> mapping;

void
AwaitConnectionAttempts(const AmsNetId &ams,
std::unique_lock<std::recursive_mutex> &lock);
void DeleteIfLastConnection(const AmsConnection *conn);
void DeleteIfLastConnection(const AmsConnectionBase *conn);

std::array<AmsPort, NUM_PORTS_MAX> ports;
};
5 changes: 4 additions & 1 deletion AdsLib/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ set(SOURCES
standalone/AmsPort.cpp
standalone/AmsRouter.cpp
standalone/NotificationDispatcher.cpp
standalone/SecureAmsConnection.cpp
standalone/TlsSocket.cpp
)

add_library(ads ${SOURCES})
Expand All @@ -39,4 +41,5 @@ if(WIN32 EQUAL 1)
target_link_libraries(ads PUBLIC ws2_32)
endif()

target_link_libraries(ads PUBLIC Threads::Threads)
find_package(OpenSSL REQUIRED)
target_link_libraries(ads PUBLIC Threads::Threads OpenSSL::SSL OpenSSL::Crypto)
26 changes: 26 additions & 0 deletions AdsLib/SecureAdsConfig.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,26 @@
// SPDX-License-Identifier: MIT
#pragma once

#include <string>

namespace bhf
{
namespace ads
{

struct SecureAdsConfig {
enum class Mode { SSC, SCA, PSK };

Mode mode = Mode::SCA;
std::string certPath; ///< SSC/SCA: PEM path to client certificate
std::string keyPath; ///< SSC/SCA: PEM path to client private key
std::string caPath; ///< SCA: PEM path to CA certificate
std::string username; ///< SSC: username for first-time route registration
std::string
password; ///< SSC: first-time registration password; PSK: password for key derivation
std::string
pskIdentity; ///< PSK: identity string; key = SHA-256(UPPER(identity) + password)
};

}
}
17 changes: 16 additions & 1 deletion AdsLib/standalone/AdsLib.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@

#include "AdsLib.h"
#include "AmsRouter.h"
#include "Log.h"

static AmsRouter &GetRouter()
{
Expand Down Expand Up @@ -37,7 +38,8 @@ long AddLocalRoute(const AmsNetId ams, const char *ip)
return GetRouter().AddRoute(ams, ip);
} catch (const std::bad_alloc &) {
return GLOBALERR_NO_MEMORY;
} catch (const std::runtime_error &) {
} catch (const std::runtime_error &e) {
LOG_ERROR("AddLocalRoute failed: " << e.what());
return GLOBALERR_TARGET_PORT;
}
}
Expand All @@ -51,6 +53,19 @@ void SetLocalAddress(const AmsNetId ams)
{
GetRouter().SetLocalAddress(ams);
}

long AddSecureRoute(const AmsNetId ams, const char *host,
const SecureAdsConfig &config)
{
try {
return GetRouter().AddSecureRoute(ams, host, config);
} catch (const std::bad_alloc &) {
return GLOBALERR_NO_MEMORY;
} catch (const std::runtime_error &e) {
LOG_ERROR("AddSecureRoute failed: " << e.what());
return GLOBALERR_TARGET_PORT;
}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions AdsLib/standalone/AdsLib.h
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#pragma once

#include "AdsDef.h"
#include "SecureAdsConfig.h"

#ifdef BHF_ADS_EXPORT_C
extern "C" {
Expand Down
3 changes: 1 addition & 2 deletions AdsLib/standalone/AmsConnection.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -70,10 +70,9 @@ AmsConnection::AmsConnection(Router &__router,
const struct addrinfo *const destination)
: router(__router)
, socket(destination)
, refCount(0)
, invokeId(0)
, ownIp(socket.Connect())
{
ownIp = socket.Connect();
receiver = std::thread(&AmsConnection::TryRecv, this);
}

Expand Down
62 changes: 58 additions & 4 deletions AdsLib/standalone/AmsRouter.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@
*/

#include "AmsRouter.h"
#include "SecureAmsConnection.h"
#include "Log.h"

#include <algorithm>
Expand Down Expand Up @@ -61,7 +62,7 @@ long AmsRouter::AddRoute(AmsNetId ams, const std::string &host)
lock.unlock();

try {
auto new_connection = std::unique_ptr<AmsConnection>(
auto new_connection = std::unique_ptr<AmsConnectionBase>(
new AmsConnection{ *this, hostAddresses.get() });
lock.lock();
connection_attempts.erase(ams);
Expand Down Expand Up @@ -96,15 +97,15 @@ void AmsRouter::DelRoute(const AmsNetId &ams)

auto route = mapping.find(ams);
if (route != mapping.end()) {
AmsConnection *conn = route->second;
AmsConnectionBase *conn = route->second;
if (0 == --conn->refCount) {
mapping.erase(route);
DeleteIfLastConnection(conn);
}
}
}

void AmsRouter::DeleteIfLastConnection(const AmsConnection *const conn)
void AmsRouter::DeleteIfLastConnection(const AmsConnectionBase *const conn)
{
if (conn) {
for (const auto &r : mapping) {
Expand Down Expand Up @@ -188,7 +189,60 @@ long AmsRouter::SetTimeout(uint16_t port, uint32_t timeout)
return 0;
}

AmsConnection *AmsRouter::GetConnection(const AmsNetId &amsDest)
long AmsRouter::AddSecureRoute(AmsNetId ams, const std::string &host,
const bhf::ads::SecureAdsConfig &config)
{
auto hostAddresses = bhf::ads::GetListOfAddresses(host, "8016");

std::unique_lock<std::recursive_mutex> lock(mutex);

AwaitConnectionAttempts(ams, lock);

const auto oldConnection = GetConnection(ams);
if (oldConnection &&
!oldConnection->IsConnectedTo(hostAddresses.get())) {
return ROUTERERR_PORTALREADYINUSE;
}

for (const auto &conn : connections) {
if (conn->IsConnectedTo(hostAddresses.get())) {
conn->refCount++;
mapping[ams] = conn.get();
return 0;
}
}

connection_attempts[ams] = {};
lock.unlock();

try {
auto new_connection = std::unique_ptr<AmsConnectionBase>(
new SecureAmsConnection{ *this, hostAddresses.get(),
config, localAddr });
lock.lock();
connection_attempts.erase(ams);
connection_attempt_events.notify_all();

auto conn = connections.emplace(std::move(new_connection));
if (conn.second) {
if (!localAddr) {
localAddr =
AmsNetId{ conn.first->get()->ownIp };
}
conn.first->get()->refCount++;
mapping[ams] = conn.first->get();
return !conn.first->get()->ownIp;
}
return -1;
} catch (std::exception &e) {
lock.lock();
connection_attempts.erase(ams);
connection_attempt_events.notify_all();
throw;
}
}

AmsConnectionBase *AmsRouter::GetConnection(const AmsNetId &amsDest)
{
std::lock_guard<std::recursive_mutex> lock(mutex);
const auto it = mapping.find(amsDest);
Expand Down
Loading