-- |
-- Module      : Marking
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Infrastructure that associates a security marking at type level to all
-- buffers created by the library.  This determines which buffers need the
-- scrubbed (Sec) or regular (Pub) variants.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE PolyKinds #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeFamilyDependencies #-}
module Marking
    ( SecurityMarking(..), Classified(..), Leak(..), index
    , Marking.toNormalForm, unsafeCast
#ifdef ML_KEM_TESTING
    , Marking.toList
#endif
    ) where

import Control.DeepSeq (NFData(..))
import Control.Monad.ST

import Data.ByteArray (Bytes, ScrubbedBytes)
import qualified Data.ByteArray as B

import Data.Kind

import Foreign.Ptr (Ptr)

import Unsafe.Coerce

import Base
import Block (Block, MutableBlock, blockIndex)
import ScrubbedBlock (ScrubbedBlock)
import qualified Block
import qualified ByteArrayST as ST
import qualified ScrubbedBlock

-- Marking attached (at type level, through DataKinds) to the buffers handled
-- in this module.
data SecurityMarking
    = Sec  -- secret information
    | Pub  -- public information

-- This transformation is called only at the expected location in the LWE
-- problem, after adding noise to secret information.
--
-- Block and ScrubbedBlock have the same representation, so we can force a
-- coercion from Sec to Pub even though the block will actually be scrubbed.
-- This is simpler than copying to a real non-scrubbed block.
class Leak t where
    -- Reinterpret a value marked secret as the same value marked public.
    --
    -- The default implementation is 'unsafeCoerce': nothing is copied, the
    -- very same heap object is returned under a different type.  This is only
    -- sound when @t Sec@ and @t Pub@ share the same runtime representation, so
    -- an instance relying on the default must guarantee that.
    leak :: t Sec -> t Pub
    leak = unsafeCoerce

-- Operations whose implementation depends on the security marking.  Each
-- marking selects one block type and one byte-array type through the two
-- injective associated type families, so the marking can always be recovered
-- from the buffer type.
class Classified (marking :: SecurityMarking) where
    -- Immutable block type used for this marking.
    type SecureBlock marking = (block :: Type -> Type) | block -> marking

    -- Allocate a mutable block for the given element count.  The proxy
    -- argument only carries the marking; instances ignore its value.
    new
        :: (PrimType ty, PrimMonad prim)
        => proxy marking
        -> CountOf ty
        -> prim (MutableBlock ty (PrimState prim))

    -- Obtain a mutable block from an immutable one.
    thaw
        :: PrimMonad m
        => SecureBlock marking ty
        -> m (MutableBlock ty (PrimState m))

    -- Turn a mutable block into an immutable block of this marking.
    unsafeFreeze
        :: PrimMonad prim
        => MutableBlock ty (PrimState prim)
        -> prim (SecureBlock marking ty)

#ifdef ML_KEM_TESTING
    -- Helpers compiled only for the test suite.
    eq
        :: (Eq ty, PrimType ty)
        => SecureBlock marking ty
        -> SecureBlock marking ty
        -> Bool
    showsPrec
        :: (PrimType ty, Show ty)
        => Int
        -> SecureBlock marking ty
        -> ShowS
    lengthBlock
        :: PrimType ty
        => SecureBlock marking ty
        -> CountOf ty
#endif

    -- Byte-array type used for this marking.
    type SecureBytes marking = bytes | bytes -> marking

    -- Build a byte array from a size and an ST initializer that receives a
    -- pointer (presumably to the freshly allocated buffer, see
    -- ByteArrayST.unsafeCreate which both instances delegate to).
    unsafeCreate
        :: Int
        -> (forall s. Ptr a -> ST s ())
        -> SecureBytes marking

    -- Length of the byte array.
    lengthBytes
        :: SecureBytes marking
        -> Int

    -- Copy the content of the byte array to the destination pointer.
    copyByteArrayToPtr
        :: SecureBytes marking
        -> Ptr a
        -> IO ()

-- Public information: regular (non-scrubbed) 'Block' and 'Bytes' variants.
-- Every method simply delegates to the corresponding implementation.
instance Classified Pub where
    type SecureBlock Pub = Block

    -- The proxy only selects the instance, its value is never inspected.
    new _ = Block.new
    thaw = Block.thaw
    unsafeFreeze = Block.unsafeFreeze

#ifdef ML_KEM_TESTING
    eq = (==)
    showsPrec = Prelude.showsPrec
    lengthBlock = Block.length
#endif

    type SecureBytes Pub = Bytes
    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes = B.length
    copyByteArrayToPtr = B.copyByteArrayToPtr

-- Secret information: scrubbed 'ScrubbedBlock' and 'ScrubbedBytes' variants.
-- Every method simply delegates to the corresponding implementation.
instance Classified Sec where
    type SecureBlock Sec = ScrubbedBlock

    -- The proxy only selects the instance, its value is never inspected.
    new _ = ScrubbedBlock.new
    thaw = ScrubbedBlock.thaw
    unsafeFreeze = ScrubbedBlock.unsafeFreeze

#ifdef ML_KEM_TESTING
    eq = (==)
    showsPrec = Prelude.showsPrec
    lengthBlock = ScrubbedBlock.length
#endif

    type SecureBytes Sec = ScrubbedBytes
    unsafeCreate = ST.unsafeCreate
    {-# INLINE unsafeCreate #-}
    lengthBytes = B.length
    copyByteArrayToPtr = B.copyByteArrayToPtr


-- for some functions we use the fact that Block and SecureBlock have the same
-- representation and implementation

-- View a block of any marking as a plain 'Block'.
--
-- Implemented with 'unsafeCoerce', which is acceptable only because 'Block'
-- and every 'SecureBlock' variant have the same representation.  Internal
-- helper: the result no longer carries the marking, so it must not escape
-- this module.
unwrap :: SecureBlock marking a -> Block a
unwrap = unsafeCoerce

-- Attach a marking to a plain 'Block'.
--
-- Implemented with 'unsafeCoerce', which is acceptable only because 'Block'
-- and every 'SecureBlock' variant have the same representation.  Internal
-- helper: the marking of the result is chosen freely by the caller, so use
-- it only to restore the marking a block had before being unwrapped.
wrap :: Block b -> SecureBlock marking b
wrap = unsafeCoerce

-- Read the element at the given offset in a block, whatever its marking.
-- Delegates to 'blockIndex' on the underlying 'Block' representation.
index :: PrimType ty => SecureBlock marking ty -> Offset ty -> ty
index = blockIndex . unwrap

#ifdef ML_KEM_TESTING
-- Testing only: list the elements of a block, whatever its marking, by going
-- through the underlying 'Block' representation.
toList :: PrimType ty => SecureBlock marking ty -> [ty]
toList block = Block.toList (unwrap block)
#endif

-- Evaluate a block, whatever its marking, by applying 'rnf' to the underlying
-- 'Block' representation.
toNormalForm :: SecureBlock marking ty -> ()
toNormalForm = rnf . unwrap

-- Reinterpret the element type of a block while keeping its marking.
--
-- The block is unwrapped, cast with 'Block.unsafeCast', then wrapped again
-- with the same marking.  No check is made here that the element types are
-- compatible: that is the responsibility of the caller.
unsafeCast :: SecureBlock marking a -> SecureBlock marking b
unsafeCast = wrap . Block.unsafeCast . unwrap