{-# LANGUAGE DeriveAnyClass #-}
{-# LANGUAGE DeriveGeneric #-}
{-# LANGUAGE DeriveTraversable #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE StandaloneDeriving #-}
{-# OPTIONS_GHC -Wno-orphans #-}

module TypeChecker.HindleyMilner
  ( Infer,
    TypeError (..),
    UType,
    Polytype,
    applyBindings,
    generalize,
    toPolytype,
    toUType,
    withBinding,
    fresh,
    Poly (..),
    UTerm (UTVar, UTUnit, UTBool, UTInt, UTFun),
    (=:=),
    lookup,
    TypeF (..),
    mkVarName,
  )
where

import Control.Monad.Except
import Control.Monad.Reader
import Control.Unification hiding (applyBindings, (=:=))
import qualified Control.Unification as U
import Control.Unification.IntVar
import Data.Foldable (fold)
import Data.Functor.Fixedpoint
import Data.Functor.Identity
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import Data.Set (Set, (\\))
import qualified Data.Set as S
import Data.Text (pack)
import GHC.Generics (Generic1)
import qualified Trees.Common as L -- Lang
import Prelude hiding (lookup)

-- * Type

-- | Fully-resolved types: the fixed point of 'TypeF'. Unlike 'UType',
-- this representation has no constructor for unification variables.
type Type = Fix TypeF

-- | One layer of the type language, with recursive positions abstracted
-- as @a@. Tying the knot with 'Fix' gives 'Type'; tying it with 'UTerm'
-- gives 'UType', which may additionally contain unification variables.
--
-- 'Unifiable' is obtained through @DeriveAnyClass@ from the class's
-- 'Generic1'-based default 'zipMatch'.
-- NOTE(review): that default requires unification-fd >= 0.11; confirm
-- against the pinned version.
data TypeF a
  = -- | A type variable referred to by name ('generalize' produces these
    -- for quantified variables).
    TVarF L.Identifier
  | TUnitF
  | TBoolF
  | TIntF
  | -- | Function type. Field order follows 'L.TFun' (see 'toUType').
    TFunF a a
  deriving (Show, Eq, Functor, Foldable, Traversable, Generic1, Unifiable)

-- * UType

-- | Types during inference: 'TypeF' structure whose leaves may also be
-- unification variables ('IntVar's), allocated and bound by the
-- 'IntBindingT' layer of 'Infer'.
type UType = UTerm TypeF IntVar

-- | A named type variable inside a 'UType'.
pattern UTVar :: L.Identifier -> UType
pattern UTVar var = UTerm (TVarF var)

-- | The unit type as a 'UType'.
pattern UTUnit :: UType
pattern UTUnit = UTerm TUnitF

-- | The boolean type as a 'UType'.
pattern UTBool :: UType
pattern UTBool = UTerm TBoolF

-- | The integer type as a 'UType'.
pattern UTInt :: UType
pattern UTInt = UTerm TIntF

-- | A function type as a 'UType'; fields are in the same order as
-- 'TFunF'.
pattern UTFun :: UType -> UType -> UType
pattern UTFun funT argT = UTerm (TFunF funT argT)

-- * Polytype

-- | A type scheme: @Forall xs t@ quantifies the names @xs@, which occur
-- in the body @t@ as 'TVarF' variables.
data Poly t = Forall [L.Identifier] t
  deriving (Eq, Show, Functor)

-- | A type scheme over fully-resolved types.
type Polytype = Poly Type

-- | A type scheme over inference-time types; this is what 'Ctx' stores.
type UPolytype = Poly UType

-- * Converters

-- | Embed a surface-language type into the inference representation.
-- The result contains no unification variables.
--
-- NOTE(review): this matches 'L.TUnit', 'L.TBool', 'L.TInt' and 'L.TFun'
-- only; presumably these are all of 'L.Type''s constructors. Confirm
-- against "Trees.Common".
toUType :: L.Type -> UType
toUType = \case
  L.TUnit -> UTUnit
  L.TBool -> UTBool
  L.TInt -> UTInt
  L.TFun funT argT -> UTFun (toUType funT) (toUType argT)

-- | Convert an inference-time type scheme into a fully-resolved one.
--
-- 'freeze' succeeds only when the body contains no unification
-- variables. For example, the result of 'generalize' qualifies when the
-- context had no free unification variables. Anything left over is a
-- caller bug, so this fails with a descriptive message instead of the
-- anonymous error from 'fromJust'.
toPolytype :: UPolytype -> Polytype
toPolytype = fmap (fromMaybe (error msg) . freeze)
  where
    msg = "toPolytype: type scheme still contains unification variables"

-- * Infer

-- | The inference monad, with three layers:
--
-- * a read-only typing context ('Ctx');
-- * short-circuiting 'TypeError's;
-- * 'IntBindingT' state for unification variables.
type Infer = ReaderT Ctx (ExceptT TypeError (IntBindingT TypeF Identity))

-- | Typing context: the type scheme of each variable in scope.
type Ctx = Map L.Identifier UPolytype

-- | Look up a variable's scheme in the context and instantiate it: every
-- quantified name is replaced by a fresh unification variable.
--
-- Throws 'UnboundVar' if the variable is not in scope.
lookup :: L.Identifier -> Infer UType
lookup var =
  asks (M.lookup var) >>= maybe (throwError (UnboundVar var)) instantiate
  where
    instantiate :: UPolytype -> Infer UType
    instantiate (Forall xs uty) = do
      -- One fresh unification variable per quantified name.
      xs' <- mapM (const fresh) xs
      return $ substU (M.fromList (zip (map Left xs) xs')) uty

-- | Run an action with @x@ bound to the scheme @ty@. Any existing binding
-- of @x@ is shadowed for the duration of the action only.
withBinding :: (MonadReader Ctx m) => L.Identifier -> UPolytype -> m a -> m a
withBinding x ty = local (M.insert x ty)

-- | Bottom-up fold over a 'UTerm'. Unification variables are mapped with
-- @f@, and each structural layer (children already folded) with @g@.
-- Variables are taken as they appear; bindings are not consulted.
ucata :: (Functor t) => (v -> a) -> (t a -> a) -> UTerm t v -> a
ucata f _ (UVar v) = f v
ucata f g (UTerm t) = g (fmap (ucata f g) t)

-- Orphan instance (the reason for -Wno-orphans at the top of the file).
-- The free-variable sets below store 'IntVar's in a 'Set', which needs
-- 'Ord'.
-- NOTE(review): newer unification-fd releases may already provide this
-- instance, in which case this line would clash. Confirm against the
-- pinned version.
deriving instance Ord IntVar

-- * FreeVars

-- | Things whose free type variables can be computed. Results mix named
-- variables ('Left') and unification variables ('Right'). The method runs
-- in 'Infer' because the 'UType' instance queries the current bindings.
class FreeVars a where
  freeVars :: a -> Infer (Set (Either L.Identifier IntVar))

-- | The result combines two sets:
--
-- * free unification variables, as reported by 'getFreeVars';
-- * the named variables that occur syntactically in the term.
--
-- NOTE(review): named variables are collected from @ut@ as written, not
-- from the terms its unification variables are bound to. 'generalize'
-- calls 'applyBindings' before asking, so this matters only for other
-- callers.
instance FreeVars UType where
  freeVars ut = do
    fuvs <- S.fromList . map Right <$> lift (lift (getFreeVars ut))
    let ftvs =
          ucata
            (const S.empty)
            ( \case
                TVarF x -> S.singleton (Left x)
                f -> fold f
            )
            ut
    return $ fuvs `S.union` ftvs

-- | Free variables of the body, minus the names the scheme quantifies.
instance FreeVars UPolytype where
  freeVars (Forall xs ut) = (\\ S.fromList (map Left xs)) <$> freeVars ut

-- | Union of the free variables of every scheme in the context.
instance FreeVars Ctx where
  freeVars ctx = S.unions <$> traverse freeVars (M.elems ctx)

-- | Allocate a fresh, unbound unification variable from the 'IntBindingT'
-- layer.
fresh :: Infer UType
fresh = UVar <$> lift (lift freeVar)

-- * Errors

-- | Errors raised during type inference.
data TypeError where
  -- | NOTE(review): not raised in this module; presumably used by callers
  -- for cases they consider impossible.
  Unreachable :: TypeError
  -- | 'lookup' found no binding for this variable.
  UnboundVar :: L.Identifier -> TypeError
  -- | Occurs-check failure reported by unification (see the 'Fallible'
  -- instance).
  Infinite :: IntVar -> UType -> TypeError
  -- | NOTE(review): not raised in this module; presumably raised by callers.
  ImpossibleBinOpApplication :: UType -> UType -> TypeError
  -- | NOTE(review): not raised in this module; presumably raised by callers.
  ImpossibleUnOpApplication :: UType -> TypeError
  -- | Structural mismatch reported by unification (see the 'Fallible'
  -- instance).
  Mismatch :: TypeF UType -> TypeF UType -> TypeError
  deriving (Show)

-- | Lets unification-fd report its failures directly as 'TypeError's.
instance Fallible TypeF IntVar TypeError where
  occursFailure = Infinite
  mismatchFailure = Mismatch

-- | Unify two types in 'Infer' and return the unified type. This lifts
-- unification-fd's 'U.=:=' past the 'ReaderT' layer; failures surface as
-- 'TypeError's through the 'Fallible' instance.
(=:=) :: UType -> UType -> Infer UType
s =:= t = lift $ s U.=:= t

-- | Substitute the current bindings throughout a type. This is
-- unification-fd's 'U.applyBindings' lifted into 'Infer'.
applyBindings :: UType -> Infer UType
applyBindings = lift . U.applyBindings

-- | Simultaneous substitution. Unification variables are looked up under
-- 'Right' keys and named variables under 'Left' keys. Unmapped variables
-- are left unchanged, and replacement terms are not themselves rewritten.
-- Bindings of unification variables are not consulted.
substU :: Map (Either L.Identifier IntVar) UType -> UType -> UType
substU m =
  ucata
    (\v -> fromMaybe (UVar v) (M.lookup (Right v) m))
    ( \case
        TVarF v -> fromMaybe (UTVar v) (M.lookup (Left v) m)
        f -> UTerm f
    )

-- | Build a name for a unification variable by appending a numeric suffix
-- to @nm@.
--
-- With wrapping 'Int' arithmetic, @v + maxBound + 1@ equals
-- @v - minBound@. NOTE(review): this presumes 'IntBindingT' numbers its
-- variables upward from 'minBound', so the first variable would become
-- @nm <> "0"@. Confirm against unification-fd.
mkVarName :: String -> IntVar -> L.Identifier
mkVarName nm (IntVar v) = pack (nm <> show (v + (maxBound :: Int) + 1))

-- | Generalize a type over every variable that is free in it but not free
-- in the current context.
--
-- Bindings are applied first. Each generalized unification variable is
-- then named with 'mkVarName' and replaced by the corresponding 'UTVar';
-- already-named free variables keep their names.
generalize :: UType -> Infer UPolytype
generalize uty = do
  uty' <- applyBindings uty
  ctx <- ask
  tmFreeVars <- freeVars uty'
  ctxFreeVars <- freeVars ctx
  let fvs = S.toList $ tmFreeVars \\ ctxFreeVars
      xs = map (either id (mkVarName "a")) fvs
  return $ Forall xs (substU (M.fromList (zip fvs (map UTVar xs))) uty')