{-# LANGUAGE DataKinds #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE RankNTypes #-}
module Builder
( Builder, builderLength, bytes, copyBuilderToPtr, create, run, runRelaxed
, runToBlock, unsafeCreate
) where
import Data.ByteArray (ByteArray)
import Control.Monad.ST
import Control.Monad.ST.Unsafe
import Data.Semigroup
import Data.Word
import Foreign.Ptr (Ptr, castPtr, plusPtr)
import qualified GHC.Exts as Exts
import Base
import Block (Block)
import Marking (Classified, Leak(..), SecurityMarking(..))
import SecureBytes (SecureBytes)
import qualified Block
import qualified ByteArrayST as ST
import qualified SecureBytes
-- | A deferred writer of a fixed number of bytes.
--
-- A 'Builder' pairs an output length with an action that writes the bytes at
-- a caller-supplied destination. The @marking@ parameter is phantom: it does
-- not appear in either field and only tags the builder with a
-- 'SecurityMarking' at the type level.
data Builder (marking :: SecurityMarking) = Builder
  { builderLength :: {-# UNPACK #-} !Int
    -- ^ Number of bytes the builder produces. The combinators in this module
    -- use it both to size the destination buffer and to advance the
    -- destination pointer, so 'copyBuilderToPtr' is expected to write exactly
    -- this many bytes.
  , copyBuilderToPtr :: forall s. Ptr Word8 -> ST s ()
    -- ^ Write the builder's bytes starting at the given pointer.
  }
-- | Concatenation of two builders.
--
-- The result has length @builderLength b1 + builderLength b2@. Its write
-- action runs @b1@ at the destination pointer and then @b2@ at that pointer
-- advanced by @builderLength b1@ bytes.
instance Semigroup (Builder marking) where
  b1 <> b2 =
    create (n1 + n2) $ \p ->
      copyBuilderToPtr b1 p >> copyBuilderToPtr b2 (p `plusPtr` n1)
    where
      n1, n2 :: Int
      n1 = builderLength b1
      n2 = builderLength b2
-- | 'mempty' is the zero-length builder 'empty'.
--
-- 'mconcat' builds a single 'Builder' for the whole list: one fold sums the
-- lengths, and a second fold chains the write actions, threading the
-- destination pointer forward by each builder's length.
instance Monoid (Builder marking) where
  mempty = empty

  mconcat builders =
    -- NOTE(review): 'Exts.inline' is kept from the original; presumably it is
    -- meant to expose the list's definition at inlined call sites so the two
    -- folds below can fuse with it. Confirm before removing.
    create (getSize (Exts.inline builders)) $
      go (Exts.inline builders)
    where
      -- Total number of bytes produced by all builders in the list.
      getSize :: [Builder m] -> Int
      getSize = getSum . Prelude.mconcat . map (Sum . builderLength)

      -- Right fold into a pointer-taking continuation; the base case writes
      -- nothing and ignores the final pointer.
      go :: [Builder m] -> Ptr Word8 -> ST s ()
      go = foldr step (const (return ()))

      -- Write one builder at @p@, then hand the advanced pointer to the rest.
      -- 'Exts.oneShot' hints to GHC that the lambda is called at most once.
      step :: Builder m -> (Ptr Word8 -> ST s ()) -> Ptr Word8 -> ST s ()
      step b k = Exts.oneShot $ \p ->
        copyBuilderToPtr b p >> k (p `plusPtr` builderLength b)
  {-# INLINE mconcat #-}
-- | 'Leak' instance for 'Builder' (the class is imported from "Marking").
-- The instance body is empty, so no class methods are overridden here.
-- NOTE(review): presumably 'Leak' provides default methods for relaxing a
-- value's security marking; confirm against the class definition.
instance Leak Builder
-- | Builder that emits the contents of a 'SecureBytes' value.
--
-- The builder's length is @SecureBytes.length b@ and its write action is
-- @SecureBytes.copyByteArrayToPtr b@. The result carries the same @marking@
-- as the input.
bytes :: Classified marking => SecureBytes marking -> Builder marking
bytes b =
  unsafeCreate (SecureBytes.length b) (SecureBytes.copyByteArrayToPtr b)
-- | Make a 'Builder' from a length and an 'ST' write action.
--
-- The first argument becomes 'builderLength' as given; no validation is
-- performed on it. The action may be written against any pointee type @a@:
-- the @Ptr Word8@ destination is re-typed with 'castPtr' before being passed
-- to it.
create :: Int -> (forall s. Ptr a -> ST s ()) -> Builder marking
create n f = Builder n (f . castPtr)
{-# INLINE create #-}
-- | The builder of length zero; its write action ignores the destination
-- pointer and does nothing. Used as 'mempty'.
empty :: Builder marking
empty = Builder 0 (\_ -> return ())
-- | Run a builder into a 'SecureBytes' value with the same marking.
--
-- Delegates to 'SecureBytes.unsafeCreate', passing 'builderLength' as the
-- size and 'copyBuilderToPtr' as the fill action.
run :: Classified marking => Builder marking -> SecureBytes marking
run b = SecureBytes.unsafeCreate (builderLength b) (copyBuilderToPtr b)
-- | Run a 'Pub'-marked builder into any 'ByteArray' type.
--
-- Only accepts @Builder Pub@, and the result type carries no marking.
-- Delegates to 'ST.unsafeCreate', passing 'builderLength' as the size and
-- 'copyBuilderToPtr' as the fill action.
runRelaxed :: ByteArray ba => Builder Pub -> ba
runRelaxed b = ST.unsafeCreate (builderLength b) (copyBuilderToPtr b)
-- | Run a 'Pub'-marked builder into a @Block Word8@.
--
-- Allocates a mutable block of 'builderLength' elements with
-- 'Block.newPinned', runs the builder's write action against it, and then
-- freezes it with 'Block.unsafeFreeze'.
runToBlock :: Builder Pub -> Block Word8
runToBlock b = runST $ do
  mb <- Block.newPinned (CountOf $ builderLength b)
  -- NOTE(review): assumes 'Block.mutableContents' yields a pointer to the
  -- start of the block's storage; it is a project helper not visible here.
  copyBuilderToPtr b (Block.mutableContents mb)
  Block.unsafeFreeze mb
-- | Make a 'Builder' from a length and an 'IO' write action.
--
-- Like 'create', but the action lives in 'IO' and is embedded into 'ST' with
-- 'unsafeIOToST'. That conversion is unchecked, which is why this function
-- is named unsafe: the caller is responsible for the action being suitable
-- to run wherever the builder is eventually executed.
unsafeCreate :: Int -> (Ptr a -> IO ()) -> Builder marking
unsafeCreate n f = create n (unsafeIOToST . f)
{-# INLINE unsafeCreate #-}