diff --git a/include/pulsar/Authentication.h b/include/pulsar/Authentication.h index 34b70cdf..a6a02b8f 100644 --- a/include/pulsar/Authentication.h +++ b/include/pulsar/Authentication.h @@ -515,11 +515,22 @@ typedef std::shared_ptr CachedTokenPtr; * Passed in parameter would be like: * ``` * "type": "client_credentials", + * "tokenEndpointAuthMethod": "client_secret_post", * "issuer_url": "https://accounts.google.com", * "client_id": "d9ZyX97q1ef8Cr81WHVC4hFQ64vSlDK3", * "client_secret": "on1uJ...k6F6R", * "audience": "https://broker.example.com" * ``` + * + * For `tokenEndpointAuthMethod = "tls_client_auth"`: + * ``` + * "type": "client_credentials", + * "tokenEndpointAuthMethod": "tls_client_auth", + * "issuer_url": "https://accounts.google.com", + * "client_id": "d9ZyX97q1ef8Cr81WHVC4hFQ64vSlDK3", + * "tls_cert_file": "/path/to/cert.pem", + * "tls_key_file": "/path/to/key.pem" + * ``` * If passed in as std::string, it should be in Json format. */ class PULSAR_PUBLIC AuthOauth2 : public Authentication { @@ -530,7 +541,14 @@ class PULSAR_PUBLIC AuthOauth2 : public Authentication { /** * Create an AuthOauth2 with a ParamMap * - * The required parameter keys are “issuer_url”, “private_key”, and “audience” + * For `tokenEndpointAuthMethod = "client_secret_post"` (default), the required parameter + * keys are “issuer_url”, “private_key”, and “audience”. + * Optional keys: `scope`, `tls_cert_file`, `tls_key_file`. + * + * For `tokenEndpointAuthMethod = "tls_client_auth"`, the required parameter keys are + * `issuer_url`, `tls_cert_file`, and `tls_key_file`. + * Optional keys: `client_id`, `audience`, `scope`. If `client_id` is omitted, the client + * uses `pulsar-client`. * * @param parameters the key-value to create OAuth 2.0 client credentials * @see http://pulsar.apache.org/docs/en/security-oauth2/#client-credentials diff --git a/lib/auth/AuthOauth2.cc b/lib/auth/AuthOauth2.cc index 9573496b..037a0c56 100644 --- a/lib/auth/AuthOauth2.cc +++ b/lib/auth/AuthOauth2.cc @@ -20,6 +20,7 @@ #include #include +#include #include #include @@ -31,6 +32,36 @@ DECLARE_LOG_OBJECT() namespace pulsar { +const std::string TlsClientAuthFlow::DEFAULT_CLIENT_ID = "pulsar-client"; +namespace { +enum class OAuth2TokenEndpointAuthMethod : std::uint8_t +{ + ClientSecretPost, + TlsClientAuth, + Unknown, +}; + +OAuth2TokenEndpointAuthMethod parseTokenEndpointAuthMethod(const std::string& authMethod) { + if (authMethod == "tls_client_auth") { + return OAuth2TokenEndpointAuthMethod::TlsClientAuth; + } + if (authMethod == "client_secret_post") { + return OAuth2TokenEndpointAuthMethod::ClientSecretPost; + } + return OAuth2TokenEndpointAuthMethod::Unknown; +} + +std::string toFlowName(OAuth2TokenEndpointAuthMethod authMethod) { + switch (authMethod) { + case OAuth2TokenEndpointAuthMethod::TlsClientAuth: + return "TlsClientAuthFlow"; + case OAuth2TokenEndpointAuthMethod::ClientSecretPost: + default: + return "ClientCredentialFlow"; + } +} +} // namespace + // AuthDataOauth2 AuthDataOauth2::AuthDataOauth2(const std::string& accessToken) { accessToken_ = accessToken; } @@ -111,6 +142,8 @@ bool Oauth2CachedToken::isExpired() { return expiresAt_ < Clock::now(); } Oauth2Flow::Oauth2Flow() {} Oauth2Flow::~Oauth2Flow() {} +static std::string buildClientCredentialsBody(CurlWrapper& curl, const ParamMap& params); + KeyFile KeyFile::fromParamMap(ParamMap& params) { const auto it = params.find("private_key"); if (it == params.cend()) { @@ -199,80 +232,186 @@ KeyFile KeyFile::fromBase64(const std::string& encoded) { } } -ClientCredentialFlow::ClientCredentialFlow(ParamMap& params) - : issuerUrl_(params["issuer_url"]), - keyFile_(KeyFile::fromParamMap(params)), - audience_(params["audience"]), - scope_(params["scope"]) {} +static std::string getWellKnownUrl(const std::string& issuerUrl) { + std::string wellKnownUrl = issuerUrl; + if (!wellKnownUrl.empty() && wellKnownUrl.back() == '/') { + wellKnownUrl.pop_back(); + } + wellKnownUrl.append("/.well-known/openid-configuration"); + return wellKnownUrl; +} -std::string ClientCredentialFlow::getTokenEndPoint() const { return tokenEndPoint_; } +static std::unique_ptr createTlsContext(const std::string& tlsTrustCertsFilePath, + const std::string& tlsCertFilePath, + const std::string& tlsKeyFilePath, + OAuth2TokenEndpointAuthMethod authMethod) { + if (tlsTrustCertsFilePath.empty() && tlsCertFilePath.empty() && tlsKeyFilePath.empty()) { + return nullptr; + } -void ClientCredentialFlow::initialize() { - if (issuerUrl_.empty()) { - LOG_ERROR("Failed to initialize ClientCredentialFlow: issuer_url is not set"); - return; + auto tlsContext = std::unique_ptr(new CurlWrapper::TlsContext); + if (!tlsTrustCertsFilePath.empty()) { + tlsContext->trustCertsFilePath = tlsTrustCertsFilePath; } - if (!keyFile_.isValid()) { - return; + if (!tlsCertFilePath.empty() && !tlsKeyFilePath.empty()) { + tlsContext->certPath = tlsCertFilePath; + tlsContext->keyPath = tlsKeyFilePath; + } else if (authMethod == OAuth2TokenEndpointAuthMethod::TlsClientAuth) { + LOG_WARN("Ignore incomplete mTLS settings: both tls_cert_file and tls_key_file are required"); } + return tlsContext; +} - // set URL: well-know endpoint - std::string wellKnownUrl = issuerUrl_; - if (wellKnownUrl.back() == '/') { - wellKnownUrl.pop_back(); +static std::string fetchTokenEndpoint(const std::string& issuerUrl, + const CurlWrapper::TlsContext* tlsContext) { + const auto wellKnownUrl = getWellKnownUrl(issuerUrl); + CurlWrapper curl; + if (!curl.init()) { + LOG_ERROR("Failed to initialize curl"); + return ""; + } + + auto result = curl.get(wellKnownUrl, "Accept: application/json", {}, tlsContext); + if (!result.error.empty()) { + LOG_ERROR("Failed to get the well-known configuration " << issuerUrl << ": " << result.error); + return ""; + } + + const auto res = result.code; + const auto responseCode = result.responseCode; + const auto& responseData = result.responseData; + const auto& errorBuffer = result.serverError; + + switch (res) { + case CURLE_OK: + LOG_DEBUG("Received well-known configuration data " << issuerUrl << " code " << responseCode); + if (responseCode == 200) { + boost::property_tree::ptree root; + std::stringstream stream; + stream << responseData; + try { + boost::property_tree::read_json(stream, root); + return root.get("token_endpoint"); + } catch (boost::property_tree::json_parser_error& e) { + LOG_ERROR("Failed to parse well-known configuration data response: " + << e.what() << "\nInput Json = " << responseData); + return ""; + } + } else { + LOG_ERROR("Response failed for getting the well-known configuration " + << issuerUrl << ". response Code " << responseCode); + } + break; + default: + LOG_ERROR("Response failed for getting the well-known configuration " + << issuerUrl << ". Error Code " << res << ": " << errorBuffer); + break; + } + return ""; +} + +static Oauth2TokenResultPtr fetchOauth2Token(const std::string& issuerUrl, const std::string& tokenEndpoint, + const ParamMap& params, + const CurlWrapper::TlsContext* tlsContext, + OAuth2TokenEndpointAuthMethod authMethod) { + Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new Oauth2TokenResult()); + if (tokenEndpoint.empty()) { + return resultPtr; } - wellKnownUrl.append("/.well-known/openid-configuration"); CurlWrapper curl; if (!curl.init()) { LOG_ERROR("Failed to initialize curl"); - return; + return resultPtr; } - std::unique_ptr tlsContext; - if (!tlsTrustCertsFilePath_.empty()) { - tlsContext.reset(new CurlWrapper::TlsContext); - tlsContext->trustCertsFilePath = tlsTrustCertsFilePath_; + + auto postData = buildClientCredentialsBody(curl, params); + if (postData.empty()) { + return resultPtr; } + LOG_DEBUG("Generate URL encoded body for " << toFlowName(authMethod) << ": " << postData); - auto result = curl.get(wellKnownUrl, "Accept: application/json", {}, tlsContext.get()); + CurlWrapper::Options options; + options.postFields = std::move(postData); + auto result = + curl.get(tokenEndpoint, "Content-Type: application/x-www-form-urlencoded", options, tlsContext); if (!result.error.empty()) { - LOG_ERROR("Failed to get the well-known configuration " << issuerUrl_ << ": " << result.error); - return; + LOG_ERROR("Failed to get the well-known configuration " << issuerUrl << ": " << result.error); + return resultPtr; } const auto res = result.code; - const auto response_code = result.responseCode; + const auto responseCode = result.responseCode; const auto& responseData = result.responseData; const auto& errorBuffer = result.serverError; switch (res) { case CURLE_OK: - LOG_DEBUG("Received well-known configuration data " << issuerUrl_ << " code " << response_code); - if (response_code == 200) { + LOG_DEBUG("Response received for issuerurl " << issuerUrl << " code " << responseCode); + if (responseCode == 200) { boost::property_tree::ptree root; std::stringstream stream; stream << responseData; try { boost::property_tree::read_json(stream, root); } catch (boost::property_tree::json_parser_error& e) { - LOG_ERROR("Failed to parse well-known configuration data response: " - << e.what() << "\nInput Json = " << responseData); + LOG_ERROR("Failed to parse json of Oauth2 response: " + << e.what() << "\nInput Json = " << responseData + << " passedin: " << options.postFields); break; } - this->tokenEndPoint_ = root.get("token_endpoint"); + resultPtr->setAccessToken(root.get("access_token", "")); + resultPtr->setExpiresIn( + root.get("expires_in", Oauth2TokenResult::undefined_expiration)); + resultPtr->setRefreshToken(root.get("refresh_token", "")); + resultPtr->setIdToken(root.get("id_token", "")); - LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_); + if (!resultPtr->getAccessToken().empty()) { + LOG_DEBUG("access_token: " << resultPtr->getAccessToken() + << " expires_in: " << resultPtr->getExpiresIn()); + } else { + LOG_ERROR("Response doesn't contain access_token, the response is: " << responseData); + } } else { - LOG_ERROR("Response failed for getting the well-known configuration " - << issuerUrl_ << ". response Code " << response_code); + LOG_ERROR("Response failed for issuerurl " << issuerUrl << ". response Code " << responseCode + << " passedin: " << options.postFields); } break; default: - LOG_ERROR("Response failed for getting the well-known configuration " - << issuerUrl_ << ". Error Code " << res << ": " << errorBuffer); + LOG_ERROR("Response failed for issuerurl " << issuerUrl << ". ErrorCode " << res << ": " + << errorBuffer << " passedin: " << options.postFields); break; } + + return resultPtr; +} + +ClientCredentialFlow::ClientCredentialFlow(ParamMap& params) + : issuerUrl_(params["issuer_url"]), + keyFile_(KeyFile::fromParamMap(params)), + audience_(params["audience"]), + scope_(params["scope"]), + tlsCertFilePath_(params["tls_cert_file"]), + tlsKeyFilePath_(params["tls_key_file"]) {} + +std::string ClientCredentialFlow::getTokenEndPoint() const { return tokenEndPoint_; } + +void ClientCredentialFlow::initialize() { + if (issuerUrl_.empty()) { + LOG_ERROR("Failed to initialize ClientCredentialFlow: issuer_url is not set"); + return; + } + if (!keyFile_.isValid()) { + return; + } + + const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, tlsCertFilePath_, tlsKeyFilePath_, + OAuth2TokenEndpointAuthMethod::ClientSecretPost); + this->tokenEndPoint_ = fetchTokenEndpoint(issuerUrl_, tlsContext.get()); + if (!this->tokenEndPoint_.empty()) { + LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_); + } } void ClientCredentialFlow::close() {} @@ -324,85 +463,93 @@ static std::string buildClientCredentialsBody(CurlWrapper& curl, const ParamMap& Oauth2TokenResultPtr ClientCredentialFlow::authenticate() { std::call_once(initializeOnce_, &ClientCredentialFlow::initialize, this); - Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new Oauth2TokenResult()); - if (tokenEndPoint_.empty()) { - return resultPtr; + const auto params = generateParamMap(); + const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, tlsCertFilePath_, tlsKeyFilePath_, + OAuth2TokenEndpointAuthMethod::ClientSecretPost); + return fetchOauth2Token(issuerUrl_, tokenEndPoint_, params, tlsContext.get(), + OAuth2TokenEndpointAuthMethod::ClientSecretPost); +} + +TlsClientAuthFlow::TlsClientAuthFlow(ParamMap& params) + : issuerUrl_(params["issuer_url"]), + clientId_(params["client_id"].empty() ? DEFAULT_CLIENT_ID : params["client_id"]), + audience_(params["audience"]), + scope_(params["scope"]), + tlsCertFilePath_(params["tls_cert_file"]), + tlsKeyFilePath_(params["tls_key_file"]) {} + +std::string TlsClientAuthFlow::getTokenEndPoint() const { return tokenEndPoint_; } + +void TlsClientAuthFlow::initialize() { + if (issuerUrl_.empty()) { + LOG_ERROR("Failed to initialize TlsClientAuthFlow: issuer_url is not set"); + return; + } + if (tlsCertFilePath_.empty() || tlsKeyFilePath_.empty()) { + LOG_ERROR("Failed to initialize TlsClientAuthFlow: tls_cert_file or tls_key_file is not set"); + return; } - CurlWrapper curl; - if (!curl.init()) { - LOG_ERROR("Failed to initialize curl"); - return resultPtr; + const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, tlsCertFilePath_, tlsKeyFilePath_, + OAuth2TokenEndpointAuthMethod::TlsClientAuth); + if (!tlsContext || tlsContext->certPath.empty() || tlsContext->keyPath.empty()) { + LOG_ERROR("Failed to initialize TlsClientAuthFlow: tls_cert_file or tls_key_file is not set"); + return; } - auto postData = buildClientCredentialsBody(curl, generateParamMap()); - if (postData.empty()) { - return resultPtr; + this->tokenEndPoint_ = fetchTokenEndpoint(issuerUrl_, tlsContext.get()); + if (!this->tokenEndPoint_.empty()) { + LOG_DEBUG("Get token endpoint: " << this->tokenEndPoint_); } - LOG_DEBUG("Generate URL encoded body for ClientCredentialFlow: " << postData); +} +void TlsClientAuthFlow::close() {} - CurlWrapper::Options options; - options.postFields = std::move(postData); - std::unique_ptr tlsContext; - if (!tlsTrustCertsFilePath_.empty()) { - tlsContext.reset(new CurlWrapper::TlsContext); - tlsContext->trustCertsFilePath = tlsTrustCertsFilePath_; +ParamMap TlsClientAuthFlow::generateParamMap() const { + ParamMap params; + params.emplace("grant_type", "client_credentials"); + params.emplace("client_id", clientId_); + if (!audience_.empty()) { + params.emplace("audience", audience_); } - auto result = curl.get(tokenEndPoint_, "Content-Type: application/x-www-form-urlencoded", options, - tlsContext.get()); - if (!result.error.empty()) { - LOG_ERROR("Failed to get the well-known configuration " << issuerUrl_ << ": " << result.error); - return resultPtr; + if (!scope_.empty()) { + params.emplace("scope", scope_); } - const auto res = result.code; - const auto response_code = result.responseCode; - const auto& responseData = result.responseData; - const auto& errorBuffer = result.serverError; + return params; +} - switch (res) { - case CURLE_OK: - LOG_DEBUG("Response received for issuerurl " << issuerUrl_ << " code " << response_code); - if (response_code == 200) { - boost::property_tree::ptree root; - std::stringstream stream; - stream << responseData; - try { - boost::property_tree::read_json(stream, root); - } catch (boost::property_tree::json_parser_error& e) { - LOG_ERROR("Failed to parse json of Oauth2 response: " - << e.what() << "\nInput Json = " << responseData << " passedin: " << postData); - break; - } +Oauth2TokenResultPtr TlsClientAuthFlow::authenticate() { + std::call_once(initializeOnce_, &TlsClientAuthFlow::initialize, this); + const auto params = generateParamMap(); + const auto tlsContext = createTlsContext(tlsTrustCertsFilePath_, tlsCertFilePath_, tlsKeyFilePath_, + OAuth2TokenEndpointAuthMethod::TlsClientAuth); + if (!tlsContext || tlsContext->certPath.empty() || tlsContext->keyPath.empty()) { + Oauth2TokenResultPtr resultPtr = Oauth2TokenResultPtr(new Oauth2TokenResult()); + return resultPtr; + } + return fetchOauth2Token(issuerUrl_, tokenEndPoint_, params, tlsContext.get(), + OAuth2TokenEndpointAuthMethod::TlsClientAuth); +} - resultPtr->setAccessToken(root.get("access_token", "")); - resultPtr->setExpiresIn( - root.get("expires_in", Oauth2TokenResult::undefined_expiration)); - resultPtr->setRefreshToken(root.get("refresh_token", "")); - resultPtr->setIdToken(root.get("id_token", "")); +// AuthOauth2 - if (!resultPtr->getAccessToken().empty()) { - LOG_DEBUG("access_token: " << resultPtr->getAccessToken() - << " expires_in: " << resultPtr->getExpiresIn()); - } else { - LOG_ERROR("Response doesn't contain access_token, the response is: " << responseData); - } - } else { - LOG_ERROR("Response failed for issuerurl " << issuerUrl_ << ". response Code " - << response_code << " passedin: " << postData); - } +AuthOauth2::AuthOauth2(ParamMap& params) { + std::string tokenEndpointAuthMethodName = params["tokenEndpointAuthMethod"]; + if (tokenEndpointAuthMethodName.empty()) { + tokenEndpointAuthMethodName = "client_secret_post"; + } + const auto tokenEndpointAuthMethod = parseTokenEndpointAuthMethod(tokenEndpointAuthMethodName); + switch (tokenEndpointAuthMethod) { + case OAuth2TokenEndpointAuthMethod::TlsClientAuth: + flowPtr_ = FlowPtr(new TlsClientAuthFlow(params)); break; - default: - LOG_ERROR("Response failed for issuerurl " << issuerUrl_ << ". ErrorCode " << res << ": " - << errorBuffer << " passedin: " << postData); + case OAuth2TokenEndpointAuthMethod::ClientSecretPost: + flowPtr_ = FlowPtr(new ClientCredentialFlow(params)); break; + case OAuth2TokenEndpointAuthMethod::Unknown: + default: + throw std::invalid_argument("Unknown tokenEndpointAuthMethod: " + tokenEndpointAuthMethodName); } - - return resultPtr; } -// AuthOauth2 - -AuthOauth2::AuthOauth2(ParamMap& params) : flowPtr_(new ClientCredentialFlow(params)) {} - AuthOauth2::~AuthOauth2() {} ParamMap parseJsonAuthParamsString(const std::string& authParamsString) { @@ -436,11 +583,13 @@ const std::string AuthOauth2::getAuthMethodName() const { return "token"; } Result AuthOauth2::getAuthData(AuthenticationDataPtr& authDataContent) { auto initialAuthData = std::dynamic_pointer_cast(authDataContent); if (initialAuthData) { - auto flowPtr = std::dynamic_pointer_cast(flowPtr_); - if (!flowPtr_) { - throw std::invalid_argument("AuthOauth2::flowPtr_ is not a ClientCredentialFlow"); + if (auto clientCredentialFlow = std::dynamic_pointer_cast(flowPtr_)) { + clientCredentialFlow->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_); + } else if (auto tlsClientAuthFlow = std::dynamic_pointer_cast(flowPtr_)) { + tlsClientAuthFlow->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_); + } else { + throw std::invalid_argument("AuthOauth2::flowPtr_ is not an OAuth2 flow implementation"); } - flowPtr->setTlsTrustCertsFilePath(initialAuthData->tlsTrustCertsFilePath_); } if (cachedTokenPtr_ == nullptr || cachedTokenPtr_->isExpired()) { diff --git a/lib/auth/AuthOauth2.h b/lib/auth/AuthOauth2.h index 035ad084..b402f37e 100644 --- a/lib/auth/AuthOauth2.h +++ b/lib/auth/AuthOauth2.h @@ -71,6 +71,36 @@ class ClientCredentialFlow : public Oauth2Flow { const KeyFile keyFile_; const std::string audience_; const std::string scope_; + const std::string tlsCertFilePath_; + const std::string tlsKeyFilePath_; + std::string tlsTrustCertsFilePath_; + std::once_flag initializeOnce_; +}; + +class TlsClientAuthFlow : public Oauth2Flow { + public: + static const std::string DEFAULT_CLIENT_ID; + + TlsClientAuthFlow(ParamMap& params); + void initialize(); + Oauth2TokenResultPtr authenticate(); + void close(); + + ParamMap generateParamMap() const; + std::string getTokenEndPoint() const; + + void setTlsTrustCertsFilePath(const std::string& tlsTrustCertsFilePath) { + tlsTrustCertsFilePath_ = tlsTrustCertsFilePath; + } + + private: + std::string tokenEndPoint_; + const std::string issuerUrl_; + const std::string clientId_; + const std::string audience_; + const std::string scope_; + const std::string tlsCertFilePath_; + const std::string tlsKeyFilePath_; std::string tlsTrustCertsFilePath_; std::once_flag initializeOnce_; }; diff --git a/tests/AuthPluginTest.cc b/tests/AuthPluginTest.cc index 6c6b8980..00739649 100644 --- a/tests/AuthPluginTest.cc +++ b/tests/AuthPluginTest.cc @@ -22,11 +22,16 @@ #include #include +#include +#include #include +#include #ifdef USE_ASIO #include +#include #else #include +#include #endif #include @@ -36,6 +41,7 @@ #include "lib/LogUtils.h" #include "lib/Utils.h" #include "lib/auth/AuthOauth2.h" +#include "lib/auth/InitialAuthData.h" DECLARE_LOG_OBJECT() using namespace pulsar; @@ -59,6 +65,8 @@ static const std::string mimServiceUrlTls = "pulsar+ssl://localhost:6653"; static const std::string mimServiceUrlHttps = "https://localhost:8444"; static const std::string mimCaPath = TEST_CONF_DIR "/hn-verification/cacert.pem"; +static const std::string brokerPublicKeyPath = TEST_CONF_DIR "/broker-cert.pem"; +static const std::string brokerPrivateKeyPath = TEST_CONF_DIR "/broker-key.pem"; static void sendCallBackTls(Result r, const MessageId& msgId) { ASSERT_EQ(r, ResultOk); @@ -324,14 +332,12 @@ static std::vector split(const std::string& s, char separator) { return tokens; } -namespace testAthenz { -std::string principalToken; - // ASIO::ip::tcp::iostream could call a virtual function during destruction, so the clang-tidy will fail by // clang-analyzer-optin.cplusplus.VirtualCall. Here we write a simple stream to read lines from socket. +template class SocketStream { public: - SocketStream(ASIO::ip::tcp::socket& socket) : socket_(socket) {} + explicit SocketStream(Stream& stream) : stream_(stream) {} bool getline(std::string& line) { auto pos = buffer_.find('\n', bufferPos_); @@ -343,7 +349,7 @@ class SocketStream { std::array buffer; ASIO_ERROR error; - auto length = socket_.read_some(ASIO::buffer(buffer.data(), buffer.size()), error); + auto length = stream_.read_some(ASIO::buffer(buffer.data(), buffer.size()), error); if (error == ASIO::error::eof) { return false; } else if (error) { @@ -362,12 +368,29 @@ class SocketStream { return true; } + bool readBytes(size_t size, std::string& out) { + while (buffer_.size() - bufferPos_ < size) { + std::array buffer; + ASIO_ERROR error; + auto length = stream_.read_some(ASIO::buffer(buffer.data(), buffer.size()), error); + if (error == ASIO::error::eof) return false; + if (error) return false; + buffer_.append(buffer.data(), length); + } + out.assign(buffer_.data() + bufferPos_, size); + bufferPos_ += size; + return true; + } + private: - ASIO::ip::tcp::socket& socket_; + Stream& stream_; std::string buffer_; size_t bufferPos_{0}; }; +namespace testAthenz { +std::string principalToken; + void mockZTS(Latch& latch, int port) { LOG_INFO("-- MockZTS started"); ASIO::io_context io; @@ -380,7 +403,7 @@ void mockZTS(Latch& latch, int port) { LOG_INFO("-- MockZTS got connection"); std::string headerLine; - SocketStream stream(socket); + SocketStream stream(socket); while (stream.getline(headerLine)) { if (headerLine.empty()) { continue; @@ -518,6 +541,84 @@ TEST(AuthPluginTest, testAuthFactoryAthenz) { } } +namespace testOauth2Tls { +class MockOauth2Server { + public: + MockOauth2Server(const std::string& responseBody, const std::string& responseContentType, int listenPort, + bool requireClientCert = true) + : responseBody_(responseBody), + responseContentType_(responseContentType), + acceptor_(io_, ASIO::ip::tcp::endpoint(ASIO::ip::tcp::v4(), static_cast(listenPort))), + sslCtx_(ASIO::ssl::context::sslv23) { + sslCtx_.set_options(ASIO::ssl::context::default_workarounds | ASIO::ssl::context::no_sslv2 | + ASIO::ssl::context::no_sslv3); + sslCtx_.use_certificate_chain_file(brokerPublicKeyPath); + sslCtx_.use_private_key_file(brokerPrivateKeyPath, ASIO::ssl::context::pem); + sslCtx_.load_verify_file(caPath); + sslCtx_.set_verify_mode(requireClientCert + ? (ASIO::ssl::verify_peer | ASIO::ssl::verify_fail_if_no_peer_cert) + : ASIO::ssl::verify_none); + } + + const std::string& request() const { return request_; } + + bool mockServe() { + ASIO::ip::tcp::socket socket(io_); + acceptor_.accept(socket); + ASIO_ERROR error; + ASIO::ssl::stream sslStream(socket, sslCtx_); + sslStream.handshake(ASIO::ssl::stream_base::server, error); + if (error) return false; + if (!readRequest(sslStream)) return false; + + const std::string response = "HTTP/1.1 200 OK\nContent-Type: " + responseContentType_ + + "\nContent-Length: " + std::to_string(responseBody_.size()) + + "\nConnection: close\n\n" + responseBody_; + ASIO::write(sslStream, ASIO::buffer(response.data(), response.size()), error); + if (error) return false; + return true; + } + + private: + bool readRequest(ASIO::ssl::stream& sslStream) { + SocketStream> stream(sslStream); + request_.clear(); + int contentLength = 0; + const std::string prefix = "Content-Length:"; + std::string headerLine; + while (stream.getline(headerLine)) { + if (headerLine.empty()) { + continue; + } + request_.append(headerLine).append("\n"); + if (headerLine.rfind(prefix, 0) == 0) { + contentLength = std::stoi(headerLine.substr(prefix.size())); + } + if (headerLine == "\r") { + break; + } + } + if (headerLine != "\r") return false; + + if (contentLength > 0) { + std::string body; + if (!stream.readBytes(static_cast(contentLength), body)) return false; + request_ += body; + } + return true; + } + + const std::string responseBody_; + const std::string responseContentType_; + std::string request_; + + ASIO::io_context io_; + ASIO::ip::tcp::acceptor acceptor_; + ASIO::ssl::context sslCtx_; +}; + +} // namespace testOauth2Tls + TEST(AuthPluginTest, testOauth2) { // test success get token from oauth2 server. pulsar::AuthenticationDataPtr data; @@ -584,11 +685,15 @@ TEST(AuthPluginTest, testOauth2RequestBody) { params["client_id"] = "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x"; params["client_secret"] = "rT7ps7WY8uhdVuBTKWZkttwLdQotmdEliaM5rLfmgNibvqziZ-g07ZH52N_poGAb"; params["audience"] = "https://dev-kt-aa9ne.us.auth0.com/api/v2/"; + params["tls_cert_file"] = "/path/to/cert.pem"; + params["tls_key_file"] = "/path/to/key.pem"; auto createExpectedResult = [&] { auto paramsCopy = params; paramsCopy.emplace("grant_type", "client_credentials"); paramsCopy.erase("issuer_url"); + paramsCopy.erase("tls_cert_file"); + paramsCopy.erase("tls_key_file"); return paramsCopy; }; @@ -668,6 +773,205 @@ TEST(AuthPluginTest, testOauth2Failure) { client5.close(); } +TEST(AuthPluginTest, testOauth2TlsClientAuth) { + const int tokenServerPort = 58081; + const int wellKnownServerPort = 58082; + const std::string tokenBody = R"({"access_token":"mockToken","expires_in":3600,"token_type":"Bearer"})"; + std::unique_ptr tokenServer; + try { + tokenServer = + std::make_unique(tokenBody, "application/json", tokenServerPort); + } catch (const std::exception& e) { + FAIL() << "Failed to bind local mock token server: " << e.what(); + } + + std::promise tokenPromise; + auto tokenFuture = tokenPromise.get_future(); + std::thread tokenThread( + [&tokenServer, &tokenPromise]() { tokenPromise.set_value(tokenServer->mockServe()); }); + + std::ostringstream wellKnownBody; + wellKnownBody << R"({"token_endpoint":"https://localhost:)" << tokenServerPort << R"(/oauth/token"})"; + std::unique_ptr wellKnownServer; + try { + wellKnownServer = std::make_unique( + wellKnownBody.str(), "application/json", wellKnownServerPort, false); + } catch (const std::exception& e) { + tokenThread.join(); + FAIL() << "Failed to bind local mock well-known server: " << e.what(); + } + + std::promise wellKnownPromise; + auto wellKnownFuture = wellKnownPromise.get_future(); + std::thread wellKnownThread([&wellKnownServer, &wellKnownPromise]() { + wellKnownPromise.set_value(wellKnownServer->mockServe()); + }); + + ParamMap params; + params["tokenEndpointAuthMethod"] = "tls_client_auth"; + params["issuer_url"] = "https://localhost:" + std::to_string(wellKnownServerPort); + params["client_id"] = "test-client"; + params["tls_cert_file"] = clientPublicKeyPath; + params["tls_key_file"] = clientPrivateKeyPath; + + AuthenticationDataPtr data = + std::static_pointer_cast(std::make_shared(caPath)); + AuthenticationPtr auth = AuthOauth2::create(params); + ASSERT_EQ(auth->getAuthData(data), ResultOk); + ASSERT_TRUE(data->hasDataFromCommand()); + ASSERT_EQ(data->getCommandData(), "mockToken"); + + ASSERT_TRUE(wellKnownFuture.get()); + ASSERT_TRUE(tokenFuture.get()); + ASSERT_NE(wellKnownServer->request().find("GET /.well-known/openid-configuration "), std::string::npos); + ASSERT_NE(tokenServer->request().find("POST /oauth/token "), std::string::npos); + ASSERT_NE(tokenServer->request().find("grant_type=client_credentials"), std::string::npos); + wellKnownThread.join(); + tokenThread.join(); +} + +TEST(AuthPluginTest, testOauth2TlsClientAuthWrongCert) { + const int tokenServerPort = 58083; + const int wellKnownServerPort = 58084; + const std::string tokenBody = R"({"access_token":"mockToken","expires_in":3600,"token_type":"Bearer"})"; + + std::unique_ptr tokenServer; + try { + tokenServer = + std::make_unique(tokenBody, "application/json", tokenServerPort); + } catch (const std::exception& e) { + FAIL() << "Failed to bind local mock token server: " << e.what(); + } + + std::promise tokenPromise; + auto tokenFuture = tokenPromise.get_future(); + std::thread tokenThread( + [&tokenServer, &tokenPromise]() { tokenPromise.set_value(tokenServer->mockServe()); }); + + std::ostringstream wellKnownBody; + wellKnownBody << R"({"token_endpoint":"https://localhost:)" << tokenServerPort << R"(/oauth/token"})"; + std::unique_ptr wellKnownServer; + try { + wellKnownServer = std::make_unique( + wellKnownBody.str(), "application/json", wellKnownServerPort, false); + } catch (const std::exception& e) { + tokenThread.join(); + FAIL() << "Failed to bind local mock well-known server: " << e.what(); + } + + std::promise wellKnownPromise; + auto wellKnownFuture = wellKnownPromise.get_future(); + std::thread wellKnownThread([&wellKnownServer, &wellKnownPromise]() { + wellKnownPromise.set_value(wellKnownServer->mockServe()); + }); + + ParamMap params; + params["tokenEndpointAuthMethod"] = "tls_client_auth"; + params["issuer_url"] = "https://localhost:" + std::to_string(wellKnownServerPort); + params["client_id"] = "test-client"; + // set wrong cert and key + params["tls_cert_file"] = TEST_CONF_DIR "/hn-verification/broker-cert.pem"; + params["tls_key_file"] = TEST_CONF_DIR "/hn-verification/broker-key.pem"; + + AuthenticationDataPtr data = + std::static_pointer_cast(std::make_shared(caPath)); + AuthenticationPtr auth = AuthOauth2::create(params); + ASSERT_EQ(auth->getAuthData(data), ResultAuthenticationError); + + ASSERT_TRUE(wellKnownFuture.get()); + ASSERT_FALSE(tokenFuture.get()); + ASSERT_NE(wellKnownServer->request().find("GET /.well-known/openid-configuration "), std::string::npos); + wellKnownThread.join(); + tokenThread.join(); +} + +TEST(AuthPluginTest, testOauth2TlsClientAuthRequestBody) { + ParamMap params; + params["tokenEndpointAuthMethod"] = "tls_client_auth"; + params["issuer_url"] = "https://dev-kt-aa9ne.us.auth0.com"; + params["client_id"] = "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x"; + params["audience"] = "https://dev-kt-aa9ne.us.auth0.com/api/v2/"; + params["tls_cert_file"] = "/path/to/cert.pem"; + params["tls_key_file"] = "/path/to/key.pem"; + + auto createExpectedResult = [&] { + auto paramsCopy = params; + paramsCopy.emplace("grant_type", "client_credentials"); + paramsCopy.erase("tokenEndpointAuthMethod"); + paramsCopy.erase("issuer_url"); + paramsCopy.erase("tls_cert_file"); + paramsCopy.erase("tls_key_file"); + return paramsCopy; + }; + + const auto expectedResult1 = createExpectedResult(); + TlsClientAuthFlow flow1(params); + ASSERT_EQ(flow1.generateParamMap(), expectedResult1); + + params["scope"] = "test-scope"; + const auto expectedResult2 = createExpectedResult(); + TlsClientAuthFlow flow2(params); + ASSERT_EQ(flow2.generateParamMap(), expectedResult2); + + params.erase("client_id"); + auto expectedResult3 = expectedResult2; + expectedResult3["client_id"] = TlsClientAuthFlow::DEFAULT_CLIENT_ID; + TlsClientAuthFlow flow3(params); + ASSERT_EQ(flow3.generateParamMap(), expectedResult3); + + params.erase("audience"); + auto expectedResult4 = expectedResult3; + expectedResult4.erase("audience"); + TlsClientAuthFlow flow4(params); + ASSERT_EQ(flow4.generateParamMap(), expectedResult4); +} + +TEST(AuthPluginTest, testOauth2TlsClientAuthFailure) { + ParamMap params; + auto getAuthDataResult = [&]() -> Result { + AuthenticationDataPtr data = + std::static_pointer_cast(std::make_shared(caPath)); + AuthenticationPtr auth = AuthOauth2::create(params); + return auth->getAuthData(data); + }; + + params["tokenEndpointAuthMethod"] = "tls_client_auth"; + params["tls_cert_file"] = clientPublicKeyPath; + params["tls_key_file"] = clientPrivateKeyPath; + + // No issuer_url + params.erase("issuer_url"); + ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError); + + // Invalid issuer_url + params["issuer_url"] = "hello"; + ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError); + + // No cert and key + params["issuer_url"] = "https://localhost:58086"; + params.erase("tls_cert_file"); + params.erase("tls_key_file"); + ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError); + + // Invalid cert and key + params["tls_cert_file"] = TEST_CONF_DIR "/not-exist-cert.pem"; + params["tls_key_file"] = TEST_CONF_DIR "/not-exist-key.pem"; + ASSERT_EQ(getAuthDataResult(), ResultAuthenticationError); +} + +TEST(AuthPluginTest, testOauth2UnknownTokenEndpointAuthMethod) { + std::string params = R"({ + "type": "client_credentials", + "tokenEndpointAuthMethod": "client_secret_get", + "issuer_url": "https://dev-kt-aa9ne.us.auth0.com", + "client_id": "Xd23RHsUnvUlP7wchjNYOaIfazgeHd9x", + "client_secret": "rT7ps7WY8uhdVuBTKWZkttwLdQotmdEliaM5rLfmgNibvqziZ-g07ZH52N_poGAb", + "audience": "https://dev-kt-aa9ne.us.auth0.com/api/v2/"})"; + + LOG_INFO("PARAMS: " << params); + ASSERT_THROW(AuthOauth2::create(params), std::invalid_argument); +} + TEST(AuthPluginTest, testInvalidPlugin) { Client client("pulsar://localhost:6650", ClientConfiguration{}.setAuth(AuthFactory::create("invalid"))); Producer producer;