-- |
-- Module      : Auxiliary
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- ML-KEM auxiliary functions
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}
{-# LANGUAGE MagicHash #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE UnboxedTuples #-}
module Auxiliary
    ( Zq, Rq, Tq, (..+), (..-)
    , ntt, nttInv, rcompress, rdecompress
    , byteEncode, byteDecode, byteEncode12, byteDecode12
    , byteEncode1, byteDecode1, sampleNTT, samplePolyCBD
#ifdef ML_KEM_TESTING
    , compress, decompress
    , bitRev7, fromZq, toZq, fromCoeffs, toCoeffs
#endif
    ) where

import Crypto.Hash.Algorithms

import Data.ByteArray (ByteArrayAccess, Bytes, View)
import qualified Data.ByteArray as B

import Data.Primitive.Types (Prim(..))

import Control.DeepSeq (NFData(..))
import Control.Monad
import Control.Monad.ST

import Data.Bits
import Data.Proxy
import Data.Word

import GHC.TypeNats

import Foreign.Ptr (Ptr, plusPtr)
import Foreign.Storable (pokeByteOff)

import Unsafe.Coerce

import Base
import Block (blockIndex)
import BlockN (BlockN, MutableBlockN)
import Builder (Builder)
import Crypto (BlockDigest)
import Machine
import Marking (Classified, SecurityMarking(..), Leak(..))
import SecureBlock (SecureBlock)
import SecureBytes (SecureBytes)
import qualified BlockN
import qualified Builder
import qualified ByteArrayST as ST
import qualified Crypto
import Math

-- Type-level length of the coefficient blocks wrapped by 'Rq' and 'Tq'.
type N = 256

-- Value-level counterpart of the type-level length 'N'.
n :: Int
n = 256

-- The prime modulus 𝑞 used for all coefficient arithmetic.
q :: Integer
q = 3329

-- 𝑞 as a 16-bit word, the storage width of a coefficient.
q16 :: Word16
q16 = fromInteger q

-- 𝑞 as a 32-bit word, the width used for products of two coefficients.
q32 :: Word32
q32 = fromInteger q

-- 𝑞 as a 64-bit word, used when computing the quotient estimate in 'reduce'.
q64 :: Word64
q64 = fromInteger q

-- Reverses the order of the 7 low-order bits of the input: bit 6 moves to
-- bit 0, bit 5 to bit 1, and so on down to bit 0 moving to bit 6.  Bit 7 of
-- the input is ignored and bit 7 of the result is always zero.
bitRev7 :: Word8 -> Word8
bitRev7 b =
    (b `unsafeShiftR` 6 .&. 1) .|.
    (b `unsafeShiftR` 5 .&. 1) `unsafeShiftL` 1 .|.
    (b `unsafeShiftR` 4 .&. 1) `unsafeShiftL` 2 .|.
    (b `unsafeShiftR` 3 .&. 1) `unsafeShiftL` 3 .|.
    (b `unsafeShiftR` 2 .&. 1) `unsafeShiftL` 4 .|.
    (b `unsafeShiftR` 1 .&. 1) `unsafeShiftL` 5 .|.
    (b .&. 1) `unsafeShiftL` 6

-- Reduction 𝑥 mod 𝑞 for 0 ≤ 𝑥 < 2𝑞
--
-- Branch-free selection between 𝑥 and 𝑥 − 𝑞.  When 𝑥 < 𝑞 the 16-bit
-- subtraction wraps around and sets the top bit, which turns 'mask' into all
-- ones and selects 𝑥 unchanged; otherwise 'mask' is zero and the subtracted
-- value is selected.
reduceSimple :: Word16 -> Word16
reduceSimple x = (mask .&. x) .|. (complement mask .&. subtracted)
  where
    subtracted = x - q16
    -- 0xFFFF when the subtraction wrapped (top bit set), 0 otherwise
    mask = negate (subtracted `unsafeShiftR` 15)
{-# INLINE reduceSimple #-}

-- Reduction 𝑥 mod 𝑞 for 0 ≤ 𝑥 < 2𝑞² + 𝑞
--
-- Estimates the quotient 𝑥 / 𝑞 by multiplying with the precomputed constant
-- ⌊2²⁴ / 𝑞⌋ in 64-bit arithmetic and shifting right by 24 bits, subtracts
-- quotient × 𝑞 from 𝑥, and lets 'reduceSimple' perform the final correction.
reduce :: Word32 -> Word16
reduce x = reduceSimple (fromIntegral remainder)
  where
    p = fromIntegral x * ((1 `unsafeShiftL` 24) `div` q64)
    quotient = fromIntegral (p `unsafeShiftR` 24)
    remainder = x - quotient * q32
{-# INLINE reduce #-}

-- Coefficient type: a 'Word16' on which the arithmetic instances of this
-- module operate modulo 𝑞.  'Show' is derived only in testing builds.
newtype Zq = Zq Word16
#ifdef ML_KEM_TESTING
    deriving (Eq, Show)
#else
    deriving Eq
#endif

-- Every method delegates to the 'Word16' instance, unwrapping the 'Zq'
-- constructor on the way in and re-wrapping it on the way out.
instance Prim Zq where
    sizeOf# (Zq a) = sizeOf# a
    {-# INLINE sizeOf# #-}
    alignment# (Zq a) = alignment# a
    {-# INLINE alignment# #-}
#if MIN_VERSION_primitive(0,9,0)
    -- Proxy-based variants, defined only for primitive >= 0.9.0
    sizeOfType# _ = sizeOfType# (Proxy :: Proxy Word16)
    {-# INLINE sizeOfType# #-}
    alignmentOfType# _ = alignmentOfType# (Proxy :: Proxy Word16)
    {-# INLINE alignmentOfType# #-}
#endif
    indexByteArray# ba i = Zq (indexByteArray# ba i)
    {-# INLINE indexByteArray# #-}
    readByteArray# mba i s =
        case readByteArray# mba i s of
            (# s', a #) -> (# s', Zq a #)
    {-# INLINE readByteArray# #-}
    writeByteArray# mba i (Zq a) = writeByteArray# mba i a
    {-# INLINE writeByteArray# #-}
    setByteArray# mba i len (Zq a) = setByteArray# mba i len a
    {-# INLINE setByteArray# #-}
    indexOffAddr# addr i = Zq (indexOffAddr# addr i)
    {-# INLINE indexOffAddr# #-}
    readOffAddr# addr i s =
        case readOffAddr# addr i s of
            (# s', a #) -> (# s', Zq a #)
    {-# INLINE readOffAddr# #-}
    writeOffAddr# addr i (Zq a) = writeOffAddr# addr i a
    {-# INLINE writeOffAddr# #-}
    setOffAddr# addr i len (Zq a) = setOffAddr# addr i len a
    {-# INLINE setOffAddr# #-}

-- NOTE(review): presumably the element size in bytes ('Zq' wraps a 'Word16');
-- confirm against the definition of 'PrimSized'.
instance PrimSized Zq where
    type PrimSize Zq = 2

-- Addition, subtraction and negation with a single 'reduceSimple' correction.
-- Assuming both operands are already below 𝑞, every intermediate value stays
-- within the 0 ≤ 𝑥 < 2𝑞 range that 'reduceSimple' expects.
instance Add Zq where
    zero = Zq 0
    Zq a .+ Zq b = Zq $ reduceSimple (a + b)
    -- 𝑞 is added before subtracting so the intermediate does not go negative
    Zq a .- Zq b = Zq $ reduceSimple (a + q16 - b)
    neg (Zq a) = Zq $ reduceSimple (q16 - a)

-- Multiplication: both operands are widened to 32 bits before multiplying and
-- the product is brought back below 𝑞 with 'reduce'.
instance Mul Zq where
    one = Zq 1
    Zq a .* Zq b = Zq $ reduce (fromIntegral a * fromIntegral b)

#ifdef ML_KEM_TESTING
-- Multiply-add with a single reduction of the widened 32-bit result.
instance MulAdd Zq where
    mulAdd (Zq x) (Zq y) (Zq acc) = Zq (reduce (widen x * widen y + widen acc))
      where
        widen :: Word16 -> Word32
        widen = fromIntegral

-- Both operands have the same type, so this is the ordinary product.
instance BiMul Zq Zq where
    x ..* y = x .* y

instance BiMulAdd Zq Zq where
    biMulAdd = mulAdd

-- Unwraps the underlying 16-bit value.
fromZq :: Zq -> Word16
fromZq (Zq raw) = raw
#endif

-- Builds a 'Zq' from an arbitrary 16-bit value, widening it to 32 bits and
-- passing it through 'reduce' first.
toZq :: Word16 -> Zq
toZq = Zq . reduce . fromIntegral

-- Polynomial stored as a block of 'N' coefficients in 'Zq', tagged with a
-- security marking.  'Eq' and 'Show' are derived only in testing builds.
newtype Rq marking = Rq (BlockN marking N Zq)
#ifdef ML_KEM_TESTING
    deriving (Eq, Show)
#endif

-- Additive structure on polynomials, delegated to the 'Add' instance of the
-- wrapped coefficient block.
instance Classified marking => Add (Rq marking) where
    zero = Rq zero
    Rq a .+ Rq b = Rq (a .+ b)
    {-# INLINE (.+) #-}
    Rq a .- Rq b = Rq (a .- b)
    {-# INLINE (.-) #-}
    neg (Rq a) = Rq (neg a)
    {-# INLINE neg #-}

infixl 6 ..+, ..-

-- Transformation called only at expected location in the LWE problem, after
-- adding noise to secret information.
--
-- Adds two secret polynomials and relabels the sum as public through 'leak'.
(..+) :: Rq Sec -> Rq Sec -> Rq Pub
a ..+ b = leak (a .+ b)
{-# INLINE (..+) #-}

-- Subtracts a secret polynomial from a public one, producing a secret result.
-- The secret block is passed first to 'BlockN.zipWith' and the operator is
-- flipped, so each coefficient is still computed as public minus secret.
(..-) :: Rq Pub -> Rq Sec -> Rq Sec
Rq a ..- Rq b = Rq $ BlockN.zipWith (flip (.-)) b a
{-# INLINE (..-) #-}

-- Makes 'leak' available on 'Rq'; no methods are defined here, so the class
-- defaults apply.
instance Leak Rq

#ifdef ML_KEM_TESTING
-- Builds a secret polynomial from a coefficient list; the result is whatever
-- 'BlockN.fromList' returns, wrapped in 'Rq' when it succeeds.
fromCoeffs :: [Zq] -> Maybe (Rq Sec)
fromCoeffs coeffs = Rq <$> BlockN.fromList coeffs

-- Lists the coefficients of a secret polynomial.
toCoeffs :: Rq Sec -> [Zq]
toCoeffs (Rq block) = BlockN.toList block
#endif

-- NTT representation of a polynomial (see 'ntt'): a block of 'N' values in
-- 'Zq', tagged with a security marking.  'NFData' is always derived through
-- the wrapped block; 'Eq' and 'Show' only in testing builds.
newtype Tq marking = Tq (BlockN marking N Zq)
#ifdef ML_KEM_TESTING
    deriving (Eq, Show, NFData)
#else
    deriving NFData
#endif

-- Additive structure on NTT representations, delegated to the 'Add' instance
-- of the wrapped coefficient block.
instance Classified marking => Add (Tq marking) where
    zero = Tq zero
    Tq a .+ Tq b = Tq (a .+ b)
    {-# INLINE (.+) #-}
    Tq a .- Tq b = Tq (a .- b)
    {-# INLINE (.-) #-}
    neg (Tq a) = Tq (neg a)
    {-# INLINE neg #-}

-- Makes 'leak' available on 'Tq'; no methods are defined here, so the class
-- defaults apply.
instance Leak Tq

-- Product of a public and a secret NTT representation, implemented by
-- 'multiplyNTTs'.
instance BiMul (Tq Pub) (Tq Sec) where
    (..*) = multiplyNTTs
    {-# INLINE (..*) #-}

-- Folds a collection of (public, secret) pairs into a secret starting value,
-- implemented by 'multiplyNTTsFold'.
instance BiMulAdd (Tq Pub) (Tq Sec) where
    biMulFold = multiplyNTTsFold
    {-# INLINE biMulFold #-}

#ifdef ML_KEM_TESTING
-- Testing-only multiplication of two secret NTT representations: the left
-- operand is relabelled as public with 'leak' and the mixed product is used.
instance Mul (Tq Sec) where
    -- 'one' at every even offset and 'zero' at every odd offset
    one = Tq (BlockN.create coefficient)
      where
        coefficient (Offset i)
            | even i    = one
            | otherwise = zero
    a .* b = leak a ..* b

instance MulAdd (Tq Sec) where
    mulAdd a = biMulAdd (leak a)
#endif

-- Equality check on secret NTT representations: both blocks are reinterpreted
-- as blocks of 'Word' with 'BlockN.unsafeCast' and compared with the
-- 'Crypto.constEqW' instance of that block type.
instance Crypto.ConstEqW (Tq Sec) where
    constEqW (Tq a) (Tq b) = Crypto.constEqW
        (BlockN.unsafeCast a :: SecureBlock Sec Word)
        (BlockN.unsafeCast b :: SecureBlock Sec Word)

-- Equality check on public NTT representations: both blocks are reinterpreted
-- as blocks of 'Word' with 'BlockN.unsafeCast' and compared with the
-- 'Crypto.constEqW' instance of that block type.
instance Crypto.ConstEqW (Tq Pub) where
    constEqW (Tq a) (Tq b) = Crypto.constEqW
        (BlockN.unsafeCast a :: SecureBlock Pub Word)
        (BlockN.unsafeCast b :: SecureBlock Pub Word)

-- Computes the NTT representation of the given polynomial
--
-- The coefficient block is handed to 'BlockN.runThaw' together with the
-- in-place transform 'mutNtt', and the resulting block is wrapped in 'Tq'.
ntt :: Classified marking => Rq marking -> Tq marking
ntt (Rq a) = Tq $ BlockN.runThaw a mutNtt
{-# INLINE ntt #-}

-- In-place forward transform over the 256 entries of the mutable block.
--
-- 'outer' runs one layer per value of @len@, starting at 128 and halving
-- until it drops below 2.  'inner' walks the block in chunks of @2 * len@,
-- picking the next entry of 'zetaPowBitRev' for each chunk; the index @i@
-- starts at 1 and keeps increasing across layers.  'loop' applies the
-- butterfly to every pair @(j, j + len)@ of the chunk.
mutNtt :: MutableBlockN marking N Zq s -> ST s ()
mutNtt !b = outer 1 128
  where
    outer !i len = when (len >= 2) $ inner i len 0

    inner !i !len start
        | start < 256 = do
            let zeta = BlockN.index zetaPowBitRev i -- 17 ^ bitRev7 i
            loop zeta (start + len) len start
            inner (i + 1) len (start + offsetShiftL 1 len)
        | otherwise = outer i (offsetShiftR 1 len)

    loop !zeta end len j =
        when (j < end) $ do
            t <- (zeta .*) <$> BlockN.read b (j + len)
            x <- BlockN.read b j
            BlockN.write b (j + len) (x .- t)
            BlockN.write b j (x .+ t)
            loop zeta end len (j + 1)
{-# NOINLINE mutNtt #-}

-- Computes the polynomial that corresponds to the given NTT representation
--
-- The block is handed to 'BlockN.runThaw' together with the in-place inverse
-- transform 'mutNttInv', and the resulting block is wrapped in 'Rq'.
nttInv :: Tq Sec -> Rq Sec
nttInv (Tq a) = Rq $ BlockN.runThaw a mutNttInv
{-# INLINE nttInv #-}

-- In-place inverse transform over the 256 entries of the mutable block.
--
-- 'outer' runs one layer per value of @len@, starting at 2 and doubling up to
-- 128.  'inner' walks the block in chunks of @2 * len@, picking entries of
-- 'zetaPowBitRev' with an index @i@ that starts at 127 and keeps decreasing
-- across layers.  'loop' applies the inverse butterfly to every pair
-- @(j, j + len)@ of the chunk.  Once all layers are done, every entry is
-- multiplied by 3303, which is the inverse of 128 modulo 3329.
mutNttInv :: MutableBlockN Sec N Zq s -> ST s ()
mutNttInv !b = do
    outer 127 2
    BlockN.iterModify (\x -> x .* Zq 3303) b
  where
    outer !i len = when (len <= 128) $ inner i len 0

    inner !i !len start
        | start < 256 = do
            let zeta = BlockN.index zetaPowBitRev i -- 17 ^ bitRev7 i
            loop zeta (start + len) len start
            inner (i - 1) len (start + offsetShiftL 1 len)
        | otherwise = outer i (offsetShiftL 1 len)

    loop !zeta end len j =
        when (j < end) $ do
            t <- BlockN.read b j
            x <- BlockN.read b (j + len)
            BlockN.write b j (t .+ x)
            BlockN.write b (j + len) (zeta .* (x .- t))
            loop zeta end len (j + 1)
{-# NOINLINE mutNttInv #-}

-- Computes the product of two NTT representations
--
-- A block marked 'Sec' is created with 'BlockN.runNew' and filled by
-- 'mutMultiplyNTTs'.
multiplyNTTs :: Tq Pub -> Tq Sec -> Tq Sec
multiplyNTTs f g = Tq $
    BlockN.runNew (Proxy :: Proxy Sec) $ mutMultiplyNTTs f g
{-# INLINE multiplyNTTs #-}

-- Writes the product of the two NTT representations into the given mutable
-- block.  For each of the 128 indices @i@, the entries at offsets @2i@ and
-- @2i + 1@ of both inputs are multiplied with 'baseCaseMultiply', using the
-- @i@-th entry of 'gamma', and the two results are stored at the same
-- offsets of the output.
mutMultiplyNTTs :: Tq Pub -> Tq Sec -> MutableBlockN Sec N Zq s -> ST s ()
mutMultiplyNTTs (Tq !f) (Tq !g) bb = loop bb 0
  where
    loop :: MutableBlockN Sec N Zq s -> Offset Zq -> ST s ()
    loop !b i = when (i < 128) $ do
        let ii = offsetShiftL 1 i
            a0 = BlockN.index f ii
            a1 = BlockN.index f (ii + 1)
            b0 = BlockN.index g ii
            b1 = BlockN.index g (ii + 1)
            (c0, c1) = baseCaseMultiply a0 a1 b0 b1 (BlockN.index gamma i)
        BlockN.write b ii c0
        BlockN.write b (ii + 1) c1
        loop b (i + 1)

-- Computes the product of two degree-one polynomials with respect to a quadratic modulus
--
-- With inputs (a0, a1), (b0, b1) and the constant g, the result is
-- (a0·b0 + a1·b1·g, a0·b1 + a1·b0), each component reduced with 'reduce'.
-- The factor b1·g is reduced on its own first, so each result is a single
-- reduction of the sum of two widened products (assuming all inputs are
-- already below 𝑞, that sum stays within the range 'reduce' accepts).
baseCaseMultiply :: Zq -> Zq -> Zq -> Zq -> Zq -> (Zq, Zq)
baseCaseMultiply (Zq a0) (Zq a1) (Zq b0) (Zq b1) (Zq g) = (Zq c0, Zq c1)
  where
    -- widening product: both operands are converted before multiplying
    x `mul` y = fromIntegral x * fromIntegral y
    b1g = reduce (b1 `mul` g)
    !c0 = reduce (a0 `mul` b0 + a1 `mul` b1g)
    !c1 = reduce (a0 `mul` b1 + a1 `mul` b0)

-- Folds a collection of (public, secret) pairs into an accumulator: starting
-- from the block wrapped in the first argument, 'multiplyNTTsAdd' is applied
-- once per pair through 'BlockN.runFold', and the resulting block is wrapped
-- back into a 'Tq'.
-- NOTE(review): BlockN.runFold presumably works on a mutable copy of the
-- initial block and visits the elements in Foldable order -- confirm in BlockN.
multiplyNTTsFold :: Foldable t => Tq Sec -> t (Tq Pub, Tq Sec) -> Tq Sec
multiplyNTTsFold (Tq c) =
    Tq . BlockN.runFold c (uncurry multiplyNTTsAdd)
{-# INLINE multiplyNTTsFold #-}

-- Multiply then add a third term.
--
-- For each of the 128 coefficient pairs (offsets 2i and 2i+1), the pair
-- currently stored in the mutable block is replaced with the result of
-- 'baseCaseMultiplyAdd' applied to the matching pairs of the two inputs, the
-- stored pair itself, and the i-th entry of 'gamma'.
multiplyNTTsAdd :: Tq Pub -> Tq Sec -> MutableBlockN Sec N Zq s -> ST s ()
multiplyNTTsAdd (Tq !f) (Tq !g) bb = loop bb 0
  where
    loop :: MutableBlockN Sec N Zq s -> Offset Zq -> ST s ()
    loop !b i = when (i < 128) $ do
        -- ii = 2 * i: offset of the first coefficient of pair number i
        let ii = offsetShiftL 1 i
        c0 <- BlockN.read b ii
        c1 <- BlockN.read b (ii + 1)
        let a0 = BlockN.index f ii
            a1 = BlockN.index f (ii + 1)
            b0 = BlockN.index g ii
            b1 = BlockN.index g (ii + 1)
            (d0, d1) = baseCaseMultiplyAdd a0 a1 b0 b1 c0 c1 (BlockN.index gamma i)
        BlockN.write b ii d0
        BlockN.write b (ii + 1) d1
        loop b (i + 1)

-- baseCaseMultiply then add a third term.
--
-- For inputs a0 + a1.X, b0 + b1.X, the term c0 + c1.X and the constant g
-- (modulus X^2 - g), the returned pair (d0, d1) is
--
--   d0 = c0 + a0.b0 + a1.(b1.g)
--   d1 = c1 + a0.b1 + a1.b0
--
-- The additions happen in Word32 and a single 'reduce' (defined elsewhere in
-- this module) is applied to each sum.
baseCaseMultiplyAdd :: Zq -> Zq -> Zq -> Zq -> Zq -> Zq -> Zq -> (Zq, Zq)
baseCaseMultiplyAdd (Zq a0) (Zq a1) (Zq b0) (Zq b1) (Zq c0) (Zq c1) (Zq g) = (Zq d0, Zq d1)
  where
    -- Widening multiplication: both operands are converted before the product
    x `mul` y = fromIntegral x * fromIntegral y
    -- b1.g is reduced first, presumably to keep the final sum within Word32
    -- (assumes the inputs are already reduced -- confirm against 'reduce')
    b1g = reduce (b1 `mul` g)
    !d0 = reduce (fromIntegral c0 + a0 `mul` b0 + a1 `mul` b1g)
    !d1 = reduce (fromIntegral c1 + a0 `mul` b1 + a1 `mul` b0)

-- Values of 17 ^ BitRev7(𝑖) mod 𝑞 for 𝑖 ∈ {0, … , 127}
--
-- The table is filled by walking the successive powers of 17: the
-- accumulator starts at 'one', is multiplied by @Zq 17@ after every write,
-- and the k-th power is stored at offset @bitRev7 k@ for k = 0 .. 127.
zetaPowBitRev :: BlockN Pub 128 Zq
zetaPowBitRev = BlockN.runNew (Proxy :: Proxy Pub) $ \out ->
    foldM_ (loop out) one offsets
  where
    -- Destination offsets, in the order the powers are generated
    offsets = Prelude.map (fromIntegral . bitRev7) [0 .. 127]
    -- Store the current power at offset i and hand the next power to the fold
    loop b acc i = BlockN.write b i acc >> return (Zq 17 .* acc)

-- Values of 17 ^ 2.BitRev7(𝑖)+1 mod 𝑞 for 𝑖 ∈ {0, … , 127}
--
-- Derived from 'zetaPowBitRev' by mapping every entry z to z * z * 17.
gamma :: BlockN Pub 128 Zq
gamma = BlockN.mapEqPrimSize (\z -> z .* z .* Zq 17) zetaPowBitRev

-- Compress a field element with 𝑑 < 12
--
-- Computes ((x * 2^d + qHalf) * factor) >> 34 in 'WordM' arithmetic, i.e.
-- the division by q is replaced with a multiplication by
-- factor = (2^34) `div` q followed by a 34-bit right shift.
-- The result is converted to Word16 without masking it to d bits.
-- NOTE(review): callers presumably keep only the low d bits (the byte
-- encoder masks each value) -- confirm before using the result directly.
compress :: Int -> Zq -> Word16
compress d (Zq x) = fromIntegral $
    ((fromIntegral x `unsafeShiftL` d + qHalf) * factor) `unsafeShiftR` 34
  where
    -- Rounding term added before the division: (q + 1) / 2
    qHalf = (q64 + 1) `unsafeShiftR` 1
    -- Fixed-point reciprocal of q with 34 fractional bits
    factor = (1 `unsafeShiftL` 34) `div` q64
{-# INLINE compress #-}

-- Decompress a field element with 𝑑 < 12
--
-- Computes (y * q + 2^(d-1)) >> d in Word32 arithmetic and wraps the result
-- in 'Zq'.  The rounding term shifts by d - 1, so d is expected to be at
-- least 1.
decompress :: Int -> Word16 -> Zq
decompress d y = Zq $ fromIntegral (x2d `unsafeShiftR` d)
  where x2d = fromIntegral y * q32 + (1 `unsafeShiftL` (d - 1))
{-# INLINE decompress #-}

-- Compress a polynomial with 𝑑 < 12
--
-- Applies 'compress' with the given d to every coefficient of the polynomial.
rcompress :: Classified marking => Int -> Rq marking -> BlockN marking N Word16
rcompress !d (Rq a) = BlockN.mapEqPrimSize (compress d) a
{-# INLINE rcompress #-}

-- Decompress a polynomial with 𝑑 < 12
--
-- Applies 'decompress' with the given d to every element of the block and
-- wraps the result in 'Rq'.
rdecompress :: Classified marking => Int -> BlockN marking N Word16 -> Rq marking
rdecompress !d = Rq . BlockN.mapEqPrimSize (decompress d)
{-# INLINE rdecompress #-}

-- Generates a pseudorandom element of T𝑞 from a seed and two indices
--
-- The hash input is the seed followed by the two index bytes.  SHAKE128
-- output is consumed three bytes at a time; each group yields two 12-bit
-- candidates, and a candidate is stored only when it is below q.  Sampling
-- stops as soon as 256 coefficients have been written.
sampleNTT :: SecureBytes Pub -> Word8 -> Word8 -> Tq Pub
sampleNTT seed !x !y = Tq $
    BlockN.runNew (Proxy :: Proxy Pub) $ \b -> runXof b (280 * 3) 0 0
  where
    -- Hash the input to xofLen bytes (the digest length is a type-level
    -- number of bits, hence the 'someNatVal' detour) and resume the
    -- rejection loop at byte position pos / coefficient index j.
    runXof !b !xofLen !pos !j = case someNatVal (fromIntegral (8 * xofLen)) of
        SomeNat proxy -> do
            let bytes = Crypto.unBlockDigest (doHash proxy)
            loop b xofLen bytes pos j

    loop !b !xofLen !bytes !pos j
        | j == 256 = return ()
        -- Output exhausted: hash again with a longer output and carry on from
        -- the same position.
        -- NOTE(review): relies on the longer output starting with the same
        -- bytes as the shorter one -- confirm for Crypto.hashToBlock.
        | pos >= Offset xofLen = runXof b (xofLen + 56 * 3) pos j
        | otherwise = do
            let c0 = fromIntegral $ blockIndex bytes pos
                c1 = fromIntegral $ blockIndex bytes (pos + 1)
                c2 = fromIntegral $ blockIndex bytes (pos + 2)
                -- d1: c0 plus the low nibble of c1 as bits 8..11
                d1 = c0 + (c1 .&. 0xF) `unsafeShiftL` 8
                -- d2: high nibble of c1 plus c2 shifted up by four bits
                d2 = (c1 `unsafeShiftR` 4) + (c2 `unsafeShiftL` 4)
            j2 <- poke b j d1
            -- The second candidate is only considered when room is left
            when (j2 < 256) $ poke b j2 d2 >>= loop b xofLen bytes (pos + 3)

    -- Store the candidate when it is below q; returns the next free index
    poke b j d
        | d < q16 = BlockN.write b j (Zq d) >> return (j + 1)
        | otherwise = return j

    doHash :: KnownNat bitlen => proxy bitlen -> BlockDigest (SHAKE128 bitlen)
    doHash _ = Crypto.hashToBlock input

    -- seed || x || y
    input :: SecureBytes Pub
    !input = B.unsafeCreate (len + 2) $ \d -> do
        B.copyByteArrayToPtr seed d
        pokeByteOff d len x
        pokeByteOff d (len + 1) y
    len = B.length seed

-- Reads the word stored at the pointer and converts it with 'fromLE'.
peekWord :: Ptr WordLE -> ST s WordM
peekWord p = fromLE <$> ST.peek p

-- Reads the word that contains the given bit position (element offset
-- 'wordOff' from the base pointer) and converts it with 'fromLE'.
peekWordPos :: Ptr WordLE -> BitPos -> ST s WordM
peekWordPos a bp = fromLE <$> ST.peekElemOff a (wordOff bp)

-- Converts a word with 'toLE' and stores it in the slot that contains the
-- given bit position (element offset 'wordOff' from the base pointer).
pokeWordPos :: Ptr WordLE -> BitPos -> WordM -> ST s ()
pokeWordPos a bp = ST.pokeElemOff a (wordOff bp) . toLE

-- Position counted in bits.  'wordOff' and 'bitPos' split it into the index
-- of the containing word and the bit index inside that word.
newtype BitPos = BitPos Int

-- The bit position 0.
zeroPos :: BitPos
zeroPos = BitPos 0

-- Index of the word containing the bit position: the position divided by
-- 'wordBits'.
wordOff :: BitPos -> Int
wordOff (BitPos p) = div p wordBits

-- Bit index inside the containing word: the position masked with
-- wordBits - 1 (this equals the remainder only when 'wordBits' is a power
-- of two).
bitPos :: BitPos -> Int
bitPos (BitPos p) = p .&. (wordBits - 1)

-- Number of bits that can be taken at the given position without leaving
-- the current word: the requested count, capped by the bits remaining in
-- the word.
availPos :: Int -> BitPos -> Int
availPos requested (BitPos p) = min available requested
  where available = wordBits - (p .&. (wordBits - 1))

-- Returns how many of the requested bits fit in the current word (see
-- 'availPos') together with the position advanced by that many bits.
nextPos :: Int -> BitPos -> (Int, BitPos)
nextPos requested (BitPos p) = (howMany, BitPos $ p + howMany)
  where howMany = availPos requested (BitPos p)

-- Mask with the howMany lowest bits set.
getMask :: Int -> WordM
getMask howMany
    -- A full-word request cannot be built with a shift.  This branch is
    -- useful only when processing one byte at a time due to architecture not
    -- supporting unaligned memory access.
    | howMany >= wordBits = maxBound
    | otherwise = (1 `unsafeShiftL` howMany) - 1

-- Takes a seed as input and outputs a pseudorandom sample from the
-- distribution D_eta
--
-- Allocates a new secret block, lets 'mutSamplePolyCBD' fill it, and wraps
-- the result in 'Rq'.
samplePolyCBD :: Word -> SecureBytes Sec -> Rq Sec
samplePolyCBD eta input = Rq $
    BlockN.runNew (Proxy :: Proxy Sec) $ mutSamplePolyCBD eta input
{-# INLINE samplePolyCBD #-}

-- Fills the mutable block with n coefficients sampled from the input bytes.
--
-- For every coefficient, two consecutive runs of eta bits are read from the
-- input; the coefficient is the number of set bits in the first run minus
-- the number of set bits in the second run, computed in 'Zq'.
-- NOTE(review): the input is assumed to hold at least 2 * eta * n bits; no
-- length check is made here -- confirm against callers.
mutSamplePolyCBD :: Word -> SecureBytes Sec -> MutableBlockN Sec N Zq s -> ST s ()
mutSamplePolyCBD !eta !input ff =
    ST.withByteArray input $ \p -> loop p ff 0 zeroPos
  where
    loop :: Ptr WordLE -> MutableBlockN Sec N Zq s -> Offset Zq -> BitPos -> ST s ()
    loop !p !f !i !bp = when (i < Offset n) $ do
        (xs, bp') <- getBits p bp 0 (fromIntegral eta)
        (ys, bp'') <- getBits p bp' 0 (fromIntegral eta)
        BlockN.write f i (Zq xs .- Zq ys)
        loop p f (i + 1) bp''

    -- Counts the set bits among the next j bits starting at bp, adding them
    -- to acc.  A run that crosses a word boundary is handled in several
    -- steps, each limited by 'nextPos' to the bits left in the current word.
    getBits :: Ptr WordLE -> BitPos -> Word16 -> Int -> ST s (Word16, BitPos)
    getBits !p !bp !acc !j
        | j == 0    = return (acc, bp)
        | otherwise = do
            -- Current word, shifted so that the bit at bp becomes bit 0
            x <- (`unsafeShiftR` bitPos bp) <$> peekWordPos p bp
            let (howMany, bp') = nextPos j bp
                bits = x .&. getMask howMany
            getBits p bp' (acc + fromIntegral (popCount bits)) (j - howMany)
{-# NOINLINE mutSamplePolyCBD #-}

-- Encodes an array of 𝑑-bit integers into a byte array for 1 ≤ 𝑑 ≤ 12
--
-- Creates a builder of size 32 * d whose content is written by
-- 'runByteEncode'.
-- NOTE(review): 32 * d is presumably the output size in bytes
-- (256 coefficients of d bits each) -- confirm against Builder.create.
byteEncode :: Int -> BlockN marking N Word16 -> Builder marking
byteEncode d f = Builder.create (32 * d) (runByteEncode d f)
{-# INLINE byteEncode #-}

-- Packs the d low bits of each of the n input values, in order, into the
-- destination buffer, one output word at a time.
--
-- Loop state: b is the destination, pos the index of the value being
-- emitted, bp the current output bit position, o the output word being
-- assembled, a the not-yet-emitted bits of the current value and j how many
-- of its bits remain to be emitted.
-- NOTE(review): a partially filled word is never flushed at the end, so the
-- total of n * d bits is assumed to be a multiple of 'wordBits' -- confirm.
runByteEncode :: Int -> BlockN marking N Word16 -> Ptr WordLE -> ST s ()
runByteEncode !d !f dst = loop dst 0 zeroPos 0 (get 0) d
  where
    get = BlockN.index f . Offset
    {-# INLINE get #-}

    loop !b !pos !bp !o !a j
        -- Last value fully emitted: done
        | j == 0, pos' == n = return ()
        -- Current value fully emitted: load the next one
        | j == 0 = loop b pos' bp o (get pos') d
        -- The bits fit in the current word without completing it
        | bitPos bp + howMany < wordBits = loop b pos bp' o' a' j'
        -- The current word is complete: store it and start an empty one
        | otherwise = pokeWordPos b bp o' >> loop b pos bp' 0 a' j'
      where
        pos' = pos + 1
        -- howMany: bits of the current value that fit in the current word
        (howMany, bp') = nextPos j bp
        x = fromIntegral a .&. getMask howMany
        o' = o .|. (x `unsafeShiftL` bitPos bp)
        a' = a `unsafeShiftR` howMany
        j' = j - howMany

-- Optimization of byteEncode when 𝑑=1
--
-- Creates a builder of size 32 whose content is written by
-- 'runByteEncode1'.
byteEncode1 :: BlockN Sec N Word16 -> Builder Sec
byteEncode1 !f = Builder.create 32 (runByteEncode1 f)
{-# INLINE byteEncode1 #-}

-- Writes the 1-bit encoding of the coefficients of @f@ to the destination
-- buffer @dst@, one machine word at a time.
--
-- The loop walks the coefficient index @pos@ from 0 up to @n@.  The low bit
-- of each coefficient is OR-ed into the accumulator word @o@ at the bit
-- offset given by @bitPos (BitPos pos)@.  Once the last bit of the current
-- word has been set, the word is stored with 'pokeWordPos' and the
-- accumulator restarts from zero.
runByteEncode1 :: BlockN marking N Word16 -> Ptr WordLE -> ST s ()
runByteEncode1 !f dst = loop dst 0 0
  where
    loop :: Ptr WordLE -> WordM -> Int -> ST s ()
    loop !b !o pos
        -- all coefficients consumed
        | pos == n = return ()
        -- still room in the current word: keep accumulating
        | bitPos bp + 1 < wordBits = loop b o' (pos + 1)
        -- word complete: flush it and start a fresh accumulator
        | otherwise = pokeWordPos b bp o' >> loop b 0 (pos + 1)
      where
        bp = BitPos pos
        -- only the least-significant bit of the coefficient is encoded
        x = fromIntegral (a .&. 1)
        o' = o .|. (x `unsafeShiftL` bitPos bp)
        -- lazily bound, so it is never read when pos == n
        a = BlockN.index f (Offset pos)

-- byteEncode with 𝑑=12 after conversion from the field
--
-- The polynomial is unwrapped from 'Tq' and its block of 'Zq' elements is
-- reinterpreted as a block of 'Word16' before being handed to 'byteEncode'
-- with a width of 12 bits.
byteEncode12 :: Tq marking -> Builder marking
byteEncode12 = byteEncode 12 . fromField
  where
    -- NOTE(review): 'unsafeCoerce' is only sound if 'Zq' has exactly the
    -- same runtime representation as 'Word16' (presumably a newtype over
    -- it) and if every stored value is already a canonical 12-bit
    -- representative — confirm against the definition of 'Zq'.
    fromField :: Tq marking -> BlockN marking N Word16
    fromField (Tq f) = unsafeCoerce f
{-# INLINE byteEncode12 #-}

-- Decodes a byte array into an array of 𝑑-bit integers for 1 ≤ 𝑑 ≤ 12
--
-- Allocates a new block through 'BlockN.runNew' for the requested security
-- marking and fills it with 'mutByteDecode'.  The explicit @forall@ brings
-- @marking@ into scope for the 'Proxy' annotation.  No bounds check on @d@
-- or on the length of @b@ is performed here.
byteDecode :: forall marking ba. (Classified marking, ByteArrayAccess ba) => Int -> ba -> BlockN marking N Word16
byteDecode d b = BlockN.runNew (Proxy :: Proxy marking) $ mutByteDecode d b
{-# INLINE byteDecode #-}

-- Fills the mutable block @f@ with @n@ integers of @d@ bits each, read
-- sequentially from the byte array @b@.
--
-- @outer@ iterates over the output index; @inner@ assembles one value from
-- one or more chunks of bits.  The chunk size @howMany@ and the advanced
-- position come from 'nextPos' (presumably it limits a chunk to what is left
-- in the current machine word — confirm against its definition).
--
-- @outer@ and @inner@ deliberately carry no local type signature: they use
-- @f@, whose state-thread type is fixed by the enclosing signature, and that
-- type variable is not in scope here.
mutByteDecode :: ByteArrayAccess ba => Int -> ba -> MutableBlockN marking N Word16 s -> ST s ()
mutByteDecode !d !b !f = ST.withByteArray b $ \p -> outer p zeroPos 0
  where
    -- start decoding element i, unless all n elements are done
    outer !p !bp i = when (i < Offset n) $ inner p i bp 0 0

    -- v holds the bits gathered so far, j counts how many of them
    inner !p !i !bp !v j
        | j == d = BlockN.write f i v >> outer p bp (i + 1)
        | otherwise = do
            let (howMany, bp') = nextPos (d - j) bp
            y <- get p bp howMany
            -- place the new chunk above the j bits already collected
            let v' = v .|. (fromIntegral y `unsafeShiftL` j)
                j' = j + howMany
            inner p i bp' v' j'

    -- read the word at bp, drop the bits below the position and keep
    -- only howMany bits
    get :: Ptr WordLE -> BitPos -> Int -> ST s WordM
    get p bp howMany = do
        x <- (`unsafeShiftR` bitPos bp) <$> peekWordPos p bp
        return (x .&. getMask howMany)
{-# SPECIALIZE mutByteDecode :: forall marking s. Int -> View Bytes -> MutableBlockN marking N Word16 s -> ST s () #-}

-- Optimization of byteDecode when 𝑑=1
--
-- Allocates a new block with the 'Sec' marking through 'BlockN.runNew' and
-- fills it with 'mutByteDecode1', which expands every input bit into one
-- 'Word16' element.
byteDecode1 :: ByteArrayAccess ba => ba -> BlockN Sec N Word16
byteDecode1 b = BlockN.runNew (Proxy :: Proxy Sec) $ mutByteDecode1 b
{-# INLINE byteDecode1 #-}

-- Fills the mutable block @f@ with @n@ values in {0, 1}, one per bit of the
-- byte array @b@, consuming the input one machine word at a time.
--
-- @outer@ loads the next word and advances the pointer by @wordBytes@;
-- @inner@ emits the low bit of the accumulator, shifts it right by one and
-- repeats @wordBits@ times before returning to @outer@.  The length of @b@
-- is not checked here.
--
-- The helpers carry no local type signature because they use @f@, whose
-- state-thread type variable is not in scope inside the where clause.
mutByteDecode1 :: ByteArrayAccess ba => ba -> MutableBlockN Sec N Word16 s -> ST s ()
mutByteDecode1 !b !f = ST.withByteArray b $ \p -> outer p 0
  where
    -- i is the index of the next output element
    outer !p i = when (i < n) $ do
        x <- peekWord p
        inner (p `plusPtr` wordBytes) x i 0

    -- j counts the bits already taken out of the current word
    inner !p !acc !i j
        | j == wordBits = outer p i
        | otherwise = do
            let v = fromIntegral (acc .&. 1)
            BlockN.write f (Offset i) v
            inner p (acc `unsafeShiftR` 1) (i + 1) (j + 1)

-- byteDecode with 𝑑=12 and conversion to the field
--
-- Decodes 12-bit integers with 'byteDecode', converts each of them with
-- 'toZq' through 'BlockN.mapEqPrimSize', and wraps the resulting block
-- in 'Tq'.
byteDecode12 :: (Classified marking, ByteArrayAccess ba) => ba -> Tq marking
byteDecode12 = Tq . BlockN.mapEqPrimSize toZq . byteDecode 12
{-# INLINE byteDecode12 #-}