From 19ee8b4840bfc406f74b08259e96d59efa7c3fbf Mon Sep 17 00:00:00 2001 From: Aviad Lichtenstadt Date: Sun, 15 Feb 2026 13:36:06 +0000 Subject: [PATCH] feat: add OAuth 2.0 client assertion JWT creation Implements RFC 7523 client assertion JWT creation for OAuth client authentication. This allows customers to use JWT bearer tokens for client authentication when migrating from Node.js to Java services. Changes: - Add ClientAssertionRequest model with clientId, tokenEndpoint, privateKey, algorithm, and expirationSeconds - Add createClientAssertion method to JwtService interface - Implement client assertion JWT creation using jjwt library with proper claims (iss, sub, aud, exp, iat, jti) - Add comprehensive unit tests covering RS256, ES256, validation, and error cases - Update README with usage examples and documentation The generated JWT follows RFC 7523 specifications: - iss and sub set to client_id - aud set to token endpoint URL - exp defaults to 5 minutes (customizable) - jti includes unique UUID to prevent replay attacks - Supports RS256, ES256, and other signing algorithms Resolves: KHealth customer request for Java SDK parity with Node.js --- README.md | 48 ++++ .../jwt/request/ClientAssertionRequest.java | 50 ++++ .../java/com/descope/sdk/mgmt/JwtService.java | 11 + .../descope/sdk/mgmt/impl/JwtServiceImpl.java | 56 ++++ .../JwtServiceImplClientAssertionTest.java | 267 ++++++++++++++++++ 5 files changed, 432 insertions(+) create mode 100644 src/main/java/com/descope/model/jwt/request/ClientAssertionRequest.java create mode 100644 src/test/java/com/descope/sdk/mgmt/impl/JwtServiceImplClientAssertionTest.java diff --git a/README.md b/README.md index e597f720..a772bd64 100644 --- a/README.md +++ b/README.md @@ -1247,6 +1247,54 @@ mgmtSignUpUser.setCustomClaims(new HashMap() {{ AuthenticationInfo res = jwtService.signUpOrIn("Dummy", mgmtSignUpUser); ``` +#### OAuth 2.0 Client Assertion JWT + +You can create client assertion JWTs for OAuth 2.0 client authentication per [RFC 7523](https://datatracker.ietf.org/doc/html/rfc7523). This is useful when authenticating to OAuth token endpoints that support JWT bearer client assertions. + +```java +import com.descope.model.jwt.request.ClientAssertionRequest; +import java.security.KeyStore; +import java.security.interfaces.RSAPrivateKey; + +JwtService jwtService = descopeClient.getManagementServices().getJwtService(); + +// Load your private key (example using keystore) +KeyStore keyStore = KeyStore.getInstance("PKCS12"); +try (InputStream is = new FileInputStream("/path/to/keystore.p12")) { + keyStore.load(is, "keystore-password".toCharArray()); +} +RSAPrivateKey privateKey = (RSAPrivateKey) keyStore.getKey("key-alias", "key-password".toCharArray()); + +// Create the client assertion JWT +ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("your-client-id") + .tokenEndpoint("https://auth.example.com/oauth/token") + .privateKey(privateKey) + .algorithm("RS256") // Optional, defaults to RS256. Also supports ES256, etc. + .expirationSeconds(300) // Optional, defaults to 300 (5 minutes) + .build(); + +try { + String clientAssertion = jwtService.createClientAssertion(request); + + // Use the client assertion in your OAuth token request + // POST to token endpoint with: + // grant_type=client_credentials + // client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer + // client_assertion= +} catch (DescopeException de) { + // Handle the error +} +``` + +The generated JWT contains the following claims per RFC 7523: +- `iss` (issuer): Your client ID +- `sub` (subject): Your client ID +- `aud` (audience): The token endpoint URL +- `exp` (expiration): Current time + expiration seconds +- `iat` (issued at): Current time +- `jti` (JWT ID): Unique identifier to prevent replay attacks + ### Audit You can perform an audit search for either specific values or full-text across the fields. Audit search is limited to the last 30 days. diff --git a/src/main/java/com/descope/model/jwt/request/ClientAssertionRequest.java b/src/main/java/com/descope/model/jwt/request/ClientAssertionRequest.java new file mode 100644 index 00000000..824dd071 --- /dev/null +++ b/src/main/java/com/descope/model/jwt/request/ClientAssertionRequest.java @@ -0,0 +1,50 @@ +package com.descope.model.jwt.request; + +import java.security.PrivateKey; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Request object for creating OAuth 2.0 client assertion JWT. + * + *

This is used for client authentication using JWT bearer tokens as per RFC 7523. + * The generated JWT can be used with OAuth token endpoints that support + * client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer + */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class ClientAssertionRequest { + /** + * The client ID (issuer and subject of the JWT). + */ + private String clientId; + + /** + * The token endpoint URL (audience of the JWT). + */ + private String tokenEndpoint; + + /** + * The private key used to sign the JWT. + * Typically an RSA or ECDSA private key. + */ + private PrivateKey privateKey; + + /** + * The signing algorithm to use (e.g., "RS256", "ES256"). + * Defaults to "RS256" if not specified. + */ + @Builder.Default + private String algorithm = "RS256"; + + /** + * JWT expiration time in seconds. + * Defaults to 300 seconds (5 minutes) if not specified. + */ + @Builder.Default + private long expirationSeconds = 300; +} diff --git a/src/main/java/com/descope/sdk/mgmt/JwtService.java b/src/main/java/com/descope/sdk/mgmt/JwtService.java index 7ee1814b..d513a910 100644 --- a/src/main/java/com/descope/sdk/mgmt/JwtService.java +++ b/src/main/java/com/descope/sdk/mgmt/JwtService.java @@ -5,6 +5,7 @@ import com.descope.model.jwt.MgmtSignUpUser; import com.descope.model.jwt.Token; import com.descope.model.jwt.request.AnonymousUserRequest; +import com.descope.model.jwt.request.ClientAssertionRequest; import com.descope.model.magiclink.LoginOptions; import java.util.Map; @@ -31,4 +32,14 @@ AuthenticationInfo signUpOrIn(String loginId, MgmtSignUpUser signUpUserDetails) AuthenticationInfo signIn(String loginId, LoginOptions loginOptions) throws DescopeException; AuthenticationInfo anonymous(AnonymousUserRequest request) throws DescopeException; + + /** + * Create an OAuth 2.0 client assertion JWT for client authentication. + * This JWT can be used with OAuth token endpoints that support RFC 7523. + * + * @param request - ClientAssertionRequest containing clientId, tokenEndpoint, privateKey, and signing algorithm + * @return - The signed JWT string that can be used as client_assertion parameter + * @throws DescopeException if JWT creation fails + */ + String createClientAssertion(ClientAssertionRequest request) throws DescopeException; } diff --git a/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java b/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java index 738cb1fe..098bb9c5 100644 --- a/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java +++ b/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java @@ -14,6 +14,7 @@ import com.descope.model.jwt.MgmtSignUpUser; import com.descope.model.jwt.Token; import com.descope.model.jwt.request.AnonymousUserRequest; +import com.descope.model.jwt.request.ClientAssertionRequest; import com.descope.model.jwt.request.ManagementSignInRequest; import com.descope.model.jwt.request.ManagementSignUpRequest; import com.descope.model.jwt.request.UpdateJwtRequest; @@ -22,8 +23,13 @@ import com.descope.model.magiclink.LoginOptions; import com.descope.proxy.ApiProxy; import com.descope.sdk.mgmt.JwtService; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.SignatureAlgorithm; import java.net.URI; +import java.security.PrivateKey; +import java.util.Date; import java.util.Map; +import java.util.UUID; import org.apache.commons.lang3.StringUtils; class JwtServiceImpl extends ManagementsBase implements JwtService { @@ -137,6 +143,56 @@ private AuthenticationInfo validateAndCreateAuthInfo(JWTResponse jwtResponse) th return new AuthenticationInfo(sessionToken, refreshToken, jwtResponse.getUser(), jwtResponse.getFirstSeen()); } + @Override + public String createClientAssertion(ClientAssertionRequest request) throws DescopeException { + if (request == null) { + throw ServerCommonException.invalidArgument("ClientAssertionRequest"); + } + if (StringUtils.isBlank(request.getClientId())) { + throw ServerCommonException.invalidArgument("clientId"); + } + if (StringUtils.isBlank(request.getTokenEndpoint())) { + throw ServerCommonException.invalidArgument("tokenEndpoint"); + } + if (request.getPrivateKey() == null) { + throw ServerCommonException.invalidArgument("privateKey"); + } + + try { + long nowMillis = System.currentTimeMillis(); + Date now = new Date(nowMillis); + Date expiration = new Date(nowMillis + (request.getExpirationSeconds() * 1000)); + + SignatureAlgorithm algorithm = getSignatureAlgorithm(request.getAlgorithm()); + PrivateKey privateKey = request.getPrivateKey(); + + return Jwts.builder() + .setIssuer(request.getClientId()) + .setSubject(request.getClientId()) + .setAudience(request.getTokenEndpoint()) + .setIssuedAt(now) + .setExpiration(expiration) + .setId(UUID.randomUUID().toString()) + .signWith(algorithm, privateKey) + .compact(); + } catch (Exception e) { + String errorMessage = "Failed to create client assertion JWT: " + e.getMessage(); + throw ServerCommonException.parseResponseError(errorMessage, null, e); + } + } + + private SignatureAlgorithm getSignatureAlgorithm(String algorithm) { + if (StringUtils.isBlank(algorithm)) { + return SignatureAlgorithm.RS256; + } + + try { + return SignatureAlgorithm.forName(algorithm); + } catch (IllegalArgumentException e) { + throw ServerCommonException.invalidArgument("algorithm - unsupported: " + algorithm); + } + } + private URI composeUpdateJwtUri() { return getUri(UPDATE_JWT_LINK); } diff --git a/src/test/java/com/descope/sdk/mgmt/impl/JwtServiceImplClientAssertionTest.java b/src/test/java/com/descope/sdk/mgmt/impl/JwtServiceImplClientAssertionTest.java new file mode 100644 index 00000000..24356f5a --- /dev/null +++ b/src/test/java/com/descope/sdk/mgmt/impl/JwtServiceImplClientAssertionTest.java @@ -0,0 +1,267 @@ +package com.descope.sdk.mgmt.impl; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.descope.exception.DescopeException; +import com.descope.exception.ServerCommonException; +import com.descope.model.client.Client; +import com.descope.model.jwt.request.ClientAssertionRequest; +import com.descope.model.mgmt.ManagementServices; +import com.descope.sdk.TestUtils; +import com.descope.sdk.mgmt.JwtService; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Jwts; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.interfaces.ECPrivateKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.util.Date; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class JwtServiceImplClientAssertionTest { + + private JwtService jwtService; + private RSAPrivateKey rsaPrivateKey; + private RSAPublicKey rsaPublicKey; + private ECPrivateKey ecPrivateKey; + + @BeforeEach + void setUp() throws Exception { + Client client = TestUtils.getClient(); + ManagementServices mgmtServices = ManagementServiceBuilder.buildServices(client); + this.jwtService = mgmtServices.getJwtService(); + + KeyPairGenerator rsaGenerator = KeyPairGenerator.getInstance("RSA"); + rsaGenerator.initialize(2048); + KeyPair rsaKeyPair = rsaGenerator.generateKeyPair(); + this.rsaPrivateKey = (RSAPrivateKey) rsaKeyPair.getPrivate(); + this.rsaPublicKey = (RSAPublicKey) rsaKeyPair.getPublic(); + + KeyPairGenerator ecGenerator = KeyPairGenerator.getInstance("EC"); + ecGenerator.initialize(256); + KeyPair ecKeyPair = ecGenerator.generateKeyPair(); + this.ecPrivateKey = (ECPrivateKey) ecKeyPair.getPrivate(); + } + + @Test + void testCreateClientAssertionWithRS256() throws Exception { + String clientId = "test-client-id"; + String tokenEndpoint = "https://auth.example.com/oauth/token"; + + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId(clientId) + .tokenEndpoint(tokenEndpoint) + .privateKey(rsaPrivateKey) + .algorithm("RS256") + .expirationSeconds(300) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + assertNotNull(jwt); + assertTrue(jwt.split("\\.").length == 3); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + assertEquals(clientId, claims.getIssuer()); + assertEquals(clientId, claims.getSubject()); + assertEquals(tokenEndpoint, claims.getAudience().iterator().next()); + assertNotNull(claims.getId()); + assertNotNull(claims.getIssuedAt()); + assertNotNull(claims.getExpiration()); + + long expirationDiff = claims.getExpiration().getTime() - claims.getIssuedAt().getTime(); + assertEquals(300000, expirationDiff, 1000); + } + + @Test + void testCreateClientAssertionWithES256() throws Exception { + String clientId = "test-client-id"; + String tokenEndpoint = "https://auth.example.com/oauth/token"; + + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId(clientId) + .tokenEndpoint(tokenEndpoint) + .privateKey(ecPrivateKey) + .algorithm("ES256") + .expirationSeconds(300) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + assertNotNull(jwt); + assertTrue(jwt.split("\\.").length == 3); + } + + @Test + void testCreateClientAssertionWithDefaultAlgorithm() throws Exception { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + assertNotNull(jwt); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + assertNotNull(claims); + } + + @Test + void testCreateClientAssertionWithCustomExpiration() throws Exception { + long customExpiration = 600; + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .expirationSeconds(customExpiration) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + long expirationDiff = claims.getExpiration().getTime() - claims.getIssuedAt().getTime(); + assertEquals(600000, expirationDiff, 1000); + } + + @Test + void testCreateClientAssertionExpirationNotInPast() throws Exception { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + assertTrue(claims.getExpiration().after(new Date())); + } + + @Test + void testCreateClientAssertionUniqueJti() throws Exception { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + String jwt1 = jwtService.createClientAssertion(request); + String jwt2 = jwtService.createClientAssertion(request); + + Claims claims1 = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt1) + .getPayload(); + + Claims claims2 = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt2) + .getPayload(); + + assertNotNull(claims1.getId()); + assertNotNull(claims2.getId()); + assertTrue(!claims1.getId().equals(claims2.getId())); + } + + @Test + void testCreateClientAssertionNullRequest() { + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(null)); + + assertEquals("The ClientAssertionRequest argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionEmptyClientId() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertEquals("The clientId argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionEmptyTokenEndpoint() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("") + .privateKey(rsaPrivateKey) + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertEquals("The tokenEndpoint argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionNullPrivateKey() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(null) + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertEquals("The privateKey argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionUnsupportedAlgorithm() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .algorithm("INVALID_ALGORITHM") + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertNotNull(thrown); + String message = thrown.getMessage(); + assertTrue(message != null && message.contains("algorithm"), + "Expected error message to contain 'algorithm', but got: " + message); + } +}