{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE CPP #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE TypeOperators #-}
{-# OPTIONS_GHC -fno-warn-orphans #-}

-- | Defines the core functionality of this package. This module is
-- distinguished from "Yesod.Persist" in that the latter additionally re-exports
-- the persistent modules themselves.
module Yesod.Persist.Core
    ( YesodPersist (..)
    , defaultRunDB
    , YesodPersistRunner (..)
    , defaultGetDBRunner
    , DBRunner (..)
    , runDBSource
    , respondSourceDB
    , YesodDB
    , get404
    , getBy404
    , insert400
    , insert400_
    ) where

import Database.Persist
import Control.Monad.Trans.Reader (ReaderT, runReaderT)

import Yesod.Core
import Data.Conduit
import Blaze.ByteString.Builder (Builder)
import Data.Pool
import Control.Monad.Trans.Resource
import Control.Exception (throwIO)
import qualified Control.Exception as E
import Yesod.Core.Types (HandlerContents (HCError))
import qualified Database.Persist.Sql as SQL
#if MIN_VERSION_persistent(2,13,0)
import Data.List.NonEmpty (toList)
import qualified Database.Persist.SqlBackend.Internal as SQL
#endif

-- | The identity function. 'defaultGetDBRunner' applies it to each 'YesodDB'
-- action before handing that action to 'runReaderT'.
--
-- NOTE(review): presumably a leftover from when @SqlPersistT@ was a newtype
-- that had to be unwrapped; confirm before removing.
unSqlPersistT :: a -> a
unSqlPersistT = id

-- | Database actions for a site: a 'ReaderT' over the site's
-- 'YesodPersistBackend', running inside 'HandlerFor'.
type YesodDB site =
    ReaderT (YesodPersistBackend site) (HandlerFor site)

-- | Sites that can run persistent database actions from their handlers.
class Monad (YesodDB site) => YesodPersist site where
    -- | The persistent backend that this site's 'YesodDB' actions read from.
    type YesodPersistBackend site

    -- | Allows you to execute database actions within Yesod Handlers. For
    -- databases that support it, code inside the action will run as an
    -- atomic transaction.
    --
    -- ==== __Example Usage__
    --
    -- > userId <- runDB $ do
    -- >   userId <- insert $ User "username" "email@example.com"
    -- >   insert_ $ UserPreferences userId True
    -- >   pure userId
    runDB :: YesodDB site a -> HandlerFor site a

-- | Helper for creating 'runDB'.
--
-- Looks up the site's persistent configuration and connection pool with the
-- two accessors and runs the given action via 'Database.Persist.runPool'.
--
-- Since 1.2.0
defaultRunDB :: PersistConfig c
             => (site -> c)
             -> (site -> PersistConfigPool c)
             -> PersistConfigBackend c (HandlerFor site) a
             -> HandlerFor site a
defaultRunDB getConfig getPool f = do
    master <- getYesod
    Database.Persist.runPool
        (getConfig master)
        f
        (getPool master)

-- | Sites that, beyond 'runDB', can hand out a reusable database runner
-- together with a cleanup action; see 'getDBRunner'.
--
-- Since 1.2.0
class YesodPersist site => YesodPersistRunner site where
    -- | Unlike 'runDB', which runs a single action, this returns a database
    -- runner function. Usually this means a connection is taken from a pool
    -- and then reused for every invocation of the runner, which is handy for
    -- streaming responses; see 'runDBSource'.
    --
    -- A cleanup action that frees the connection is returned alongside the
    -- runner. If your code finishes successfully you /must/ call it to
    -- indicate that changes should be committed; otherwise, at least for SQL
    -- backends, a rollback is used instead.
    --
    -- Since 1.2.0
    getDBRunner :: HandlerFor site (DBRunner site, HandlerFor site ())

-- | A database runner, as returned by 'getDBRunner'. It runs 'YesodDB'
-- actions in 'HandlerFor'. The field is rank-2, so one runner can be applied
-- to actions of any result type.
newtype DBRunner site = DBRunner
    { runDBRunner :: forall a. YesodDB site a -> HandlerFor site a
    }

-- | Helper for implementing 'getDBRunner'.
--
-- Takes a connection from the pool returned by the accessor and begins a
-- transaction on it. It also registers (via 'allocate') a release action that
-- rolls the transaction back and destroys the connection. The returned
-- cleanup action instead commits, puts the connection back into the pool and
-- unregisters that release action. The returned runner runs every 'YesodDB'
-- action against this same connection.
--
-- Since 1.2.0
defaultGetDBRunner :: (SQL.IsSqlBackend backend, YesodPersistBackend site ~ backend)
                   => (site -> Pool backend)
                   -> HandlerFor site (DBRunner site, HandlerFor site ())
defaultGetDBRunner getPool = do
    pool <- fmap getPool getYesod
    -- Apply a raw SqlBackend operation (begin/commit/rollback) to a connection.
    let withPrep conn f = f (persistBackend conn) (SQL.getStmtConn $ persistBackend conn)
#if MIN_VERSION_persistent(2,9,0)
        begin conn = withPrep conn (\c f -> SQL.connBegin c f Nothing)
#else
        begin conn = withPrep conn SQL.connBegin
#endif
    (relKey, (conn, local)) <- allocate
        (do
            (conn, local) <- takeResource pool
            -- The release action is only registered once this acquire step
            -- returns, so if BEGIN fails we must free the pool slot here;
            -- otherwise the connection is neither returned nor destroyed.
            begin conn `E.onException` destroyResource pool local conn
            return (conn, local)
            )
        (\(conn, local) ->
            -- Destroy the connection even if ROLLBACK itself throws (e.g. the
            -- connection is already dead), so the pool slot is always freed.
            withPrep conn SQL.connRollback
                `E.finally` destroyResource pool local conn)

    let cleanup = liftIO $ do
            withPrep conn SQL.connCommit
            putResource local conn
            -- The connection is back in the pool; stop the release action
            -- from rolling back / destroying it later.
            _ <- unprotect relKey
            return ()

    return (DBRunner $ \x -> runReaderT (unSqlPersistT x) conn, cleanup)

-- | Like 'runDB', but transforms a @Source@. See 'respondSourceDB' for an
-- example, practical use case.
--
-- One runner is obtained from 'getDBRunner' before the source starts, and
-- every database action in the source goes through it. The accompanying
-- cleanup action is called only after the source has run to completion. If
-- the source throws, or downstream stops consuming early, the cleanup is not
-- reached. With 'defaultGetDBRunner', the release action it registered via
-- 'allocate' then rolls the transaction back instead.
--
-- Since 1.2.0
runDBSource :: YesodPersistRunner site
            => ConduitT () a (YesodDB site) ()
            -> ConduitT () a (HandlerFor site) ()
runDBSource src = do
    (dbrunner, cleanup) <- lift getDBRunner
    transPipe (runDBRunner dbrunner) src
    lift cleanup

-- | Extends 'respondSource' to create a streaming database response body.
--
-- The body source is run through 'runDBSource', so all of its database
-- actions share a single runner obtained from 'getDBRunner'.
respondSourceDB :: YesodPersistRunner site
                => ContentType
                -> ConduitT () (Flush Builder) (YesodDB site) ()
                -> HandlerFor site TypedContent
respondSourceDB ctype = respondSource ctype . runDBSource

-- | Get the given entity by ID, or return a 404 not found if it doesn't exist.
--
-- The 404 is raised by throwing @'HCError' 'NotFound'@ in 'IO', so it works
-- in any 'MonadIO' base monad, not just a handler.
get404 :: (MonadIO m, PersistStoreRead backend, PersistRecordBackend val backend)
       => Key val
       -> ReaderT backend m val
get404 key = get key >>= maybe notFound' return

-- | Get the given entity by unique key, or return a 404 not found if it
-- doesn't exist.
--
-- The 404 is raised by throwing @'HCError' 'NotFound'@ in 'IO', so it works
-- in any 'MonadIO' base monad, not just a handler.
getBy404 :: (PersistUniqueRead backend, PersistRecordBackend val backend, MonadIO m)
         => Unique val
         -> ReaderT backend m (Entity val)
getBy404 key = getBy key >>= maybe notFound' return

-- | Create a new record in the database, returning an automatically
-- created key, or raise a 400 bad request if a uniqueness constraint
-- is violated.
--
-- On a conflict, the 400 response ('InvalidArgs') lists the Haskell field
-- names of the unique constraint reported by 'checkUnique'.
--
-- NOTE: the uniqueness check and the insert are separate database
-- operations. A conflicting row written concurrently between them is not
-- turned into a 400; it presumably surfaces as an error from 'insert'
-- instead.
--
-- @since 1.4.1
insert400
    :: ( MonadIO m
       , PersistUniqueWrite backend
       , PersistRecordBackend val backend
#if MIN_VERSION_persistent(2,14,0)
       , SafeToInsert val
#endif
       )
    => val
    -> ReaderT backend m (Key val)
insert400 datum = do
    conflict <- checkUnique datum
    case conflict of
        Just unique ->
            badRequest' $ map (getName . fst) $ mkList $ persistUniqueToFieldNames unique
        Nothing -> insert datum
  where
    -- The field-name accessor and the shape of the field list returned by
    -- 'persistUniqueToFieldNames' differ between persistent versions.
#if MIN_VERSION_persistent(2,12,0)
    getName = unFieldNameHS
#else
    getName = unHaskellName
#endif
#if MIN_VERSION_persistent(2,13,0)
    mkList = toList
#else
    mkList = id
#endif

-- | Same as 'insert400', but doesn't return a key.
--
-- @since 1.4.1
insert400_ :: ( MonadIO m
              , PersistUniqueWrite backend
              , PersistRecordBackend val backend
#if MIN_VERSION_persistent(2,14,0)
              , SafeToInsert val
#endif
              )
           => val
           -> ReaderT backend m ()
insert400_ datum = insert400 datum >> return ()

-- | Should be equivalent to @lift . notFound@, but there's an apparent bug in
-- GHC 7.4.2 that leads to segfaults. This is a workaround.
--
-- Throws @'HCError' 'NotFound'@ from 'IO', lifted into any 'MonadIO'.
notFound' :: MonadIO m => m a
notFound' = liftIO $ throwIO $ HCError NotFound

-- | Constructed like 'notFound'', and for the same reasons.
--
-- Throws @'HCError' ('InvalidArgs' msgs)@ from 'IO', lifted into any
-- 'MonadIO'.
badRequest' :: MonadIO m => Texts -> m a
badRequest' msgs = liftIO $ throwIO $ HCError $ InvalidArgs msgs