diff --git a/kms/capi/capi.go b/kms/capi/capi.go index 661a7808..8586ec79 100644 --- a/kms/capi/capi.go +++ b/kms/capi/capi.go @@ -54,6 +54,7 @@ const ( IssuerNameArg = "issuer" KeySpec = "key-spec" // 0, 1, 2; none/NONE, at_keyexchange/AT_KEYEXCHANGE, at_signature/AT_SIGNATURE SkipFindCertificateKey = "skip-find-certificate-key" // skips looking up certificate private key when storing a certificate + DeleteKeyArg = "delete-key" // when "true" on a DeleteCertificate URI, also deletes the CNG key associated with the certificate ) const ( @@ -97,6 +98,7 @@ type uriAttributes struct { description string keySpec string skipFindCertificateKey bool + deleteKey bool pin string } @@ -140,6 +142,7 @@ func parseURI(rawuri string) (*uriAttributes, error) { description: u.Get(DescriptionArg), keySpec: u.Get(KeySpec), skipFindCertificateKey: u.GetBool(SkipFindCertificateKey), + deleteKey: u.GetBool(DeleteKeyArg), pin: u.Pin(), }, nil } @@ -1065,19 +1068,13 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { return nil } - if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { - return fmt.Errorf("failed removing certificate: %w", err) - } - return nil + return deleteCertContextAndMaybeKey(certHandle, u.deleteKey) case len(u.keyID) > 0: certHandle, err = findCertificateBySubjectKeyID(st, u.keyID) if err != nil { return err } - if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { - return fmt.Errorf("failed removing certificate: %w", err) - } - return nil + return deleteCertContextAndMaybeKey(certHandle, u.deleteKey) case u.issuerName != "" && u.serialNumber != nil: // TODO: Replace this search with a CERT_ID + CERT_ISSUER_SERIAL_NUMBER search instead // https://learn.microsoft.com/en-us/windows/win32/api/wincrypt/ns-wincrypt-cert_id @@ -1106,11 +1103,7 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { } if bytes.Equal(x509Cert.SerialNumber.Bytes(), serialBytes) { - if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { - return fmt.Errorf("failed removing certificate: %w", err) - } - - return nil + return deleteCertContextAndMaybeKey(certHandle, u.deleteKey) } prevCert = certHandle } @@ -1131,48 +1124,134 @@ func (k *CAPIKMS) DeleteCertificate(req *apiv1.DeleteCertificateRequest) error { if err != nil { return err } - if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { - return fmt.Errorf("failed removing certificate: %w", err) - } - return nil + return deleteCertContextAndMaybeKey(certHandle, u.deleteKey) default: return fmt.Errorf("%q, %q, or %q and %q is required to find a certificate", HashArg, KeyIDArg, IssuerNameArg, SerialNumberArg) } } -// CleanupCredentials implements [apiv1.CredentialsCleaner]. It finds all -// certificates in the Windows certificate store issued to the subject in req by -// the issuer in req, and deletes any that have already expired. +// deleteCertContextAndMaybeKey removes the certificate referenced by certHandle +// from its store, and when deleteKey is true also removes the CNG private key +// associated with it. The key is removed first so that, if either step fails, +// the certificate remains and a subsequent pass can retry the cleanup; deleting +// the cert first risks orphaning the key once the cert is no longer reachable +// to look up its CRYPT_KEY_PROV_INFO. +// +// certHandle is consumed: CertDeleteCertificateFromStore always frees the +// context (per Microsoft docs); on the bail-out path before the cert delete we +// free it explicitly. +func deleteCertContextAndMaybeKey(certHandle *windows.CertContext, deleteKey bool) error { + if deleteKey { + // best effort: a cert without a CNG-backed key shouldn't fail the call. + if kh, err := cryptFindCertificatePrivateKey(certHandle); err == nil { + if err := nCryptDeleteKey(kh); err != nil { + windows.CertFreeCertificateContext(certHandle) + return fmt.Errorf("failed removing CNG private key: %w", err) + } + } + } + + if err := windows.CertDeleteCertificateFromStore(certHandle); err != nil { + return fmt.Errorf("failed removing certificate: %w", err) + } + return nil +} + +// CleanupCredentials implements [apiv1.CredentialsCleaner]. It scans the Windows +// certificate store identified by req for certificates issued by req.Issuer and +// removes any that have expired, along with their CNG private keys. When +// req.RawSubject is non-empty, only certificates whose DER-encoded Subject +// matches are considered. func (k *CAPIKMS) CleanupCredentials(req *apiv1.CleanupCredentialsRequest) error { - certs, err := k.FindCertificatesByIssuer(&apiv1.LoadCertificateRequest{ - Name: uri.New("capi", url.Values{ - "issuer": []string{req.Issuer}, - "store-location": []string{req.StoreLocation}, - "store": []string{req.Store}, - }).String(), - }, req.RawSubject) + if req == nil { + return errors.New("cleanupCredentialsRequest cannot be nil") + } + if req.Issuer == "" { + return fmt.Errorf("%q is required", IssuerNameArg) + } + + storeLocation := cmp.Or(req.StoreLocation, UserStoreLocation) + storeName := cmp.Or(req.Store, MyStore) + + var certStoreLocation uint32 + switch storeLocation { + case UserStoreLocation: + certStoreLocation = certStoreCurrentUser + case MachineStoreLocation: + certStoreLocation = certStoreLocalMachine + default: + return fmt.Errorf("invalid cert store location %q", storeLocation) + } + + st, err := windows.CertOpenStore( + certStoreProvSystem, + 0, + 0, + certStoreLocation, + uintptr(unsafe.Pointer(wide(storeName))), + ) if err != nil { - return fmt.Errorf("failed loading certificates by issuer %q: %w", req.Issuer, err) + return fmt.Errorf("CertOpenStore for the %q store %q returned: %w", storeLocation, storeName, err) } - var deleteErrors []error now := time.Now() - for _, cert := range certs { - if cert.NotAfter.Before(now) { - deleteURI := uri.New("capi", url.Values{ - "store-location": []string{req.StoreLocation}, - "store": []string{req.Store}, - "issuer": []string{req.Issuer}, - "serial": []string{"0x" + cert.SerialNumber.Text(16)}, - }).String() - - if err := k.DeleteCertificate(&apiv1.DeleteCertificateRequest{Name: deleteURI}); err != nil { - deleteErrors = append(deleteErrors, fmt.Errorf("failed deleting expired certificate (serial %s): %w", cert.SerialNumber.Text(16), err)) - } + var errs []error + + // Restart the enumeration after each deletion to avoid juggling cert-context + // lifetimes across a delete: CertDeleteCertificateFromStore frees the context + // it was given, which would invalidate the "previous" pointer that the next + // CertFindCertificateInStore call expects. + for { + deleted, err := k.cleanupOnePass(st, req.Issuer, req.RawSubject, now) + if err != nil { + errs = append(errs, err) + } + if !deleted { + break } } - return errors.Join(deleteErrors...) + return errors.Join(errs...) +} + +// cleanupOnePass walks the store looking for the first expired certificate +// matching issuer (and rawSubject, if set), deletes it together with its CNG +// key, and returns deleted=true. When no matching expired certificate is found +// it returns deleted=false and the caller stops iterating. Errors encountered +// while inspecting individual certificates are returned with deleted=false so +// the caller can record them and exit (a persistent inspection error would +// otherwise spin the outer loop forever). +func (k *CAPIKMS) cleanupOnePass(st windows.Handle, issuer string, rawSubject []byte, now time.Time) (bool, error) { + var prevCert *windows.CertContext + for { + certHandle, err := findCertificateInStore(st, + encodingX509ASN|encodingPKCS7, + 0, + findIssuerStr, + uintptr(unsafe.Pointer(wide(issuer))), prevCert) + if err != nil { + return false, fmt.Errorf("findCertificateInStore failed: %w", err) + } + if certHandle == nil { + // prevCert was freed by the last findCertificateInStore call. + return false, nil + } + + x509Cert, err := certContextToX509(certHandle) + if err != nil { + windows.CertFreeCertificateContext(certHandle) + return false, fmt.Errorf("could not unmarshal certificate: %w", err) + } + + matchesSubject := len(rawSubject) == 0 || bytes.Equal(x509Cert.RawSubject, rawSubject) + if matchesSubject && x509Cert.NotAfter.Before(now) { + if err := deleteCertContextAndMaybeKey(certHandle, true); err != nil { + return false, fmt.Errorf("failed deleting expired certificate (serial %s): %w", x509Cert.SerialNumber.Text(16), err) + } + return true, nil + } + prevCert = certHandle + } } func (k *CAPIKMS) getKeyFlags(u *uriAttributes) (uint32, error) { @@ -1300,4 +1379,7 @@ func validateIntermediateCertificate(c *x509.Certificate) error { return nil } -var _ apiv1.CertificateManager = (*CAPIKMS)(nil) +var ( + _ apiv1.CertificateManager = (*CAPIKMS)(nil) + _ apiv1.CredentialsCleaner = (*CAPIKMS)(nil) +) diff --git a/kms/platform/kms_windows_test.go b/kms/platform/kms_windows_test.go index a7a8712f..0bd91b65 100644 --- a/kms/platform/kms_windows_test.go +++ b/kms/platform/kms_windows_test.go @@ -503,19 +503,29 @@ func TestKMS_SearchKeys_capi(t *testing.T) { func TestKMS_CleanupCredentials_capi(t *testing.T) { capiKMS := mustCAPIKMS(t) - // Use an expired certificate + // Use an expired certificate. withNoCleanup disables the t.Cleanup hooks + // that would otherwise call DeleteKey / DeleteCertificate on the test's + // behalf; CleanupCredentials is expected to remove both the certificate + // and its CNG private key, so a follow-up DeleteKey would fail with + // "not found". The test instead asserts that both have been removed. chain := mustCreatePlatformCertificate(t, capiKMS, - withNoCleanupCertificate(), + withNoCleanup(), withTemplateModifier(func(c *x509.Certificate) *x509.Certificate { c.NotBefore = time.Now().Add(-time.Minute).Truncate(time.Second) c.NotAfter = time.Now().Add(-time.Second).Truncate(time.Second) return c })) + // Sanity-check the precondition: both the certificate and the CNG key + // exist before CleanupCredentials runs. _, err := capiKMS.LoadCertificate(&apiv1.LoadCertificateRequest{ Name: platformCertName, }) require.NoError(t, err) + _, err = capiKMS.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: platformKeyName, + }) + require.NoError(t, err) type args struct { req *apiv1.CleanupCredentialsRequest @@ -533,7 +543,12 @@ func TestKMS_CleanupCredentials_capi(t *testing.T) { _, loadErr := capiKMS.LoadCertificate(&apiv1.LoadCertificateRequest{ Name: platformCertName, }) - return assert.NoError(t, err) && assert.Error(t, loadErr) + _, getErr := capiKMS.GetPublicKey(&apiv1.GetPublicKeyRequest{ + Name: platformKeyName, + }) + return assert.NoError(t, err) && + assert.Error(t, loadErr, "certificate should be removed") && + assert.Error(t, getErr, "CNG private key should be removed") }}, } for _, tt := range tests { diff --git a/kms/tpmkms/tpmkms.go b/kms/tpmkms/tpmkms.go index ab66f516..13a2e400 100644 --- a/kms/tpmkms/tpmkms.go +++ b/kms/tpmkms/tpmkms.go @@ -1073,6 +1073,10 @@ func (k *TPMKMS) deleteCertificateFromWindowsCertificateStore(req *apiv1.DeleteC uv := url.Values{} uv.Set("store-location", location) uv.Set("store", store) + // Also remove the CNG private key paired with the certificate; TPMKMS-managed + // certs have a 1:1 CNG key with no independent use, so leaving the .PCPKSP + // blob behind would orphan it on disk. + uv.Set("delete-key", "true") switch { case o.serial != "":