diff --git a/.github/workflows/checks.yaml b/.github/workflows/checks.yaml index 051123ae..9b98607b 100644 --- a/.github/workflows/checks.yaml +++ b/.github/workflows/checks.yaml @@ -113,7 +113,7 @@ jobs: - name: Check out and start up platform with deps/containers id: run-platform - uses: opentdf/platform/test/start-up-with-containers@main + uses: opentdf/platform/test/start-up-with-containers@fix/start-additional-kas-workflow with: platform-ref: main @@ -240,7 +240,7 @@ jobs: working-directory: cmdline - name: Start additional kas - uses: opentdf/platform/test/start-additional-kas@main + uses: opentdf/platform/test/start-additional-kas@fix/start-additional-kas-workflow with: kas-port: 8282 kas-name: beta diff --git a/cmdline/src/main/java/io/opentdf/platform/Command.java b/cmdline/src/main/java/io/opentdf/platform/Command.java index c031c4a4..b0a1bfee 100644 --- a/cmdline/src/main/java/io/opentdf/platform/Command.java +++ b/cmdline/src/main/java/io/opentdf/platform/Command.java @@ -1,6 +1,16 @@ package io.opentdf.platform; import com.google.gson.Gson; +import com.google.gson.JsonDeserializationContext; +import com.google.gson.JsonDeserializer; +import com.google.gson.JsonElement; +import com.google.gson.JsonObject; +import com.google.gson.JsonParseException; +import com.nimbusds.jose.jwk.JWK; +import com.google.gson.GsonBuilder; +import com.google.gson.reflect.TypeToken; + +import java.text.ParseException; import com.google.gson.JsonSyntaxException; import com.nimbusds.jose.JOSEException; import io.opentdf.platform.sdk.AssertionConfig; @@ -57,8 +67,38 @@ class Versions { + "\",\"tdfSpecVersion\":\"" + Versions.TDF_SPEC + "\"}") class Command { - @Option(names = { "-V", "--version" }, versionHelp = true, description = "display version info") - boolean versionInfoRequested; + private static class AssertionKeyDeserializer implements JsonDeserializer { + @Override + public AssertionConfig.AssertionKey deserialize(JsonElement json, java.lang.reflect.Type typeOfT, JsonDeserializationContext context) throws JsonParseException { + JsonObject jsonObject = json.getAsJsonObject(); + AssertionConfig.AssertionKey assertionKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.NotDefined, null); + + if (jsonObject.has("alg")) { + assertionKey.alg = context.deserialize(jsonObject.get("alg"), AssertionConfig.AssertionKeyAlg.class); + } + if (jsonObject.has("key")) { + assertionKey.key = context.deserialize(jsonObject.get("key"), Object.class); + } + if (jsonObject.has("jwk")) { + try { + assertionKey.jwk = JWK.parse(jsonObject.get("jwk").toString()); + } catch (ParseException e) { + throw new JsonParseException("Failed to parse jwk", e); + } + } + if (jsonObject.has("x5c")) { + assertionKey.x5c = context.deserialize(jsonObject.get("x5c"), new TypeToken>() {}.getType()); + } + + return assertionKey; + } + } + + private Gson buildGson() { + return new GsonBuilder() + .registerTypeAdapter(AssertionConfig.AssertionKey.class, new AssertionKeyDeserializer()) + .create(); + } private static final String PRIVATE_KEY_HEADER = "-----BEGIN PRIVATE KEY-----"; private static final String PRIVATE_KEY_FOOTER = "-----END PRIVATE KEY-----"; @@ -177,7 +217,7 @@ void encrypt( if (assertion.isPresent()) { var assertionConfig = assertion.get(); - Gson gson = new Gson(); + Gson gson = buildGson(); AssertionConfig[] assertionConfigs; try { @@ -252,7 +292,7 @@ void decrypt(@Option(names = { "-f", "--file" }, required = true) Path tdfPath, try (var stdout = new BufferedOutputStream(System.out)) { if (assertionVerification.isPresent()) { var assertionVerificationInput = assertionVerification.get(); - Gson gson = new Gson(); + Gson gson = buildGson(); AssertionVerificationKeys assertionVerificationKeys; try { diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/AssertionConfig.java b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionConfig.java index e36c304c..a1e83a8d 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/AssertionConfig.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/AssertionConfig.java @@ -2,11 +2,14 @@ import com.google.gson.Gson; import com.google.gson.annotations.SerializedName; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.util.Base64; import java.net.InetAddress; import java.net.UnknownHostException; import java.time.OffsetDateTime; import java.time.format.DateTimeFormatter; +import java.util.List; import java.util.Objects; /** @@ -88,12 +91,24 @@ public String toString() { static public class AssertionKey { public Object key; public AssertionKeyAlg alg = AssertionKeyAlg.NotDefined; + public transient JWK jwk; + public transient List x5c; public AssertionKey(AssertionKeyAlg alg, Object key) { this.alg = alg; this.key = key; } + public AssertionKey withJwk(JWK jwk) { + this.jwk = jwk; + return this; + } + + public AssertionKey withX5c(List x5c) { + this.x5c = x5c; + return this; + } + public boolean isDefined() { return alg != AssertionKeyAlg.NotDefined; } diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/CryptoUtils.java b/sdk/src/main/java/io/opentdf/platform/sdk/CryptoUtils.java index e7f8c5c3..158e47be 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/CryptoUtils.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/CryptoUtils.java @@ -1,8 +1,11 @@ package io.opentdf.platform.sdk; +import com.nimbusds.jose.jwk.RSAKey; + import javax.crypto.Mac; import javax.crypto.spec.SecretKeySpec; import java.security.*; +import java.security.interfaces.RSAPublicKey; import java.security.spec.ECGenParameterSpec; import java.util.Base64; @@ -58,6 +61,15 @@ public static String getPublicKeyPEM(PublicKey publicKey) { "\r\n-----END PUBLIC KEY-----"; } + public static String getPublicKeyJWK(PublicKey publicKey) { + if (publicKey instanceof RSAPublicKey) { + RSAKey jwk = new RSAKey.Builder((RSAPublicKey) publicKey).build(); + return jwk.toString(); + } else { + throw new IllegalArgumentException("Unsupported public key algorithm: " + publicKey.getAlgorithm()); + } + } + public static String getPrivateKeyPEM(PrivateKey privateKey) { return "-----BEGIN PRIVATE KEY-----\r\n" + Base64.getMimeEncoder().encodeToString(privateKey.getEncoded()) + diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java b/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java index 9cd94aa1..20fbb13e 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/Manifest.java @@ -20,6 +20,9 @@ import com.nimbusds.jose.crypto.MACVerifier; import com.nimbusds.jose.crypto.RSASSASigner; import com.nimbusds.jose.crypto.RSASSAVerifier; +import com.nimbusds.jose.crypto.factories.DefaultJWSVerifierFactory; +import com.nimbusds.jose.jwk.JWK; +import com.nimbusds.jose.util.X509CertUtils; import com.nimbusds.jwt.JWTClaimsSet; import com.nimbusds.jwt.SignedJWT; import io.opentdf.platform.sdk.SDK.AssertionException; @@ -33,6 +36,7 @@ import java.security.NoSuchAlgorithmException; import java.security.PrivateKey; import java.security.interfaces.RSAPublicKey; +import java.security.cert.X509Certificate; import java.text.ParseException; import java.util.ArrayList; import java.util.Base64; @@ -400,7 +404,7 @@ public void sign(final HashValues hashValues, final AssertionConfig.AssertionKey // returns the hash and the signature. It returns an error if the verification // fails. public Assertion.HashValues verify(AssertionConfig.AssertionKey assertionKey) - throws ParseException, JOSEException { + throws ParseException, JOSEException, java.security.cert.CertificateException { if (binding == null) { throw new AssertionException("Binding is null in assertion", this.id); } @@ -409,7 +413,37 @@ public Assertion.HashValues verify(AssertionConfig.AssertionKey assertionKey) binding = null; // Clear the binding after use SignedJWT signedJWT = SignedJWT.parse(signatureString); - JWSVerifier verifier = createVerifier(assertionKey); + JWSHeader header = signedJWT.getHeader(); + JWSVerifier verifier = null; + + // Check for JWK in header + if (header.getJWK() != null) { + try { + verifier = createVerifier(header.getJWK()); + } catch (JOSEException e) { + throw new SDKException("Invalid JWK in JWT header", e); + } + } + + // Check for X.509 certificate chain in header + if (verifier == null && header.getX509CertChain() != null && !header.getX509CertChain().isEmpty()) { + try { + X509Certificate cert = X509CertUtils.parse(header.getX509CertChain().get(0).decode()); + if (cert.getPublicKey() instanceof RSAPublicKey) { + verifier = createVerifier((RSAPublicKey) cert.getPublicKey()); + } else { + throw new SDKException("Unsupported public key type in X.509 certificate"); + } + } catch (IllegalArgumentException e) { + throw new SDKException("Invalid Base64 in X.509 certificate in JWT header", e); + } + } + + + if (verifier == null) { + verifier = createVerifier(assertionKey); + } + if (!signedJWT.verify(verifier)) { throw new SDKException("Unable to verify assertion signature"); @@ -424,19 +458,27 @@ public Assertion.HashValues verify(AssertionConfig.AssertionKey assertionKey) private SignedJWT createSignedJWT(final JWTClaimsSet claims, final AssertionConfig.AssertionKey assertionKey) throws SDKException { - final JWSHeader jwsHeader; + final JWSHeader.Builder headerBuilder; switch (assertionKey.alg) { case RS256: - jwsHeader = new JWSHeader.Builder(JWSAlgorithm.RS256).build(); + headerBuilder = new JWSHeader.Builder(JWSAlgorithm.RS256); break; case HS256: - jwsHeader = new JWSHeader.Builder(JWSAlgorithm.HS256).build(); + headerBuilder = new JWSHeader.Builder(JWSAlgorithm.HS256); break; default: throw new SDKException("Unknown assertion key algorithm, error signing assertion"); } - return new SignedJWT(jwsHeader, claims); + if (assertionKey.jwk != null) { + headerBuilder.jwk(assertionKey.jwk); + } + + if (assertionKey.x5c != null) { + headerBuilder.x509CertChain(assertionKey.x5c); + } + + return new SignedJWT(headerBuilder.build(), claims); } private JWSSigner createSigner(final AssertionConfig.AssertionKey assertionKey) @@ -460,13 +502,30 @@ private JWSSigner createSigner(final AssertionConfig.AssertionKey assertionKey) private JWSVerifier createVerifier(AssertionConfig.AssertionKey assertionKey) throws JOSEException { switch (assertionKey.alg) { case RS256: - return new RSASSAVerifier((RSAPublicKey) assertionKey.key); + if (assertionKey.key instanceof JWK) { + return createVerifier((JWK) assertionKey.key); + } else if (assertionKey.key instanceof RSAPublicKey) { + return createVerifier((RSAPublicKey) assertionKey.key); + } else { + throw new SDKException("Expected JWK or RSAPublicKey for RS256 algorithm"); + } case HS256: return new MACVerifier((byte[]) assertionKey.key); default: throw new SDKException("Unknown verify key, unable to verify assertion signature"); } } + + private JWSVerifier createVerifier(JWK jwk) throws JOSEException { + if (jwk instanceof com.nimbusds.jose.jwk.RSAKey) { + return new RSASSAVerifier(jwk.toRSAKey()); + } + throw new JOSEException("Unsupported JWK type: " + jwk.getKeyType() + ". Only RSA keys are supported."); + } + + private JWSVerifier createVerifier(RSAPublicKey publicKey) { + return new RSASSAVerifier(publicKey); + } } public static class AssertionValueAdapter implements JsonDeserializer { diff --git a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java index 2ae08f5b..ee8a495c 100644 --- a/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java +++ b/sdk/src/main/java/io/opentdf/platform/sdk/TDF.java @@ -695,7 +695,7 @@ Reader loadTDF(SeekableByteChannel tdf, Config.TDFReaderConfig tdfReaderConfig) Manifest.Assertion.HashValues hashValues; try { hashValues = assertion.verify(assertionKey); - } catch (ParseException | JOSEException e) { + } catch (ParseException | JOSEException | java.security.cert.CertificateException e) { throw new SDKException("error validating assertion hash", e); } var hashOfAssertionAsHex = assertion.hash(); diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java b/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java index 3c58de4d..ca20a7ac 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TDFTest.java @@ -4,6 +4,7 @@ import com.connectrpc.UnaryBlockingCall; import com.google.gson.reflect.TypeToken; import com.nimbusds.jose.JOSEException; +import com.nimbusds.jose.jwk.JWK; import com.google.gson.Gson; import io.opentdf.platform.policy.KeyAccessServer; import io.opentdf.platform.policy.kasregistry.KeyAccessServerRegistryServiceClient; @@ -24,6 +25,7 @@ import java.nio.charset.StandardCharsets; import java.security.KeyPair; import java.security.SecureRandom; +import java.security.cert.X509Certificate; import java.util.ArrayList; import java.util.Base64; import java.util.Collections; @@ -44,788 +46,979 @@ import static org.assertj.core.api.Assertions.assertThat; public class TDFTest { - protected static KeyAccessServerRegistryServiceClient kasRegistryService; - protected static String platformUrl = "http://localhost:8080"; + protected static KeyAccessServerRegistryServiceClient kasRegistryService; + protected static String platformUrl = "http://localhost:8080"; - protected static SDK.KAS kas = new SDK.KAS() { - @Override - public void close() { - } - - @Override - public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) { - // handle platform url - int index; - // if the kasinfo url contains the platform url, remove it - if (kasInfo.URL.startsWith(platformUrl)) { - index = Integer.parseInt(kasInfo.URL - .replaceFirst("^" + Pattern.quote(platformUrl) + "/kas", "")); - } else { - index = Integer.parseInt(kasInfo.URL.replaceFirst("^https://example.com/kas", "")); - } - var kiCopy = new Config.KASInfo(); - kiCopy.KID = "r1"; - kiCopy.PublicKey = CryptoUtils.getPublicKeyPEM(keypairs.get(index).getPublic()); - kiCopy.URL = kasInfo.URL; - return kiCopy; - } - - @Override - public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy, KeyType sessionKeyType) { - - try { - int index; - // if the keyAccess.url contains the platform url, remove it - if (keyAccess.url.startsWith(platformUrl)) { - index = Integer.parseInt(keyAccess.url - .replaceFirst("^" + Pattern.quote(platformUrl) + "/kas", "")); - } else { - index = Integer.parseInt( - keyAccess.url.replaceFirst("^https://example.com/kas", "")); - } - var bytes = Base64.getDecoder().decode(keyAccess.wrappedKey); - if (sessionKeyType.isEc()) { - var kasPrivateKey = CryptoUtils - .getPrivateKeyPEM(keypairs.get(index).getPrivate()); - var privateKey = ECKeyPair.privateKeyFromPem(kasPrivateKey); - var clientEphemeralPublicKey = keyAccess.ephemeralPublicKey; - var publicKey = ECKeyPair.publicKeyFromPem(clientEphemeralPublicKey); - byte[] symKey = ECKeyPair.computeECDHKey(publicKey, privateKey); - - var sessionKey = ECKeyPair.calculateHKDF(GLOBAL_KEY_SALT, symKey); - - AesGcm gcm = new AesGcm(sessionKey); - AesGcm.Encrypted encrypted = new AesGcm.Encrypted(bytes); - return gcm.decrypt(encrypted); - } else { - var decryptor = new AsymDecryption(keypairs.get(index).getPrivate()); - return decryptor.decrypt(bytes); - } - } catch (Exception e) { - throw new RuntimeException(e); - } - } - - @Override - public KASKeyCache getKeyCache() { - return new KASKeyCache(); - } - }; - - private static ArrayList keypairs = new ArrayList<>(); - - @BeforeAll - static void setupKeyPairsAndMocks() { - for (int i = 0; i < 2 + new Random().nextInt(5); i++) { - if (i % 2 == 0) { - keypairs.add(CryptoUtils.generateRSAKeypair()); - } else { - keypairs.add(CryptoUtils.generateECKeypair(KeyType.EC256Key.getECCurve().getCurveName())); - } - } - - kasRegistryService = mock(KeyAccessServerRegistryServiceClient.class); - List kasRegEntries = new ArrayList<>(); - for (Config.KASInfo kasInfo : getRSAKASInfos()) { - kasRegEntries.add(KeyAccessServer.newBuilder() - .setUri(kasInfo.URL).build()); - } - for (Config.KASInfo kasInfo : getECKASInfos()) { - kasRegEntries.add(KeyAccessServer.newBuilder() - .setUri(kasInfo.URL).build()); - } - ListKeyAccessServersResponse mockResponse = ListKeyAccessServersResponse.newBuilder() - .addAllKeyAccessServers(kasRegEntries) - .build(); - - // Stub the listKeyAccessServers method - when(kasRegistryService.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), any())) - .thenReturn(new UnaryBlockingCall<>() { - @Override - public ResponseMessage execute() { - return new ResponseMessage.Success<>(mockResponse, - Collections.emptyMap(), - Collections.emptyMap()); - } - - @Override - public void cancel() { - // this never happens in tests - } - }); + protected static SDK.KAS kas = new SDK.KAS() { + @Override + public void close() { } - @Test - void testSimpleTDFEncryptAndDecrypt() throws Exception { - - class TDFConfigPair { - public final Config.TDFConfig tdfConfig; - public final Config.TDFReaderConfig tdfReaderConfig; + @Override + public Config.KASInfo getPublicKey(Config.KASInfo kasInfo) { + // handle platform url + int index; + // if the kasinfo url contains the platform url, remove it + if (kasInfo.URL.startsWith(platformUrl)) { + index = Integer.parseInt(kasInfo.URL + .replaceFirst("^" + Pattern.quote(platformUrl) + "/kas", "")); + } else { + index = Integer.parseInt(kasInfo.URL.replaceFirst("^https://example.com/kas", "")); + } + var kiCopy = new Config.KASInfo(); + kiCopy.KID = "r1"; + kiCopy.PublicKey = CryptoUtils.getPublicKeyPEM(keypairs.get(index).getPublic()); + kiCopy.URL = kasInfo.URL; + return kiCopy; + } - public TDFConfigPair(Config.TDFConfig tdfConfig, Config.TDFReaderConfig tdfReaderConfig) { - this.tdfConfig = tdfConfig; - this.tdfReaderConfig = tdfReaderConfig; - } + @Override + public byte[] unwrap(Manifest.KeyAccess keyAccess, String policy, KeyType sessionKeyType) { + + try { + int index; + // if the keyAccess.url contains the platform url, remove it + if (keyAccess.url.startsWith(platformUrl)) { + index = Integer.parseInt(keyAccess.url + .replaceFirst("^" + Pattern.quote(platformUrl) + "/kas", "")); + } else { + index = Integer.parseInt( + keyAccess.url.replaceFirst("^https://example.com/kas", "")); } - - SecureRandom secureRandom = new SecureRandom(); - byte[] key = new byte[32]; - secureRandom.nextBytes(key); - - var assertion1 = new AssertionConfig(); - assertion1.id = "assertion1"; - assertion1.type = AssertionConfig.Type.BaseAssertion; - assertion1.scope = AssertionConfig.Scope.TrustedDataObj; - assertion1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; - assertion1.statement = new AssertionConfig.Statement(); - assertion1.statement.format = "base64binary"; - assertion1.statement.schema = "text"; - assertion1.statement.value = "ICAgIDxlZGoOkVkaD4="; - assertion1.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, key); - - var assertionVerificationKeys = new Config.AssertionVerificationKeys(); - assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey( - AssertionConfig.AssertionKeyAlg.HS256, - key); - - List tdfConfigPairs = List.of( - new TDFConfigPair( - Config.newTDFConfig(Config.withAutoconfigure(false), - Config.withKasInformation(getRSAKASInfos()), - Config.withMetaData("here is some metadata"), - Config.withDataAttributes( - "https://example.org/attr/a/value/b", - "https://example.org/attr/c/value/d"), - Config.withAssertionConfig(assertion1)), - Config.newTDFReaderConfig(Config.withAssertionVerificationKeys( - assertionVerificationKeys))), - new TDFConfigPair( - Config.newTDFConfig(Config.withAutoconfigure(false), - Config.withKasInformation(getECKASInfos()), - Config.withMetaData("here is some metadata"), - Config.WithWrappingKeyAlg(KeyType.EC256Key), - Config.withDataAttributes( - "https://example.org/attr/a/value/b", - "https://example.org/attr/c/value/d"), - Config.withAssertionConfig(assertion1)), - Config.newTDFReaderConfig( - Config.withAssertionVerificationKeys( - assertionVerificationKeys), - Config.WithSessionKeyType(KeyType.EC256Key)))); - - for (TDFConfigPair configPair : tdfConfigPairs) { - String plainText = "this is extremely sensitive stuff!!!"; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF(new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - var manifest = tdf.createTDF(plainTextInputStream, tdfOutputStream, configPair.tdfConfig) - .getManifest(); - - assertThat(manifest.assertions).asList().hasSize(1); - var assertion = manifest.assertions.get(0); - assertThat(assertion.appliesToState).isEqualTo("unencrypted"); - assertThat(assertion.type).isEqualTo("other"); - assertThat(assertion.statement.value).isEqualTo("ICAgIDxlZGoOkVkaD4="); - assertThat(assertion.statement.schema).isEqualTo("text"); - assertThat(assertion.statement.format).isEqualTo("base64binary"); - - assertThat(manifest.payload.isEncrypted).isTrue(); - var size = manifest.encryptionInformation.integrityInformation.segments.stream() - .map(s -> s.segmentSize) - .reduce(0L, Long::sum); - assertThat(size).isEqualTo(plainText.getBytes().length); - - var unwrappedData = new ByteArrayOutputStream(); - var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), - configPair.tdfReaderConfig, platformUrl); - assertThat(reader.getManifest().payload.mimeType).isEqualTo("application/octet-stream"); - - reader.readPayload(unwrappedData); - - assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) - .withFailMessage("extracted data does not match") - .isEqualTo(plainText); - assertThat(reader.getMetadata()).isEqualTo("here is some metadata"); - - var policyObject = reader.readPolicyObject(); - assertThat(policyObject).isNotNull(); - assertThat(policyObject.body.dataAttributes.stream().map(a -> a.attribute) - .collect(Collectors.toList())) - .asList() - .containsExactlyInAnyOrder("https://example.org/attr/a/value/b", - "https://example.org/attr/c/value/d"); + var bytes = Base64.getDecoder().decode(keyAccess.wrappedKey); + if (sessionKeyType.isEc()) { + var kasPrivateKey = CryptoUtils + .getPrivateKeyPEM(keypairs.get(index).getPrivate()); + var privateKey = ECKeyPair.privateKeyFromPem(kasPrivateKey); + var clientEphemeralPublicKey = keyAccess.ephemeralPublicKey; + var publicKey = ECKeyPair.publicKeyFromPem(clientEphemeralPublicKey); + byte[] symKey = ECKeyPair.computeECDHKey(publicKey, privateKey); + + var sessionKey = ECKeyPair.calculateHKDF(GLOBAL_KEY_SALT, symKey); + + AesGcm gcm = new AesGcm(sessionKey); + AesGcm.Encrypted encrypted = new AesGcm.Encrypted(bytes); + return gcm.decrypt(encrypted); + } else { + var decryptor = new AsymDecryption(keypairs.get(index).getPrivate()); + return decryptor.decrypt(bytes); } + } catch (Exception e) { + throw new RuntimeException(e); + } } - @Test - void testSimpleTDFWithAssertionWithRS256() throws Exception { - String assertion1Id = "assertion1"; - var keypair = CryptoUtils.generateRSAKeypair(); - var assertionConfig = new AssertionConfig(); - assertionConfig.id = assertion1Id; - assertionConfig.type = AssertionConfig.Type.BaseAssertion; - assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; - assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; - assertionConfig.statement = new AssertionConfig.Statement(); - assertionConfig.statement.format = "base64binary"; - assertionConfig.statement.schema = "text"; - assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; - assertionConfig.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, - keypair.getPrivate()); - - var rsaKasInfo = new Config.KASInfo(); - rsaKasInfo.URL = "https://example.com/kas" + 0; - - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(rsaKasInfo), - Config.withSystemMetadataAssertion(), - Config.withAssertionConfig(assertionConfig)); - - String plainText = "this is extremely sensitive stuff!!!"; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - var assertionVerificationKeys = new Config.AssertionVerificationKeys(); - assertionVerificationKeys.keys.put(assertion1Id, - new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, - keypair.getPublic())); - - var unwrappedData = new ByteArrayOutputStream(); - Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( - Config.withAssertionVerificationKeys(assertionVerificationKeys)); - var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, - platformUrl); - reader.readPayload(unwrappedData); - - assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) - .withFailMessage("extracted data does not match") - .isEqualTo(plainText); + @Override + public KASKeyCache getKeyCache() { + return new KASKeyCache(); } - - @Test - void testWithAssertionVerificationDisabled() throws Exception { - String assertion1Id = "assertion1"; - var keypair = CryptoUtils.generateRSAKeypair(); - var assertionConfig = new AssertionConfig(); - assertionConfig.id = assertion1Id; - assertionConfig.type = AssertionConfig.Type.BaseAssertion; - assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; - assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; - assertionConfig.statement = new AssertionConfig.Statement(); - assertionConfig.statement.format = "base64binary"; - assertionConfig.statement.schema = "text"; - assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; - assertionConfig.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, - keypair.getPrivate()); - - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(getRSAKASInfos()), - Config.withAssertionConfig(assertionConfig)); - - String plainText = "this is extremely sensitive stuff!!!"; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - var assertionVerificationKeys = new Config.AssertionVerificationKeys(); - assertionVerificationKeys.keys.put(assertion1Id, - new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, - keypair.getPublic())); - - var unwrappedData = new ByteArrayOutputStream(); - var dataToUnwrap = new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()); - var emptyConfig = Config.newTDFReaderConfig(); - var thrown = assertThrows(SDKException.class, () -> { - tdf.loadTDF(dataToUnwrap, emptyConfig, platformUrl); - }); - assertThat(thrown.getCause()).isInstanceOf(JOSEException.class); - - // try with assertion verification disabled and not passing the assertion - // verification keys - Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( - Config.withDisableAssertionVerification(true)); - var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, - platformUrl); - reader.readPayload(unwrappedData); - - assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) - .withFailMessage("extracted data does not match") - .isEqualTo(plainText); + }; + + private static ArrayList keypairs = new ArrayList<>(); + + @BeforeAll + static void setupKeyPairsAndMocks() { + for (int i = 0; i < 2 + new Random().nextInt(5); i++) { + if (i % 2 == 0) { + keypairs.add(CryptoUtils.generateRSAKeypair()); + } else { + keypairs.add(CryptoUtils.generateECKeypair(KeyType.EC256Key.getECCurve().getCurveName())); + } } - @Test - void testSimpleTDFWithAssertionWithHS256() throws Exception { - String assertion1Id = "assertion1"; - var assertionConfig1 = new AssertionConfig(); - assertionConfig1.id = assertion1Id; - assertionConfig1.type = AssertionConfig.Type.BaseAssertion; - assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj; - assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; - assertionConfig1.statement = new AssertionConfig.Statement(); - assertionConfig1.statement.format = "base64binary"; - assertionConfig1.statement.schema = "text"; - assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4="; - - String assertion2Id = "assertion2"; - var assertionConfig2 = new AssertionConfig(); - assertionConfig2.id = assertion2Id; - assertionConfig2.type = AssertionConfig.Type.HandlingAssertion; - assertionConfig2.scope = AssertionConfig.Scope.TrustedDataObj; - assertionConfig2.appliesToState = AssertionConfig.AppliesToState.Unencrypted; - assertionConfig2.statement = new AssertionConfig.Statement(); - assertionConfig2.statement.format = "json"; - assertionConfig2.statement.schema = "urn:nato:stanag:5636:A:1:elements:json"; - assertionConfig2.statement.value = "{\"uuid\":\"f74efb60-4a9a-11ef-a6f1-8ee1a61c148a\",\"body\":{\"dataAttributes\":null,\"dissem\":null}}"; - - var rsaKasInfo = new Config.KASInfo(); - rsaKasInfo.URL = "https://example.com/kas" + 0; - - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(rsaKasInfo), - Config.withAssertionConfig(assertionConfig1, assertionConfig2)); - - String plainText = "this is extremely sensitive stuff!!!"; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - var unwrappedData = new ByteArrayOutputStream(); - var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), - Config.newTDFReaderConfig(), platformUrl); - reader.readPayload(unwrappedData); - - assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) - .withFailMessage("extracted data does not match") - .isEqualTo(plainText); - - var manifest = reader.getManifest(); - var assertions = manifest.assertions; - assertThat(assertions.size()).isEqualTo(2); - for (var assertion : assertions) { - if (assertion.id.equals(assertion1Id)) { - assertThat(assertion.statement.format).isEqualTo("base64binary"); - assertThat(assertion.statement.schema).isEqualTo("text"); - assertThat(assertion.statement.value).isEqualTo("ICAgIDxlZGoOkVkaD4="); - assertThat(assertion.type).isEqualTo(AssertionConfig.Type.BaseAssertion.toString()); - } else if (assertion.id.equals(assertion2Id)) { - assertThat(assertion.statement.format).isEqualTo("json"); - assertThat(assertion.statement.schema) - .isEqualTo("urn:nato:stanag:5636:A:1:elements:json"); - assertThat(assertion.statement.value).isEqualTo( - "{\"uuid\":\"f74efb60-4a9a-11ef-a6f1-8ee1a61c148a\",\"body\":{\"dataAttributes\":null,\"dissem\":null}}"); - assertThat(assertion.type).isEqualTo(AssertionConfig.Type.HandlingAssertion.toString()); - } else { - throw new RuntimeException("unexpected assertion id: " + assertion.id); - } - } + kasRegistryService = mock(KeyAccessServerRegistryServiceClient.class); + List kasRegEntries = new ArrayList<>(); + for (Config.KASInfo kasInfo : getRSAKASInfos()) { + kasRegEntries.add(KeyAccessServer.newBuilder() + .setUri(kasInfo.URL).build()); } - - @Test - void testSimpleTDFWithAssertionWithHS256Failure() throws Exception { - // var keypair = CryptoUtils.generateRSAKeypair(); - SecureRandom secureRandom = new SecureRandom(); - byte[] key = new byte[32]; - secureRandom.nextBytes(key); - - String assertion1Id = "assertion1"; - var assertionConfig1 = new AssertionConfig(); - assertionConfig1.id = assertion1Id; - assertionConfig1.type = AssertionConfig.Type.BaseAssertion; - assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj; - assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; - assertionConfig1.statement = new AssertionConfig.Statement(); - assertionConfig1.statement.format = "base64binary"; - assertionConfig1.statement.schema = "text"; - assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4="; - assertionConfig1.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, - key); - - var rsaKasInfo = new Config.KASInfo(); - rsaKasInfo.URL = "https://example.com/kas" + 0; - - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(rsaKasInfo), - Config.withAssertionConfig(assertionConfig1)); - - String plainText = "this is extremely sensitive stuff!!!"; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - byte[] notkey = new byte[32]; - secureRandom.nextBytes(notkey); - var assertionVerificationKeys = new Config.AssertionVerificationKeys(); - assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey( - AssertionConfig.AssertionKeyAlg.HS256, - notkey); - Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( - Config.withAssertionVerificationKeys(assertionVerificationKeys)); - - try { - tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, - platformUrl); - throw new RuntimeException("assertion verify key error thrown"); - - } catch (SDKException e) { - assertThat(e).hasMessageContaining("verify"); - } + for (Config.KASInfo kasInfo : getECKASInfos()) { + kasRegEntries.add(KeyAccessServer.newBuilder() + .setUri(kasInfo.URL).build()); } + ListKeyAccessServersResponse mockResponse = ListKeyAccessServersResponse.newBuilder() + .addAllKeyAccessServers(kasRegEntries) + .build(); + + // Stub the listKeyAccessServers method + when(kasRegistryService.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), any())) + .thenReturn(new UnaryBlockingCall<>() { + @Override + public ResponseMessage execute() { + return new ResponseMessage.Success<>(mockResponse, + Collections.emptyMap(), + Collections.emptyMap()); + } + + @Override + public void cancel() { + // this never happens in tests + } + }); + } - @Test - public void testCreatingTDFWithMultipleSegments() throws Exception { - var random = new Random(); + @Test + void testSimpleTDFEncryptAndDecrypt() throws Exception { - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(getRSAKASInfos()), - Config.withSegmentSize(Config.MIN_SEGMENT_SIZE)); - - // data should be large enough to have multiple complete and a partial segment - var data = new byte[(int) (Config.MIN_SEGMENT_SIZE * 2.8)]; - random.nextBytes(data); - var plainTextInputStream = new ByteArrayInputStream(data); - var tdfOutputStream = new ByteArrayOutputStream(); - var tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - var unwrappedData = new ByteArrayOutputStream(); - var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), platformUrl); - reader.readPayload(unwrappedData); - - assertThat(unwrappedData.toByteArray()) - .withFailMessage("extracted data does not match") - .containsExactly(data); + class TDFConfigPair { + public final Config.TDFConfig tdfConfig; + public final Config.TDFReaderConfig tdfReaderConfig; + public TDFConfigPair(Config.TDFConfig tdfConfig, Config.TDFReaderConfig tdfReaderConfig) { + this.tdfConfig = tdfConfig; + this.tdfReaderConfig = tdfReaderConfig; + } } - @Test - public void testCreatingTooLargeTDF() { - var random = new Random(); - var maxSize = random.nextInt(1024); - var numReturned = new AtomicInteger(0); - - // return 1 more byte than the maximum size - var is = new InputStream() { - @Override - public int read() { - if (numReturned.get() > maxSize) { - return -1; - } - numReturned.incrementAndGet(); - return 1; - } - - @Override - public int read(byte[] b, int off, int len) { - var numToReturn = Math.min(len, maxSize - numReturned.get() + 1); - numReturned.addAndGet(numToReturn); - return numToReturn; - } - }; - - var os = new OutputStream() { - @Override - public void write(int b) { - } - - @Override - public void write(byte[] b, int off, int len) { - } - }; - - var tdf = new TDF(maxSize, new FakeServicesBuilder().setKas(kas).build()); - var tdfConfig = Config.newTDFConfig( - Config.withAutoconfigure(false), + SecureRandom secureRandom = new SecureRandom(); + byte[] key = new byte[32]; + secureRandom.nextBytes(key); + + var assertion1 = new AssertionConfig(); + assertion1.id = "assertion1"; + assertion1.type = AssertionConfig.Type.BaseAssertion; + assertion1.scope = AssertionConfig.Scope.TrustedDataObj; + assertion1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertion1.statement = new AssertionConfig.Statement(); + assertion1.statement.format = "base64binary"; + assertion1.statement.schema = "text"; + assertion1.statement.value = "ICAgIDxlZGoOkVkaD4="; + assertion1.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, key); + + var assertionVerificationKeys = new Config.AssertionVerificationKeys(); + assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey( + AssertionConfig.AssertionKeyAlg.HS256, + key); + + List tdfConfigPairs = List.of( + new TDFConfigPair( + Config.newTDFConfig(Config.withAutoconfigure(false), Config.withKasInformation(getRSAKASInfos()), - Config.withSegmentSize(Config.MIN_SEGMENT_SIZE)); - assertThrows(SDK.DataSizeNotSupported.class, - () -> tdf.createTDF(is, os, tdfConfig), - "didn't throw an exception when we created TDF that was too large"); - assertThat(numReturned.get()) - .withFailMessage("test returned the wrong number of bytes") - .isEqualTo(maxSize + 1); + Config.withMetaData("here is some metadata"), + Config.withDataAttributes( + "https://example.org/attr/a/value/b", + "https://example.org/attr/c/value/d"), + Config.withAssertionConfig(assertion1)), + Config.newTDFReaderConfig(Config.withAssertionVerificationKeys( + assertionVerificationKeys))), + new TDFConfigPair( + Config.newTDFConfig(Config.withAutoconfigure(false), + Config.withKasInformation(getECKASInfos()), + Config.withMetaData("here is some metadata"), + Config.WithWrappingKeyAlg(KeyType.EC256Key), + Config.withDataAttributes( + "https://example.org/attr/a/value/b", + "https://example.org/attr/c/value/d"), + Config.withAssertionConfig(assertion1)), + Config.newTDFReaderConfig( + Config.withAssertionVerificationKeys( + assertionVerificationKeys), + Config.WithSessionKeyType(KeyType.EC256Key)))); + + for (TDFConfigPair configPair : tdfConfigPairs) { + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF(new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + var manifest = tdf.createTDF(plainTextInputStream, tdfOutputStream, configPair.tdfConfig) + .getManifest(); + + assertThat(manifest.assertions).asList().hasSize(1); + var assertion = manifest.assertions.get(0); + assertThat(assertion.appliesToState).isEqualTo("unencrypted"); + assertThat(assertion.type).isEqualTo("other"); + assertThat(assertion.statement.value).isEqualTo("ICAgIDxlZGoOkVkaD4="); + assertThat(assertion.statement.schema).isEqualTo("text"); + assertThat(assertion.statement.format).isEqualTo("base64binary"); + + assertThat(manifest.payload.isEncrypted).isTrue(); + var size = manifest.encryptionInformation.integrityInformation.segments.stream() + .map(s -> s.segmentSize) + .reduce(0L, Long::sum); + assertThat(size).isEqualTo(plainText.getBytes().length); + + var unwrappedData = new ByteArrayOutputStream(); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), + configPair.tdfReaderConfig, platformUrl); + assertThat(reader.getManifest().payload.mimeType).isEqualTo("application/octet-stream"); + + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + assertThat(reader.getMetadata()).isEqualTo("here is some metadata"); + + var policyObject = reader.readPolicyObject(); + assertThat(policyObject).isNotNull(); + assertThat(policyObject.body.dataAttributes.stream().map(a -> a.attribute) + .collect(Collectors.toList())) + .asList() + .containsExactlyInAnyOrder("https://example.org/attr/a/value/b", + "https://example.org/attr/c/value/d"); } - - @Test - public void testCreateTDFWithMimeType() throws Exception { - final String mimeType = "application/pdf"; - - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(getRSAKASInfos()), - Config.withMimeType(mimeType)); - - String plainText = "this is extremely sensitive stuff!!!"; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), platformUrl); - assertThat(reader.getManifest().payload.mimeType).isEqualTo(mimeType); - } - - @Test - void legacyTDFRoundTrips() throws IOException { - final String mimeType = "application/pdf"; - var assertionConfig1 = new AssertionConfig(); - assertionConfig1.id = "assertion1"; - assertionConfig1.type = AssertionConfig.Type.BaseAssertion; - assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj; - assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; - assertionConfig1.statement = new AssertionConfig.Statement(); - assertionConfig1.statement.format = "base64binary"; - assertionConfig1.statement.schema = "text"; - assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4="; - - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(getRSAKASInfos()), - Config.withTargetMode("4.2.1"), - Config.withAssertionConfig(assertionConfig1), - Config.withMimeType(mimeType)); - - byte[] data = new byte[129]; - new Random().nextBytes(data); - InputStream plainTextInputStream = new ByteArrayInputStream(data); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - var dataOutputStream = new ByteArrayOutputStream(); - - var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), platformUrl); - var integrityInformation = reader.getManifest().encryptionInformation.integrityInformation; - assertThat(reader.getManifest().tdfVersion).isNull(); - var decodedSignature = Base64.getDecoder().decode(integrityInformation.rootSignature.signature); - for (var b : decodedSignature) { - assertThat(isHexChar(b)) - .withFailMessage("non-hex byte in signature: " + b) - .isTrue(); - } - for (var s : integrityInformation.segments) { - var decodedSegmentSignature = Base64.getDecoder().decode(s.hash); - for (var b : decodedSegmentSignature) { - assertThat(isHexChar(b)) - .withFailMessage("non-hex byte in segment signature: " + b) - .isTrue(); - } - } - reader.readPayload(dataOutputStream); - assertThat(reader.getManifest().payload.mimeType).isEqualTo(mimeType); - assertArrayEquals(data, dataOutputStream.toByteArray(), "extracted data does not match"); - var manifest = reader.getManifest(); - var assertions = manifest.assertions; - assertThat(assertions.size()).isEqualTo(1); - var assertion = assertions.get(0); - assertThat(assertion.id).isEqualTo("assertion1"); + } + + @Test + void testSimpleTDFWithAssertionWithRS256() throws Exception { + String assertion1Id = "assertion1"; + var keypair = CryptoUtils.generateRSAKeypair(); + var assertionConfig = new AssertionConfig(); + assertionConfig.id = assertion1Id; + assertionConfig.type = AssertionConfig.Type.BaseAssertion; + assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig.statement = new AssertionConfig.Statement(); + assertionConfig.statement.format = "base64binary"; + assertionConfig.statement.schema = "text"; + assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; + assertionConfig.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPrivate()); + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + 0; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo), + Config.withSystemMetadataAssertion(), + Config.withAssertionConfig(assertionConfig)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var assertionVerificationKeys = new Config.AssertionVerificationKeys(); + assertionVerificationKeys.keys.put(assertion1Id, + new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPublic())); + + var unwrappedData = new ByteArrayOutputStream(); + Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( + Config.withAssertionVerificationKeys(assertionVerificationKeys)); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, + platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + } + + @Test + void testSimpleTDFWithAssertionWithJWK() throws Exception { + var keypair = CryptoUtils.generateRSAKeypair(); + var assertionConfig = new AssertionConfig(); + assertionConfig.type = AssertionConfig.Type.BaseAssertion; + assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig.statement = new AssertionConfig.Statement(); + assertionConfig.statement.format = "base64binary"; + assertionConfig.statement.schema = "text"; + assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; + + JWK jwk = JWK.parse(CryptoUtils.getPublicKeyJWK(keypair.getPublic())); + assertionConfig.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPrivate()); + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + 0; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo), + Config.withAssertionConfig(assertionConfig)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var assertionVerificationKeys = new Config.AssertionVerificationKeys(); + assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, jwk); + + var unwrappedData = new ByteArrayOutputStream(); + Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( + Config.withAssertionVerificationKeys(assertionVerificationKeys)); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, + platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + } + + @Test + void testSimpleTDFWithAssertionWithX5C() throws Exception { + var keypair = CryptoUtils.generateRSAKeypair(); + var assertionConfig = new AssertionConfig(); + assertionConfig.type = AssertionConfig.Type.BaseAssertion; + assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig.statement = new AssertionConfig.Statement(); + assertionConfig.statement.format = "base64binary"; + assertionConfig.statement.schema = "text"; + assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; + + X509Certificate cert = TestUtil.createTestCertificate(keypair.getPublic(), keypair.getPrivate()); + assertionConfig.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPrivate()); + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + 0; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo), + Config.withAssertionConfig(assertionConfig)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var assertionVerificationKeys = new Config.AssertionVerificationKeys(); + assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, cert.getPublicKey()); + + var unwrappedData = new ByteArrayOutputStream(); + Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( + Config.withAssertionVerificationKeys(assertionVerificationKeys)); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, + platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + } + + @Test + void testVerifyAssertionWithJwkHeaderInJWT() throws Exception { + var keypair = CryptoUtils.generateRSAKeypair(); + + var assertionConfig = new AssertionConfig(); + assertionConfig.type = AssertionConfig.Type.BaseAssertion; + assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig.statement = new AssertionConfig.Statement(); + assertionConfig.statement.format = "base64binary"; + assertionConfig.statement.schema = "text"; + assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; + + JWK jwk = JWK.parse(CryptoUtils.getPublicKeyJWK(keypair.getPublic())); + + assertionConfig.signingKey = new AssertionConfig.AssertionKey( + AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPrivate() + ).withJwk(jwk.toPublicJWK()); + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + 0; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo), + Config.withAssertionConfig(assertionConfig)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var unwrappedData = new ByteArrayOutputStream(); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), + Config.newTDFReaderConfig(), platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + } + + @Test + void testVerifyAssertionWithX5cHeaderInJWT() throws Exception { + var keypair = CryptoUtils.generateRSAKeypair(); + var assertionConfig = new AssertionConfig(); + assertionConfig.type = AssertionConfig.Type.BaseAssertion; + assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig.statement = new AssertionConfig.Statement(); + assertionConfig.statement.format = "base64binary"; + assertionConfig.statement.schema = "text"; + assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; + + X509Certificate cert = TestUtil.createTestCertificate(keypair.getPublic(), keypair.getPrivate()); + List x5c = new ArrayList<>(); + x5c.add(com.nimbusds.jose.util.Base64.encode(cert.getEncoded())); + + assertionConfig.signingKey = new AssertionConfig.AssertionKey( + AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPrivate() + ).withX5c(x5c); + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + 0; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo), + Config.withAssertionConfig(assertionConfig)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var unwrappedData = new ByteArrayOutputStream(); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), + Config.newTDFReaderConfig(), platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + } + + @Test + void testWithAssertionVerificationDisabled() throws Exception { + String assertion1Id = "assertion1"; + var keypair = CryptoUtils.generateRSAKeypair(); + var assertionConfig = new AssertionConfig(); + assertionConfig.id = assertion1Id; + assertionConfig.type = AssertionConfig.Type.BaseAssertion; + assertionConfig.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig.statement = new AssertionConfig.Statement(); + assertionConfig.statement.format = "base64binary"; + assertionConfig.statement.schema = "text"; + assertionConfig.statement.value = "ICAgIDxlZGoOkVkaD4="; + assertionConfig.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPrivate()); + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(getRSAKASInfos()), + Config.withAssertionConfig(assertionConfig)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var assertionVerificationKeys = new Config.AssertionVerificationKeys(); + assertionVerificationKeys.keys.put(assertion1Id, + new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.RS256, + keypair.getPublic())); + + var unwrappedData = new ByteArrayOutputStream(); + var dataToUnwrap = new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()); + var emptyConfig = Config.newTDFReaderConfig(); + var thrown = assertThrows(SDKException.class, () -> { + tdf.loadTDF(dataToUnwrap, emptyConfig, platformUrl); + }); + assertThat(thrown.getCause()).isInstanceOf(JOSEException.class); + + // try with assertion verification disabled and not passing the assertion + // verification keys + Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( + Config.withDisableAssertionVerification(true)); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, + platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + } + + @Test + void testSimpleTDFWithAssertionWithHS256() throws Exception { + String assertion1Id = "assertion1"; + var assertionConfig1 = new AssertionConfig(); + assertionConfig1.id = assertion1Id; + assertionConfig1.type = AssertionConfig.Type.BaseAssertion; + assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig1.statement = new AssertionConfig.Statement(); + assertionConfig1.statement.format = "base64binary"; + assertionConfig1.statement.schema = "text"; + assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4="; + + String assertion2Id = "assertion2"; + var assertionConfig2 = new AssertionConfig(); + assertionConfig2.id = assertion2Id; + assertionConfig2.type = AssertionConfig.Type.HandlingAssertion; + assertionConfig2.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig2.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig2.statement = new AssertionConfig.Statement(); + assertionConfig2.statement.format = "json"; + assertionConfig2.statement.schema = "urn:nato:stanag:5636:A:1:elements:json"; + assertionConfig2.statement.value = "{\"uuid\":\"f74efb60-4a9a-11ef-a6f1-8ee1a61c148a\",\"body\":{\"dataAttributes\":null,\"dissem\":null}}"; + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + 0; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo), + Config.withAssertionConfig(assertionConfig1, assertionConfig2)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var unwrappedData = new ByteArrayOutputStream(); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), + Config.newTDFReaderConfig(), platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + + var manifest = reader.getManifest(); + var assertions = manifest.assertions; + assertThat(assertions.size()).isEqualTo(2); + for (var assertion : assertions) { + if (assertion.id.equals(assertion1Id)) { assertThat(assertion.statement.format).isEqualTo("base64binary"); assertThat(assertion.statement.schema).isEqualTo("text"); assertThat(assertion.statement.value).isEqualTo("ICAgIDxlZGoOkVkaD4="); assertThat(assertion.type).isEqualTo(AssertionConfig.Type.BaseAssertion.toString()); + } else if (assertion.id.equals(assertion2Id)) { + assertThat(assertion.statement.format).isEqualTo("json"); + assertThat(assertion.statement.schema) + .isEqualTo("urn:nato:stanag:5636:A:1:elements:json"); + assertThat(assertion.statement.value).isEqualTo( + "{\"uuid\":\"f74efb60-4a9a-11ef-a6f1-8ee1a61c148a\",\"body\":{\"dataAttributes\":null,\"dissem\":null}}"); + assertThat(assertion.type).isEqualTo(AssertionConfig.Type.HandlingAssertion.toString()); + } else { + throw new RuntimeException("unexpected assertion id: " + assertion.id); + } } - - @Test - void testSystemMetadataAssertion() throws Exception { - Config.TDFConfig tdfConfig = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(getRSAKASInfos()), - Config.withSystemMetadataAssertion() // Enable system metadata assertion - ); - - String plainText = "Test data for system metadata assertion."; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes(StandardCharsets.UTF_8)); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF( - new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryService).build()); - var createdManifest = tdf.createTDF(plainTextInputStream, tdfOutputStream, tdfConfig).getManifest(); - - // Verify the created manifest directly - assertThat(createdManifest.assertions).isNotNull(); - assertThat(createdManifest.assertions.size()).isEqualTo(1); - Manifest.Assertion sysAssertion = createdManifest.assertions.get(0); - assertThat(sysAssertion.id).isEqualTo("system-metadata"); - assertThat(sysAssertion.type).isEqualTo(AssertionConfig.Type.BaseAssertion.toString()); - assertThat(sysAssertion.scope).isEqualTo(AssertionConfig.Scope.Payload.toString()); - assertThat(sysAssertion.appliesToState) - .isEqualTo(AssertionConfig.AppliesToState.Unencrypted.toString()); - assertThat(sysAssertion.statement.format).isEqualTo("json"); - assertThat(sysAssertion.statement.schema).isEqualTo("system-metadata-v1"); - - // Deserialize and check the metadata JSON - Gson gson = new Gson(); - java.lang.reflect.Type mapType = new TypeToken>() { - }.getType(); - Map metadataMap = gson.fromJson(sysAssertion.statement.value, mapType); - - assertThat(metadataMap).containsKey("tdf_spec_version"); - assertThat(metadataMap.get("tdf_spec_version")).isEqualTo(TDF.TDF_SPEC_VERSION); - assertThat(metadataMap).containsKey("creation_date"); - assertThat(metadataMap.get("creation_date")).isNotBlank(); - assertThat(metadataMap).containsKey("operating_system"); - assertThat(metadataMap.get("operating_system")).isEqualTo(System.getProperty("os.name")); - assertThat(metadataMap).containsKey("sdk_version"); - assertThat(metadataMap.get("sdk_version")).isEqualTo("Java-" + Version.SDK); - assertThat(metadataMap).containsKey("java_version"); - assertThat(metadataMap.get("java_version")).isEqualTo(System.getProperty("java.version")); - assertThat(metadataMap).containsKey("architecture"); - assertThat(metadataMap.get("architecture")).isEqualTo(System.getProperty("os.arch")); + } + + @Test + void testSimpleTDFWithAssertionWithHS256Failure() throws Exception { + // var keypair = CryptoUtils.generateRSAKeypair(); + SecureRandom secureRandom = new SecureRandom(); + byte[] key = new byte[32]; + secureRandom.nextBytes(key); + + String assertion1Id = "assertion1"; + var assertionConfig1 = new AssertionConfig(); + assertionConfig1.id = assertion1Id; + assertionConfig1.type = AssertionConfig.Type.BaseAssertion; + assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig1.statement = new AssertionConfig.Statement(); + assertionConfig1.statement.format = "base64binary"; + assertionConfig1.statement.schema = "text"; + assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4="; + assertionConfig1.signingKey = new AssertionConfig.AssertionKey(AssertionConfig.AssertionKeyAlg.HS256, + key); + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + 0; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo), + Config.withAssertionConfig(assertionConfig1)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + byte[] notkey = new byte[32]; + secureRandom.nextBytes(notkey); + var assertionVerificationKeys = new Config.AssertionVerificationKeys(); + assertionVerificationKeys.defaultKey = new AssertionConfig.AssertionKey( + AssertionConfig.AssertionKeyAlg.HS256, + notkey); + Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( + Config.withAssertionVerificationKeys(assertionVerificationKeys)); + + try { + tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, + platformUrl); + throw new RuntimeException("assertion verify key error thrown"); + + } catch (SDKException e) { + assertThat(e).hasMessageContaining("verify"); } - - @Test - void testKasAllowlist() throws Exception { - - KeyAccessServerRegistryServiceClient kasRegistryServiceNoUrl = mock( - KeyAccessServerRegistryServiceClient.class); - List kasRegEntries = new ArrayList<>(); - kasRegEntries.add(KeyAccessServer.newBuilder() - .setUri("http://example.com/kas0").build()); - - ListKeyAccessServersResponse mockResponse = ListKeyAccessServersResponse.newBuilder() - .addAllKeyAccessServers(kasRegEntries) - .build(); - - // Stub the listKeyAccessServers method - when(kasRegistryServiceNoUrl.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), - any())) - .thenReturn(new UnaryBlockingCall<>() { - @Override - public ResponseMessage execute() { - return new ResponseMessage.Success<>(mockResponse, - Collections.emptyMap(), - Collections.emptyMap()); - } - - @Override - public void cancel() { - // we never do this during tests - } - }); - - var rsaKasInfo = new Config.KASInfo(); - rsaKasInfo.URL = "https://example.com/kas" + Integer.toString(0); - - Config.TDFConfig config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(rsaKasInfo)); - - String plainText = "this is extremely sensitive stuff!!!"; - InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); - - TDF tdf = new TDF(new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryServiceNoUrl).build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - var unwrappedData = new ByteArrayOutputStream(); - - // should throw error because the kas url is not in the allowlist - try { - tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), - Config.newTDFReaderConfig(), - platformUrl); - throw new RuntimeException("expected allowlist error to be thrown"); - } catch (Exception e) { - assertThat(e).hasMessageContaining("KasAllowlist"); + } + + @Test + public void testCreatingTDFWithMultipleSegments() throws Exception { + var random = new Random(); + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(getRSAKASInfos()), + Config.withSegmentSize(Config.MIN_SEGMENT_SIZE)); + + // data should be large enough to have multiple complete and a partial segment + var data = new byte[(int) (Config.MIN_SEGMENT_SIZE * 2.8)]; + random.nextBytes(data); + var plainTextInputStream = new ByteArrayInputStream(data); + var tdfOutputStream = new ByteArrayOutputStream(); + var tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + var unwrappedData = new ByteArrayOutputStream(); + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toByteArray()) + .withFailMessage("extracted data does not match") + .containsExactly(data); + + } + + @Test + public void testCreatingTooLargeTDF() { + var random = new Random(); + var maxSize = random.nextInt(1024); + var numReturned = new AtomicInteger(0); + + // return 1 more byte than the maximum size + var is = new InputStream() { + @Override + public int read() { + if (numReturned.get() > maxSize) { + return -1; } + numReturned.incrementAndGet(); + return 1; + } + + @Override + public int read(byte[] b, int off, int len) { + var numToReturn = Math.min(len, maxSize - numReturned.get() + 1); + numReturned.addAndGet(numToReturn); + return numToReturn; + } + }; - // with custom allowlist should succeed - Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( - Config.WithKasAllowlist("https://example.com")); - tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, platformUrl); - - // with ignore allowlist should succeed - readerConfig = Config.newTDFReaderConfig( - Config.WithIgnoreKasAllowlist(true)); - Reader reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), - readerConfig, - platformUrl); - reader.readPayload(unwrappedData); - - assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) - .withFailMessage("extracted data does not match") - .isEqualTo(plainText); - - // use the platform url as kas url, should succeed - var platformKasInfo = new Config.KASInfo(); - platformKasInfo.URL = platformUrl + "/kas" + Integer.toString(0); - config = Config.newTDFConfig( - Config.withAutoconfigure(false), - Config.withKasInformation(platformKasInfo)); - plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); - tdfOutputStream = new ByteArrayOutputStream(); - tdf = new TDF(new FakeServicesBuilder().setKas(kas) - .setKeyAccessServerRegistryService(kasRegistryServiceNoUrl) - .build()); - tdf.createTDF(plainTextInputStream, tdfOutputStream, config); - - unwrappedData = new ByteArrayOutputStream(); - reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), - Config.newTDFReaderConfig(), platformUrl); - reader.readPayload(unwrappedData); - - assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) - .withFailMessage("extracted data does not match") - .isEqualTo(plainText); - } + var os = new OutputStream() { + @Override + public void write(int b) { + } - @Nonnull - private static Config.KASInfo[] getKASInfos(Predicate filter) { - var kasInfos = new ArrayList(); - for (int i = 0; i < keypairs.size(); i++) { - if (filter.test(i)) { - var kasInfo = new Config.KASInfo(); - kasInfo.URL = "https://example.com/kas" + Integer.toString(i); - kasInfo.PublicKey = null; - kasInfos.add(kasInfo); - } - } - return kasInfos.toArray(Config.KASInfo[]::new); - } + @Override + public void write(byte[] b, int off, int len) { + } + }; - @Nonnull - private static Config.KASInfo[] getRSAKASInfos() { - return getKASInfos(i -> i % 2 == 0); + var tdf = new TDF(maxSize, new FakeServicesBuilder().setKas(kas).build()); + var tdfConfig = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(getRSAKASInfos()), + Config.withSegmentSize(Config.MIN_SEGMENT_SIZE)); + assertThrows(SDK.DataSizeNotSupported.class, + () -> tdf.createTDF(is, os, tdfConfig), + "didn't throw an exception when we created TDF that was too large"); + assertThat(numReturned.get()) + .withFailMessage("test returned the wrong number of bytes") + .isEqualTo(maxSize + 1); + } + + @Test + public void testCreateTDFWithMimeType() throws Exception { + final String mimeType = "application/pdf"; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(getRSAKASInfos()), + Config.withMimeType(mimeType)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), platformUrl); + assertThat(reader.getManifest().payload.mimeType).isEqualTo(mimeType); + } + + @Test + void legacyTDFRoundTrips() throws IOException { + final String mimeType = "application/pdf"; + var assertionConfig1 = new AssertionConfig(); + assertionConfig1.id = "assertion1"; + assertionConfig1.type = AssertionConfig.Type.BaseAssertion; + assertionConfig1.scope = AssertionConfig.Scope.TrustedDataObj; + assertionConfig1.appliesToState = AssertionConfig.AppliesToState.Unencrypted; + assertionConfig1.statement = new AssertionConfig.Statement(); + assertionConfig1.statement.format = "base64binary"; + assertionConfig1.statement.schema = "text"; + assertionConfig1.statement.value = "ICAgIDxlZGoOkVkaD4="; + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(getRSAKASInfos()), + Config.withTargetMode("4.2.1"), + Config.withAssertionConfig(assertionConfig1), + Config.withMimeType(mimeType)); + + byte[] data = new byte[129]; + new Random().nextBytes(data); + InputStream plainTextInputStream = new ByteArrayInputStream(data); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var dataOutputStream = new ByteArrayOutputStream(); + + var reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), platformUrl); + var integrityInformation = reader.getManifest().encryptionInformation.integrityInformation; + assertThat(reader.getManifest().tdfVersion).isNull(); + var decodedSignature = Base64.getDecoder().decode(integrityInformation.rootSignature.signature); + for (var b : decodedSignature) { + assertThat(isHexChar(b)) + .withFailMessage("non-hex byte in signature: " + b) + .isTrue(); + } + for (var s : integrityInformation.segments) { + var decodedSegmentSignature = Base64.getDecoder().decode(s.hash); + for (var b : decodedSegmentSignature) { + assertThat(isHexChar(b)) + .withFailMessage("non-hex byte in segment signature: " + b) + .isTrue(); + } } + reader.readPayload(dataOutputStream); + assertThat(reader.getManifest().payload.mimeType).isEqualTo(mimeType); + assertArrayEquals(data, dataOutputStream.toByteArray(), "extracted data does not match"); + var manifest = reader.getManifest(); + var assertions = manifest.assertions; + assertThat(assertions.size()).isEqualTo(1); + var assertion = assertions.get(0); + assertThat(assertion.id).isEqualTo("assertion1"); + assertThat(assertion.statement.format).isEqualTo("base64binary"); + assertThat(assertion.statement.schema).isEqualTo("text"); + assertThat(assertion.statement.value).isEqualTo("ICAgIDxlZGoOkVkaD4="); + assertThat(assertion.type).isEqualTo(AssertionConfig.Type.BaseAssertion.toString()); + } + + @Test + void testSystemMetadataAssertion() throws Exception { + Config.TDFConfig tdfConfig = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(getRSAKASInfos()), + Config.withSystemMetadataAssertion() // Enable system metadata assertion + ); + + String plainText = "Test data for system metadata assertion."; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes(StandardCharsets.UTF_8)); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF( + new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryService).build()); + var createdManifest = tdf.createTDF(plainTextInputStream, tdfOutputStream, tdfConfig).getManifest(); + + // Verify the created manifest directly + assertThat(createdManifest.assertions).isNotNull(); + assertThat(createdManifest.assertions.size()).isEqualTo(1); + Manifest.Assertion sysAssertion = createdManifest.assertions.get(0); + assertThat(sysAssertion.id).isEqualTo("system-metadata"); + assertThat(sysAssertion.type).isEqualTo(AssertionConfig.Type.BaseAssertion.toString()); + assertThat(sysAssertion.scope).isEqualTo(AssertionConfig.Scope.Payload.toString()); + assertThat(sysAssertion.appliesToState) + .isEqualTo(AssertionConfig.AppliesToState.Unencrypted.toString()); + assertThat(sysAssertion.statement.format).isEqualTo("json"); + assertThat(sysAssertion.statement.schema).isEqualTo("system-metadata-v1"); + + // Deserialize and check the metadata JSON + Gson gson = new Gson(); + java.lang.reflect.Type mapType = new TypeToken>() { + }.getType(); + Map metadataMap = gson.fromJson(sysAssertion.statement.value, mapType); + + assertThat(metadataMap).containsKey("tdf_spec_version"); + assertThat(metadataMap.get("tdf_spec_version")).isEqualTo(TDF.TDF_SPEC_VERSION); + assertThat(metadataMap).containsKey("creation_date"); + assertThat(metadataMap.get("creation_date")).isNotBlank(); + assertThat(metadataMap).containsKey("operating_system"); + assertThat(metadataMap.get("operating_system")).isEqualTo(System.getProperty("os.name")); + assertThat(metadataMap).containsKey("sdk_version"); + assertThat(metadataMap.get("sdk_version")).isEqualTo("Java-" + Version.SDK); + assertThat(metadataMap).containsKey("java_version"); + assertThat(metadataMap.get("java_version")).isEqualTo(System.getProperty("java.version")); + assertThat(metadataMap).containsKey("architecture"); + assertThat(metadataMap.get("architecture")).isEqualTo(System.getProperty("os.arch")); + } + + @Test + void testKasAllowlist() throws Exception { + + KeyAccessServerRegistryServiceClient kasRegistryServiceNoUrl = mock( + KeyAccessServerRegistryServiceClient.class); + List kasRegEntries = new ArrayList<>(); + kasRegEntries.add(KeyAccessServer.newBuilder() + .setUri("http://example.com/kas0").build()); + + ListKeyAccessServersResponse mockResponse = ListKeyAccessServersResponse.newBuilder() + .addAllKeyAccessServers(kasRegEntries) + .build(); + + // Stub the listKeyAccessServers method + when(kasRegistryServiceNoUrl.listKeyAccessServersBlocking(any(ListKeyAccessServersRequest.class), + any())) + .thenReturn(new UnaryBlockingCall<>() { + @Override + public ResponseMessage execute() { + return new ResponseMessage.Success<>(mockResponse, + Collections.emptyMap(), + Collections.emptyMap()); + } + + @Override + public void cancel() { + // we never do this during tests + } + }); + + var rsaKasInfo = new Config.KASInfo(); + rsaKasInfo.URL = "https://example.com/kas" + Integer.toString(0); + + Config.TDFConfig config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(rsaKasInfo)); + + String plainText = "this is extremely sensitive stuff!!!"; + InputStream plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + ByteArrayOutputStream tdfOutputStream = new ByteArrayOutputStream(); + + TDF tdf = new TDF(new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryServiceNoUrl).build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + var unwrappedData = new ByteArrayOutputStream(); - @Nonnull - private static Config.KASInfo[] getECKASInfos() { - return getKASInfos(i -> i % 2 != 0); + // should throw error because the kas url is not in the allowlist + try { + tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), + Config.newTDFReaderConfig(), + platformUrl); + throw new RuntimeException("expected allowlist error to be thrown"); + } catch (Exception e) { + assertThat(e).hasMessageContaining("KasAllowlist"); } - private static boolean isHexChar(byte b) { - return (b >= 'a' && b <= 'f') || (b >= '0' && b <= '9'); + // with custom allowlist should succeed + Config.TDFReaderConfig readerConfig = Config.newTDFReaderConfig( + Config.WithKasAllowlist("https://example.com")); + tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), readerConfig, platformUrl); + + // with ignore allowlist should succeed + readerConfig = Config.newTDFReaderConfig( + Config.WithIgnoreKasAllowlist(true)); + Reader reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), + readerConfig, + platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + + // use the platform url as kas url, should succeed + var platformKasInfo = new Config.KASInfo(); + platformKasInfo.URL = platformUrl + "/kas" + Integer.toString(0); + config = Config.newTDFConfig( + Config.withAutoconfigure(false), + Config.withKasInformation(platformKasInfo)); + plainTextInputStream = new ByteArrayInputStream(plainText.getBytes()); + tdfOutputStream = new ByteArrayOutputStream(); + tdf = new TDF(new FakeServicesBuilder().setKas(kas) + .setKeyAccessServerRegistryService(kasRegistryServiceNoUrl) + .build()); + tdf.createTDF(plainTextInputStream, tdfOutputStream, config); + + unwrappedData = new ByteArrayOutputStream(); + reader = tdf.loadTDF(new SeekableInMemoryByteChannel(tdfOutputStream.toByteArray()), + Config.newTDFReaderConfig(), platformUrl); + reader.readPayload(unwrappedData); + + assertThat(unwrappedData.toString(StandardCharsets.UTF_8)) + .withFailMessage("extracted data does not match") + .isEqualTo(plainText); + } + + @Nonnull + private static Config.KASInfo[] getKASInfos(Predicate filter) { + var kasInfos = new ArrayList(); + for (int i = 0; i < keypairs.size(); i++) { + if (filter.test(i)) { + var kasInfo = new Config.KASInfo(); + kasInfo.URL = "https://example.com/kas" + Integer.toString(i); + kasInfo.PublicKey = null; + kasInfos.add(kasInfo); + } } -} + return kasInfos.toArray(Config.KASInfo[]::new); + } + + @Nonnull + private static Config.KASInfo[] getRSAKASInfos() { + return getKASInfos(i -> i % 2 == 0); + } + + @Nonnull + private static Config.KASInfo[] getECKASInfos() { + return getKASInfos(i -> i % 2 != 0); + } + + private static boolean isHexChar(byte b) { + return (b >= 'a' && b <= 'f') || (b >= '0' && b <= '9'); + } +} \ No newline at end of file diff --git a/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java index 53076007..13cbedbe 100644 --- a/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java +++ b/sdk/src/test/java/io/opentdf/platform/sdk/TestUtil.java @@ -5,6 +5,18 @@ import java.util.Collections; + +import org.bouncycastle.asn1.x500.X500Name; +import org.bouncycastle.cert.jcajce.JcaX509CertificateConverter; +import org.bouncycastle.cert.jcajce.JcaX509v3CertificateBuilder; +import org.bouncycastle.operator.jcajce.JcaContentSignerBuilder; + +import java.math.BigInteger; +import java.security.cert.X509Certificate; +import java.security.PrivateKey; +import java.security.PublicKey; +import java.util.Date; + public class TestUtil { static UnaryBlockingCall successfulUnaryCall(T result) { return new UnaryBlockingCall() { @@ -19,4 +31,18 @@ public void cancel() { } }; } + + public static X509Certificate createTestCertificate(PublicKey publicKey, PrivateKey privateKey) throws Exception { + X500Name owner = new X500Name("CN=Test"); + JcaX509v3CertificateBuilder builder = new JcaX509v3CertificateBuilder( + owner, + BigInteger.ONE, + new Date(System.currentTimeMillis() - 1000L * 60 * 60 * 24), + new Date(System.currentTimeMillis() + (1000L * 60 * 60 * 24 * 365)), + owner, + publicKey + ); + + return new JcaX509CertificateConverter().getCertificate(builder.build(new JcaContentSignerBuilder("SHA256WithRSA").build(privateKey))); + } }