{-# LANGUAGE LambdaCase #-}

module Transformations.Ll.Ll (llAst) where

import Control.Monad.State (State, gets, modify, runState)
import qualified Data.List.NonEmpty as NE
import Data.Text (pack)
import MonadUtils
import qualified Transformations.Ll.Lfr as Lfr
import qualified Transformations.Simplifier.SimplifiedAst as Ast
import qualified Trees.Common as Common

-- * AST Lambda Lifter

-- | Convert AST to its lambda-free representation.
--
-- Every top-level declaration is lambda-lifted in turn. The functions lifted
-- out of a declaration are emitted as separate global declarations placed
-- right before it. The identifier counter stored in the input program seeds
-- fresh-name generation, and the counter reached at the end of the pass is
-- stored in the resulting program.
llAst :: Ast.Program -> Lfr.Program
llAst (Ast.Program gDecls cnt) =
  let initEnv = Env {genFunDecls = [], idCnt = cnt}
      (gDecls', env') = runState (mapM llGDecl gDecls) initEnv
   in Lfr.Program (concat gDecls') (idCnt env')

-- * Internal

-- ** Lambda Lifter State

-- | State monad the lambda lifter runs in.
type LlState = State Env

-- | State threaded through the lambda lifter.
data Env = Env
  { -- | Functions lifted so far for the top-level declaration currently being
    -- processed, most recently lifted first.
    genFunDecls :: [FunDeclaration],
    -- | Counter used to build fresh identifiers for anonymous functions.
    idCnt :: Common.IdCnt
  }

-- | A lifted function: its name, its parameters and its (already
-- lambda-lifted) body.
data FunDeclaration = FunDecl Common.Identifier' [Common.Identifier'] Lfr.Expression

-- ** Lambda Lifters

-- | Lambda-lift a single top-level declaration.
--
-- Returns the functions lifted out of the declaration, in the order they were
-- generated, followed by the translated declaration itself. The lifted
-- functions are drained from the state, so the next declaration starts with
-- an empty accumulator.
llGDecl :: Ast.Declaration -> LlState [Lfr.GlobalDeclaration]
llGDecl decl = do
  decl' <- case decl of
    Ast.DeclVar ident value ->
      ll1 (Lfr.GlobVarDecl . Lfr.VarDecl ident) value
    Ast.DeclFun ident _ (Ast.Fun params body) ->
      ll1 (Lfr.GlobFunDecl ident (NE.toList params)) body
  genDecls <- collectGenFunDecls
  -- The accumulator is built newest-first (functions are prepended when
  -- recorded), so reversing yields generation order and moves the
  -- declaration itself to the end.
  return $ reverse (decl' : genDecls)
  where
    -- Take the functions lifted so far and reset the accumulator.
    collectGenFunDecls :: LlState [Lfr.GlobalDeclaration]
    collectGenFunDecls = do
      genDecls <- gets genFunDecls
      modify $ \env -> env {genFunDecls = []}
      return (convertFunDecl <$> genDecls)

    -- Turn an internal lifted function into a global function declaration.
    convertFunDecl :: FunDeclaration -> Lfr.GlobalDeclaration
    convertFunDecl (FunDecl i ps b) = Lfr.GlobFunDecl i ps b

-- | Lambda-lift an expression.
--
-- Anonymous functions are replaced by a reference to a freshly named
-- function. Functions bound by @let@ are removed from the expression. In both
-- cases the function is recorded in the lifter state so that the enclosing
-- top-level declaration can emit it as a global. All other constructs are
-- translated structurally.
--
-- NOTE(review): lifted bodies are moved to the top level as-is; this pass does
-- no free-variable abstraction. It presumably relies on an earlier pass
-- (closure conversion / renaming) having eliminated free locals and made
-- identifiers unique. Confirm against the pipeline.
llExpr :: Ast.Expression -> LlState Lfr.Expression
llExpr = \case
  Ast.ExprId ident -> return $ Lfr.ExprId ident
  Ast.ExprPrimVal val -> return $ Lfr.ExprPrimVal val
  Ast.ExprBinOp op lhs rhs -> ll2 (Lfr.ExprBinOp op) lhs rhs
  Ast.ExprUnOp op x -> ll1 (Lfr.ExprUnOp op) x
  Ast.ExprApp f arg -> ll2 Lfr.ExprApp f arg
  Ast.ExprIte c t e -> ll3 Lfr.ExprIte c t e
  Ast.ExprLetIn decl expr -> case decl of
    Ast.DeclVar ident value -> do
      varDecl <- ll1 (Lfr.VarDecl ident) value
      expr' <- llExpr expr
      return $ Lfr.ExprLetIn varDecl expr'
    -- A local function is lifted under its own name and the @let@ vanishes.
    Ast.DeclFun ident _ fun -> llFun ident fun >> llExpr expr
  Ast.ExprFun fun -> do
    -- Anonymous function: give it a fresh name and refer to it by that name.
    ident <- genId
    llFun ident fun
    return $ Lfr.ExprId ident

-- | Lambda-lift the body of a function and record the result, under the given
-- name, as a lifted function in the lifter state.
--
-- The body is processed before the function is recorded, so any functions
-- lifted out of the body are recorded before this one.
llFun :: Common.Identifier' -> Ast.Fun -> LlState ()
llFun ident (Ast.Fun params body) = do
  fun <- ll1 (FunDecl ident (NE.toList params)) body
  modify $ \env -> env {genFunDecls = fun : genFunDecls env}

-- ** Identifier Generation

-- | Build a fresh generated identifier (base name @ll@) from the current
-- counter value, then advance the counter by one.
genId :: LlState Common.Identifier'
genId = do
  cnt <- gets idCnt
  modify $ \env -> env {idCnt = cnt + 1}
  return $ Common.Gen cnt $ pack "ll"

-- ** Utils

-- | Lambda-lift one subexpression and apply the given function to the result.
ll1 ::
  (Lfr.Expression -> a) ->
  (Ast.Expression -> LlState a)
ll1 = liftM1' llExpr

-- | Lambda-lift two subexpressions and combine the results with the given
-- function. The order in which the subexpressions are processed is whatever
-- 'liftM2'' from "MonadUtils" uses (presumably left to right; confirm).
ll2 ::
  (Lfr.Expression -> Lfr.Expression -> a) ->
  (Ast.Expression -> Ast.Expression -> LlState a)
ll2 = liftM2' llExpr

-- | Lambda-lift three subexpressions and combine the results with the given
-- function. The order in which the subexpressions are processed is whatever
-- 'liftM3'' from "MonadUtils" uses (presumably left to right; confirm).
ll3 ::
  (Lfr.Expression -> Lfr.Expression -> Lfr.Expression -> a) ->
  (Ast.Expression -> Ast.Expression -> Ast.Expression -> LlState a)
ll3 = liftM3' llExpr