-- |
-- Module      : Crypto.PubKey.ML_KEM
-- License     : BSD-3-Clause
-- Maintainer  : Olivier Chéron <olivier.cheron@gmail.com>
-- Stability   : provisional
-- Portability : unknown
--
-- Module-Lattice-based Key-Encapsulation Mechanism (ML-KEM), defined
-- in <https://csrc.nist.gov/pubs/fips/203/final FIPS 203>.
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE TypeFamilies #-}
module Crypto.PubKey.ML_KEM
    ( EncapsulationKey, DecapsulationKey, Ciphertext, SharedSecret
    -- * Operations
    , generate, generateOpen, generateWith, encapsulate, encapsulateWith
    , decapsulate
    -- * Parameter sets
    , ParamSet, ML_KEM_512, ML_KEM_768, ML_KEM_1024
    -- * Conversions and checks
    , Decode(..), Encode(..)
    , toPublic, checkKeyPair
    ) where

import Crypto.Random

import Data.ByteArray (ByteArray, ByteArrayAccess, ScrubbedBytes)
import qualified Data.ByteArray as B

import Internal

-- | ML-KEM-512 parameter set (NIST security category 1).
--
-- A type-level tag with a single nullary constructor; pass it (or a proxy
-- of it) to select this parameter set in 'generate' and related functions.
data ML_KEM_512 = ML_KEM_512
    deriving (Show)
-- | ML-KEM-768 parameter set (NIST security category 3).
--
-- A type-level tag with a single nullary constructor; pass it (or a proxy
-- of it) to select this parameter set in 'generate' and related functions.
data ML_KEM_768 = ML_KEM_768
    deriving (Show)
-- | ML-KEM-1024 parameter set (NIST security category 5).
--
-- A type-level tag with a single nullary constructor; pass it (or a proxy
-- of it) to select this parameter set in 'generate' and related functions.
data ML_KEM_1024 = ML_KEM_1024
    deriving (Show)

-- ML-KEM-512: module rank k = 2.  The 'Params' arguments match FIPS 203
-- Table 2 read as (eta1, eta2, d_u, d_v) = (3, 2, 10, 4); that field order
-- is assumed from the values, see 'Params' in Internal to confirm.
-- The proxy argument is ignored: the parameters depend only on the type.
instance ParamSet ML_KEM_512 where
    type K ML_KEM_512 = 2
    getParams _ = Params 3 2 10 4
-- ML-KEM-768: module rank k = 3.  The 'Params' arguments match FIPS 203
-- Table 2 read as (eta1, eta2, d_u, d_v) = (2, 2, 10, 4); that field order
-- is assumed from the values, see 'Params' in Internal to confirm.
-- The proxy argument is ignored: the parameters depend only on the type.
instance ParamSet ML_KEM_768 where
    type K ML_KEM_768 = 3
    getParams _ = Params 2 2 10 4
-- ML-KEM-1024: module rank k = 4.  The 'Params' arguments match FIPS 203
-- Table 2 read as (eta1, eta2, d_u, d_v) = (2, 2, 11, 5); that field order
-- is assumed from the values, see 'Params' in Internal to confirm.
-- The proxy argument is ignored: the parameters depend only on the type.
instance ParamSet ML_KEM_1024 where
    type K ML_KEM_1024 = 4
    getParams _ = Params 2 2 11 5

-- | Generate an ML-KEM key pair from a fresh random seed.
--
-- Draws 64 random bytes in a single call.  The first 32 bytes are used as
-- the seed @d@ and the last 32 bytes as the implicit-rejection seed @z@.
-- The seed is not returned; use 'generateOpen' when it must be stored and
-- later re-expanded with 'generateWith'.
generate :: (ParamSet a, MonadRandom m)
         => proxy a -> m (EncapsulationKey a, DecapsulationKey a)
