diff --git a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java index f34e336553b..66591cda153 100644 --- a/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java +++ b/netty/src/test/java/io/grpc/netty/AdvancedTlsTest.java @@ -436,16 +436,13 @@ public void onFileReloadingTrustManagerBadInitialContentTest() throws Exception } @Test - public void keyManagerAliasesTest() { + public void keyManagerAliasesTest() throws Exception { AdvancedTlsX509KeyManager km = new AdvancedTlsX509KeyManager(); - assertArrayEquals( - new String[] {"default"}, km.getClientAliases("", null)); - assertEquals( - "default", km.chooseClientAlias(new String[] {"default"}, null, null)); - assertArrayEquals( - new String[] {"default"}, km.getServerAliases("", null)); - assertEquals( - "default", km.chooseServerAlias("default", null, null)); + km.updateIdentityCredentials(serverCert0, serverKey0); + assertArrayEquals(new String[] {"key-1"}, km.getClientAliases("", null)); + assertEquals("key-1", km.chooseClientAlias(new String[] {"key-1"}, null, null)); + assertArrayEquals(new String[] {"key-1"}, km.getServerAliases("", null)); + assertEquals("key-1", km.chooseServerAlias("key-1", null, null)); } @Test diff --git a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java index 1f807cd405d..eea664f2ad4 100644 --- a/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java +++ b/util/src/main/java/io/grpc/util/AdvancedTlsX509KeyManager.java @@ -32,6 +32,7 @@ import java.util.concurrent.ScheduledExecutorService; import java.util.concurrent.ScheduledFuture; import java.util.concurrent.TimeUnit; +import java.util.concurrent.atomic.AtomicInteger; import java.util.logging.Level; import java.util.logging.Logger; import javax.net.ssl.SSLEngine; @@ -40,59 +41,86 @@ /** * AdvancedTlsX509KeyManager is an {@code X509ExtendedKeyManager} that allows users to configure * advanced TLS features, such as private key and certificate chain reloading. + * + *

The alias increments on every credential load (e.g. {@code "key-1"}, {@code "key-2"}, ...), + * so the same alias always maps to the same key material. The previous alias is retained for one + * rotation to allow in-progress handshakes to complete, ensuring alias-to-key-material consistency + * across credential reloads. */ public final class AdvancedTlsX509KeyManager extends X509ExtendedKeyManager { private static final Logger log = Logger.getLogger(AdvancedTlsX509KeyManager.class.getName()); // Minimum allowed period for refreshing files with credential information. - private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1 ; - // The credential information to be sent to peers to prove our identity. - private volatile KeyInfo keyInfo; + private static final int MINIMUM_REFRESH_PERIOD_IN_MINUTES = 1; + // Prefix for the key material alias; revision counter appended on each credential load. + static final String ALIAS_PREFIX = "key-"; + + private final AtomicInteger revision = new AtomicInteger(0); + // Snapshot of current and previous KeyInfo; previous is retained for in-progress handshakes + // after one rotation. + private volatile KeyInfoSnapshot snapshot = new KeyInfoSnapshot(null, null); + + public AdvancedTlsX509KeyManager() {} + + private String alias() { + KeyInfo curr = this.snapshot.current; + return curr != null ? curr.alias : null; + } @Override public PrivateKey getPrivateKey(String alias) { - if (alias.equals("default")) { - return this.keyInfo.key; + KeyInfoSnapshot snap = this.snapshot; + if (snap.current != null && snap.current.alias.equals(alias)) { + return snap.current.key; + } + if (snap.previous != null && snap.previous.alias.equals(alias)) { + return snap.previous.key; } return null; } @Override public X509Certificate[] getCertificateChain(String alias) { - if (alias.equals("default")) { - return Arrays.copyOf(this.keyInfo.certs, this.keyInfo.certs.length); + KeyInfoSnapshot snap = this.snapshot; + if (snap.current != null && snap.current.alias.equals(alias)) { + return Arrays.copyOf(snap.current.certs, snap.current.certs.length); + } + if (snap.previous != null && snap.previous.alias.equals(alias)) { + return Arrays.copyOf(snap.previous.certs, snap.previous.certs.length); } return null; } @Override public String[] getClientAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? new String[] {alias} : null; } @Override public String chooseClientAlias(String[] keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineClientAlias(String[] keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } @Override public String[] getServerAliases(String keyType, Principal[] issuers) { - return new String[] {"default"}; + String alias = alias(); + return alias != null ? new String[] {alias} : null; } @Override public String chooseServerAlias(String keyType, Principal[] issuers, Socket socket) { - return "default"; + return alias(); } @Override public String chooseEngineServerAlias(String keyType, Principal[] issuers, SSLEngine engine) { - return "default"; + return alias(); } /** @@ -116,7 +144,9 @@ public void updateIdentityCredentials(PrivateKey key, X509Certificate[] certs) { * @param key the private key that is going to be used */ public void updateIdentityCredentials(X509Certificate[] certs, PrivateKey key) { - this.keyInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key")); + KeyInfo newInfo = new KeyInfo(checkNotNull(certs, "certs"), checkNotNull(key, "key"), + ALIAS_PREFIX + revision.incrementAndGet()); + this.snapshot = new KeyInfoSnapshot(newInfo, this.snapshot.current); } /** @@ -218,10 +248,22 @@ private static class KeyInfo { // The private key and the cert chain we will use to send to peers to prove our identity. final X509Certificate[] certs; final PrivateKey key; + final String alias; - public KeyInfo(X509Certificate[] certs, PrivateKey key) { + public KeyInfo(X509Certificate[] certs, PrivateKey key, String alias) { this.certs = certs; this.key = key; + this.alias = alias; + } + } + + private static class KeyInfoSnapshot { + final KeyInfo current; + final KeyInfo previous; + + KeyInfoSnapshot(KeyInfo current, KeyInfo previous) { + this.current = current; + this.previous = previous; } } @@ -309,4 +351,3 @@ public interface Closeable extends java.io.Closeable { void close(); } } - diff --git a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java index f96c85e4f4f..b8431d4f991 100644 --- a/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java +++ b/util/src/test/java/io/grpc/util/AdvancedTlsX509KeyManagerTest.java @@ -18,6 +18,7 @@ import static org.junit.Assert.assertArrayEquals; import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertThrows; import static org.junit.Assert.assertTrue; import static org.junit.Assert.fail; @@ -48,7 +49,6 @@ public class AdvancedTlsX509KeyManagerTest { private static final String SERVER_0_PEM_FILE = "server0.pem"; private static final String CLIENT_0_KEY_FILE = "client.key"; private static final String CLIENT_0_PEM_FILE = "client.pem"; - private static final String ALIAS = "default"; private ScheduledExecutorService executor; @@ -79,22 +79,62 @@ public void setUp() throws Exception { public void updateTrustCredentials_replacesIssuers() throws Exception { // Overall happy path checking of public API. AdvancedTlsX509KeyManager serverKeyManager = new AdvancedTlsX509KeyManager(); + serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias1 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1", alias1); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1)); serverKeyManager.updateIdentityCredentials(clientCert0File, clientKey0File); - assertEquals(clientKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(ALIAS)); - - serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File,1, + String alias2 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(AdvancedTlsX509KeyManager.ALIAS_PREFIX + "2", alias2); + assertEquals(clientKey0, serverKeyManager.getPrivateKey(alias2)); + assertArrayEquals(clientCert0, serverKeyManager.getCertificateChain(alias2)); + // Previous alias still resolves — retained to allow in-progress handshakes to complete. + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias1)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias1)); + + serverKeyManager.updateIdentityCredentials(serverCert0File, serverKey0File, 1, TimeUnit.MINUTES, executor); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias3 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias3)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias3)); + // alias1 is now two rotations back — no longer retained. + assertNull(serverKeyManager.getPrivateKey(alias1)); serverKeyManager.updateIdentityCredentials(serverCert0, serverKey0); - assertEquals(serverKey0, serverKeyManager.getPrivateKey(ALIAS)); - assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(ALIAS)); + String alias4 = serverKeyManager.chooseEngineServerAlias(null, null, null); + assertEquals(serverKey0, serverKeyManager.getPrivateKey(alias4)); + assertArrayEquals(serverCert0, serverKeyManager.getCertificateChain(alias4)); + } + + @Test + public void allAliasMethods_returnNullBeforeCredentialsLoaded() { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + + assertNull(keyManager.chooseClientAlias(null, null, null)); + assertNull(keyManager.chooseServerAlias(null, null, null)); + assertNull(keyManager.chooseEngineClientAlias(null, null, null)); + assertNull(keyManager.chooseEngineServerAlias(null, null, null)); + assertNull(keyManager.getClientAliases(null, null)); + assertNull(keyManager.getServerAliases(null, null)); + assertNull(keyManager.getPrivateKey("key-1")); + assertNull(keyManager.getCertificateChain("key-1")); + } + + @Test + public void allAliasMethods_agreeAfterCredentialLoad() throws Exception { + AdvancedTlsX509KeyManager keyManager = new AdvancedTlsX509KeyManager(); + keyManager.updateIdentityCredentials(serverCert0, serverKey0); + + String expectedAlias = AdvancedTlsX509KeyManager.ALIAS_PREFIX + "1"; + assertEquals(expectedAlias, keyManager.chooseClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseServerAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineClientAlias(null, null, null)); + assertEquals(expectedAlias, keyManager.chooseEngineServerAlias(null, null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getClientAliases(null, null)); + assertArrayEquals(new String[]{expectedAlias}, keyManager.getServerAliases(null, null)); } @Test