{-# LANGUAGE LambdaCase #-}

module Transformations.Cc.Cc (ccAst) where

import Control.Monad.Reader (MonadReader (local), Reader, asks, runReader)
import Data.Foldable (Foldable (foldl'))
import Data.Foldable.Extra (notNull)
import Data.Functor.Foldable (ListF (..), Recursive (cata))
import Data.List.NonEmpty (NonEmpty, (<|))
import qualified Data.List.NonEmpty as NE
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromMaybe)
import Data.Set (Set, (\\))
import qualified Data.Set as Set
import MonadUtils
import qualified StdLib
import qualified Transformations.Simplifier.SimplifiedAst as Ast
import qualified Trees.Common as Ast

-- * AST Closure Converter

-- | Convert AST to its closure-free representation.
--
-- Each top-level declaration is converted with 'ccGDecl' in an environment
-- whose global identifiers are the standard-library names together with the
-- program's own top-level declaration names, and with no identifier rewrites
-- active yet. The identifier counter is passed through untouched.
ccAst :: Ast.Program -> Ast.Program
ccAst (Ast.Program decls cnt) = Ast.Program convertedDecls cnt
  where
    convertedDecls = runReader (traverse ccGDecl decls) initialEnv
    initialEnv = Env {globalIds = globals, declIdMappings = Map.empty}
    globals = Set.fromList (map Ast.Txt StdLib.decls ++ map Ast.declId decls)

-- * Internal

-- ** Closure Converter Info

-- | Reader monad that threads the closure-conversion environment.
type CcInfo = Reader Env

-- | Environment consulted while closure-converting expressions.
data Env = Env
  { -- | Identifiers that are never captured as extra closure parameters
    -- ('ccAst' fills this with standard-library and top-level names).
    globalIds :: Set Ast.Identifier'
  , -- | Rewrites applied to identifier references: a closure-converted local
    -- function is mapped to its application to the variables it captured.
    declIdMappings :: Map Ast.Identifier' Ast.Expression
  }

-- ** Closure Converters

-- | Closure-convert a single top-level declaration.
--
-- Top-level names are global, so a top-level function keeps its parameter
-- list as is; only the function body, or the variable's value, is converted.
ccGDecl :: Ast.Declaration -> CcInfo Ast.Declaration
ccGDecl = \case
  Ast.DeclVar ident val ->
    cc1 (Ast.DeclVar ident) val
  Ast.DeclFun ident isRec (Ast.Fun params body) ->
    Ast.DeclFun ident isRec <$> cc1 (Ast.Fun params) body

-- | Closure-convert a single expression.
--
-- Every local function (a @let@-bound 'Ast.DeclFun' or an anonymous
-- 'Ast.ExprFun') gets the variables it captures prepended to its parameter
-- list. Every use of it is replaced by an application to those variables.
-- Identifiers in 'globalIds' are never captured.
ccExpr :: Ast.Expression -> CcInfo Ast.Expression
ccExpr = \case
  Ast.ExprId ident -> do
    ms <- asks declIdMappings
    -- A reference to a closure-converted local function becomes the
    -- application recorded for it; any other identifier is left unchanged.
    return $ fromMaybe (Ast.ExprId ident) (ms Map.!? ident)
  Ast.ExprPrimVal val -> return $ Ast.ExprPrimVal val
  Ast.ExprBinOp op lhs rhs -> cc2 (Ast.ExprBinOp op) lhs rhs
  Ast.ExprUnOp op x -> cc1 (Ast.ExprUnOp op) x
  Ast.ExprApp f arg -> cc2 Ast.ExprApp f arg
  Ast.ExprIte c t e -> cc3 Ast.ExprIte c t e
  Ast.ExprLetIn decl expr -> case decl of
    Ast.DeclVar ident value -> do
      decl' <- cc1 (Ast.DeclVar ident) value
      cc1 (Ast.ExprLetIn decl') expr
    Ast.DeclFun ident isRec (Ast.Fun params body) -> do
      fv <- captured (toSet (ident <| params)) body

      -- While converting the function body (recursive calls) and the `in`
      -- part, rewrite references to `ident` as `ident fv1 .. fvN`.
      let withNewMappings =
            if notNull fv
              then local $ \env ->
                let app = apply (Ast.ExprId ident) (Ast.ExprId <$> fv)
                    ms = Map.insert ident app (declIdMappings env)
                 in env {declIdMappings = ms}
              else id

      withNewMappings $ do
        closedFun <- cc1 (Ast.Fun (prependList fv params)) body
        let decl' = Ast.DeclFun ident isRec closedFun
        cc1 (Ast.ExprLetIn decl') expr
  Ast.ExprFun (Ast.Fun params body) -> do
    fv <- captured (toSet params) body
    closedFun <- cc1 (Ast.Fun (prependList fv params)) body
    -- An anonymous function is applied to its captured variables in place.
    return $ apply (Ast.ExprFun closedFun) (Ast.ExprId <$> fv)
  where
    -- Left-nested application: apply f [a, b] == (f a) b.
    apply :: (Foldable t) => Ast.Expression -> t Ast.Expression -> Ast.Expression
    apply = foldl' Ast.ExprApp

    -- Variables a closure over `body` must capture, given the identifiers
    -- `bound` by the function itself (its name and/or parameters).
    --
    -- A free identifier that is already mapped in 'declIdMappings' (an
    -- enclosing closure-converted function) is expanded to all identifiers
    -- of its mapped application. Otherwise the converted body would reference
    -- that function's captured variables without capturing them itself.
    captured :: Set Ast.Identifier' -> Ast.Expression -> CcInfo [Ast.Identifier']
    captured bound body = do
      gs <- asks globalIds
      ms <- asks declIdMappings
      let free = findFv body \\ bound
          expand v = maybe (Set.singleton v) findFv (ms Map.!? v)
      return . Set.toList $ foldMap expand free \\ (bound <> gs)

-- | Collect the identifiers that occur free in an expression.
--
-- Binders are scoped lexically:
--
-- * in @let x = v in e@, @x@ is removed from the free variables of @e@ only,
--   not from those of @v@;
-- * a function's parameters are removed from its own body only, never from
--   the @in@ part of the enclosing @let@;
-- * a @let@-bound function's name is removed from both its body and the
--   @in@ part, whether or not the binding is marked recursive (this matches
--   how 'ccExpr' treats the name).
findFv :: Ast.Expression -> Set Ast.Identifier'
findFv = \case
  Ast.ExprId ident -> Set.singleton ident
  Ast.ExprPrimVal _ -> Set.empty
  Ast.ExprBinOp _ lhs rhs -> findFv' [lhs, rhs]
  Ast.ExprUnOp _ x -> findFv x
  Ast.ExprApp f arg -> findFv' [f, arg]
  Ast.ExprIte c t e -> findFv' [c, t, e]
  Ast.ExprLetIn decl expr -> case decl of
    Ast.DeclVar ident value ->
      findFv value <> Set.delete ident (findFv expr)
    Ast.DeclFun ident _ (Ast.Fun params body) ->
      let bodyFv = findFv body \\ toSet (ident <| params)
          exprFv = Set.delete ident (findFv expr)
       in bodyFv <> exprFv
  Ast.ExprFun (Ast.Fun params body) -> findFv body \\ toSet params
  where
    -- Union of the free variables of several sibling sub-expressions.
    findFv' :: [Ast.Expression] -> Set Ast.Identifier'
    findFv' = cata merger
      where
        merger (Cons x accum) = accum <> findFv x
        merger Nil = Set.empty

-- ** Collection Utils

-- | Collect the elements of a non-empty list into a 'Set', dropping
-- duplicates.
toSet :: (Ord a) => NonEmpty a -> Set a
toSet = foldr Set.insert Set.empty

-- | Prepend a possibly empty list to a non-empty list, preserving order:
-- @prependList [a, b] (c :| cs) == a :| (b : c : cs)@.
--
-- A right fold of '(<|)' is total, so the partial 'NE.fromList' (and the
-- special case needed to keep it safe) is unnecessary.
prependList :: [a] -> NonEmpty a -> NonEmpty a
prependList xs ne = foldr (<|) ne xs

-- ** Utils

-- | Lift a one-argument builder over 'ccExpr' using 'liftM1''.
-- Presumably this converts the argument and feeds the result to the builder
-- (see MonadUtils to confirm).
cc1 ::
  (Ast.Expression -> a) ->
  (Ast.Expression -> CcInfo a)
cc1 build expr = liftM1' ccExpr build expr

-- | Lift a two-argument builder over 'ccExpr' using 'liftM2''.
-- Presumably this converts both arguments and feeds the results to the
-- builder (see MonadUtils to confirm).
cc2 ::
  (Ast.Expression -> Ast.Expression -> a) ->
  (Ast.Expression -> Ast.Expression -> CcInfo a)
cc2 build lhs rhs = liftM2' ccExpr build lhs rhs

-- | Lift a three-argument builder over 'ccExpr' using 'liftM3''.
-- Presumably this converts all three arguments and feeds the results to the
-- builder (see MonadUtils to confirm).
cc3 ::
  (Ast.Expression -> Ast.Expression -> Ast.Expression -> a) ->
  (Ast.Expression -> Ast.Expression -> Ast.Expression -> CcInfo a)
cc3 build x y z = liftM3' ccExpr build x y z