Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -79,32 +79,36 @@ public void checkClientTrusted(final X509Certificate[] certificates, final Strin
if (LOG.isDebugEnabled()) {
printCertificateChain(certificates, s);
}
if (!authStrictness) {
return;
}
if (certificates == null || certificates.length < 1 || certificates[0] == null) {

final X509Certificate primaryClientCertificate = (certificates != null && certificates.length > 0 && certificates[0] != null) ? certificates[0] : null;
String exceptionMsg = "";

if (authStrictness && primaryClientCertificate == null) {
throw new CertificateException("In strict auth mode, certificate(s) are expected from client:" + clientAddress);
} else if (primaryClientCertificate == null) {
LOG.info("No certificate was received from client, but continuing since strict auth mode is disabled");
return;
}
final X509Certificate primaryClientCertificate = certificates[0];

// Revocation check
final BigInteger serialNumber = primaryClientCertificate.getSerialNumber();
if (serialNumber == null || crlDao.findBySerial(serialNumber) != null) {
final String errorMsg = String.format("Client is using revoked certificate of serial=%x, subject=%s from address=%s",
primaryClientCertificate.getSerialNumber(), primaryClientCertificate.getSubjectDN(), clientAddress);
LOG.error(errorMsg);
throw new CertificateException(errorMsg);
exceptionMsg = (Strings.isNullOrEmpty(exceptionMsg)) ? errorMsg : (exceptionMsg + ". " + errorMsg);
}

// Validity check
if (!allowExpiredCertificate) {
try {
primaryClientCertificate.checkValidity();
} catch (final CertificateExpiredException | CertificateNotYetValidException e) {
final String errorMsg = String.format("Client certificate has expired with serial=%x, subject=%s from address=%s",
primaryClientCertificate.getSerialNumber(), primaryClientCertificate.getSubjectDN(), clientAddress);
LOG.error(errorMsg);
throw new CertificateException(errorMsg); }
try {
primaryClientCertificate.checkValidity();
} catch (final CertificateExpiredException | CertificateNotYetValidException e) {
final String errorMsg = String.format("Client certificate has expired with serial=%x, subject=%s from address=%s",
primaryClientCertificate.getSerialNumber(), primaryClientCertificate.getSubjectDN(), clientAddress);
LOG.error(errorMsg);
if (!allowExpiredCertificate) {
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@Slair1 what if allowExpiredCertificate is true ? log the error and silently ignore it ?
should it be appended to exceptionMsg ?

throw new CertificateException(errorMsg);
}
}

// Ownership check
Expand All @@ -122,13 +126,21 @@ public void checkClientTrusted(final X509Certificate[] certificates, final Strin
if (!certMatchesOwnership) {
final String errorMsg = "Certificate ownership verification failed for client: " + clientAddress;
LOG.error(errorMsg);
throw new CertificateException(errorMsg);
exceptionMsg = (Strings.isNullOrEmpty(exceptionMsg)) ? errorMsg : (exceptionMsg + ". " + errorMsg);
}
if (activeCertMap != null && !Strings.isNullOrEmpty(clientAddress)) {
activeCertMap.put(clientAddress, primaryClientCertificate);
if (authStrictness && !Strings.isNullOrEmpty(exceptionMsg)) {
throw new CertificateException(exceptionMsg);
}
if (LOG.isDebugEnabled()) {
LOG.debug("Client/agent connection from ip=" + clientAddress + " has been validated and trusted.");
if (authStrictness) {
LOG.debug("Client/agent connection from ip=" + clientAddress + " has been validated and trusted.");
} else {
LOG.debug("Client/agent connection from ip=" + clientAddress + " accepted without certificate validation.");
}
}

if (primaryClientCertificate != null && activeCertMap != null && !Strings.isNullOrEmpty(clientAddress)) {
activeCertMap.put(clientAddress, primaryClientCertificate);
}
}

Expand All @@ -138,9 +150,6 @@ public void checkServerTrusted(X509Certificate[] x509Certificates, String s) thr

@Override
public X509Certificate[] getAcceptedIssuers() {
if (!authStrictness) {
return null;
}
return new X509Certificate[]{caCertificate};
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -267,9 +267,16 @@ public SSLEngine createSSLEngine(final SSLContext sslContext, final String remot
final boolean allowExpiredCertificate = rootCAAllowExpiredCert.value();

TrustManager[] tms = new TrustManager[]{new RootCACustomTrustManager(remoteAddress, authStrictness, allowExpiredCertificate, certMap, caCertificate, crlDao)};

sslContext.init(kmf.getKeyManagers(), tms, new SecureRandom());
final SSLEngine sslEngine = sslContext.createSSLEngine();
sslEngine.setNeedClientAuth(authStrictness);
// If authStrictness require SSL and validate client cert, otherwise prefer SSL but don't validate client cert
if (authStrictness) {
sslEngine.setNeedClientAuth(true); // Require SSL and client cert validation
} else {
sslEngine.setWantClientAuth(true); // Prefer SSL but don't validate client cert
}

return sslEngine;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,43 @@ public void setUp() throws Exception {
}

@Test
public void testAuthNotStrict() throws Exception {
public void testAuthNotStrictWithInvalidCert() throws Exception {
final RootCACustomTrustManager trustManager = new RootCACustomTrustManager(clientIp, false, true, certMap, caCertificate, crlDao);
trustManager.checkClientTrusted(null, null);
Assert.assertNull(trustManager.getAcceptedIssuers());
}

@Test
public void testAuthNotStrictWithRevokedCert() throws Exception {
Mockito.when(crlDao.findBySerial(Mockito.any(BigInteger.class))).thenReturn(new CrlVO());
final RootCACustomTrustManager trustManager = new RootCACustomTrustManager(clientIp, false, true, certMap, caCertificate, crlDao);
trustManager.checkClientTrusted(new X509Certificate[]{caCertificate}, "RSA");
Assert.assertTrue(certMap.containsKey(clientIp));
Assert.assertEquals(certMap.get(clientIp), caCertificate);
}

@Test
public void testAuthNotStrictWithInvalidCertOwnership() throws Exception {
Mockito.when(crlDao.findBySerial(Mockito.any(BigInteger.class))).thenReturn(null);
final RootCACustomTrustManager trustManager = new RootCACustomTrustManager(clientIp, false, true, certMap, caCertificate, crlDao);
trustManager.checkClientTrusted(new X509Certificate[]{caCertificate}, "RSA");
Assert.assertTrue(certMap.containsKey(clientIp));
Assert.assertEquals(certMap.get(clientIp), caCertificate);
}

@Test(expected = CertificateException.class)
public void testAuthNotStrictWithDenyExpiredCertAndOwnership() throws Exception {
Mockito.when(crlDao.findBySerial(Mockito.any(BigInteger.class))).thenReturn(null);
final RootCACustomTrustManager trustManager = new RootCACustomTrustManager(clientIp, false, false, certMap, caCertificate, crlDao);
trustManager.checkClientTrusted(new X509Certificate[]{expiredClientCertificate}, "RSA");
}

@Test
public void testAuthNotStrictWithAllowExpiredCertAndOwnership() throws Exception {
Mockito.when(crlDao.findBySerial(Mockito.any(BigInteger.class))).thenReturn(null);
final RootCACustomTrustManager trustManager = new RootCACustomTrustManager(clientIp, false, true, certMap, caCertificate, crlDao);
trustManager.checkClientTrusted(new X509Certificate[]{expiredClientCertificate}, "RSA");
Assert.assertTrue(certMap.containsKey(clientIp));
Assert.assertEquals(certMap.get(clientIp), expiredClientCertificate);
}

@Test(expected = CertificateException.class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -133,6 +133,7 @@ public void testCreateSSLEngineWithoutAuthStrictness() throws Exception {
provider.rootCAAuthStrictness = Mockito.mock(ConfigKey.class);
Mockito.when(provider.rootCAAuthStrictness.value()).thenReturn(Boolean.FALSE);
final SSLEngine e = provider.createSSLEngine(SSLUtils.getSSLContext(), "/1.2.3.4:5678", null);
Assert.assertTrue(e.getWantClientAuth());
Assert.assertFalse(e.getNeedClientAuth());
}

Expand All @@ -149,4 +150,4 @@ public void testGetProviderName() throws Exception {
Assert.assertEquals(provider.getProviderName(), "root");
}

}
}