{-# LANGUAGE LambdaCase #-}

module Transformations.Simplifier.Simplifier (simplifyAst) where

import Control.Monad.State (State, get, modify, runState)
import Data.Text (pack)
import MonadUtils
import qualified Parser.Ast as Ast
import qualified Transformations.Simplifier.SimplifiedAst as SAst
import qualified Trees.Common as Common

-- * AST Simplifier

-- | Simplify a whole program.
--
-- Every top-level statement is simplified in order with 'simplifyStmt', while
-- a single identifier counter is threaded through all of them, starting at @0@.
-- The counter value left after the last statement is stored in the resulting
-- 'SAst.Program' next to the simplified declarations (presumably so later
-- passes can keep minting fresh identifiers without clashes — confirm with
-- the consumers of 'SAst.Program').
simplifyAst :: Ast.Program -> SAst.Program
simplifyAst (Ast.Program stmts) =
  let (decls, cnt) = runState (mapM simplifyStmt stmts) 0
   in SAst.Program decls cnt

-- * Internal

-- ** Simplifier State

-- | Monad the simplifier runs in: a plain 'State' whose state is the
-- identifier counter that 'genId' reads and advances.
type SimplifierState a = State Common.IdCnt a

-- ** Simplifiers

-- | Turn a top-level statement into a declaration.
--
-- A declaration statement is simplified with 'simplifyDecl'. A bare
-- expression statement has no name of its own, so it becomes a variable
-- declaration bound to a freshly generated identifier (see 'genId').
simplifyStmt :: Ast.Statement -> SimplifierState SAst.Declaration
simplifyStmt = \case
  Ast.StmtDecl decl -> simplifyDecl decl
  Ast.StmtExpr expr -> SAst.DeclVar <$> genId <*> simplifyExpr expr

-- | Simplify a variable or function declaration.
--
-- * Variable: the binder goes through 'convertTypedId', which drops its
--   optional type annotation, and the bound value is simplified.
-- * Function: the name goes through 'convertId', the recursion flag is passed
--   through unchanged, and the function itself is simplified with
--   'simplifyFun'.
simplifyDecl :: Ast.Declaration -> SimplifierState SAst.Declaration
simplifyDecl = \case
  Ast.DeclVar ident value ->
    simplify1 (SAst.DeclVar (convertTypedId ident)) value
  Ast.DeclFun ident isRec fun ->
    SAst.DeclFun (convertId ident) isRec <$> simplifyFun fun

-- | Simplify an expression, recursing structurally into every sub-expression.
--
-- Identifiers are wrapped with 'convertId'. Primitive values and operators are
-- copied over as-is, and the shape of the tree is preserved. Nested
-- declarations ('Ast.ExprLetIn') and lambdas ('Ast.ExprFun') delegate to
-- 'simplifyDecl' and 'simplifyFun' respectively.
simplifyExpr :: Ast.Expression -> SimplifierState SAst.Expression
simplifyExpr = \case
  Ast.ExprId ident -> pure $ SAst.ExprId (convertId ident)
  Ast.ExprPrimVal val -> pure $ SAst.ExprPrimVal val
  Ast.ExprBinOp op lhs rhs -> simplify2 (SAst.ExprBinOp op) lhs rhs
  Ast.ExprUnOp op x -> simplify1 (SAst.ExprUnOp op) x
  Ast.ExprApp f arg -> simplify2 SAst.ExprApp f arg
  Ast.ExprIte c t e -> simplify3 SAst.ExprIte c t e
  Ast.ExprLetIn decl expr ->
    SAst.ExprLetIn <$> simplifyDecl decl <*> simplifyExpr expr
  Ast.ExprFun fun -> SAst.ExprFun <$> simplifyFun fun

-- | Simplify a function (lambda).
--
-- Parameter type annotations are dropped via 'convertTypedId'. The second
-- field of 'Ast.Fun', a @Maybe Type@ (presumably the return-type annotation —
-- confirm against "Parser.Ast"), is discarded. The body is simplified with
-- 'simplifyExpr'.
simplifyFun :: Ast.Fun -> SimplifierState SAst.Fun
simplifyFun (Ast.Fun params _ body) =
  simplify1 (SAst.Fun (convertTypedId <$> params)) body

-- ** Identifier Conversion and Generation

-- | Wrap a source identifier in the @Common.Txt@ constructor of
-- @Common.Identifier'@. This keeps it distinct from identifiers built with
-- @Common.Gen@ (see 'genId').
convertId :: Common.Identifier -> Common.Identifier'
convertId = Common.Txt

-- | Like 'convertId', but for a binder that may carry an optional type
-- annotation. The annotation is discarded and only the name is kept.
convertTypedId :: (Common.Identifier, Maybe Common.Type) -> Common.Identifier'
convertTypedId = convertId . fst

-- | Mint a fresh identifier @Common.Gen n "simp"@ and advance the counter.
--
-- @n@ is the counter value before this call. The counter is incremented by one
-- afterwards, so successive calls yield distinct @n@.
genId :: SimplifierState Common.Identifier'
genId = do
  cnt <- get
  modify (+ 1)
  pure $ Common.Gen cnt (pack "simp")

-- ** Utils

-- | Simplify one sub-expression with 'simplifyExpr' and hand the result to
-- the given pure builder (typically a partially applied 'SAst' constructor).
simplify1 ::
  (SAst.Expression -> a) ->
  (Ast.Expression -> SimplifierState a)
simplify1 = liftM1' simplifyExpr

-- | Simplify two sub-expressions with 'simplifyExpr' and combine the results
-- using the given pure builder.
simplify2 ::
  (SAst.Expression -> SAst.Expression -> a) ->
  (Ast.Expression -> Ast.Expression -> SimplifierState a)
simplify2 = liftM2' simplifyExpr

-- | Simplify three sub-expressions with 'simplifyExpr' and combine the results
-- using the given pure builder (used for if-then-else).
simplify3 ::
  (SAst.Expression -> SAst.Expression -> SAst.Expression -> a) ->
  (Ast.Expression -> Ast.Expression -> Ast.Expression -> SimplifierState a)
simplify3 = liftM3' simplifyExpr