generate p = do
    seed <- getRandomBytes 64
    let d = B.takeView seed 32                -- view over bytes [0, 32)
        z = B.drop 32 (seed :: ScrubbedBytes) -- bytes [32, 64)
    return (Internal.keyGen p d z)

-- | Generate a random seed @(d, z)@ and the expanded key pair, and return
-- all four values.  This is Algorithm 19b introduced in Section 7 of
-- <https://www.rfc-editor.org/rfc/rfc9935 RFC 9935>.
--
-- @d@ and @z@ are each 32 random bytes, drawn in that order.  Later, use
-- 'generateWith' to re-expand a seed @(d, z)@ recovered from storage.
generateOpen :: (ParamSet a, ByteArray d, ByteArray z, MonadRandom m)
             => proxy a -> m (EncapsulationKey a, DecapsulationKey a, d, z)
generateOpen p = do
    d <- getRandomBytes 32
    z <- getRandomBytes 32
    -- keyGen takes z as ScrubbedBytes, so convert from the caller's type.
    let (ek, dk) = Internal.keyGen p d (B.convert z)
    return (ek, dk, d, z)

-- | Deterministically generate an ML-KEM key pair from the seed @(d, z)@.
--
-- Both @d@ and @z@ must be exactly 32 bytes long; otherwise the result is
-- 'Nothing'.  Re-expanding a seed produced by 'generateOpen' yields the
-- same key pair.
generateWith :: (ParamSet a, ByteArrayAccess d, ByteArrayAccess z)
             => proxy a -> d -> z -> Maybe (EncapsulationKey a, DecapsulationKey a)
generateWith p d z
    | B.length d /= 32 = Nothing
    | B.length z /= 32 = Nothing
    | otherwise        = Just $ Internal.keyGen p d (B.convert z)

-- | Generate a shared secret key and an associated ciphertext under the
-- given encapsulation key, using 32 fresh random bytes as the message.
encapsulate :: (ParamSet a, MonadRandom m)
            => EncapsulationKey a -> m (SharedSecret a, Ciphertext a)
encapsulate ek = do
    m <- getRandomBytes 32
    -- The annotation fixes the otherwise ambiguous random buffer type.
    return (Internal.encaps ek (m :: ScrubbedBytes))

-- | Generate a shared secret key and an associated ciphertext using a
-- caller-supplied random input.
--
-- The input must be exactly 32 bytes; otherwise the result is 'Nothing'.
-- It must never be reused across encapsulations.  Intended for testing,
-- e.g. with known-answer vectors.
encapsulateWith :: (ParamSet a, ByteArrayAccess m)
                => EncapsulationKey a -> m -> Maybe (SharedSecret a, Ciphertext a)
encapsulateWith ek m
    | B.length m /= 32 = Nothing
    | otherwise        = Just $ Internal.encaps ek m

-- | Recover the shared secret for a given ciphertext using the
-- decapsulation key.
--
-- Uses implicit rejection: if the ciphertext or encapsulation key has been
-- tampered with, no error is reported and a different shared secret is
-- returned instead.
decapsulate :: ParamSet a => DecapsulationKey a -> Ciphertext a -> SharedSecret a
decapsulate = Internal.decaps

-- | Try to detect corruption in a key pair by running a pair-wise
-- consistency test.
--
-- A fresh 32-byte random message is encapsulated under the encapsulation
-- key.  The resulting ciphertext is then decapsulated with the
-- decapsulation key, and the two shared secrets are compared.
--
-- Returns @True@ when they match, i.e. the key pair is found valid.  This
-- does not fully guarantee that the key pair was properly generated.
checkKeyPair :: (ParamSet a, MonadRandom m)
             => (EncapsulationKey a, DecapsulationKey a) -> m Bool
checkKeyPair (ek, dk) = do
    m <- getRandomBytes 32
    let (kk, ct) = Internal.encaps ek (m :: ScrubbedBytes)
        kk'      = Internal.decaps dk ct
    return (kk' == kk)