module Network.TLS.MAC (
    macSSL,
    hmac,
    prf_MD5,
    prf_SHA1,
    prf_SHA256,
    prf_TLS,
    prf_MD5SHA1,
    PRF,
) where

import Data.ByteArray (ByteArray, ByteArrayAccess)
import qualified Data.ByteArray as BA

import Network.TLS.Crypto
import Network.TLS.Imports
import Network.TLS.Types

-- | A keyed MAC function: given the key (secret) and the message, it
-- returns the authentication tag. 'macSSL' and 'hmac' (once applied to a
-- 'Hash') have this shape, and 'hmacIter' is parameterised over it.
type HMAC = Secret -> ByteString -> Secret

-- | SSL 3.0 record MAC:
--
-- > hash (secret <> pad2 <> hash (secret <> pad1 <> msg))
--
-- Here @pad1@ is the byte @0x36@ and @pad2@ is the byte @0x5c@. Each is
-- repeated 48 times for 'MD5' and 40 times for 'SHA1'.
--
-- Only 'MD5' and 'SHA1' are supported. Any other 'Hash' is treated as an
-- internal error and aborts via 'error'.
macSSL :: Hash -> HMAC
macSSL alg secret msg =
    f $
        BA.concat
            [ secret
            , BA.replicate padLen 0x5c
            , f $ BA.concat [secret, BA.replicate padLen 0x36, BA.convert msg]
            ]
  where
    -- The pad length depends on the digest. The SSL 3.0 MAC is only
    -- defined for MD5 and SHA-1.
    padLen :: Int
    padLen = case alg of
        MD5 -> 48
        SHA1 -> 40
        _ -> error ("internal error: macSSL called with " ++ show alg)

    f :: Secret -> Secret
    f = hash alg

-- | HMAC (RFC 2104) over the given 'Hash':
--
-- > H((K' xor opad) <> H((K' xor ipad) <> msg))
--
-- @K'@ is the key normalised to exactly one hash block, as follows:
--
-- * A key longer than the block size is first hashed.
-- * The result is then right-padded with zero bytes up to the block size.
--
-- @ipad@ and @opad@ are the bytes @0x36@ and @0x5c@, applied byte-wise.
hmac :: (ByteArray ba, ByteArrayAccess ba) => Hash -> ba -> ByteString -> ba
hmac alg secret msg = f $ BA.append opad (f $ BA.append ipad $ BA.convert msg)
  where
    opad = BA.map (xor 0x5c) k'
    ipad = BA.map (xor 0x36) k'

    f = hash alg

    -- 'hashBlockSize' already yields an 'Int', so no conversion is needed
    -- when comparing with or subtracting from 'BA.length'.
    bl :: Int
    bl = hashBlockSize alg

    -- Key normalised to exactly one block.
    k' = BA.append kt pad
      where
        kt = if BA.length secret > bl then f secret else secret
        pad = BA.replicate (bl - BA.length kt) 0
0

-- | TLS @P_hash@ data expansion, returned as a list of chunks.
--
-- With @A(0) = seed@ and @A(i) = f secret A(i-1)@, the output is
-- @f secret (A(1) <> seed)@, then @f secret (A(2) <> seed)@, and so on.
--
-- Arguments:
--
-- * @f@ is the keyed MAC (typically @'hmac' alg@).
-- * @seed@ is appended to every @A(i)@.
-- * @aprev@ is @A(i-1)@. Callers start with @aprev == seed@.
-- * @len@ is the number of bytes still wanted.
--
-- Each chunk is one MAC output. The final chunk is truncated with
-- 'BA.take' so that, for non-negative @len@, the concatenation is exactly
-- @len@ bytes long. The list is produced lazily.
hmacIter
    :: HMAC -> Secret -> ByteString -> ByteString -> Int -> [Secret]
hmacIter f secret seed aprev len
    | digestsize >= len = [BA.take len out]
    | otherwise =
        out : hmacIter f secret seed (BA.convert an) (len - digestsize)
  where
    an = f secret aprev
    out = f secret (BA.concat [an, BA.convert seed])
    digestsize = BA.length out

-- | A pseudo-random function used to derive keying material. Its
-- arguments are the secret, the seed, and the requested number of output
-- bytes; it returns that amount of output.
type PRF = Secret -> ByteString -> Int -> Secret

-- | @P_SHA-1@ expansion: HMAC-SHA1 driven through 'hmacIter'. The seed
-- serves both as @A(0)@ and as the per-block suffix.
prf_SHA1 :: PRF
prf_SHA1 secret seed len = BA.concat $ hmacIter (hmac SHA1) secret seed seed len

-- | @P_MD5@ expansion: HMAC-MD5 driven through 'hmacIter'. The seed
-- serves both as @A(0)@ and as the per-block suffix.
prf_MD5 :: PRF
prf_MD5 secret seed len = BA.concat $ hmacIter (hmac MD5) secret seed seed len

-- | TLS 1.0/1.1-style PRF: @P_MD5(S1, seed) xor P_SHA-1(S2, seed)@.
--
-- The secret is split into two halves, @S1@ (leading) and @S2@
-- (trailing), each @ceil(length secret / 2)@ bytes long. When the secret
-- has odd length, its middle byte therefore belongs to both halves.
--
-- The seed is used as given. NOTE(review): any label is presumably
-- prepended by the caller; confirm.
prf_MD5SHA1 :: PRF
prf_MD5SHA1 secret seed len =
    BA.xor (prf_MD5 s1 seed len) (prf_SHA1 s2 seed len)
  where
    slen :: Int
    slen = BA.length secret

    -- Rounded down; the leading half takes the extra byte when slen is odd.
    half :: Int
    half = slen `div` 2

    s1 :: Secret
    s1 = BA.take (half + slen `mod` 2) secret

    s2 :: Secret
    s2 = BA.drop half secret

-- | @P_SHA256@ expansion: HMAC-SHA256 driven through 'hmacIter'. This is
-- the default TLS 1.2 PRF construction.
prf_SHA256 :: PRF
prf_SHA256 secret seed len = BA.concat $ hmacIter (hmac SHA256) secret seed seed len

-- | @P_hash@ expansion for an arbitrary 'Hash': HMAC over @halg@ driven
-- through 'hmacIter'.
--
-- The 'Version' argument is currently ignored. It is kept so that the PRF
-- can one day depend on the protocol version as well as on the cipher's
-- PRF hash.
prf_TLS :: Version -> Hash -> PRF
prf_TLS _ halg secret seed len =
    BA.concat $ hmacIter (hmac halg) secret seed seed len