-- |
-- Module      : Crypto
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Crypto-related utilities like the ML-KEM hash and PRF functions, or more
-- general concerns like constant-time equality and selection.
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE ScopedTypeVariables #-}
module Crypto
    ( ConstEqW(..), BoolW, andW, toBool, constSelectBytes, snoc, append, eq
    , prf, h, j, g, BlockDigest, unBlockDigest, hashToBlock
    ) where

import Crypto.Hash (Context)
import Crypto.Hash.Algorithms
import Crypto.Hash.IO

import Control.Exception (assert)
import Control.Monad
import Control.Monad.ST

import Data.ByteArray (ByteArray, ByteArrayAccess, Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B

import Data.Bits
import Data.Word

import GHC.TypeNats

import Foreign.Marshal.Utils (fillBytes)
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import Foreign.Storable (pokeByteOff)

import Block (Block)
import Builder (Builder)
import Machine
import ScrubbedBlock (ScrubbedBlock)
import Vector (Vector)
import qualified Block
import qualified Builder
import qualified ByteArrayST as ST
import qualified ScrubbedBlock
import qualified Vector

-- | A boolean carried in a full machine word, for constant-time code.
--
-- The values built in this module are either all bits clear ('falseW') or
-- all bits set ('trueW', and the result of 'eqW').  The word can therefore
-- be used directly as a bit mask, see 'andW' and 'constSelectBytes'.
newtype BoolW = BoolW Word

#ifdef ML_KEM_TESTING
-- Only compiled when ML_KEM_TESTING is defined: shows the value as a Bool.
instance Show BoolW where
    showsPrec d = showsPrec d . toBool
#endif

-- | Convert a word-sized boolean into an ordinary 'Bool'.
--
-- Any non-zero word counts as 'True'.  The result is a plain 'Bool', so
-- code receiving it may branch on it.
toBool :: BoolW -> Bool
toBool (BoolW w) = w /= 0

-- | The two canonical masks: every bit clear ('falseW') or every bit set
-- ('trueW').
falseW, trueW :: BoolW
falseW = BoolW 0
trueW  = BoolW (complement 0)

-- | Branch-free conjunction: the bitwise AND of the two masks.
andW :: BoolW -> BoolW -> BoolW
andW (BoolW x) (BoolW y) = BoolW $ x .&. y

-- | Width in bits of the 'Word' that backs 'BoolW'.
bitsW :: Int
bitsW = finiteBitSize (0 :: Word)

-- | Width in bytes of the 'Word' that backs 'BoolW' (the loop stride of the
-- word-wise helpers in this module).
bytesW :: Int
bytesW = bitsW `div` 8

-- | Branch-free equality of two words.
--
-- Returns a mask with every bit set when the words are equal and every bit
-- clear otherwise.  The computation uses only bitwise and arithmetic
-- operations, with no conditional on the values.
eqW :: Word -> Word -> BoolW
eqW a b = BoolW (spreadTopBit (complement d .&. (d - 1)))
  where
    -- Zero exactly when the two words are equal.
    d = a `xor` b

    -- Why this works:
    --  * For d == 0, @complement d .&. (d - 1)@ is all ones, so its top bit
    --    is set.
    --  * For non-zero d, it equals the bits strictly below d's lowest set
    --    bit, so its top bit is clear.
    --
    -- spreadTopBit copies the most significant bit into every position: it
    -- moves that bit down to bit 0, then negates (0 -> 0, 1 -> all ones).
    spreadTopBit :: Word -> Word
    spreadTopBit w = negate (w `unsafeShiftR` (bitsW - 1))

-- | @assertMultW n x@ returns @x@ after asserting that @n@ is a multiple of
-- 'bytesW'.
--
-- The test @n .&. (bytesW - 1) == 0@ relies on 'bytesW' being a power of
-- two.  Like any 'assert', GHC ignores the check when optimising unless
-- @-fno-ignore-asserts@ is given.
assertMultW :: Int -> a -> a
assertMultW n x = assert (n .&. lowBits == 0) x
  where
    lowBits = bytesW - 1

-- | Types whose values can be compared for equality in constant time.
--
-- The comparison result is a 'BoolW' mask rather than a 'Bool'.  Callers can
-- keep combining it (for example with 'andW') before converting it with
-- 'toBool'.
class ConstEqW a where
    -- | Compare two values: 'trueW' when equal, 'falseW' otherwise.
    constEqW :: a -> a -> BoolW

-- | Element-wise comparison of two vectors of the same length @n@.
--
-- 'Vector.fold1ZipWith' visits the element pairs.  Each pair is compared with
-- 'constEqW', and the results are combined with 'andW'.
instance ConstEqW a => ConstEqW (Vector n a) where
    constEqW = Vector.fold1ZipWith step constEqW
      where
        step acc x y = acc `andW` constEqW x y

-- | Word-by-word comparison of two blocks.
--
-- Blocks of different lengths are unequal.  Otherwise each word pair is
-- compared with 'eqW', and the masks are accumulated with 'andW' starting
-- from 'trueW'.  The only branch here is on the lengths.
instance ConstEqW (Block Word) where
    constEqW x y
        | Block.length x == Block.length y = Block.foldZipWith step trueW x y
        | otherwise = falseW
      where
        step acc u v = acc `andW` eqW u v

-- | Word-by-word comparison of two scrubbed blocks.
--
-- Blocks of different lengths are unequal.  Otherwise each word pair is
-- compared with 'eqW', and the masks are accumulated with 'andW' starting
-- from 'trueW'.  The only branch here is on the lengths.
instance ConstEqW (ScrubbedBlock Word) where
    constEqW x y
        | ScrubbedBlock.length x == ScrubbedBlock.length y =
            ScrubbedBlock.foldZipWith step trueW x y
        | otherwise = falseW
      where
        step acc u v = acc `andW` eqW u v

-- | Constant-time comparison of 'Bytes', delegated to 'bytesConstEqW'.
instance ConstEqW Bytes where
    constEqW x y = bytesConstEqW x y

-- | Constant-time comparison of 'ScrubbedBytes', delegated to
-- 'bytesConstEqW'.
instance ConstEqW ScrubbedBytes where
    constEqW x y = bytesConstEqW x y

-- | Compare two byte arrays word by word without branching on their
-- contents.
--
-- Arrays of different lengths are unequal.  Otherwise every word pair is
-- checked with 'eqW', and the masks are accumulated with 'andW' through
-- 'foldZipWith'.  That helper expects the common length to be a multiple
-- of 'bytesW'.
bytesConstEqW :: (ByteArrayAccess bs1, ByteArrayAccess bs2) => bs1 -> bs2 -> BoolW
bytesConstEqW a b
    | sameLength = foldZipWith combine trueW a b
    | otherwise  = falseW
  where
    sameLength = B.length a == B.length b
    combine acc x y = acc `andW` eqW x y

-- | Strict left fold over the machine words of two byte arrays, taken in
-- pairs.
--
-- @foldZipWith f c a b@ reads both arrays one 'Word' at a time and returns
-- @f (... (f (f c a0 b0) a1 b1) ...) an bn@.  The whole input is always
-- traversed; there is no early exit.
--
-- Preconditions (checked only when assertions are enabled): both arrays have
-- the same length, and that length is a multiple of 'bytesW'.
--
-- The loop stops once the offset reaches or passes the length.  If the
-- precondition is violated in an optimised build (where 'assert' is compiled
-- out), the fold therefore still terminates instead of walking indefinitely
-- past the end of the buffers.
foldZipWith :: forall bs1 bs2 c. (ByteArrayAccess bs1, ByteArrayAccess bs2)
            => (c -> Word -> Word -> c) -> c -> bs1 -> bs2 -> c
foldZipWith f c a b = assert (sa == sb) $ assertMultW sa $ assertMultW sb $
    runST $ ST.withByteArray a $ \pa -> ST.withByteArray b $ \pb ->
        loop pa pb c 0
  where
    !sa = B.length a
    !sb = B.length b

    -- Walk both buffers in lock step, keeping the accumulator strict.
    loop :: Ptr Word -> Ptr Word -> c -> Int -> ST s c
    loop !pa !pb !acc i
        | i >= sa = return acc
        | otherwise = do
            va <- ST.peek pa
            vb <- ST.peek pb
            loop (pa `plusPtr` bytesW) (pb `plusPtr` bytesW) (f acc va vb) (i + bytesW)
{-# INLINE foldZipWith #-}

-- | Combine two byte arrays word by word into a fresh 'ScrubbedBytes' of the
-- same length.
--
-- Preconditions (checked only when assertions are enabled): equal lengths,
-- and a length that is a multiple of 'bytesW'.
zipWith :: (Word -> Word -> Word) -> ScrubbedBytes -> ScrubbedBytes -> ScrubbedBytes
zipWith f a b =
    assert (sa == sb) . assertMultW sa . assertMultW sb $
        ST.unsafeCreate sa $ \dst ->
            ST.withByteArray a $ \srcA ->
                ST.withByteArray b $ \srcB ->
                    go dst srcA srcB 0
  where
    !sa = B.length a
    !sb = B.length b

    -- The source pointers advance one word per step.  The destination
    -- pointer stays fixed and is written at byte offset @off@.
    go :: Ptr Word -> Ptr Word -> Ptr Word -> Int -> ST s ()
    go !dst !pa !pb off
        | off >= sa = return ()
        | otherwise = do
            x <- ST.peek pa
            y <- ST.peek pb
            ST.pokeByteOff dst off (f x y)
            go dst (pa `plusPtr` bytesW) (pb `plusPtr` bytesW) (off + bytesW)
{-# INLINE zipWith #-}

-- | Branch-free selection between two byte arrays.
--
-- Word by word, the result takes the bits of the first array where the mask
-- is set and the bits of the second array where it is clear.  With 'trueW'
-- this yields the first array, and with 'falseW' the second.
--
-- The arrays must meet the preconditions of the word-wise 'Crypto.zipWith':
-- equal lengths that are a multiple of 'bytesW'.
constSelectBytes :: BoolW -> ScrubbedBytes -> ScrubbedBytes -> ScrubbedBytes
constSelectBytes (BoolW !m) = Crypto.zipWith pick
  where
    pick yes no = (yes .&. m) .|. (no .&. complement m)

-- This version of snoc accepts any ByteArrayAccess input and internally calls
-- copyByteArrayToPtr, so it needs no trampoline when the input is backed by
-- Block Word8.
-- | Copy the input into a new 'ScrubbedBytes' one byte longer, with @b@
-- stored in the final position.
snoc :: ByteArrayAccess a => a -> Word8 -> ScrubbedBytes
snoc a b = B.allocAndFreeze (len + 1) fill
  where
    len = B.length a
    -- Copy the prefix, then write the extra byte just past it.
    fill p = B.copyByteArrayToPtr a p >> pokeByteOff p len b
{-# INLINE snoc #-}

-- This version of append is more polymorphic and requires no trampoline when
-- fed with an input backed by Block Word8.
-- | Concatenate two byte arrays into a new 'ScrubbedBytes'.
append :: (ByteArrayAccess a, ByteArrayAccess b) => a -> b -> ScrubbedBytes
append a b = B.allocAndFreeze (lenA + lenB) $ \dst -> do
    B.copyByteArrayToPtr a dst
    -- The second input starts right after the first.
    B.copyByteArrayToPtr b (dst `plusPtr` lenA)
  where
    lenA = B.length a
    lenB = B.length b
{-# INLINE append #-}

-- | Ordinary, variable-time equality of two byte arrays, compared one
-- 'WordM' at a time.
--
-- Unlike 'constEqW', this returns as soon as a differing word is found.
-- Its running time therefore depends on the data.
--
-- Preconditions (checked only when assertions are enabled): the arrays have
-- the same length, and that length passes 'assertMultM' (presumably a
-- multiple of 'wordBytes' -- confirm in "Machine").
--
-- The loop stops once the offset reaches or passes the length.  If the
-- precondition is violated in an optimised build (where 'assert' is compiled
-- out), the comparison therefore still terminates instead of walking
-- indefinitely past the end of the buffers.
eq :: (ByteArrayAccess a, ByteArrayAccess b) => a -> b -> Bool
eq a b = assert (sa == sb) $ assertMultM sa $ assertMultM sb $
    runST $ ST.withByteArray a $ \pa -> ST.withByteArray b $ \pb ->
        loop pa pb 0
  where
    !sa = B.length a
    !sb = B.length b

    loop :: Ptr WordM -> Ptr WordM -> Int -> ST s Bool
    loop !pa !pb i
        | i >= sa = return True
        | otherwise = do
            va <- ST.peek pa
            vb <- ST.peek pb
            if va == vb
                then loop (pa `plusPtr` wordBytes) (pb `plusPtr` wordBytes) (i + wordBytes)
                else return False

-- | ML-KEM pseudorandom function.
--
-- Computes SHAKE256 over the seed @s@ followed by the single byte @b@, with
-- an output length of @8 * 64 * eta@ bits (that is, @64 * eta@ bytes).
prf :: ByteArrayAccess s => Word -> s -> Word8 -> ScrubbedBytes
prf !eta s !b =
    case someNatVal (fromIntegral outputBits) of
        SomeNat proxy -> unDigest (shake proxy)
  where
    outputBits :: Word
    outputBits = 8 * 64 * eta

    -- The proxy only fixes the type-level output length of SHAKE256.
    shake :: KnownNat bitlen => proxy bitlen -> Digest (SHAKE256 bitlen)
    shake _ = hash (snoc s b)

-- | ML-KEM hash function H: SHA3-256 of the input, returned as 'Bytes'.
h :: ByteArrayAccess s => s -> Bytes
h input = Builder.run (hashWith SHA3_256 input)

-- | ML-KEM hash function J: SHAKE256 with a 256-bit (32-byte) output,
-- returned as 'ScrubbedBytes'.
j :: ScrubbedBytes -> ScrubbedBytes
j input = Builder.run (hashWith shake256 input)
  where
    shake256 :: SHAKE256 256
    shake256 = SHAKE256

-- | ML-KEM hash function G: SHA3-512 of the input, split in two.
--
-- The first 32 bytes of the digest are converted to @ba@.  The remaining
-- bytes are returned as a view into the scrubbed digest buffer.
g :: ByteArray ba => ScrubbedBytes -> (ba, B.View ScrubbedBytes)
g c = (firstHalf, secondHalf)
  where
    digest = Builder.run (hashWith SHA3_512 c)
    firstHalf = B.convert (B.takeView digest 32)
    secondHalf = B.dropView digest 32

-- Override cryptonite/crypton types and hashing functions.
--
-- The standard Digest type is a newtype over an unpinned Block Word8, which
-- requires a trampoline to implement most Ptr accesses to the underlying
-- byte array.  Instead, we re-implement the Digest type here over
-- ScrubbedBytes, as well as over pinned Block backends, to avoid trampoline
-- costs.  Additionally, we use the mutable API to avoid copying the hashing
-- Context between the init/update/finalize steps, and then clear its
-- content.

-- | Digest of hash algorithm @a@, stored in a 'ScrubbedBytes' buffer.
newtype Digest a = Digest { unDigest :: ScrubbedBytes }

-- | Digest of hash algorithm @a@, stored in a 'Block' of bytes.
newtype BlockDigest a = BlockDigest { unBlockDigest :: Block Word8 }

-- | Hash a byte array with algorithm @a@, chosen by the result type.
--
-- Produces a 'Digest' backed by 'ScrubbedBytes'.
hash :: forall a ba. (HashAlgorithm a, ByteArrayAccess ba) => ba -> Digest a
hash input = Digest (Builder.run (hashWith witness input))
  where
    -- A type witness for the algorithm only.  hashWith passes it to
    -- 'hashDigestSize', which is expected not to inspect its value.
    witness = undefined :: a

-- | Hash 'Bytes' with algorithm @a@, chosen by the result type.
--
-- The digest is built through 'Builder.runToBlock', so it is stored in a
-- 'Block' of bytes.
hashToBlock :: forall a. HashAlgorithm a => Bytes -> BlockDigest a
hashToBlock input = BlockDigest (Builder.runToBlock (hashWith witness input))
  where
    -- A type witness for the algorithm only.  hashWith passes it to
    -- 'hashDigestSize', which is expected not to inspect its value.
    witness = undefined :: a

-- | Hash @ba@ with algorithm @a@, writing the digest directly into a fresh
-- buffer of 'hashDigestSize' bytes.
--
-- The mutable hashing API is used, so the context is not copied between the
-- init, update and finalize steps.  Once the digest has been written, the
-- context memory is overwritten with zeros.  The algorithm argument only
-- supplies the digest size.
hashWith :: forall marking a ba. (HashAlgorithm a, ByteArrayAccess ba) => a -> ba -> Builder marking
hashWith alg ba = Builder.unsafeCreate (hashDigestSize alg) $ \dig -> do
    ctx <- hashMutableInit :: IO (MutableContext a)
    hashMutableUpdate ctx ba
    B.withByteArray ctx $ \pctx -> do
        -- Finalize straight from the context memory into the output buffer.
        hashInternalFinalize (castPtr pctx :: Ptr (Context a)) dig
        -- Wipe the context, which held state derived from the input.
        fillBytes pctx 0 (B.length ctx)