-- |
-- Module      : BlockN
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- A secure block with length at type level
--
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE ScopedTypeVariables #-}
module BlockN
    ( BlockN, MutableBlockN, index, iterModify, mapEqPrimSize
    , BlockN.read, runNew, runThaw, runFold, BlockN.unsafeCast
    , BlockN.write, BlockN.zipWith
#ifdef ML_KEM_TESTING
    , create, BlockN.fromList, BlockN.replicate, BlockN.toList
#endif
    ) where

import Control.DeepSeq (NFData(..))
import Control.Monad.ST

import Data.Proxy

import Base
import Block (MutableBlock, blockRead, blockWrite, unsafeCastMut)
import Equality
import Fusion
import Marking (Classified, SecurityMarking)
import SecureBlock (SecureBlock)
import qualified SecureBlock
import Math

-- | Immutable block of @a@ elements wrapping a 'SecureBlock' that carries the
-- same security marking.  The type-level natural @n@ is a phantom parameter:
-- it does not appear in the wrapped field and only records the block length
-- in the type.
newtype BlockN marking (n :: Nat) a = BlockN
    { unBlockN :: SecureBlock marking a
      -- ^ Drop the length index and expose the underlying block.
    }

#ifdef ML_KEM_TESTING
-- | Equality delegates to 'SecureBlock.eq' on the two wrapped blocks.
instance (Classified marking, Eq a, PrimType a) => Eq (BlockN marking n a) where
    (==) (BlockN lhs) (BlockN rhs) = SecureBlock.eq lhs rhs

-- | Rendering delegates to 'SecureBlock.showsPrec' on the wrapped block; the
-- length index is not part of the output produced at this level.
instance (Classified marking, PrimType a, Show a) => Show (BlockN marking n a) where
    showsPrec prec (BlockN blk) = SecureBlock.showsPrec prec blk
#endif

-- | Normal-form evaluation is delegated to 'SecureBlock.toNormalForm' on the
-- wrapped block.
instance NFData (BlockN marking n a) where
    rnf (BlockN blk) = SecureBlock.toNormalForm blk

