From 28e8481dc81cf2b52f5aab17db24d95bc329e089 Mon Sep 17 00:00:00 2001 From: Tony Arcieri Date: Fri, 30 Jan 2026 09:12:36 -0700 Subject: [PATCH] Implement new `kem::Kem` trait Companion PR to RustCrypto/traits#2243 This implements a trait which describes a whole KEM type family, with a similar shape to the former `dhkem::DhKem` and `ml_kem::KemCore` traits (both of which have been removed and replaced with `kem::Kem`). As part of this, the `*Params` types in `ml_kem` have been merged with the former type aliases of the `ml_kem::Kem` type (which have also been removed), and now `MlKem512`, `MlKem768`, and `MlKem1024` are the one true ZSTs for describing ML-KEM parameters. --- Cargo.lock | 20 +++--- dhkem/Cargo.toml | 4 +- dhkem/src/ecdh_kem.rs | 49 +++++++------ dhkem/src/lib.rs | 24 ++++--- dhkem/src/x25519_kem.rs | 39 +++++----- dhkem/tests/hpke_p256_test.rs | 4 +- dhkem/tests/tests.rs | 18 ++--- ml-kem/Cargo.toml | 4 +- ml-kem/benches/mlkem.rs | 24 +++---- ml-kem/src/kem.rs | 132 +++++++++++----------------------- ml-kem/src/lib.rs | 121 +++++++++++++++++-------------- ml-kem/src/param.rs | 23 +++--- ml-kem/src/pkcs8.rs | 16 ++--- ml-kem/src/pke.rs | 30 ++++---- ml-kem/src/traits.rs | 42 +++++------ ml-kem/tests/encap-decap.rs | 20 +++--- ml-kem/tests/key-gen.rs | 13 ++-- ml-kem/tests/pkcs8.rs | 15 ++-- ml-kem/tests/wycheproof.rs | 14 ++-- x-wing/Cargo.toml | 2 +- x-wing/src/lib.rs | 96 +++++++++++-------------- 21 files changed, 337 insertions(+), 373 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index bbc16e9..3440075 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -384,9 +384,9 @@ dependencies = [ [[package]] name = "digest" -version = "0.11.0-rc.8" +version = "0.11.0-rc.9" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "2fc1408b7a9f59a7b933faff3e9e7fc15a05a524effd3b3d1601156944c8077f" +checksum = "bff8de092798697546237a3a701e4174fe021579faec9b854379af9bf1e31962" dependencies = [ "block-buffer", "const-oid", @@ -624,18 +624,18 @@ checksum = "e712f64ec3850b98572bffac52e2c6f282b29fe6c5fa6d42334b30be438d95c1" [[package]] name = "hkdf" -version = "0.13.0-rc.3" +version = "0.13.0-rc.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "cfbb4225acf2b5cc4e12d384672cd6d1f0cb980ff5859ffcf144db25b593a24d" +checksum = "c1493605868fc7d216afa78a26956d56f5c0a12dbdb8ee4fe9e0b70a28ec7d57" dependencies = [ "hmac", ] [[package]] name = "hmac" -version = "0.13.0-rc.3" +version = "0.13.0-rc.4" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "f1c597ac7d6cc8143e30e83ef70915e7f883b18d8bec2e2b2bce47f5bbb06d57" +checksum = "d9956e202a691c5c86c60303a421f66f93f44b29433407b7c43cf2bebadc750e" dependencies = [ "digest", ] @@ -724,9 +724,9 @@ dependencies = [ [[package]] name = "kem" -version = "0.3.0-rc.2" +version = "0.3.0-rc.3" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1d13ca544bfe26ab1f199a9eef34be41bf06827d0372b35f61846edb9af8d18d" +checksum = "5eb982d00ac39162293481bac7f737b667d4ed1661bf057529466bf031351d43" dependencies = [ "crypto-common", "rand_core", @@ -919,9 +919,9 @@ checksum = "8b870d8c151b6f2fb93e84a13146138f05d02ed11c7e7c54f8826aaaf7c9f184" [[package]] name = "pkcs8" -version = "0.11.0-rc.8" +version = "0.11.0-rc.10" source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "77089aec8290d0b7bb01b671b091095cf1937670725af4fd73d47249f03b12c0" +checksum = "b226d2cc389763951db8869584fd800cbbe2962bf454e2edeb5172b31ee99774" dependencies = [ "der", "spki", diff --git a/dhkem/Cargo.toml b/dhkem/Cargo.toml index 58cf4b2..e1b9af3 100644 --- a/dhkem/Cargo.toml +++ b/dhkem/Cargo.toml @@ -14,7 +14,7 @@ keywords = ["crypto", "ecdh", "ecc"] readme = "README.md" [dependencies] -kem = "0.3.0-rc.2" +kem = "0.3.0-rc.3" rand_core = "0.10.0-rc-6" # optional dependencies @@ -29,7 +29,7 @@ zeroize = { version = "1.8.1", optional = true, default-features = false } [dev-dependencies] getrandom = { version = "0.4.0-rc.1", features = ["sys_rng"] } hex-literal = "1" -hkdf = "0.13.0-rc.3" +hkdf = "0.13.0-rc.4" sha2 = "0.11.0-rc.4" [features] diff --git a/dhkem/src/ecdh_kem.rs b/dhkem/src/ecdh_kem.rs index 2158a30..92e5e3d 100644 --- a/dhkem/src/ecdh_kem.rs +++ b/dhkem/src/ecdh_kem.rs @@ -10,7 +10,7 @@ use elliptic_curve::{ }, }; use kem::{ - Ciphertext, Encapsulate, Generate, InvalidKey, KemParams, KeyExport, KeySizeUser, SharedSecret, + Ciphertext, Encapsulate, Generate, InvalidKey, Kem, KeyExport, KeySizeUser, SharedKey, TryDecapsulate, TryKeyInit, }; use rand_core::{CryptoRng, TryCryptoRng}; @@ -29,15 +29,20 @@ pub type EcdhEncapsulationKey = EncapsulationKey>; /// traits from the `elliptic-curve` crate. /// /// Implements a KEM interface that internally uses ECDH. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] pub struct EcdhKem(PhantomData); -impl KemParams for EcdhEncapsulationKey +impl Kem for EcdhKem where C: CurveArithmetic, FieldBytesSize: ModulusSize, + EcdhDecapsulationKey: TryDecapsulate + Generate, + EcdhEncapsulationKey: Encapsulate + Clone, { + type DecapsulationKey = EcdhDecapsulationKey; + type EncapsulationKey = EcdhEncapsulationKey; type CiphertextSize = UncompressedPointSize; - type SharedSecretSize = FieldBytesSize; + type SharedKeySize = FieldBytesSize; } /// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: @@ -97,13 +102,26 @@ where } } -impl Encapsulate for EcdhEncapsulationKey +impl Generate for EcdhDecapsulationKey +where + C: CurveArithmetic, + FieldBytesSize: ModulusSize, +{ + fn try_generate_from_rng(rng: &mut R) -> Result { + Ok(EphemeralSecret::try_generate_from_rng(rng)?.into()) + } +} + +impl Encapsulate> for EcdhEncapsulationKey where C: CurveArithmetic, FieldBytesSize: ModulusSize, AffinePoint: FromEncodedPoint + ToEncodedPoint, { - fn encapsulate_with_rng(&self, rng: &mut R) -> (Ciphertext, SharedSecret) + fn encapsulate_with_rng( + &self, + rng: &mut R, + ) -> (Ciphertext>, SharedKey>) where R: CryptoRng + ?Sized, { @@ -111,25 +129,12 @@ where let sk = EphemeralSecret::generate_from_rng(rng); let ss = sk.diffie_hellman(&self.0); - // TODO(tarcieri): sk.public_key().to_uncompressed_point() - let mut pk = UncompressedPoint::::default(); - pk.copy_from_slice(sk.public_key().to_encoded_point(false).as_bytes()); - + let pk = sk.public_key().to_uncompressed_point(); (pk, ss.raw_secret_bytes().clone()) } } -impl Generate for EcdhDecapsulationKey -where - C: CurveArithmetic, - FieldBytesSize: ModulusSize, -{ - fn try_generate_from_rng(rng: &mut R) -> Result { - Ok(EphemeralSecret::try_generate_from_rng(rng)?.into()) - } -} - -impl TryDecapsulate for EcdhDecapsulationKey +impl TryDecapsulate> for EcdhDecapsulationKey where C: CurveArithmetic, FieldBytesSize: ModulusSize, @@ -139,8 +144,8 @@ where fn try_decapsulate( &self, - encapsulated_key: &Ciphertext, - ) -> Result, Error> { + encapsulated_key: &Ciphertext>, + ) -> Result>, Error> { let encapsulated_key = PublicKey::::from_sec1_bytes(encapsulated_key)?; let shared_secret = self.dk.diffie_hellman(&encapsulated_key); Ok(shared_secret.raw_secret_bytes().clone()) diff --git a/dhkem/src/lib.rs b/dhkem/src/lib.rs index 762ac0a..33d9460 100644 --- a/dhkem/src/lib.rs +++ b/dhkem/src/lib.rs @@ -30,7 +30,7 @@ //! [RFC9180]: https://datatracker.ietf.org/doc/html/rfc9180#name-dh-based-kem-dhkem //! [TLS KEM combiner]: https://datatracker.ietf.org/doc/html/draft-ietf-tls-hybrid-design-10 -pub use kem::{self, Decapsulator, Encapsulate, Generate, KemParams, TryDecapsulate}; +pub use kem::{self, Encapsulate, Generate, Kem, TryDecapsulate}; #[cfg(feature = "ecdh")] mod ecdh_kem; @@ -60,13 +60,8 @@ pub struct DecapsulationKey { ek: EncapsulationKey, } -impl Decapsulator for DecapsulationKey -where - EncapsulationKey: Encapsulate + Clone, -{ - type Encapsulator = EncapsulationKey; - - fn encapsulator(&self) -> &EncapsulationKey { +impl AsRef> for DecapsulationKey { + fn as_ref(&self) -> &EncapsulationKey { &self.ek } } @@ -74,7 +69,6 @@ where impl From for DecapsulationKey where EK: for<'a> From<&'a DK>, - EncapsulationKey: KemParams, { fn from(dk: DK) -> Self { let ek = EncapsulationKey(EK::from(&dk)); @@ -140,6 +134,9 @@ impl Zeroize for DecapsulationKey { #[cfg(feature = "zeroize")] impl ZeroizeOnDrop for DecapsulationKey {} +/// NIST P-256 DHKEM. +#[cfg(feature = "p256")] +pub type NistP256Kem = EcdhKem; /// NIST P-256 ECDH Decapsulation Key. #[cfg(feature = "p256")] pub type NistP256DecapsulationKey = EcdhDecapsulationKey; @@ -147,6 +144,9 @@ pub type NistP256DecapsulationKey = EcdhDecapsulationKey; #[cfg(feature = "p256")] pub type NistP256EncapsulationKey = EcdhEncapsulationKey; +/// NIST P-256 DHKEM. +#[cfg(feature = "p384")] +pub type NistP384Kem = EcdhKem; /// NIST P-384 ECDH Decapsulation Key. #[cfg(feature = "p384")] pub type NistP384DecapsulationKey = EcdhDecapsulationKey; @@ -154,6 +154,9 @@ pub type NistP384DecapsulationKey = EcdhDecapsulationKey; #[cfg(feature = "p384")] pub type NistP384EncapsulationKey = EcdhEncapsulationKey; +/// NIST P-521 DHKEM. +#[cfg(feature = "p521")] +pub type NistP521Kem = EcdhKem; /// NIST P-521 ECDH Decapsulation Key. #[cfg(feature = "p521")] pub type NistP521DecapsulationKey = EcdhDecapsulationKey; @@ -161,6 +164,9 @@ pub type NistP521DecapsulationKey = EcdhDecapsulationKey; #[cfg(feature = "p521")] pub type NistP521EncapsulationKey = EcdhEncapsulationKey; +/// secp256k1 DHKEM. +#[cfg(feature = "p521")] +pub type Secp256k1Kem = EcdhKem; /// secp256k1 ECDH Decapsulation Key. #[cfg(feature = "k256")] pub type Secp256k1DecapsulationKey = EcdhDecapsulationKey; diff --git a/dhkem/src/x25519_kem.rs b/dhkem/src/x25519_kem.rs index f8ba32f..bc3ab05 100644 --- a/dhkem/src/x25519_kem.rs +++ b/dhkem/src/x25519_kem.rs @@ -1,7 +1,7 @@ use crate::{DecapsulationKey, EncapsulationKey}; use kem::{ - Decapsulate, Encapsulate, Generate, InvalidKey, KemParams, Key, KeyExport, KeySizeUser, - TryKeyInit, common::array::Array, consts::U32, + Decapsulate, Encapsulate, Generate, InvalidKey, Kem, Key, KeyExport, KeySizeUser, TryKeyInit, + common::array::Array, consts::U32, }; use rand_core::{CryptoRng, TryCryptoRng, UnwrapErr}; use x25519::{PublicKey, ReusableSecret}; @@ -20,16 +20,19 @@ pub type X25519EncapsulationKey = EncapsulationKey; type Ciphertext = Array; /// X25519 shared secrets are also compressed Montgomery x/u-coordinates. -type SharedSecret = Array; +type SharedKey = Array; /// X22519 Diffie-Hellman KEM adapter. /// /// Implements a KEM interface that internally uses X25519 ECDH. +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] pub struct X25519Kem; -impl KemParams for EncapsulationKey { +impl Kem for X25519Kem { + type DecapsulationKey = X25519DecapsulationKey; + type EncapsulationKey = X25519EncapsulationKey; type CiphertextSize = U32; - type SharedSecretSize = U32; + type SharedKeySize = U32; } /// From [RFC9810 §7.1.1]: `SerializePublicKey` and `DeserializePublicKey`: @@ -69,8 +72,17 @@ impl KeyExport for X25519EncapsulationKey { } } -impl Encapsulate for X25519EncapsulationKey { - fn encapsulate_with_rng(&self, rng: &mut R) -> (Ciphertext, SharedSecret) +impl Generate for X25519DecapsulationKey { + fn try_generate_from_rng(rng: &mut R) -> Result { + // TODO(tarcieri): don't panic! Fallible `ReusableSecret` generation? + Ok(Self::from(ReusableSecret::random_from_rng(&mut UnwrapErr( + rng, + )))) + } +} + +impl Encapsulate for X25519EncapsulationKey { + fn encapsulate_with_rng(&self, rng: &mut R) -> (Ciphertext, SharedKey) where R: CryptoRng + ?Sized, { @@ -82,17 +94,8 @@ impl Encapsulate for X25519EncapsulationKey { } } -impl Generate for X25519DecapsulationKey { - fn try_generate_from_rng(rng: &mut R) -> Result { - // TODO(tarcieri): don't panic! Fallible `ReusableSecret` generation? - Ok(Self::from(ReusableSecret::random_from_rng(&mut UnwrapErr( - rng, - )))) - } -} - -impl Decapsulate for X25519DecapsulationKey { - fn decapsulate(&self, encapsulated_key: &Ciphertext) -> SharedSecret { +impl Decapsulate for X25519DecapsulationKey { + fn decapsulate(&self, encapsulated_key: &Ciphertext) -> SharedKey { let public_key = PublicKey::from(encapsulated_key.0); self.dk.diffie_hellman(&public_key).to_bytes().into() } diff --git a/dhkem/tests/hpke_p256_test.rs b/dhkem/tests/hpke_p256_test.rs index d85a3c4..3663ae7 100644 --- a/dhkem/tests/hpke_p256_test.rs +++ b/dhkem/tests/hpke_p256_test.rs @@ -5,7 +5,7 @@ use dhkem::NistP256DecapsulationKey; use elliptic_curve::Generate; use hex_literal::hex; use hkdf::Hkdf; -use kem::{Decapsulator, Encapsulate, KeyExport, TryDecapsulate}; +use kem::{Encapsulate, KeyExport, TryDecapsulate}; use rand_core::{TryCryptoRng, TryRng}; use sha2::Sha256; @@ -83,7 +83,7 @@ fn test_dhkem_p256_hkdf_sha256() { "f3ce7fdae57e1a310d87f1ebbde6f328be0a99cdbcadf4d6589cf29de4b8ffd2" ))) .unwrap(); - let pkr = skr.encapsulator(); + let pkr = skr.as_ref(); assert_eq!(&pkr.to_bytes(), &example_pkr); let (pke, ss1) = pkr.encapsulate_with_rng(&mut ConstantRng(&hex!( diff --git a/dhkem/tests/tests.rs b/dhkem/tests/tests.rs index afebbee..d7f1887 100644 --- a/dhkem/tests/tests.rs +++ b/dhkem/tests/tests.rs @@ -6,11 +6,11 @@ feature = "x25519" ))] -use kem::{Decapsulator, Encapsulate, Generate, TryDecapsulate}; +use kem::{Encapsulate, Generate, Kem, TryDecapsulate}; -fn test_kem() { - let dk = DK::generate(); - let ek = dk.encapsulator(); +fn test_kem() { + let dk = K::DecapsulationKey::generate(); + let ek = dk.as_ref().clone(); let (ek, ss1) = ek.encapsulate(); let ss2 = dk.try_decapsulate(&ek).unwrap(); assert_eq!(ss1.as_slice(), ss2.as_slice()); @@ -19,29 +19,29 @@ fn test_kem() { #[cfg(feature = "x25519")] #[test] fn test_x25519() { - test_kem::(); + test_kem::(); } #[cfg(feature = "k256")] #[test] fn test_k256() { - test_kem::(); + test_kem::(); } #[cfg(feature = "p256")] #[test] fn test_p256() { - test_kem::(); + test_kem::(); } #[cfg(feature = "p384")] #[test] fn test_p384() { - test_kem::(); + test_kem::(); } #[cfg(feature = "p521")] #[test] fn test_p521() { - test_kem::(); + test_kem::(); } diff --git a/ml-kem/Cargo.toml b/ml-kem/Cargo.toml index 8ae90e5..911bcfa 100644 --- a/ml-kem/Cargo.toml +++ b/ml-kem/Cargo.toml @@ -27,14 +27,14 @@ zeroize = ["module-lattice/zeroize", "dep:zeroize"] [dependencies] array = { version = "0.4.4", package = "hybrid-array", features = ["extra-sizes", "subtle"] } module-lattice = { version = "0.1.0-rc.0", features = ["subtle"] } -kem = "0.3.0-rc.2" +kem = "0.3.0-rc.3" rand_core = "0.10.0-rc-6" sha3 = { version = "0.11.0-rc.3", default-features = false } subtle = { version = "2", default-features = false } # optional dependencies const-oid = { version = "0.10.1", optional = true, default-features = false, features = ["db"] } -pkcs8 = { version = "0.11.0-rc.8", optional = true, default-features = false } +pkcs8 = { version = "0.11.0-rc.10", optional = true, default-features = false } zeroize = { version = "1.8.1", optional = true, default-features = false } [dev-dependencies] diff --git a/ml-kem/benches/mlkem.rs b/ml-kem/benches/mlkem.rs index 51ac9af..e7934f8 100644 --- a/ml-kem/benches/mlkem.rs +++ b/ml-kem/benches/mlkem.rs @@ -1,4 +1,5 @@ -use ::kem::{Decapsulate, Decapsulator, Encapsulate, Generate}; +use ::kem::{Decapsulate, Encapsulate, Kem, KeyExport, KeyInit, TryKeyInit}; +use core::hint::black_box; use criterion::{Criterion, criterion_group, criterion_main}; use getrandom::SysRng; use ml_kem::*; @@ -10,25 +11,25 @@ fn criterion_benchmark(c: &mut Criterion) { // Key generation c.bench_function("keygen", |b| { b.iter(|| { - let dk = ml_kem_768::DecapsulationKey::generate_from_rng(&mut rng); - let _dk_bytes = dk.to_encoded_bytes(); - let _ek_bytes = dk.encapsulator().to_encoded_bytes(); + let (dk, ek) = MlKem768::generate_keypair_from_rng(&mut rng); + let _dk_bytes = black_box(dk.to_seed().unwrap()); + let _ek_bytes = black_box(ek.to_bytes()); }) }); - let dk = ml_kem_768::DecapsulationKey::generate_from_rng(&mut rng); - let dk_bytes = dk.to_encoded_bytes(); - let ek_bytes = dk.encapsulator().to_encoded_bytes(); - let ek = ml_kem_768::EncapsulationKey::from_encoded_bytes(&ek_bytes).unwrap(); + let (dk, ek) = MlKem768::generate_keypair_from_rng(&mut rng); + let dk_bytes = dk.to_seed().unwrap(); + let ek_bytes = ek.to_bytes(); + let ek = ::EncapsulationKey::new(&ek_bytes).unwrap(); // Encapsulation c.bench_function("encapsulate", |b| { b.iter(|| ek.encapsulate_with_rng(&mut rng)) }); - let (ct, _ss) = ek.encapsulate_with_rng(&mut rng); + let (ct, _sk) = ek.encapsulate_with_rng(&mut rng); // Decapsulation - let dk = ::DecapsulationKey::from_encoded_bytes(&dk_bytes).unwrap(); + let dk = ::DecapsulationKey::new(&dk_bytes); c.bench_function("decapsulate", |b| { b.iter(|| { @@ -39,8 +40,7 @@ fn criterion_benchmark(c: &mut Criterion) { // Round trip c.bench_function("round_trip", |b| { b.iter(|| { - let dk = ml_kem_768::DecapsulationKey::generate_from_rng(&mut rng); - let ek = dk.encapsulator(); + let (dk, ek) = MlKem768::generate_keypair_from_rng(&mut rng); let (ct, _sk) = ek.encapsulate_with_rng(&mut rng); dk.decapsulate(&ct); }) diff --git a/ml-kem/src/kem.rs b/ml-kem/src/kem.rs index 6065f02..2441ae8 100644 --- a/ml-kem/src/kem.rs +++ b/ml-kem/src/kem.rs @@ -2,31 +2,24 @@ // Re-export traits from the `kem` crate pub use ::kem::{ - Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, Key, KeyExport, KeyInit, + Ciphertext, Decapsulate, Encapsulate, Generate, InvalidKey, Kem, Key, KeyExport, KeyInit, KeySizeUser, TryKeyInit, }; -use sha3::Digest; use crate::{ - B32, Encoded, EncodedSizeUser, KemCore, Seed, + B32, Encoded, EncodedSizeUser, Seed, SharedKey, crypto::{G, H, J}, - param::{ - DecapsulationKeySize, EncapsulationKeySize, EncodedCiphertext, ExpandedDecapsulationKey, - KemParams, - }, + param::{DecapsulationKeySize, EncapsulationKeySize, ExpandedDecapsulationKey, KemParams}, pke::{DecryptionKey, EncryptionKey}, }; -use array::typenum::{U32, U64}; -use core::marker::PhantomData; +use array::sizes::{U32, U64}; use rand_core::{CryptoRng, TryCryptoRng, TryRng}; +use sha3::Digest; use subtle::{ConditionallySelectable, ConstantTimeEq}; #[cfg(feature = "zeroize")] use zeroize::{Zeroize, ZeroizeOnDrop}; -/// A shared key resulting from an ML-KEM transaction -pub(crate) type SharedKey = B32; - /// A `DecapsulationKey` provides the ability to generate a new key pair, and decapsulate an /// encapsulated shared key. #[derive(Clone, Debug)] @@ -75,11 +68,11 @@ where } } -impl

