diff --git a/README.md b/README.md index e597f720..a772bd64 100644 --- a/README.md +++ b/README.md @@ -1247,6 +1247,54 @@ mgmtSignUpUser.setCustomClaims(new HashMap() {{ AuthenticationInfo res = jwtService.signUpOrIn("Dummy", mgmtSignUpUser); ``` +#### OAuth 2.0 Client Assertion JWT + +You can create client assertion JWTs for OAuth 2.0 client authentication per [RFC 7523](https://datatracker.ietf.org/doc/html/rfc7523). This is useful when authenticating to OAuth token endpoints that support JWT bearer client assertions. + +```java +import com.descope.model.jwt.request.ClientAssertionRequest; +import java.security.KeyStore; +import java.security.interfaces.RSAPrivateKey; + +JwtService jwtService = descopeClient.getManagementServices().getJwtService(); + +// Load your private key (example using keystore) +KeyStore keyStore = KeyStore.getInstance("PKCS12"); +try (InputStream is = new FileInputStream("/path/to/keystore.p12")) { + keyStore.load(is, "keystore-password".toCharArray()); +} +RSAPrivateKey privateKey = (RSAPrivateKey) keyStore.getKey("key-alias", "key-password".toCharArray()); + +// Create the client assertion JWT +ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("your-client-id") + .tokenEndpoint("https://auth.example.com/oauth/token") + .privateKey(privateKey) + .algorithm("RS256") // Optional, defaults to RS256. Also supports ES256, etc. + .expirationSeconds(300) // Optional, defaults to 300 (5 minutes) + .build(); + +try { + String clientAssertion = jwtService.createClientAssertion(request); + + // Use the client assertion in your OAuth token request + // POST to token endpoint with: + // grant_type=client_credentials + // client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer + // client_assertion= +} catch (DescopeException de) { + // Handle the error +} +``` + +The generated JWT contains the following claims per RFC 7523: +- `iss` (issuer): Your client ID +- `sub` (subject): Your client ID +- `aud` (audience): The token endpoint URL +- `exp` (expiration): Current time + expiration seconds +- `iat` (issued at): Current time +- `jti` (JWT ID): Unique identifier to prevent replay attacks + ### Audit You can perform an audit search for either specific values or full-text across the fields. Audit search is limited to the last 30 days. diff --git a/src/main/java/com/descope/model/jwt/request/ClientAssertionRequest.java b/src/main/java/com/descope/model/jwt/request/ClientAssertionRequest.java new file mode 100644 index 00000000..824dd071 --- /dev/null +++ b/src/main/java/com/descope/model/jwt/request/ClientAssertionRequest.java @@ -0,0 +1,50 @@ +package com.descope.model.jwt.request; + +import java.security.PrivateKey; +import lombok.AllArgsConstructor; +import lombok.Builder; +import lombok.Data; +import lombok.NoArgsConstructor; + +/** + * Request object for creating OAuth 2.0 client assertion JWT. + * + *

This is used for client authentication using JWT bearer tokens as per RFC 7523. + * The generated JWT can be used with OAuth token endpoints that support + * client_assertion_type=urn:ietf:params:oauth:client-assertion-type:jwt-bearer + */ +@Data +@Builder +@NoArgsConstructor +@AllArgsConstructor +public class ClientAssertionRequest { + /** + * The client ID (issuer and subject of the JWT). + */ + private String clientId; + + /** + * The token endpoint URL (audience of the JWT). + */ + private String tokenEndpoint; + + /** + * The private key used to sign the JWT. + * Typically an RSA or ECDSA private key. + */ + private PrivateKey privateKey; + + /** + * The signing algorithm to use (e.g., "RS256", "ES256"). + * Defaults to "RS256" if not specified. + */ + @Builder.Default + private String algorithm = "RS256"; + + /** + * JWT expiration time in seconds. + * Defaults to 300 seconds (5 minutes) if not specified. + */ + @Builder.Default + private long expirationSeconds = 300; +} diff --git a/src/main/java/com/descope/sdk/mgmt/JwtService.java b/src/main/java/com/descope/sdk/mgmt/JwtService.java index 7ee1814b..d513a910 100644 --- a/src/main/java/com/descope/sdk/mgmt/JwtService.java +++ b/src/main/java/com/descope/sdk/mgmt/JwtService.java @@ -5,6 +5,7 @@ import com.descope.model.jwt.MgmtSignUpUser; import com.descope.model.jwt.Token; import com.descope.model.jwt.request.AnonymousUserRequest; +import com.descope.model.jwt.request.ClientAssertionRequest; import com.descope.model.magiclink.LoginOptions; import java.util.Map; @@ -31,4 +32,14 @@ AuthenticationInfo signUpOrIn(String loginId, MgmtSignUpUser signUpUserDetails) AuthenticationInfo signIn(String loginId, LoginOptions loginOptions) throws DescopeException; AuthenticationInfo anonymous(AnonymousUserRequest request) throws DescopeException; + + /** + * Create an OAuth 2.0 client assertion JWT for client authentication. + * This JWT can be used with OAuth token endpoints that support RFC 7523. + * + * @param request - ClientAssertionRequest containing clientId, tokenEndpoint, privateKey, and signing algorithm + * @return - The signed JWT string that can be used as client_assertion parameter + * @throws DescopeException if JWT creation fails + */ + String createClientAssertion(ClientAssertionRequest request) throws DescopeException; } diff --git a/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java b/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java index 738cb1fe..098bb9c5 100644 --- a/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java +++ b/src/main/java/com/descope/sdk/mgmt/impl/JwtServiceImpl.java @@ -14,6 +14,7 @@ import com.descope.model.jwt.MgmtSignUpUser; import com.descope.model.jwt.Token; import com.descope.model.jwt.request.AnonymousUserRequest; +import com.descope.model.jwt.request.ClientAssertionRequest; import com.descope.model.jwt.request.ManagementSignInRequest; import com.descope.model.jwt.request.ManagementSignUpRequest; import com.descope.model.jwt.request.UpdateJwtRequest; @@ -22,8 +23,13 @@ import com.descope.model.magiclink.LoginOptions; import com.descope.proxy.ApiProxy; import com.descope.sdk.mgmt.JwtService; +import io.jsonwebtoken.Jwts; +import io.jsonwebtoken.SignatureAlgorithm; import java.net.URI; +import java.security.PrivateKey; +import java.util.Date; import java.util.Map; +import java.util.UUID; import org.apache.commons.lang3.StringUtils; class JwtServiceImpl extends ManagementsBase implements JwtService { @@ -137,6 +143,56 @@ private AuthenticationInfo validateAndCreateAuthInfo(JWTResponse jwtResponse) th return new AuthenticationInfo(sessionToken, refreshToken, jwtResponse.getUser(), jwtResponse.getFirstSeen()); } + @Override + public String createClientAssertion(ClientAssertionRequest request) throws DescopeException { + if (request == null) { + throw ServerCommonException.invalidArgument("ClientAssertionRequest"); + } + if (StringUtils.isBlank(request.getClientId())) { + throw ServerCommonException.invalidArgument("clientId"); + } + if (StringUtils.isBlank(request.getTokenEndpoint())) { + throw ServerCommonException.invalidArgument("tokenEndpoint"); + } + if (request.getPrivateKey() == null) { + throw ServerCommonException.invalidArgument("privateKey"); + } + + try { + long nowMillis = System.currentTimeMillis(); + Date now = new Date(nowMillis); + Date expiration = new Date(nowMillis + (request.getExpirationSeconds() * 1000)); + + SignatureAlgorithm algorithm = getSignatureAlgorithm(request.getAlgorithm()); + PrivateKey privateKey = request.getPrivateKey(); + + return Jwts.builder() + .setIssuer(request.getClientId()) + .setSubject(request.getClientId()) + .setAudience(request.getTokenEndpoint()) + .setIssuedAt(now) + .setExpiration(expiration) + .setId(UUID.randomUUID().toString()) + .signWith(algorithm, privateKey) + .compact(); + } catch (Exception e) { + String errorMessage = "Failed to create client assertion JWT: " + e.getMessage(); + throw ServerCommonException.parseResponseError(errorMessage, null, e); + } + } + + private SignatureAlgorithm getSignatureAlgorithm(String algorithm) { + if (StringUtils.isBlank(algorithm)) { + return SignatureAlgorithm.RS256; + } + + try { + return SignatureAlgorithm.forName(algorithm); + } catch (IllegalArgumentException e) { + throw ServerCommonException.invalidArgument("algorithm - unsupported: " + algorithm); + } + } + private URI composeUpdateJwtUri() { return getUri(UPDATE_JWT_LINK); } diff --git a/src/test/java/com/descope/sdk/mgmt/impl/JwtServiceImplClientAssertionTest.java b/src/test/java/com/descope/sdk/mgmt/impl/JwtServiceImplClientAssertionTest.java new file mode 100644 index 00000000..24356f5a --- /dev/null +++ b/src/test/java/com/descope/sdk/mgmt/impl/JwtServiceImplClientAssertionTest.java @@ -0,0 +1,267 @@ +package com.descope.sdk.mgmt.impl; + +import static org.junit.jupiter.api.Assertions.assertEquals; +import static org.junit.jupiter.api.Assertions.assertNotNull; +import static org.junit.jupiter.api.Assertions.assertThrows; +import static org.junit.jupiter.api.Assertions.assertTrue; + +import com.descope.exception.DescopeException; +import com.descope.exception.ServerCommonException; +import com.descope.model.client.Client; +import com.descope.model.jwt.request.ClientAssertionRequest; +import com.descope.model.mgmt.ManagementServices; +import com.descope.sdk.TestUtils; +import com.descope.sdk.mgmt.JwtService; +import io.jsonwebtoken.Claims; +import io.jsonwebtoken.Jwts; +import java.security.KeyPair; +import java.security.KeyPairGenerator; +import java.security.interfaces.ECPrivateKey; +import java.security.interfaces.RSAPrivateKey; +import java.security.interfaces.RSAPublicKey; +import java.util.Date; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +public class JwtServiceImplClientAssertionTest { + + private JwtService jwtService; + private RSAPrivateKey rsaPrivateKey; + private RSAPublicKey rsaPublicKey; + private ECPrivateKey ecPrivateKey; + + @BeforeEach + void setUp() throws Exception { + Client client = TestUtils.getClient(); + ManagementServices mgmtServices = ManagementServiceBuilder.buildServices(client); + this.jwtService = mgmtServices.getJwtService(); + + KeyPairGenerator rsaGenerator = KeyPairGenerator.getInstance("RSA"); + rsaGenerator.initialize(2048); + KeyPair rsaKeyPair = rsaGenerator.generateKeyPair(); + this.rsaPrivateKey = (RSAPrivateKey) rsaKeyPair.getPrivate(); + this.rsaPublicKey = (RSAPublicKey) rsaKeyPair.getPublic(); + + KeyPairGenerator ecGenerator = KeyPairGenerator.getInstance("EC"); + ecGenerator.initialize(256); + KeyPair ecKeyPair = ecGenerator.generateKeyPair(); + this.ecPrivateKey = (ECPrivateKey) ecKeyPair.getPrivate(); + } + + @Test + void testCreateClientAssertionWithRS256() throws Exception { + String clientId = "test-client-id"; + String tokenEndpoint = "https://auth.example.com/oauth/token"; + + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId(clientId) + .tokenEndpoint(tokenEndpoint) + .privateKey(rsaPrivateKey) + .algorithm("RS256") + .expirationSeconds(300) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + assertNotNull(jwt); + assertTrue(jwt.split("\\.").length == 3); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + assertEquals(clientId, claims.getIssuer()); + assertEquals(clientId, claims.getSubject()); + assertEquals(tokenEndpoint, claims.getAudience().iterator().next()); + assertNotNull(claims.getId()); + assertNotNull(claims.getIssuedAt()); + assertNotNull(claims.getExpiration()); + + long expirationDiff = claims.getExpiration().getTime() - claims.getIssuedAt().getTime(); + assertEquals(300000, expirationDiff, 1000); + } + + @Test + void testCreateClientAssertionWithES256() throws Exception { + String clientId = "test-client-id"; + String tokenEndpoint = "https://auth.example.com/oauth/token"; + + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId(clientId) + .tokenEndpoint(tokenEndpoint) + .privateKey(ecPrivateKey) + .algorithm("ES256") + .expirationSeconds(300) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + assertNotNull(jwt); + assertTrue(jwt.split("\\.").length == 3); + } + + @Test + void testCreateClientAssertionWithDefaultAlgorithm() throws Exception { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + assertNotNull(jwt); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + assertNotNull(claims); + } + + @Test + void testCreateClientAssertionWithCustomExpiration() throws Exception { + long customExpiration = 600; + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .expirationSeconds(customExpiration) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + long expirationDiff = claims.getExpiration().getTime() - claims.getIssuedAt().getTime(); + assertEquals(600000, expirationDiff, 1000); + } + + @Test + void testCreateClientAssertionExpirationNotInPast() throws Exception { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + String jwt = jwtService.createClientAssertion(request); + + Claims claims = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt) + .getPayload(); + + assertTrue(claims.getExpiration().after(new Date())); + } + + @Test + void testCreateClientAssertionUniqueJti() throws Exception { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + String jwt1 = jwtService.createClientAssertion(request); + String jwt2 = jwtService.createClientAssertion(request); + + Claims claims1 = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt1) + .getPayload(); + + Claims claims2 = Jwts.parser() + .setSigningKey(rsaPublicKey) + .build() + .parseSignedClaims(jwt2) + .getPayload(); + + assertNotNull(claims1.getId()); + assertNotNull(claims2.getId()); + assertTrue(!claims1.getId().equals(claims2.getId())); + } + + @Test + void testCreateClientAssertionNullRequest() { + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(null)); + + assertEquals("The ClientAssertionRequest argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionEmptyClientId() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertEquals("The clientId argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionEmptyTokenEndpoint() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("") + .privateKey(rsaPrivateKey) + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertEquals("The tokenEndpoint argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionNullPrivateKey() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(null) + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertEquals("The privateKey argument is invalid", thrown.getMessage()); + } + + @Test + void testCreateClientAssertionUnsupportedAlgorithm() { + ClientAssertionRequest request = ClientAssertionRequest.builder() + .clientId("test-client") + .tokenEndpoint("https://auth.example.com/token") + .privateKey(rsaPrivateKey) + .algorithm("INVALID_ALGORITHM") + .build(); + + ServerCommonException thrown = assertThrows( + ServerCommonException.class, + () -> jwtService.createClientAssertion(request)); + + assertNotNull(thrown); + String message = thrown.getMessage(); + assertTrue(message != null && message.contains("algorithm"), + "Expected error message to contain 'algorithm', but got: " + message); + } +}