-- |
-- Module      : Math
-- License     : BSD-3-Clause
-- Copyright   : (c) 2025 Olivier Chéron
--
-- Type classes that define additive and multiplicative operations, as well as
-- multiply-then-add operations that can often make chained computations faster.
--
-- The module also defines a non-homogeneous multiplication that typically
-- combines a public operand (left) with a secret operand (right), producing a
-- secret output.
--
{-# LANGUAGE CPP #-}
{-# LANGUAGE MultiParamTypeClasses #-}
module Math
    ( Add(..), Mul(..), MulAdd(..), BiMul(..), BiMulAdd(..)
    ) where

#if !(MIN_VERSION_base(4,20,0))
import Data.List (foldl')
#endif

-- Additive operators: left-associative, precedence 6 (like (+) and (-)).
infixl 6 .+
infixl 6 .-
-- Multiplicative operators: precedence 7, so they bind tighter than the
-- additive ones. Note that (.*) and (..*) share precedence 7 but have opposite
-- associativity, so mixing them without parentheses is a parse error.
infixl 7 .*
infixr 7 ..*

-- | Additive operations on @a@.
class Add a where
    -- | The zero value.
    zero :: a
    -- | Addition and subtraction (both @infixl 6@).
    (.+), (.-) :: a -> a -> a
    -- | Negation.
    neg :: a -> a

-- | Multiplicative operations on @a@, layered on top of 'Add'.
class Add a => Mul a where
    -- | Multiplication (@infixl 7@, binding tighter than '.+' and '.-').
    (.*) :: a -> a -> a
    -- | The one value.
    one :: a

-- | A fused multiply-then-add on @a@.
class Mul a => MulAdd a where
    -- | @mulAdd x y z@ multiplies @x@ by @y@ and then adds @z@.
    --
    -- Invariant: @mulAdd x y z == x .* y .+ z@
    mulAdd :: a -> a -> a -> a

-- | Non-homogeneous multiplication: a left operand of type @b@ times a right
-- operand of type @a@, giving an @a@. Typically the left operand is public
-- and the right one is secret.
class Add a => BiMul b a where
    -- | Multiply the right operand by the left one (@infixr 7@).
    (..*) :: b -> a -> a

-- | Fused multiply-then-add for the non-homogeneous product '..*'.
--
-- Instances must define at least one of 'biMulAdd' or 'biMulFold'. Each
-- default is written in terms of the other, so an instance that defines
-- neither would loop forever (GHC warns about this via the MINIMAL pragma).
class BiMul b a => BiMulAdd b a where
    {-# MINIMAL biMulAdd | biMulFold #-}

    -- | @biMulAdd p s x@ multiplies @s@ by @p@ and then adds @x@.
    --
    -- Invariant: @biMulAdd p s x == p ..* s .+ x@
    --
    -- The default runs 'biMulFold' over a single pair.
    biMulAdd :: b -> a -> a -> a
    biMulAdd p s x = biMulFold x [(p, s)]
    {-# INLINE biMulAdd #-}

    -- | Repeated 'biMulAdd': starting from the accumulator, add the product of
    -- every pair in order, left to right. For example,
    --
    -- @biMulFold x [(p1, s1), (p2, s2)] == biMulAdd p2 s2 (biMulAdd p1 s1 x)@
    --
    -- The default uses a strict left fold ('foldl''), so no chain of thunks
    -- builds up in the accumulator.
    biMulFold :: Foldable t => a -> t (b, a) -> a
    biMulFold = foldl' $ \acc (p, s) -> biMulAdd p s acc
    {-# INLINE biMulFold #-}