Decapsulate for DecapsulationKey

+impl

Decapsulate

for DecapsulationKey

where - P: KemParams, + P: Kem, SharedKeySize = U32> + KemParams, { - fn decapsulate(&self, encapsulated_key: &EncodedCiphertext

) -> SharedKey { + fn decapsulate(&self, encapsulated_key: &Ciphertext

) -> SharedKey { let mp = self.dk_pke.decrypt(encapsulated_key); let (Kp, rp) = G(&[&mp, &self.ek.h]); let Kbar = J(&[self.z.as_slice(), encapsulated_key.as_ref()]); @@ -88,13 +81,11 @@ where } } -impl

Decapsulator for DecapsulationKey

+impl

AsRef> for DecapsulationKey

where P: KemParams, { - type Encapsulator = EncapsulationKey

; - - fn encapsulator(&self) -> &EncapsulationKey

{ + fn as_ref(&self) -> &EncapsulationKey

{ &self.ek } } @@ -223,7 +214,7 @@ where #[allow(clippy::similar_names)] // allow dk_pke, ek_pke, following the spec pub(crate) fn generate_deterministic(d: B32, z: B32) -> Self { let (dk_pke, ek_pke) = DecryptionKey::generate(&d); - let ek = EncapsulationKey::new(ek_pke); + let ek = EncapsulationKey::from_encryption_key(ek_pke); let d = Some(d); Self { dk_pke, ek, d, z } } @@ -242,9 +233,9 @@ where impl

EncapsulationKey

where - P: KemParams, + P: Kem + KemParams, { - pub(crate) fn new(ek_pke: EncryptionKey

) -> Self { + pub(crate) fn from_encryption_key(ek_pke: EncryptionKey

) -> Self { let h = H(ek_pke.to_bytes()); Self { ek_pke, h } } @@ -255,18 +246,18 @@ where /// Do NOT use this function unless you know what you're doing. If you fail to use all uniform /// random bytes even once, you can have catastrophic security failure. #[cfg_attr(not(feature = "hazmat"), doc(hidden))] - pub fn encapsulate_deterministic(&self, m: &B32) -> (EncodedCiphertext

, SharedKey) { + pub fn encapsulate_deterministic(&self, m: &B32) -> (Ciphertext

, SharedKey) { let (K, r) = G(&[m, &self.h]); let c = self.ek_pke.encrypt(m, &r); (c, K) } } -impl

Encapsulate for EncapsulationKey

+impl

Encapsulate

for EncapsulationKey

where - P: KemParams, + P: Kem + KemParams, { - fn encapsulate_with_rng(&self, rng: &mut R) -> (EncodedCiphertext

, SharedKey) + fn encapsulate_with_rng(&self, rng: &mut R) -> (Ciphertext

, SharedKey) where R: CryptoRng + ?Sized, { @@ -282,7 +273,7 @@ where type EncodedSize = EncapsulationKeySize

; fn from_encoded_bytes(enc: &Encoded) -> Result { - Ok(Self::new(EncryptionKey::from_bytes(enc)?)) + Ok(Self::from_encryption_key(EncryptionKey::from_bytes(enc)?)) } fn to_encoded_bytes(&self) -> Encoded { @@ -290,14 +281,6 @@ where } } -impl

::kem::KemParams for EncapsulationKey

-where - P: KemParams, -{ - type CiphertextSize = P::CiphertextSize; - type SharedSecretSize = U32; -} - impl

KeyExport for EncapsulationKey

where P: KemParams, @@ -320,7 +303,7 @@ where { fn new(encapsulation_key: &Key) -> Result { EncryptionKey::from_bytes(encapsulation_key) - .map(Self::new) + .map(Self::from_encryption_key) .map_err(|_| InvalidKey) } } @@ -336,69 +319,34 @@ where } } -/// An implementation of overall ML-KEM functionality. Generic over parameter sets, but then ties -/// together all of the other related types and sizes. -#[derive(Clone)] -pub struct Kem

-where - P: KemParams, -{ - _phantom: PhantomData

, -} - -impl

KemCore for Kem

-where - P: KemParams, -{ - type SharedKeySize = U32; - type CiphertextSize = P::CiphertextSize; - type DecapsulationKey = DecapsulationKey

; - type EncapsulationKey = EncapsulationKey

; - - /// Generate a new (decapsulation, encapsulation) key pair - fn generate( - rng: &mut R, - ) -> (Self::DecapsulationKey, Self::EncapsulationKey) { - let dk = Self::DecapsulationKey::generate_from_rng(rng); - let ek = dk.encapsulation_key().clone(); - (dk, ek) - } - - fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey) { - let dk = Self::DecapsulationKey::from_seed(seed); - let ek = dk.encapsulation_key().clone(); - (dk, ek) - } -} - #[cfg(test)] mod test { use super::*; - use crate::{MlKem512Params, MlKem768Params, MlKem1024Params}; - use ::kem::{Decapsulate, Encapsulate, Generate}; + use crate::{MlKem512, MlKem768, MlKem1024}; + use ::kem::{Encapsulate, Generate, TryDecapsulate}; use array::typenum::Unsigned; use getrandom::SysRng; use rand_core::UnwrapErr; fn round_trip_test

() where - P: KemParams, + P: Kem, { let mut rng = UnwrapErr(SysRng); - let dk = DecapsulationKey::

::generate_from_rng(&mut rng); - let ek = dk.encapsulation_key(); + let dk = P::DecapsulationKey::generate_from_rng(&mut rng); + let ek = dk.as_ref().clone(); let (ct, k_send) = ek.encapsulate_with_rng(&mut rng); - let k_recv = dk.decapsulate(&ct); + let k_recv = dk.try_decapsulate(&ct).unwrap(); assert_eq!(k_send, k_recv); } #[test] fn round_trip() { - round_trip_test::(); - round_trip_test::(); - round_trip_test::(); + round_trip_test::(); + round_trip_test::(); + round_trip_test::(); } fn expanded_key_test

() @@ -420,9 +368,9 @@ mod test { #[test] fn expanded_key() { - expanded_key_test::(); - expanded_key_test::(); - expanded_key_test::(); + expanded_key_test::(); + expanded_key_test::(); + expanded_key_test::(); } fn invalid_hash_expanded_key_test

() @@ -444,9 +392,9 @@ mod test { #[test] fn invalid_hash_expanded_key() { - invalid_hash_expanded_key_test::(); - invalid_hash_expanded_key_test::(); - invalid_hash_expanded_key_test::(); + invalid_hash_expanded_key_test::(); + invalid_hash_expanded_key_test::(); + invalid_hash_expanded_key_test::(); } fn seed_test

() @@ -464,9 +412,9 @@ mod test { #[test] fn seed() { - seed_test::(); - seed_test::(); - seed_test::(); + seed_test::(); + seed_test::(); + seed_test::(); } fn key_inequality_test

() @@ -489,8 +437,8 @@ mod test { #[test] fn key_inequality() { - key_inequality_test::(); - key_inequality_test::(); - key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); } } diff --git a/ml-kem/src/lib.rs b/ml-kem/src/lib.rs index 3478046..a91ebe3 100644 --- a/ml-kem/src/lib.rs +++ b/ml-kem/src/lib.rs @@ -25,13 +25,12 @@ //! // NOTE: requires the `getrandom` feature is enabled //! //! use ml_kem::{ -//! ml_kem_768::DecapsulationKey, -//! kem::{Decapsulate, Decapsulator, Encapsulate, Generate, KeyInit} +//! MlKem768, +//! kem::{Decapsulate, Encapsulate, Kem} //! }; //! //! // Generate a decapsulation/encapsulation keypair -//! let dk = DecapsulationKey::generate(); -//! let ek = dk.encapsulator(); +//! let (dk, ek) = MlKem768::generate_keypair(); //! //! // Encapsulate a shared key to the holder of the decapsulation key, receive the shared //! // secret `k_send` and the encapsulated form `ct`. @@ -71,10 +70,11 @@ pub mod pkcs8; /// Trait definitions mod traits; +pub use ::kem::{Ciphertext, Kem}; pub use array; -pub use ml_kem_512::MlKem512Params; -pub use ml_kem_768::MlKem768Params; -pub use ml_kem_1024::MlKem1024Params; +pub use ml_kem_512::MlKem512; +pub use ml_kem_768::MlKem768; +pub use ml_kem_1024::MlKem1024; pub use module_lattice::encoding::ArraySize; pub use param::{ExpandedDecapsulationKey, ParameterSet}; pub use traits::*; @@ -96,14 +96,15 @@ pub type Seed = Array; /// cipher with a 128-bit key. pub mod ml_kem_512 { use super::{Debug, ParameterSet, U2, U3, U4, U10, kem}; - use crate::param; + use crate::param::{self, EncodedUSize, EncodedVSize}; + use array::{sizes::U32, typenum::Sum}; /// `MlKem512` is the parameter set for security category 1, corresponding to key search on a /// block cipher with a 128-bit key. - #[derive(Default, Clone, Debug, PartialEq)] - pub struct MlKem512Params; + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] + pub struct MlKem512; - impl ParameterSet for MlKem512Params { + impl ParameterSet for MlKem512 { type K = U2; type Eta1 = U3; type Eta2 = U2; @@ -111,35 +112,44 @@ pub mod ml_kem_512 { type Dv = U4; } + impl kem::Kem for MlKem512 { + type DecapsulationKey = DecapsulationKey; + type EncapsulationKey = EncapsulationKey; + type CiphertextSize = Sum, EncodedVSize>; + type SharedKeySize = U32; + } + /// An ML-KEM-512 `DecapsulationKey` which provides the ability to generate a new key pair, and /// decapsulate an encapsulated shared key. - pub type DecapsulationKey = kem::DecapsulationKey; + pub type DecapsulationKey = kem::DecapsulationKey; /// An ML-KEM-512 `EncapsulationKey` provides the ability to encapsulate a shared key so that it /// can only be decapsulated by the holder of the corresponding decapsulation key. - pub type EncapsulationKey = kem::EncapsulationKey; + pub type EncapsulationKey = kem::EncapsulationKey; /// Encoded ML-KEM-512 ciphertexts. - pub type EncodedCiphertext = param::EncodedCiphertext; + pub type Ciphertext = kem::Ciphertext; /// Legacy expanded decapsulation keys. Prefer seeds instead. #[doc(hidden)] #[deprecated(since = "0.3.0", note = "use `Seed` instead")] - pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; + pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; } /// ML-KEM-768 is the parameter set for security category 3, corresponding to key search on a block /// cipher with a 192-bit key. pub mod ml_kem_768 { use super::{Debug, ParameterSet, U2, U3, U4, U10, kem}; - use crate::param; + use crate::param::{self, EncodedUSize, EncodedVSize}; + use array::sizes::U32; + use array::typenum::Sum; /// `MlKem768` is the parameter set for security category 3, corresponding to key search on a /// block cipher with a 192-bit key. - #[derive(Default, Clone, Debug, PartialEq)] - pub struct MlKem768Params; + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] + pub struct MlKem768; - impl ParameterSet for MlKem768Params { + impl ParameterSet for MlKem768 { type K = U3; type Eta1 = U2; type Eta2 = U2; @@ -147,34 +157,43 @@ pub mod ml_kem_768 { type Dv = U4; } + impl kem::Kem for MlKem768 { + type DecapsulationKey = DecapsulationKey; + type EncapsulationKey = EncapsulationKey; + type CiphertextSize = Sum, EncodedVSize>; + type SharedKeySize = U32; + } + /// An ML-KEM-768 `DecapsulationKey` which provides the ability to generate a new key pair, and /// decapsulate an encapsulated shared key. - pub type DecapsulationKey = kem::DecapsulationKey; + pub type DecapsulationKey = kem::DecapsulationKey; /// An ML-KEM-768 `EncapsulationKey` provides the ability to encapsulate a shared key so that it /// can only be decapsulated by the holder of the corresponding decapsulation key. - pub type EncapsulationKey = kem::EncapsulationKey; + pub type EncapsulationKey = kem::EncapsulationKey; /// Encoded ML-KEM-512 ciphertexts. - pub type EncodedCiphertext = param::EncodedCiphertext; + pub type Ciphertext = kem::Ciphertext; /// Legacy expanded decapsulation keys. Prefer seeds instead. #[doc(hidden)] #[deprecated(since = "0.3.0", note = "use `Seed` instead")] - pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; + pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; } /// ML-KEM-1024 is the parameter set for security category 5, corresponding to key search on a block /// cipher with a 256-bit key. pub mod ml_kem_1024 { use super::{Debug, ParameterSet, U2, U4, U5, U11, kem, param}; + use crate::param::{EncodedUSize, EncodedVSize}; + use array::{sizes::U32, typenum::Sum}; /// `MlKem1024` is the parameter set for security category 5, corresponding to key search on a /// block cipher with a 256-bit key. - #[derive(Default, Clone, Debug, PartialEq)] - pub struct MlKem1024Params; + #[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] + pub struct MlKem1024; - impl ParameterSet for MlKem1024Params { + impl ParameterSet for MlKem1024 { type K = U4; type Eta1 = U2; type Eta2 = U2; @@ -182,21 +201,28 @@ pub mod ml_kem_1024 { type Dv = U5; } + impl kem::Kem for MlKem1024 { + type DecapsulationKey = DecapsulationKey; + type EncapsulationKey = EncapsulationKey; + type CiphertextSize = Sum, EncodedVSize>; + type SharedKeySize = U32; + } + /// An ML-KEM-1024 `DecapsulationKey` which provides the ability to generate a new key pair, and /// decapsulate an encapsulated shared key. - pub type DecapsulationKey = kem::DecapsulationKey; + pub type DecapsulationKey = kem::DecapsulationKey; /// An ML-KEM-1024 `EncapsulationKey` provides the ability to encapsulate a shared key so that /// it can only be decapsulated by the holder of the corresponding decapsulation key. - pub type EncapsulationKey = kem::EncapsulationKey; + pub type EncapsulationKey = kem::EncapsulationKey; /// Encoded ML-KEM-512 ciphertexts. - pub type EncodedCiphertext = param::EncodedCiphertext; + pub type Ciphertext = kem::Ciphertext; /// Legacy expanded decapsulation keys. Prefer seeds instead. #[doc(hidden)] #[deprecated(since = "0.3.0", note = "use `Seed` instead")] - pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; + pub type ExpandedDecapsulationKey = param::ExpandedDecapsulationKey; } /// An ML-KEM-512 `DecapsulationKey` which provides the ability to generate a new key pair, and @@ -223,45 +249,30 @@ pub type DecapsulationKey1024 = ml_kem_1024::DecapsulationKey; /// can only be decapsulated by the holder of the corresponding decapsulation key. pub type EncapsulationKey1024 = ml_kem_1024::EncapsulationKey; -/// A shared key produced by the KEM `K` -pub type SharedKey = Array::SharedKeySize>; - -/// A ciphertext produced by the KEM `K` -pub type Ciphertext = Array::CiphertextSize>; - -/// ML-KEM with the parameter set for security category 1, corresponding to key search on a block -/// cipher with a 128-bit key. -pub type MlKem512 = kem::Kem; - -/// ML-KEM with the parameter set for security category 3, corresponding to key search on a block -/// cipher with a 192-bit key. -pub type MlKem768 = kem::Kem; - -/// ML-KEM with the parameter set for security category 5, corresponding to key search on a block -/// cipher with a 256-bit key. -pub type MlKem1024 = kem::Kem; +/// Shared key established by using ML-KEM, returned from both encapsulation and decapsulation. +pub type SharedKey = Array; #[cfg(test)] #[cfg(feature = "getrandom")] mod test { use super::*; - use ::kem::{Decapsulate, Encapsulate, Generate}; + use ::kem::{Encapsulate, Generate, TryDecapsulate}; fn round_trip_test() where - K: Decapsulate + Generate, + K: Kem, { - let dk = K::generate(); - let ek = dk.encapsulator(); + let dk = K::DecapsulationKey::generate(); + let ek = dk.as_ref().clone(); let (ct, k_send) = ek.encapsulate(); - let k_recv = dk.decapsulate(&ct); + let k_recv = dk.try_decapsulate(&ct).unwrap(); assert_eq!(k_send, k_recv); } #[test] fn round_trip() { - round_trip_test::(); - round_trip_test::(); - round_trip_test::(); + round_trip_test::(); + round_trip_test::(); + round_trip_test::(); } } diff --git a/ml-kem/src/param.rs b/ml-kem/src/param.rs index 3b4bc4a..7dad2f4 100644 --- a/ml-kem/src/param.rs +++ b/ml-kem/src/param.rs @@ -16,7 +16,7 @@ pub(crate) use module_lattice::encoding::{ }; use crate::{ - B32, + B32, Ciphertext, Kem, algebra::{BaseField, Elem, NttVector}, }; use array::{ @@ -108,23 +108,23 @@ pub trait ParameterSet: Default + Clone + Debug + PartialEq { type Dv: EncodingSize; } -type EncodedUSize

= EncodedVectorSize<

::Du,

::K>; -type EncodedVSize

= EncodedPolynomialSize<

::Dv>; +pub(crate) type EncodedUSize

= + EncodedVectorSize<

::Du,

::K>; +pub(crate) type EncodedVSize

= EncodedPolynomialSize<

::Dv>; type EncodedU

= Array>; type EncodedV

= Array>; /// Derived parameter relevant to K-PKE -pub trait PkeParams: ParameterSet { +pub trait PkeParams: Kem + ParameterSet { type NttVectorSize: ArraySize; type EncryptionKeySize: ArraySize; - type CiphertextSize: ArraySize; fn encode_u12(p: &NttVector) -> EncodedNttVector; fn decode_u12(v: &EncodedNttVector) -> NttVector; - fn concat_ct(u: EncodedU, v: EncodedV) -> EncodedCiphertext; - fn split_ct(ct: &EncodedCiphertext) -> (&EncodedU, &EncodedV); + fn concat_ct(u: EncodedU, v: EncodedV) -> Ciphertext; + fn split_ct(ct: &Ciphertext) -> (&EncodedU, &EncodedV); fn concat_ek(t_hat: EncodedNttVector, rho: B32) -> EncodedEncryptionKey; fn split_ek(ek: &EncodedEncryptionKey) -> (&EncodedNttVector, &B32); @@ -133,11 +133,11 @@ pub trait PkeParams: ParameterSet { pub type EncodedNttVector

= Array::NttVectorSize>; pub type EncodedDecryptionKey

= Array::NttVectorSize>; pub type EncodedEncryptionKey

= Array::EncryptionKeySize>; -pub type EncodedCiphertext

= Array::CiphertextSize>; impl

PkeParams for P where - P: ParameterSet, + P: Kem, EncodedVSize

>, SharedKeySize = U32> + + ParameterSet, U384: Mul, Prod: ArraySize + Add + Div + Rem, EncodedUSize

: Add>, @@ -149,7 +149,6 @@ where { type NttVectorSize = EncodedVectorSize; type EncryptionKeySize = Sum; - type CiphertextSize = Sum, EncodedVSize

>; fn encode_u12(p: &NttVector) -> EncodedNttVector { Encode::::encode(p) @@ -159,11 +158,11 @@ where Encode::::decode(v) } - fn concat_ct(u: EncodedU, v: EncodedV) -> EncodedCiphertext { + fn concat_ct(u: EncodedU, v: EncodedV) -> Ciphertext { u.concat(v) } - fn split_ct(ct: &EncodedCiphertext) -> (&EncodedU, &EncodedV) { + fn split_ct(ct: &Ciphertext) -> (&EncodedU, &EncodedV) { ct.split_ref() } diff --git a/ml-kem/src/pkcs8.rs b/ml-kem/src/pkcs8.rs index 8f2ee26..2b4a787 100644 --- a/ml-kem/src/pkcs8.rs +++ b/ml-kem/src/pkcs8.rs @@ -16,7 +16,7 @@ pub use const_oid::AssociatedOid; pub use ::pkcs8::{EncodePrivateKey, EncodePublicKey}; use crate::{ - MlKem512Params, MlKem768Params, MlKem1024Params, + MlKem512, MlKem768, MlKem1024, kem::{DecapsulationKey, EncapsulationKey}, param::{EncapsulationKeySize, KemParams}, pke::EncryptionKey, @@ -42,19 +42,19 @@ const SEED_TAG_NUMBER: TagNumber = TagNumber(0); /// ML-KEM seed serialized as ASN.1. type SeedString<'a> = ContextSpecific<&'a OctetStringRef>; -impl AssociatedOid for MlKem512Params { +impl AssociatedOid for MlKem512 { const OID: ::pkcs8::ObjectIdentifier = const_oid::db::fips203::ID_ALG_ML_KEM_512; } -impl AssociatedOid for MlKem768Params { +impl AssociatedOid for MlKem768 { const OID: ::pkcs8::ObjectIdentifier = const_oid::db::fips203::ID_ALG_ML_KEM_768; } -impl AssociatedOid for MlKem1024Params { +impl AssociatedOid for MlKem1024 { const OID: ::pkcs8::ObjectIdentifier = const_oid::db::fips203::ID_ALG_ML_KEM_1024; } -impl AssociatedAlgorithmIdentifier for MlKem512Params { +impl AssociatedAlgorithmIdentifier for MlKem512 { type Params = ::pkcs8::der::AnyRef<'static>; const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier = @@ -64,7 +64,7 @@ impl AssociatedAlgorithmIdentifier for MlKem512Params { }; } -impl AssociatedAlgorithmIdentifier for MlKem768Params { +impl AssociatedAlgorithmIdentifier for MlKem768 { type Params = ::pkcs8::der::AnyRef<'static>; const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier = @@ -74,7 +74,7 @@ impl AssociatedAlgorithmIdentifier for MlKem768Params { }; } -impl AssociatedAlgorithmIdentifier for MlKem1024Params { +impl AssociatedAlgorithmIdentifier for MlKem1024 { type Params = ::pkcs8::der::AnyRef<'static>; const ALGORITHM_IDENTIFIER: spki::AlgorithmIdentifier = @@ -139,7 +139,7 @@ where None => return Err(spki::Error::KeyMalformed), }; - Ok(Self::new(enc_key)) + Ok(Self::from_encryption_key(enc_key)) } } diff --git a/ml-kem/src/pke.rs b/ml-kem/src/pke.rs index 57e769b..d67ff3e 100644 --- a/ml-kem/src/pke.rs +++ b/ml-kem/src/pke.rs @@ -5,9 +5,9 @@ use crate::algebra::{ }; use crate::compress::Compress; use crate::crypto::{G, PRF}; -use crate::param::{EncodedCiphertext, EncodedDecryptionKey, EncodedEncryptionKey, PkeParams}; +use crate::param::{EncodedDecryptionKey, EncodedEncryptionKey, PkeParams}; use array::typenum::{U1, Unsigned}; -use kem::InvalidKey; +use kem::{Ciphertext, InvalidKey}; use module_lattice::encoding::Encode; use subtle::{Choice, ConstantTimeEq}; @@ -85,7 +85,7 @@ where /// Decrypt ciphertext to obtain the encrypted value, according to the K-PKE.Decrypt procedure. // Algorithm 14. kK-PKE.Decrypt(dk_PKE, c) - pub fn decrypt(&self, ciphertext: &EncodedCiphertext

) -> B32 { + pub fn decrypt(&self, ciphertext: &Ciphertext

) -> B32 { let (c1, c2) = P::split_ct(ciphertext); let mut u: Vector = Encode::::decode(c1); @@ -129,7 +129,7 @@ where { /// Encrypt the specified message for the holder of the corresponding decryption key, using the /// provided randomness, according the `K-PKE.Encrypt` procedure. - pub fn encrypt(&self, message: &B32, randomness: &B32) -> EncodedCiphertext

{ + pub fn encrypt(&self, message: &B32, randomness: &B32) -> Ciphertext

{ let r = sample_poly_vec_cbd::(randomness, 0); let e1 = sample_poly_vec_cbd::(randomness, P::K::U8); @@ -208,7 +208,7 @@ where #[cfg(test)] mod test { use super::*; - use crate::{MlKem512Params, MlKem768Params, MlKem1024Params}; + use crate::{MlKem512, MlKem768, MlKem1024}; use ::kem::Generate; use getrandom::{SysRng, rand_core::UnwrapErr}; @@ -229,9 +229,9 @@ mod test { #[test] fn round_trip() { - round_trip_test::(); - round_trip_test::(); - round_trip_test::(); + round_trip_test::(); + round_trip_test::(); + round_trip_test::(); } fn codec_test

() @@ -253,9 +253,9 @@ mod test { #[test] fn codec() { - codec_test::(); - codec_test::(); - codec_test::(); + codec_test::(); + codec_test::(); + codec_test::(); } #[test] @@ -263,7 +263,7 @@ mod test { // Create an invalid key: all bytes set to 0xFF // When decoded as 12-bit coefficients, this produces values of 0xFFF = 4095 > 3329 let invalid_key = [0xFF; 1184]; - assert!(EncryptionKey::::from_bytes(&invalid_key.into()).is_err()); + assert!(EncryptionKey::::from_bytes(&invalid_key.into()).is_err()); } fn key_inequality_test

() @@ -283,8 +283,8 @@ mod test { #[test] fn key_inequality() { - key_inequality_test::(); - key_inequality_test::(); - key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); + key_inequality_test::(); } } diff --git a/ml-kem/src/traits.rs b/ml-kem/src/traits.rs index ed77199..5be5bb7 100644 --- a/ml-kem/src/traits.rs +++ b/ml-kem/src/traits.rs @@ -1,10 +1,8 @@ //! Trait definitions use crate::{ArraySize, Seed}; -use array::Array; -use core::fmt::Debug; -use kem::{Decapsulate, Encapsulate, InvalidKey}; -use rand_core::CryptoRng; +use array::{Array, sizes::U64}; +use kem::{InvalidKey, Kem, KeyInit, KeySizeUser}; /// An object that knows what size it is pub trait EncodedSizeUser: Sized { @@ -24,26 +22,20 @@ pub trait EncodedSizeUser: Sized { /// A byte array encoding a value the indicated size pub type Encoded = Array::EncodedSize>; -/// A generic interface to a Key Encapsulation Method -pub trait KemCore: Clone { - /// The size of a shared key generated by this KEM - type SharedKeySize: ArraySize; - - /// The size of a ciphertext encapsulating a shared key - type CiphertextSize: ArraySize; - - /// A decapsulation key for this KEM - type DecapsulationKey: Decapsulate + EncodedSizeUser + Debug + PartialEq; - - /// An encapsulation key for this KEM - type EncapsulationKey: Encapsulate + EncodedSizeUser + Clone + Debug + Eq + PartialEq; - - /// Generate a new (decapsulation, encapsulation) key pair. - fn generate( - rng: &mut R, - ) -> (Self::DecapsulationKey, Self::EncapsulationKey); +/// Initialize a KEM from a seed. +pub trait FromSeed: Kem { + /// Using the provided [`Seed`] value, create a KEM keypair. + fn from_seed(seed: &Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey); +} - /// Generate a new (decapsulation, encapsulation) key pair deterministically from the given - /// uniformly random seed value. - fn from_seed(seed: Seed) -> (Self::DecapsulationKey, Self::EncapsulationKey); +impl FromSeed for K +where + K: Kem, + K::DecapsulationKey: KeyInit + KeySizeUser, +{ + fn from_seed(seed: &Seed) -> (K::DecapsulationKey, K::EncapsulationKey) { + let dk = K::DecapsulationKey::new(seed); + let ek = dk.as_ref().clone(); + (dk, ek) + } } diff --git a/ml-kem/tests/encap-decap.rs b/ml-kem/tests/encap-decap.rs index 153dc69..0ec9dcb 100644 --- a/ml-kem/tests/encap-decap.rs +++ b/ml-kem/tests/encap-decap.rs @@ -62,8 +62,8 @@ fn verify_encap_group(tg: &acvp::EncapTestGroup) { fn verify_encap(tc: &acvp::EncapTestCase) where - K: KemCore, - K::EncapsulationKey: EncapsulateDeterministic, + K: Kem, + K::EncapsulationKey: EncapsulateDeterministic + EncodedSizeUser, { let m = Array::try_from(tc.m.as_slice()).unwrap(); let ek_bytes = Encoded::::try_from(tc.ek.as_slice()).unwrap(); @@ -78,16 +78,20 @@ where fn verify_decap_group(tg: &acvp::DecapTestGroup) { for tc in tg.tests.iter() { match tg.parameter_set { - acvp::ParameterSet::MlKem512 => verify_decap::(tc, &tg.dk), - acvp::ParameterSet::MlKem768 => verify_decap::(tc, &tg.dk), - acvp::ParameterSet::MlKem1024 => verify_decap::(tc, &tg.dk), + acvp::ParameterSet::MlKem512 => verify_decap::(tc, &tg.dk), + acvp::ParameterSet::MlKem768 => verify_decap::(tc, &tg.dk), + acvp::ParameterSet::MlKem1024 => verify_decap::(tc, &tg.dk), } } } -fn verify_decap(tc: &acvp::DecapTestCase, dk_slice: &[u8]) { - let dk_bytes = Encoded::::try_from(dk_slice).unwrap(); - let dk = K::from_encoded_bytes(&dk_bytes).unwrap(); +fn verify_decap(tc: &acvp::DecapTestCase, dk_slice: &[u8]) +where + K: Kem, + K::DecapsulationKey: Decapsulate + EncodedSizeUser, +{ + let dk_bytes = Encoded::::try_from(dk_slice).unwrap(); + let dk = K::DecapsulationKey::from_encoded_bytes(&dk_bytes).unwrap(); let c = ::kem::Ciphertext::::try_from(tc.c.as_slice()).unwrap(); let k = dk.decapsulate(&c); diff --git a/ml-kem/tests/key-gen.rs b/ml-kem/tests/key-gen.rs index 9e28c14..2ffbc91 100644 --- a/ml-kem/tests/key-gen.rs +++ b/ml-kem/tests/key-gen.rs @@ -1,6 +1,6 @@ -use ml_kem::*; - use array::ArrayN; +use core::fmt::Debug; +use ml_kem::*; use std::{fs::read_to_string, path::PathBuf}; #[test] @@ -25,14 +25,19 @@ fn acvp_key_gen() { } } -fn verify(tc: &acvp::TestCase) { +fn verify(tc: &acvp::TestCase) +where + K: Kem + FromSeed, + K::DecapsulationKey: EncodedSizeUser + Debug + PartialEq, + K::EncapsulationKey: EncodedSizeUser, +{ // Import test data into the relevant array structures let d = ArrayN::::try_from(tc.d.as_slice()).unwrap(); let z = ArrayN::::try_from(tc.z.as_slice()).unwrap(); let dk_bytes = Encoded::::try_from(tc.dk.as_slice()).unwrap(); let ek_bytes = Encoded::::try_from(tc.ek.as_slice()).unwrap(); - let (dk, ek) = K::from_seed(d.concat(z)); + let (dk, ek) = K::from_seed(&d.concat(z)); // Verify correctness via serialization assert_eq!(dk.to_encoded_bytes().as_slice(), tc.dk.as_slice()); diff --git a/ml-kem/tests/pkcs8.rs b/ml-kem/tests/pkcs8.rs index 3c3d3ba..8095b18 100644 --- a/ml-kem/tests/pkcs8.rs +++ b/ml-kem/tests/pkcs8.rs @@ -2,8 +2,9 @@ #![cfg(all(feature = "pkcs8", feature = "alloc"))] +use core::fmt::Debug; use getrandom::SysRng; -use ml_kem::{EncodedSizeUser, KemCore, MlKem512, MlKem768, MlKem1024, Seed}; +use ml_kem::{Kem, MlKem512, MlKem768, MlKem1024, Seed, kem::KeyExport}; use pkcs8::{ DecodePrivateKey, DecodePublicKey, EncodePrivateKey, EncodePublicKey, PrivateKeyInfoRef, SubjectPublicKeyInfoRef, @@ -19,12 +20,12 @@ type SeedString<'a> = ContextSpecific<&'a OctetStringRef>; fn der_serialization_and_deserialization(expected_encaps_len: u32) where - K: KemCore, + K: Kem, K::EncapsulationKey: EncodePublicKey + DecodePublicKey, - K::DecapsulationKey: EncodePrivateKey + DecodePrivateKey + From + PartialEq, + K::DecapsulationKey: EncodePrivateKey + DecodePrivateKey + Debug + From + PartialEq, { let mut rng = UnwrapErr(SysRng); - let (decaps_key, encaps_key) = K::generate(&mut rng); + let (decaps_key, encaps_key) = K::generate_keypair_from_rng(&mut rng); // TEST: (de)serialize encapsulation key into DER document { @@ -38,7 +39,7 @@ where // verify that original encapsulation key corresponds to deserialized encapsulation key let pub_key = parsed.decode_msg::().unwrap(); assert_eq!( - encaps_key.to_encoded_bytes().as_slice(), + encaps_key.to_bytes().as_slice(), pub_key.subject_public_key.as_bytes().unwrap() ); } @@ -102,7 +103,7 @@ fn pkcs8_serialize_and_deserialize_round_trip() { #[cfg(feature = "pem")] fn compare_with_reference_keys(variant: usize, ref_pub_key_pem: &str, ref_priv_key_pem: &str) where - K: KemCore, + K: Kem, K::EncapsulationKey: EncodePublicKey, K::DecapsulationKey: EncodePrivateKey, { @@ -143,7 +144,7 @@ where let seed: [u8; SEED_LEN] = core::array::from_fn(|i| u8::try_from(i).unwrap()); let mut rng = SeedBasedRng { seed, index: 0 }; - let (decaps_key, encaps_key) = K::generate(&mut rng); + let (decaps_key, encaps_key) = K::generate_keypair_from_rng(&mut rng); let gen_pub_key_pem = encaps_key .to_public_key_pem(pkcs8::LineEnding::LF) diff --git a/ml-kem/tests/wycheproof.rs b/ml-kem/tests/wycheproof.rs index d390f53..0e48191 100644 --- a/ml-kem/tests/wycheproof.rs +++ b/ml-kem/tests/wycheproof.rs @@ -2,7 +2,7 @@ use array::{Array, ArraySize}; use ml_kem::{ - EncodedSizeUser, KemCore, MlKem512, MlKem768, MlKem1024, + EncodedSizeUser, FromSeed, MlKem512, MlKem768, MlKem1024, kem::{Decapsulate, KeyExport, TryKeyInit}, }; use serde::Deserialize; @@ -106,11 +106,11 @@ macro_rules! mlkem_test { } }; - let (dk, ek) = $kem::from_seed(test_seed); + let (dk, ek) = $kem::from_seed(&test_seed); assert_eq!(test.ek.as_slice(), ek.to_bytes().as_slice()); - use ml_kem::$kem_module::EncodedCiphertext; - let test_c: EncodedCiphertext = match decode_optional_hex(&test.c, "c") { + use ml_kem::$kem_module::Ciphertext; + let test_c: Ciphertext = match decode_optional_hex(&test.c, "c") { Some(dk) => dk, None => { assert_eq!(test.result, ExpectedResult::Invalid); @@ -148,7 +148,7 @@ macro_rules! mlkem_keygen_seed_test { let test_seed = decode_expected_hex(&test.seed, "seed"); let test_dk = decode_expected_hex(&test.dk, "dk"); - let (dk, ek) = $kem::from_seed(test_seed); + let (dk, ek) = $kem::from_seed(&test_seed); assert_eq!(test_dk, dk.to_encoded_bytes()); assert_eq!(test.ek.as_slice(), ek.to_bytes().as_slice()); } @@ -218,7 +218,7 @@ macro_rules! mlkem_decaps_test { #[allow(deprecated)] use ml_kem::$kem_module::{ - DecapsulationKey, EncodedCiphertext, ExpandedDecapsulationKey, + Ciphertext, DecapsulationKey, ExpandedDecapsulationKey, }; #[allow(deprecated)] @@ -242,7 +242,7 @@ macro_rules! mlkem_decaps_test { } let dk = dk_result.unwrap(); - let test_c: EncodedCiphertext = match decode_optional_hex(&test.c, "c") { + let test_c: Ciphertext = match decode_optional_hex(&test.c, "c") { Some(dk) => dk, None => { assert_eq!(test.result, ExpectedResult::Invalid); diff --git a/x-wing/Cargo.toml b/x-wing/Cargo.toml index 0f444f1..26f55e6 100644 --- a/x-wing/Cargo.toml +++ b/x-wing/Cargo.toml @@ -19,7 +19,7 @@ zeroize = ["dep:zeroize", "ml-kem/zeroize", "x25519-dalek/zeroize"] hazmat = [] [dependencies] -kem = "0.3.0-rc.2" +kem = "0.3.0-rc.3" ml-kem = { version = "=0.3.0-pre.5", default-features = false, features = ["hazmat"] } rand_core = { version = "0.10.0-rc-6", default-features = false } sha3 = { version = "0.11.0-rc.4", default-features = false } diff --git a/x-wing/src/lib.rs b/x-wing/src/lib.rs index 7970760..aa8bee1 100644 --- a/x-wing/src/lib.rs +++ b/x-wing/src/lib.rs @@ -17,21 +17,24 @@ #![cfg_attr(feature = "getrandom", doc = "```")] #![cfg_attr(not(feature = "getrandom"), doc = "```ignore")] //! // NOTE: requires the `getrandom` feature is enabled -//! use kem::{Decapsulate, Encapsulate}; +//! use x_wing::{ +//! XWingKem, +//! kem::{Decapsulate, Encapsulate, Kem} +//! }; //! -//! let (sk, pk) = x_wing::generate_key_pair(); -//! let (ct, ss_sender) = pk.encapsulate(); -//! let ss_receiver = sk.decapsulate(&ct); -//! assert_eq!(ss_sender, ss_receiver); +//! let (sk, pk) = XWingKem::generate_keypair(); +//! let (ct, sk_sender) = pk.encapsulate(); +//! let sk_receiver = sk.decapsulate(&ct); +//! assert_eq!(sk_sender, sk_receiver); //! ``` pub use kem::{ - self, Decapsulate, Decapsulator, Encapsulate, Generate, InvalidKey, KemParams, Key, KeyExport, - KeyInit, KeySizeUser, TryKeyInit, + self, Decapsulate, Encapsulate, Generate, InvalidKey, Kem, Key, KeyExport, KeyInit, + KeySizeUser, TryKeyInit, }; use ml_kem::{ - EncodedSizeUser, KemCore, MlKem768, MlKem768Params, + EncodedSizeUser, FromSeed, MlKem768, array::{ Array, ArrayN, AsArrayRef, sizes::{U32, U1120, U1184, U1216}, @@ -47,8 +50,8 @@ use x25519_dalek::{PublicKey, StaticSecret}; #[cfg(feature = "zeroize")] use zeroize::{Zeroize, ZeroizeOnDrop}; -type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey; -type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey; +type MlKem768DecapsulationKey = ml_kem::kem::DecapsulationKey; +type MlKem768EncapsulationKey = ml_kem::kem::EncapsulationKey; const X_WING_LABEL: &[u8; 6] = br"\.//^\"; @@ -62,9 +65,20 @@ pub const CIPHERTEXT_SIZE: usize = 1120; pub const ENCAPSULATION_RANDOMNESS_SIZE: usize = 64; /// Serialized ciphertext. -pub type Ciphertext = Array; +pub type Ciphertext = kem::Ciphertext; /// Shared secret key. -pub type SharedSecret = Array; +pub type SharedKey = Array; + +/// X-Wing Key Encapsulation Method (X-Wing-KEM). +#[derive(Clone, Copy, Debug, Default, Eq, PartialEq, PartialOrd, Ord)] +pub struct XWingKem; + +impl Kem for XWingKem { + type DecapsulationKey = DecapsulationKey; + type EncapsulationKey = EncapsulationKey; + type CiphertextSize = U1120; + type SharedKeySize = U32; +} // The naming convention of variables matches the RFC. // ss -> Shared Secret @@ -96,7 +110,7 @@ impl EncapsulationKey { pub fn encapsulate_deterministic( &self, randomness: &ArrayN, - ) -> (Ciphertext, SharedSecret) { + ) -> (Ciphertext, SharedKey) { // Split randomness into two 32-byte arrays let (rand_m, rand_x) = randomness.split::(); @@ -116,8 +130,8 @@ impl EncapsulationKey { } } -impl Encapsulate for EncapsulationKey { - fn encapsulate_with_rng(&self, rng: &mut R) -> (Ciphertext, SharedSecret) +impl Encapsulate for EncapsulationKey { + fn encapsulate_with_rng(&self, rng: &mut R) -> (Ciphertext, SharedKey) where R: CryptoRng + ?Sized, { @@ -132,11 +146,6 @@ impl Encapsulate for EncapsulationKey { } } -impl KemParams for EncapsulationKey { - type CiphertextSize = U1120; - type SharedSecretSize = U32; -} - impl KeySizeUser for EncapsulationKey { type KeySize = U1216; } @@ -185,9 +194,15 @@ impl DecapsulationKey { } } -impl Decapsulate for DecapsulationKey { +impl AsRef for DecapsulationKey { + fn as_ref(&self) -> &EncapsulationKey { + &self.ek + } +} + +impl Decapsulate for DecapsulationKey { #[allow(clippy::similar_names)] // So we can use the names as in the RFC - fn decapsulate(&self, ct: &Ciphertext) -> SharedSecret { + fn decapsulate(&self, ct: &Ciphertext) -> SharedKey { let ct = CiphertextMessage::from(ct); let (sk_m, sk_x, _pk_m, pk_x) = expand_key(&self.sk); @@ -200,14 +215,6 @@ impl Decapsulate for DecapsulationKey { } } -impl Decapsulator for DecapsulationKey { - type Encapsulator = EncapsulationKey; - - fn encapsulator(&self) -> &EncapsulationKey { - &self.ek - } -} - impl Drop for DecapsulationKey { fn drop(&mut self) { #[cfg(feature = "zeroize")] @@ -259,7 +266,7 @@ fn expand_key( let mut expanded: Shake256Reader = shaker.finalize_xof(); let seed = read_from(&mut expanded).into(); - let (sk_m, pk_m) = MlKem768::from_seed(seed); + let (sk_m, pk_m) = MlKem768::from_seed(&seed); let sk_x = read_from(&mut expanded); let sk_x = StaticSecret::from(sk_x); @@ -315,30 +322,12 @@ impl From for Ciphertext { } } -/// Generate a X-Wing key pair using `OsRng`. -#[cfg(feature = "getrandom")] -#[must_use] -pub fn generate_key_pair() -> (DecapsulationKey, EncapsulationKey) { - let sk = DecapsulationKey::generate(); - let pk = sk.encapsulator().clone(); - (sk, pk) -} - -/// Generate a X-Wing key pair using the provided rng. -pub fn generate_key_pair_from_rng( - rng: &mut R, -) -> (DecapsulationKey, EncapsulationKey) { - let sk = DecapsulationKey::generate_from_rng(rng); - let pk = sk.encapsulator().clone(); - (sk, pk) -} - fn combiner( ss_m: &ArrayN, ss_x: &x25519_dalek::SharedSecret, ct_x: &PublicKey, pk_x: &PublicKey, -) -> SharedSecret { +) -> SharedKey { use sha3::Digest; let mut hasher = Sha3_256::new(); @@ -358,6 +347,7 @@ fn read_from(reader: &mut Shake256Reader) -> [u8; N] { #[cfg(test)] mod tests { + use crate::{Kem, XWingKem}; use core::convert::Infallible; use getrandom::SysRng; use ml_kem::array::Array; @@ -430,7 +420,7 @@ mod tests { fn run_test(test_vector: TestVector) { let mut seed = SeedRng::new(test_vector.seed); - let (sk, pk) = generate_key_pair_from_rng(&mut seed); + let (sk, pk) = XWingKem::generate_keypair_from_rng(&mut seed); assert_eq!(sk.as_bytes(), &test_vector.sk); assert_eq!(&*pk.to_bytes(), test_vector.pk.as_slice()); @@ -461,9 +451,9 @@ mod tests { } #[test] + #[cfg(feature = "getrandom")] fn key_serialize() { - let sk = DecapsulationKey::generate_from_rng(&mut UnwrapErr(SysRng)); - let pk = sk.encapsulator().clone(); + let (sk, pk) = XWingKem::generate_keypair(); let sk_bytes = sk.as_bytes(); let pk_bytes = pk.to_bytes();