-- |
-- Module      : Machine
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Architecture-dependent utilities, to read/write unaligned machine words
-- in little-endian order
--
{-# LANGUAGE CPP #-}
module Machine
    ( WordM, WordLE, assertMultM, fromLE, toLE, wordBits, wordBytes
    ) where

#include "MachDeps.h"

-- Taken from `bytestring`, a list of architectures known to accept
-- unaligned loads and stores
#if defined(i386_HOST_ARCH) || defined(x86_64_HOST_ARCH)          \
    || ((defined(arm_HOST_ARCH) || defined(aarch64_HOST_ARCH))    \
         && defined(__ARM_FEATURE_UNALIGNED))                     \
    || defined(powerpc_HOST_ARCH) || defined(powerpc64_HOST_ARCH) \
    || defined(powerpc64le_HOST_ARCH)
#define MLKEM_ALLOW_UNALIGNED_OP 1

-- Little-endian conversion in `memory` / `ram` is skipped at compile time
-- only on AMD/Intel; here we also short-circuit it on little-endian ARM.
#if (defined(arm_HOST_ARCH) || defined(aarch64_HOST_ARCH)) \
    && !defined(WORDS_BIGENDIAN)
#define MLKEM_FORCE_LITTLE_ENDIAN_ARCH 1
#endif

#endif

import Control.Exception (assert)

#ifdef MLKEM_ALLOW_UNALIGNED_OP
import qualified Data.Memory.Endian as B
#endif

import Data.Bits
import Data.Word

#ifdef MLKEM_ALLOW_UNALIGNED_OP

-- | Our preferred machine word for bulk loads and stores: 64 bits when
-- @MachDeps.h@ reports @WORD_SIZE_IN_BITS == 64@, otherwise 32 bits.
#if WORD_SIZE_IN_BITS == 64
type WordM = Word64
#else
type WordM = Word32
#endif

-- | A 'WordM' tagged (via @memory@'s @LE@ wrapper) as little-endian data.
type WordLE = B.LE WordM

-- | Convert a little-endian tagged word into a native 'WordM'.
--
-- When @MLKEM_FORCE_LITTLE_ENDIAN_ARCH@ is set (little-endian ARM), the
-- @LE@ wrapper is simply removed. Otherwise @memory@'s 'B.fromLE' decides
-- whether a byte swap is needed on this host.
fromLE :: WordLE -> WordM
#ifdef MLKEM_FORCE_LITTLE_ENDIAN_ARCH
fromLE = B.unLE  -- unwrap constructor with no byte swapping
#else
fromLE = B.fromLE  -- byte swap if necessary
#endif

-- | Tag a native 'WordM' as little-endian data.
--
-- When @MLKEM_FORCE_LITTLE_ENDIAN_ARCH@ is set (little-endian ARM), the
-- value is simply wrapped in the @LE@ constructor. Otherwise @memory@'s
-- 'B.toLE' decides whether a byte swap is needed on this host.
toLE :: WordM -> WordLE
#ifdef MLKEM_FORCE_LITTLE_ENDIAN_ARCH
toLE = B.LE  -- wrap constructor with no byte swapping
#else
toLE = B.toLE  -- byte swap if necessary
#endif

#else

-- Unaligned memory access is not allowed, so we fall back to one byte at a
-- time, and endianness does not matter.

-- | Without unaligned access, data is processed one byte at a time.
type WordM = Word8
-- | A single byte has no byte order, so the "little-endian" word is the
-- plain word itself.
type WordLE = WordM

-- | Identity: a single-byte word needs no byte-order conversion.
fromLE :: WordLE -> WordM
fromLE w = w

-- | Identity: a single byte is already in little-endian order.
toLE :: WordM -> WordLE
toLE b = b

#endif

-- | Width of 'WordM' in bits (64, 32 or 8 depending on the CPP branch
-- selected for this build).
wordBits :: Int
wordBits = finiteBitSize (0 :: WordM)

-- | Width of 'WordM' in bytes.
wordBytes :: Int
wordBytes = wordBits `div` 8

-- | @assertMultM n x@ returns @x@ after asserting, with 'assert', that @n@
-- is a multiple of 'wordBytes'. Whether the check actually runs depends on
-- GHC's assertion settings.
--
-- The bit-mask test relies on 'wordBytes' being a power of two, which holds
-- for every 'WordM' choice in this module (8, 4 or 1 bytes). With 1-byte
-- words the mask is 0, so the assertion always holds.
assertMultM :: Int -> a -> a
assertMultM n = assert (n .&. mask == 0)
  where
    mask :: Int
    mask = wordBytes - 1