-- Created: 2023-02-06 17:54
{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}

module Main (main) where

import Data.Kind (Type)
import Data.Type.Equality
-- | Untyped expression tree. Nothing here stops ill-typed terms such as
-- @IsZero_ (Bool_ False)@ from being built; 'inferType' and 'typeCheck'
-- are responsible for rejecting them.
data UntypedExpr where
  Int_ :: Int -> UntypedExpr
  Bool_ :: Bool -> UntypedExpr
  IsZero_ :: UntypedExpr -> UntypedExpr
  Plus_ :: UntypedExpr -> UntypedExpr -> UntypedExpr
  If_ :: UntypedExpr -> UntypedExpr -> UntypedExpr -> UntypedExpr
  deriving (Show)
-- | Typed expression tree, indexed by its object-language type (@Int@ or
-- @Bool@). The constructor signatures make ill-typed terms such as
-- @IsZero (Bool False)@ unrepresentable: GHC rejects them at compile time.
--
-- The kind is written with 'Type' rather than the deprecated @*@.
data Expr :: Type -> Type where
  Int :: Int -> Expr Int
  Bool :: Bool -> Expr Bool
  IsZero :: Expr Int -> Expr Bool
  Plus :: Expr Int -> Expr Int -> Expr Int
  If :: Expr Bool -> Expr a -> Expr a -> Expr a

-- Every field is an Int, a Bool or another Expr, so showing an @Expr a@
-- never needs to show a bare value of type @a@. No @Show a@ context is
-- needed, which lets an Expr of any index be shown.
deriving instance Show (Expr a)
-- | Runtime witness for the two object-language types. Each index has
-- exactly one constructor, so pattern matching on an @ExprType a@ reveals
-- which type @a@ is.
data ExprType :: Type -> Type where
  IntT :: ExprType Int
  BoolT :: ExprType Bool

-- Neither constructor has fields, so the instance needs no @Show a@ context.
-- This lets a witness whose index is existentially hidden still be shown.
deriving instance Show (ExprType a)
-- | A typed 'Expr' whose index is hidden, packaged with the 'ExprType'
-- witness for that index. Matching on the witness recovers the hidden type.
data ExprAny where
  ExprAny :: ExprType a -> Expr a -> ExprAny
-- | Infer the type of an untyped expression, producing a typed 'Expr' together
-- with the 'ExprType' witness for its type.
--
-- Inference is compositional: every sub-expression is checked or inferred
-- recursively. Nested terms such as @Plus_ (Int_ 1) (Plus_ (Int_ 2) (Int_ 3))@
-- or @If_ (IsZero_ (Int_ 0)) (Int_ 1) (Int_ 2)@ are therefore accepted,
-- not just literal operands.
--
-- Returns 'Nothing' when the expression is ill-typed:
--
--   * the operand of 'IsZero_' or an operand of 'Plus_' is not an Int;
--   * the condition of 'If_' is not a Bool;
--   * the two branches of 'If_' have different types.
inferType :: UntypedExpr -> Maybe ExprAny
inferType (Int_ n) = Just (ExprAny IntT (Int n))
inferType (Bool_ b) = Just (ExprAny BoolT (Bool b))
inferType (IsZero_ operand) = do
  typedOperand <- typeCheck IntT operand
  Just (ExprAny BoolT (IsZero typedOperand))
inferType (Plus_ lhs rhs) = do
  typedLhs <- typeCheck IntT lhs
  typedRhs <- typeCheck IntT rhs
  Just (ExprAny IntT (Plus typedLhs typedRhs))
inferType (If_ cond thenBranch elseBranch) = do
  typedCond <- typeCheck BoolT cond
  ExprAny thenType typedThen <- inferType thenBranch
  ExprAny elseType typedElse <- inferType elseBranch
  -- The branches may have any type, but it must be the same for both.
  -- Matching on Refl tells GHC the two hidden indices are equal.
  Refl <- equalType thenType elseType
  Just (ExprAny thenType (If typedCond typedThen typedElse))
-- | Decide whether two type witnesses denote the same type.
--
-- On success, the returned 'Refl' proof lets the caller treat @a@ and @b@ as
-- the same type. Mismatched witnesses yield 'Nothing'.
equalType :: ExprType a -> ExprType b -> Maybe (a :~: b)
equalType lhs rhs = case lhs of
  IntT -> case rhs of
    IntT -> Just Refl
    BoolT -> Nothing
  BoolT -> case rhs of
    BoolT -> Just Refl
    IntT -> Nothing
-- | Check an untyped expression against an expected type.
--
-- Returns the typed 'Expr' on success. Returns 'Nothing' if 'inferType'
-- rejects the expression or if its inferred type differs from @expected@.
typeCheck :: ExprType a -> UntypedExpr -> Maybe (Expr a)
typeCheck expected untyped = do
  ExprAny actual typed <- inferType untyped
  -- Refl refines the hidden index of 'typed' to the caller's @a@.
  Refl <- equalType expected actual
  pure typed
-- | Print a report on an untyped expression to stdout:
--
--   * the expression itself;
--   * its inferred type, or that it is not valid;
--   * whether it checks against @expectedType@, and if so the resulting
--     typed expression.
runTypeChecker :: Show a => ExprType a -> UntypedExpr -> IO ()
runTypeChecker expectedType expr = do
  putStrLn ("Type checking expression: " <> show expr)
  -- Get the type of the expression, if it is properly constructed.
  putStrLn (describeInferred (inferType expr))
  -- Type check against the expected type.
  maybe reportMismatch reportMatch (typeCheck expectedType expr)
  where
    wanted = show expectedType

    describeInferred :: Maybe ExprAny -> String
    describeInferred (Just (ExprAny IntT _)) = "Type: IntT"
    describeInferred (Just (ExprAny BoolT _)) = "Type: BoolT"
    describeInferred Nothing = "Expression is not valid"

    reportMismatch = do
      putStrLn ("Doesn't match wanted type (" <> wanted <> ")")
      putStrLn "Cannot generate a typed expression"

    reportMatch typedExpr = do
      putStrLn ("Matches wanted type (" <> wanted <> ")")
      putStrLn "Typed expression:"
      print typedExpr
-- | Demo: run the checker on three sample expressions, separated by blank
-- lines. The first is checked against the wrong type, the second against
-- the right one, and the third is ill-typed.
main :: IO ()
main =
  runTypeChecker IntT intZeroTest
    >> blankLine
    >> runTypeChecker BoolT intZeroTest
    >> blankLine
    >> runTypeChecker BoolT (IsZero_ (Bool_ False))
  where
    intZeroTest = IsZero_ (Int_ 10)
    blankLine = putStrLn ""
{- Running main in GHCi prints:

ghci> main
Type checking expression: IsZero_ (Int_ 10)
Type: BoolT
Doesn't match wanted type (IntT)
Cannot generate a typed expression

Type checking expression: IsZero_ (Int_ 10)
Type: BoolT
Matches wanted type (BoolT)
Typed expression:
IsZero (Int 10)

Type checking expression: IsZero_ (Bool_ False)
Expression is not valid
Doesn't match wanted type (BoolT)
Cannot generate a typed expression
-}