{-# LANGUAGE LambdaCase #-}

module Transformations.Relabeler.Relabeler (relabelAst) where

import Control.Monad (replicateM_)
import Control.Monad.State (State, gets, modify, runState)
import qualified Data.List.NonEmpty as NE
import Data.Maybe (fromMaybe)
import MonadUtils
import qualified Transformations.Simplifier.SimplifiedAst as Ast
import qualified Trees.Common as Ast

-- * AST Relabeler

-- | Relabel identifiers in the AST so that every declaration introduces an
-- identifier with a unique name.
--
-- Each binding site (variable name, function name, function parameter) is
-- renamed to a fresh 'Ast.Gen' identifier tagged from the program's counter,
-- and every use is rewritten to the fresh name of the innermost binding in
-- scope. Identifiers with no binding in scope are left unchanged. The result
-- carries the counter advanced past every tag handed out here.
relabelAst :: Ast.Program -> Ast.Program
relabelAst (Ast.Program decls cnt) = Ast.Program relabeled finalCnt
  where
    -- Top-level declarations are never popped, so each one stays in scope
    -- for every declaration that follows it.
    (relabeled, finalEnv) = runState (traverse relabelDecl decls) (Env [] cnt)
    finalCnt = idCnt finalEnv

-- * Internal

-- ** Relabeler State

-- | The relabeling monad: a 'State' computation threading an 'Env'.
type RelabelerState = State Env

-- | State threaded through the relabeler.
data Env = Env
  { idMappings :: [IdMapping]
  -- ^ Active renamings, most recently pushed first; each pair maps an
  -- identifier as written to the fresh name it is rewritten to.
  , idCnt :: Ast.IdCnt
  -- ^ Tag for the next fresh 'Ast.Gen' identifier.
  }

-- | A renaming pair: (identifier as written, fresh identifier it becomes).
type IdMapping = (Ast.Identifier', Ast.Identifier')

-- ** Relabelers

-- | Relabel a single declaration and bring its (fresh) name into scope.
--
-- The new mapping is pushed but /not/ popped here: the caller decides how
-- long it stays visible. 'relabelExpr' pops it after a @let ... in@ body;
-- top-level declarations keep it for the rest of the program.
relabelDecl :: Ast.Declaration -> RelabelerState Ast.Declaration
relabelDecl (Ast.DeclVar ident value) = do
  -- The initializer is relabeled before the name is bound, so it cannot
  -- refer to the variable it defines.
  newValue <- relabelExpr value
  newIdent <- pushAndMapId ident
  pure (Ast.DeclVar newIdent newValue)
relabelDecl (Ast.DeclFun ident isRec fun)
  | isRec = do
      -- A recursive function binds its own name first, so uses of the name
      -- inside the body resolve to the new binding.
      newIdent <- pushAndMapId ident
      newFun <- relabelFun fun
      pure (Ast.DeclFun newIdent isRec newFun)
  | otherwise = do
      -- A non-recursive body is relabeled before the name is bound, so uses
      -- of the name inside it resolve to whatever binding was already in scope.
      newFun <- relabelFun fun
      newIdent <- pushAndMapId ident
      pure (Ast.DeclFun newIdent isRec newFun)

-- | Relabel an expression, rewriting every identifier use to the fresh name
-- of its innermost binding currently in scope.
--
-- Scoping constructs leave the mapping stack as they found it: a
-- @let ... in@ pops the mapping its declaration pushed once the body is done,
-- and function literals go through 'relabelFun', which pops its parameters.
relabelExpr :: Ast.Expression -> RelabelerState Ast.Expression
relabelExpr (Ast.ExprId ident) = fmap Ast.ExprId (mapId ident)
relabelExpr (Ast.ExprPrimVal val) = pure (Ast.ExprPrimVal val)
relabelExpr (Ast.ExprBinOp op lhs rhs) = relabel2 (Ast.ExprBinOp op) lhs rhs
relabelExpr (Ast.ExprUnOp op x) = relabel1 (Ast.ExprUnOp op) x
relabelExpr (Ast.ExprApp f args) = relabel2 Ast.ExprApp f args
relabelExpr (Ast.ExprIte c t e) = relabel3 Ast.ExprIte c t e
relabelExpr (Ast.ExprLetIn decl body) = do
  newDecl <- relabelDecl decl
  newBody <- relabelExpr body
  -- 'relabelDecl' leaves exactly one new mapping pushed; the binding goes
  -- out of scope at the end of the body.
  popId
  pure (Ast.ExprLetIn newDecl newBody)
relabelExpr (Ast.ExprFun fun) = fmap Ast.ExprFun (relabelFun fun)

-- | Relabel a function literal.
--
-- Parameters are bound to fresh names left to right and stay in scope only
-- while the body is relabeled; afterwards exactly as many mappings are
-- popped as parameters were pushed.
relabelFun :: Ast.Fun -> RelabelerState Ast.Fun
relabelFun (Ast.Fun params body) = do
  params' <- traverse pushAndMapId params
  body' <- relabelExpr body
  replicateM_ arity popId
  pure (Ast.Fun params' body')
  where
    arity = NE.length params

-- ** Identifier Mappings

-- | Resolve an identifier use to the fresh name of its innermost binding.
--
-- The mapping list is searched from its head, so the most recently pushed
-- binding wins (shadowing). An identifier with no binding in scope is
-- returned unchanged.
mapId :: Ast.Identifier' -> RelabelerState Ast.Identifier'
mapId ident = gets (fromMaybe ident . lookup ident . idMappings)

-- | Bind a fresh name for an identifier and return that fresh name.
pushAndMapId :: Ast.Identifier' -> RelabelerState Ast.Identifier'
pushAndMapId ident = do
  pushId ident
  -- The mapping just pushed sits at the head, so this lookup yields it.
  mapId ident

-- | Push a mapping from an identifier to a fresh 'Ast.Gen' name.
--
-- The fresh name keeps the identifier's base name and is tagged with the
-- current counter value; the counter is then incremented so the next fresh
-- name gets a different tag.
pushId :: Ast.Identifier' -> RelabelerState ()
pushId ident = modify bind
  where
    bind (Env mappings next) =
      Env
        { idMappings = (ident, Ast.Gen next (baseName ident)) : mappings,
          idCnt = next + 1
        }

    -- Strip any previous generated tag, keeping only the textual name.
    baseName (Ast.Gen _ name) = name
    baseName (Ast.Txt name) = name

-- | Drop the most recently pushed identifier mapping, ending its scope.
--
-- Every 'popId' must be paired with an earlier push. Popping an empty
-- mapping stack means the relabeler's scoping is unbalanced, which is an
-- internal bug. Rather than relying on the partial 'tail' (and its generic
-- empty-list error), fail with a message that names the broken invariant.
popId :: RelabelerState ()
popId = modify $ \env -> case idMappings env of
  _ : rest -> env {idMappings = rest}
  [] -> error "Relabeler.popId: no identifier mapping to pop (unbalanced scopes)"

-- ** Utils

-- | Relabel one sub-expression with 'relabelExpr' and rebuild the node with
-- @build@.
relabel1
  :: (Ast.Expression -> a)
  -> Ast.Expression
  -> RelabelerState a
relabel1 build expr = liftM1' relabelExpr build expr

-- | Relabel two sub-expressions with 'relabelExpr' and rebuild the node with
-- @build@. Effect order is whatever 'liftM2'' from "MonadUtils" uses
-- (presumably left to right — confirm there).
relabel2
  :: (Ast.Expression -> Ast.Expression -> a)
  -> Ast.Expression
  -> Ast.Expression
  -> RelabelerState a
relabel2 build lhs rhs = liftM2' relabelExpr build lhs rhs

-- | Relabel three sub-expressions with 'relabelExpr' and rebuild the node
-- with @build@. Effect order is whatever 'liftM3'' from "MonadUtils" uses
-- (presumably left to right — confirm there).
relabel3
  :: (Ast.Expression -> Ast.Expression -> Ast.Expression -> a)
  -> Ast.Expression
  -> Ast.Expression
  -> Ast.Expression
  -> RelabelerState a
relabel3 build x y z = liftM3' relabelExpr build x y z