-- | Element-wise lifting of the 'Add' operations of @a@ to blocks of length
-- @n@.  Every method is a thin wrapper over a block combinator and is marked
-- @INLINE@ so that the combinator is exposed at call sites.
instance (Classified marking, KnownNat n, PrimType a, Add a) => Add (BlockN marking n a) where
    -- Block filled with the 'zero' of the element type.
    zero = BlockN.replicate zero
    {-# INLINE zero #-}

    -- Pairwise sum of the elements found at the same offset.
    (.+) = BlockN.zipWith (.+)
    {-# INLINE (.+) #-}

    -- Pairwise difference of the elements found at the same offset.
    (.-) = BlockN.zipWith (.-)
    {-# INLINE (.-) #-}

    -- Negation applied to every element.
    neg = BlockN.map neg
    {-# INLINE neg #-}

-- | Mutable counterpart of 'BlockN': a 'MutableBlock' of @a@ elements in state
-- token @m@.  Both the security marking and the length @n@ are phantom
-- parameters; neither appears in the wrapped field.
newtype MutableBlockN (marking :: SecurityMarking) (n :: Nat) a m = MutableBlockN
    { unMutableBlockN :: MutableBlock a m
      -- ^ Drop the type-level indices and expose the underlying mutable block.
    }

-- | Plug 'BlockN' into the 'Fusion' machinery: the mutable representation is
-- 'MutableBlockN' with identical indices, and the three primitives are the
-- local 'new', 'thaw' and 'unsafeFreeze' specialised to 'ST'.
instance (Classified marking, KnownNat n, PrimType a) => Fusion (BlockN marking n a) where
    type Mut (BlockN marking n a) s = MutableBlockN marking n a s

    -- The proxy only selects the security marking used for the allocation.
    newF = new (Proxy :: Proxy marking)

    thawF = thaw

    unsafeFreezeF = unsafeFreeze

-- Endomorphism specialization: a different implementation is substituted
-- wherever possible with rewrite rules.  Identical input and output types give
-- a chance for a transformation to be fused on an existing mutable block.

-- Both rules carry the phase annotation [~2], i.e. they are active only
-- before simplifier phase 2, while 'BlockN.map' and 'BlockN.zipWith' are
-- declared INLINE [2] and therefore are not inlined until phase 2.  The rules
-- are typed: they can only fire at call sites whose types also fit the
-- signatures of 'mapEndo' (same element type in and out) and 'zipWithEndoL'
-- (result shares element type and marking with the first block).
{-# RULES
"mapEndo" [~2] forall f. BlockN.map f = mapEndo f
"zipWithEndoL" [~2] forall f. BlockN.zipWith f = zipWithEndoL f
  #-}

-- | Element stored at the given offset.  This is a direct call to
-- 'SecureBlock.index' on the wrapped block; the offset is not compared with
-- @n@ at this level.
index :: PrimType a => BlockN marking n a -> Offset a -> a
index (BlockN blk) = SecureBlock.index blk

-- | Block of length @n@ in which every element is the given value.  Built
-- with 'create' and an initializer that ignores the offset.
replicate :: forall marking n a. (Classified marking, KnownNat n, PrimType a)
          => a
          -> BlockN marking n a
replicate x = create (const x)

#ifdef ML_KEM_TESTING
-- | Convert a list to a block of length @n@.  The result is 'Nothing' unless
-- the list has exactly @n@ elements; otherwise the elements are written in
-- order starting at offset 0.
fromList :: forall marking n a. (Classified marking, KnownNat n, PrimType a) => [a] -> Maybe (BlockN marking n a)
fromList elems
    | Prelude.length elems == sz =
        Just (runNew (Proxy :: Proxy marking) (\mb -> fill mb 0 elems))
    | otherwise = Nothing
  where
    -- Expected number of elements, taken from the type-level length.
    !sz = fromIntegral (natVal (Proxy :: Proxy n))

    -- Store the remaining elements at consecutive offsets.
    fill !_ !_ [] = return ()
    fill !mb !off (x : rest) = write mb off x >> fill mb (off + 1) rest

-- | List of the block elements, as produced by 'SecureBlock.toList' on the
-- wrapped block.
toList :: PrimType a => BlockN marking n a -> [a]
toList (BlockN blk) = SecureBlock.toList blk
#endif

-- | Build a block of length @n@ from an initializer function: a new mutable
-- block is obtained through 'runNew' and every offset is filled by 'iterSet'
-- with the value the initializer returns for that offset.
create :: forall marking n ty. (Classified marking, KnownNat n, PrimType ty)
       => (Offset ty -> ty)
       -> BlockN marking n ty
create initializer = runNew markingProxy (iterSet initializer)
  where
    -- Only used to select the security marking of the new block.
    markingProxy = Proxy :: Proxy marking
{-# INLINE create #-}

-- | Apply a function to every element, producing a new block of the same
-- length.  Inlining is delayed to phase 2, matching the phase at which the
-- @"mapEndo"@ rewrite rule stops being active.
map :: (Classified marking, KnownNat n, PrimType a, PrimType b)
    => (a -> b) -> BlockN marking n a -> BlockN marking n b
map f (BlockN !src) = create element
  where
    -- The offset is unwrapped and re-wrapped because source and destination
    -- offsets are indexed by different element types.
    element (Offset i) = f (SecureBlock.index src (Offset i))
{-# INLINE [2] map #-}

-- | Variant of 'BlockN.map' for functions that keep the element type.  It is
-- 'mapEqPrimSize' at a more specific type and is the right-hand side of the
-- @"mapEndo"@ rewrite rule.
mapEndo :: (Classified marking, KnownNat n, PrimType a)
        => (a -> a)
        -> BlockN marking n a
        -> BlockN marking n a
mapEndo = mapEqPrimSize
{-# INLINE mapEndo #-}

-- | Map between element types related by 'EqPrimSize'.  The input block is
-- turned into a fusion context with 'thawContext', the function is applied
-- through 'iterMapContext', and 'runContext' produces the resulting block.
mapEqPrimSize :: (Classified marking, KnownNat n, EqPrimSize a b)
              => (a -> b)
              -> BlockN marking n a
              -> BlockN marking n b
mapEqPrimSize f =
    runContext
        . iterMapContext f
        . thawContext
{-# INLINE mapEqPrimSize #-}

-- | Replace every element of a mutable block with the result of applying the
-- function to it.  This is 'iterModifyIx' with a function that ignores the
-- offset.
iterModify :: forall marking n ty prim. (PrimType ty, KnownNat n, PrimMonad prim)
           => (ty -> ty)
           -> MutableBlockN marking n ty (PrimState prim)
           -> prim ()
iterModify f = iterModifyIx (const f)
{-# INLINE iterModify #-}

-- | Replace every element of a mutable block with the result of applying the
-- function to its offset and current value.  Offsets are visited in
-- increasing order from 0 until the offset equals the type-level length @n@.
iterModifyIx :: forall marking n ty prim. (PrimType ty, KnownNat n, PrimMonad prim)
             => (Offset ty -> ty -> ty)
             -> MutableBlockN marking n ty (PrimState prim)
             -> prim ()
iterModifyIx f (MutableBlockN !ma) = go 0
  where
    -- Number of elements, taken from the type-level length.
    !sz = fromIntegral (natVal (Proxy :: Proxy n))

    go :: Offset ty -> prim ()
    go off
        | off .==# sz = pure ()
        | otherwise = do
            x <- blockRead ma off
            blockWrite ma off (f off x)
            go (off + 1)
{-# INLINE iterModifyIx #-}

-- | Overwrite every element of a mutable block with the value the function
-- returns for its offset.  Previous content is never read.  Offsets are
-- visited in increasing order from 0 until the offset equals the type-level
-- length @n@.
iterSet :: forall marking n ty prim. (PrimType ty, KnownNat n, PrimMonad prim)
        => (Offset ty -> ty)
        -> MutableBlockN marking n ty (PrimState prim)
        -> prim ()
iterSet f (MutableBlockN !ma) = go 0
  where
    -- Number of elements, taken from the type-level length.
    !sz = fromIntegral (natVal (Proxy :: Proxy n))

    go :: Offset ty -> prim ()
    go off
        | off .==# sz = pure ()
        | otherwise = do
            blockWrite ma off (f off)
            go (off + 1)
{-# INLINE iterSet #-}

-- | Combine two blocks of the same length element by element into a new
-- block.  The three security markings are independent type variables.
-- Inlining is delayed to phase 2, matching the phase at which the
-- @"zipWithEndoL"@ rewrite rule stops being active.
zipWith :: (Classified mc, KnownNat n, PrimType a, PrimType b, PrimType c)
        => (a -> b -> c) -> BlockN ma n a -> BlockN mb n b -> BlockN mc n c
zipWith f (BlockN !lhs) (BlockN !rhs) = create element
  where
    -- The raw position is re-wrapped once per input because each offset is
    -- indexed by its own element type.
    element (Offset i) =
        f (SecureBlock.index lhs (Offset i))
          (SecureBlock.index rhs (Offset i))
{-# INLINE [2] zipWith #-}

-- | Variant of 'BlockN.zipWith' whose result has the element type and the
-- marking of the first block.  The first block is turned into a fusion
-- context with 'thawContext' and mapped with 'iterMapIxContext', reading the
-- matching element of the second block by position; the second block is
-- additionally passed to 'seqContext'.  Right-hand side of the
-- @"zipWithEndoL"@ rewrite rule.
zipWithEndoL :: (Classified ma, KnownNat n, PrimType a, PrimType b)
             => (a -> b -> a) -> BlockN ma n a -> BlockN mb n b -> BlockN ma n a
zipWithEndoL f a b =
    runContext $ seqContext b $ iterMapIxContext combine $ thawContext a
  where
    -- Pair the current element of the first block with the element of the
    -- second block found at the same position.
    combine pos x = f x (index b (Offset pos))
{-# INLINE zipWithEndoL #-}

-- | Reinterpret the content as a 'SecureBlock' of another element type by
-- calling 'SecureBlock.unsafeCast' on the wrapped block.  The result no
-- longer carries a type-level length.
unsafeCast :: BlockN marking n a -> SecureBlock marking b
unsafeCast (BlockN blk) = SecureBlock.unsafeCast blk

-- | Read the element at the given offset of a mutable block.  Direct call to
-- 'blockRead' on the wrapped block; the offset is not compared with @n@ at
-- this level.
read :: (PrimMonad prim, PrimType a)
     => MutableBlockN marking n a (PrimState prim)
     -> Offset a
     -> prim a
read (MutableBlockN mb) = blockRead mb

-- | Store a value at the given offset of a mutable block.  Direct call to
-- 'blockWrite' on the wrapped block; the offset is not compared with @n@ at
-- this level.
write :: (PrimMonad prim, PrimType a)
      => MutableBlockN marking n a (PrimState prim)
      -> Offset a
      -> a
      -> prim ()
write (MutableBlockN mb) = blockWrite mb

-- | Allocate a mutable block holding @n@ elements through 'SecureBlock.new'.
-- The proxy argument selects the security marking and is forwarded as is.
new :: forall proxy marking n a prim. (Classified marking, KnownNat n, PrimMonad prim, PrimType a)
    => proxy marking
    -> prim (MutableBlockN marking n a (PrimState prim))
new prx = fmap MutableBlockN (SecureBlock.new prx (CountOf sz))
  where
    -- Element count, taken from the type-level length.
    !sz = fromIntegral (natVal (Proxy :: Proxy n))
{-# INLINE new #-}

-- | Produce a block by applying an 'ST' action to a mutable block obtained
-- from the input: the input goes through 'thawContext', the action is
-- registered with 'modifyContext' and 'runContext' yields the result.
runThaw :: (Classified marking, KnownNat n, PrimType a)
        => BlockN marking n a
        -> (forall s. MutableBlockN marking n a s -> ST s ())
        -> BlockN marking n a
runThaw a f = runContext $ modifyContext f $ thawContext a
{-# INLINE runThaw #-}

-- | Produce a block by applying an 'ST' action to a mutable block coming
-- from 'newContext'.  The proxy value itself is ignored; it only fixes the
-- @marking@ type.
runNew :: (Classified marking, KnownNat n, PrimType a)
       => proxy marking
       -> (forall s. MutableBlockN marking n a s -> ST s ())
       -> BlockN marking n a
runNew _ f = runContext $ modifyContext f newContext
{-# INLINE runNew #-}

-- | Produce a block by folding a container of inputs over a mutable block
-- obtained from the initial block: the initial block goes through
-- 'thawContext', 'foldContext' applies the action for the elements of the
-- container, and 'runContext' yields the result.
runFold :: (Classified marking, KnownNat n, PrimType a, Foldable t)
        => BlockN marking n a
        -> (forall s. b -> MutableBlockN marking n a s -> ST s ())
        -> t b
        -> BlockN marking n a
runFold a f = runContext . foldContext f start
  where
    -- Context holding the initial block content.
    start = thawContext a
{-# INLINE runFold #-}

-- | Obtain a mutable block from an immutable one by calling
-- 'SecureBlock.thaw' on the wrapped block and re-attaching the same indices.
thaw :: (Classified marking, PrimMonad prim)
     => BlockN marking n a
     -> prim (MutableBlockN marking n a (PrimState prim))
thaw (BlockN blk) = MutableBlockN <$> SecureBlock.thaw blk

-- | Obtain an immutable block from a mutable one by calling
-- 'SecureBlock.unsafeFreeze' on the wrapped block and re-attaching the same
-- indices.
unsafeFreeze :: (Classified marking, PrimMonad prim)
             => MutableBlockN marking n a (PrimState prim)
             -> prim (BlockN marking n a)
unsafeFreeze (MutableBlockN mb) = BlockN <$> SecureBlock.unsafeFreeze mb

-- | Map a function over a mutable block in place, changing the element type.
-- For each position from 0 up to @n - 1@ the element is read at type @a@ and
-- the function result is written at the same position through an
-- 'unsafeCastMut' view at type @b@.  The returned block is that same cast
-- view, so the input block must not be used at type @a@ afterwards.
unsafeMapIx :: forall marking n a b prim. (KnownNat n, EqPrimSize a b, PrimMonad prim)
            => (Int -> a -> b)
            -> MutableBlockN marking n a (PrimState prim)
            -> prim (MutableBlockN marking n b (PrimState prim))
unsafeMapIx f (MutableBlockN !ma) =
    (MutableBlockN . ensureEqPrimSize witness) <$> go 0
  where
    -- Value handed to 'ensureEqPrimSize' only for its type; presumably it is
    -- never evaluated (depends on 'ensureEqPrimSize', defined elsewhere).
    witness = undefined :: a -> b

    -- Number of elements, taken from the type-level length.
    !sz = fromIntegral (natVal (Proxy :: Proxy n))

    -- The same storage viewed with the destination element type.
    dst :: MutableBlock b (PrimState prim)
    dst = unsafeCastMut ma

    go :: Int -> prim (MutableBlock b (PrimState prim))
    go pos
        | pos == sz = return dst
        | otherwise = do
            x <- blockRead ma (Offset pos)
            blockWrite dst (Offset pos) (f pos x)
            go (pos + 1)
{-# INLINE unsafeMapIx #-}

--

-- | Element-wise map over a block held in a fusion 'Context'.
--
-- This is 'iterMapIxContext' with a function that ignores the element index,
-- so it takes part in the same rewrite rules once inlined.
iterMapContext :: (EqPrimSize a b, Classified marking, KnownNat n) => (a -> b) -> Context (BlockN marking n a) -> Context (BlockN marking n b)
iterMapContext f = iterMapIxContext (\_ x -> f x)
{-# INLINE iterMapContext #-}

-- | Index-aware element-wise map over a block held in a fusion 'Context'.
--
-- The transformation is handed to 'mapContext' as a 'MapF' with two
-- strategies:
--
-- * 'mapUpdate' rewrites an existing mutable block in place with
--   'unsafeMapIx', reusing its buffer for the result;
--
-- * 'mapInit' starts from an immutable block: it allocates a fresh mutable
--   block with 'newF' and fills every offset @i@ with @f i@ applied to the
--   source element at that same offset.
--
-- Inlining is delayed to phase 1 so that the rewrite rules mentioning this
-- function still see it by name in earlier phases.
iterMapIxContext :: (EqPrimSize a b, Classified marking, KnownNat n) => (Int -> a -> b) -> Context (BlockN marking n a) -> Context (BlockN marking n b)
iterMapIxContext f = mapContext m
  where m = MapF { mapUpdate = unsafeMapIx f
                 , mapInit = \x -> newF >>= \mb -> iterSet (g x) mb >> return mb
                 }
        -- Value of the destination element at a given offset, computed from
        -- the source block.  The offset is unwrapped and rewrapped because
        -- source and destination offsets carry different element types.
        g x (Offset i) = f i (index x (Offset i))
{-# INLINE [1] iterMapIxContext #-}


-- Fusion rules
--
-- "iterMapIxContext/iterMapIxContext" merges consecutive element-wise
-- transformations into a single one.  For example @a .+ b .+ c@ becomes a
-- single loop that processes all input blocks in parallel and writes to the
-- destination block.
--
-- "iterMapIxContext/seqContext" moves strictness annotations upstream so that
-- they do not prevent other rules from firing.
--
-- Phase control: both rules are active only before phase 1 (@[~1]@), whereas
-- 'iterMapIxContext' is inlined from phase 1 onwards (@INLINE [1]@).  The
-- rules therefore get a chance to match calls to 'iterMapIxContext' before
-- its definition is unfolded.

{-# RULES
"iterMapIxContext/seqContext" [~1] forall a f c. iterMapIxContext f (seqContext a c) = seqContext a (iterMapIxContext f c)
"iterMapIxContext/iterMapIxContext" [~1] forall f g c. iterMapIxContext f (iterMapIxContext g c) = iterMapIxContext (\i a -> f i (g i a)) c
  #-}