-- Created: 2023-02-06 17:54

{-# LANGUAGE GADTs #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE StandaloneDeriving #-}
module Main (main) where

import Data.Kind (Type)
import Data.Type.Equality
 
-- | Unchecked syntax tree.  Ill-typed terms such as @IsZero_ (Bool_ False)@
-- are representable here; 'inferType' / 'typeCheck' are what turn a tree
-- into a typed 'Expr' (or reject it).
data UntypedExpr =
    Int_ Int                                 -- ^ integer literal
  | Bool_ Bool                               -- ^ boolean literal
  | IsZero_ UntypedExpr                      -- ^ zero test; operand should be an integer term
  | Plus_ UntypedExpr UntypedExpr            -- ^ addition of two integer terms
  | If_ UntypedExpr UntypedExpr UntypedExpr  -- ^ condition, then-branch, else-branch
  deriving Show
 
-- | Well-typed expressions.  The type index is the object-language type of
-- the expression, so ill-typed trees such as @IsZero (Bool False)@ cannot be
-- constructed at all.
data Expr :: Type -> Type where
    Int :: Int -> Expr Int
    Bool :: Bool -> Expr Bool
    IsZero :: Expr Int -> Expr Bool
    Plus :: Expr Int -> Expr Int -> Expr Int
    -- Both branches must share the index of the whole conditional.
    If :: Expr Bool -> Expr a -> Expr a -> Expr a

-- No @Show a@ context: no constructor stores a bare value of type @a@, so an
-- expression is showable whatever its index.  This also makes it possible to
-- show an 'Expr' whose index is existentially hidden (e.g. inside 'ExprAny').
deriving instance Show (Expr a)
 
-- | Runtime witnesses for the types an 'Expr' can have.  Pattern matching on
-- a witness tells GHC what the index @a@ is.
data ExprType :: Type -> Type where
    IntT :: ExprType Int
    BoolT :: ExprType Bool

-- The witnesses carry no value of type @a@, so no @Show a@ context is needed;
-- this lets a witness be shown even when @a@ is existentially hidden.
deriving instance Show (ExprType a)
 
-- | A typed expression whose index is hidden, packaged with a witness of that
-- index so a consumer can recover it by pattern matching on the 'ExprType'.
data ExprAny :: Type where
    ExprAny :: ExprType a -> Expr a -> ExprAny
 
-- | Infer the type of an untyped expression, returning the typed 'Expr'
-- together with a witness of its type, or 'Nothing' if it is ill-typed.
--
-- Inference is structural: every sub-expression is itself inferred or
-- checked, so arbitrarily nested well-typed terms are accepted (e.g.
-- @Plus_ (Int_ 1) (Plus_ (Int_ 2) (Int_ 3))@ or an 'If_' whose condition
-- is an 'IsZero_'), not just terms built directly from literals.
inferType :: UntypedExpr -> Maybe ExprAny
inferType (Int_ n) = Just $ ExprAny IntT (Int n)
inferType (Bool_ b) = Just $ ExprAny BoolT (Bool b)
inferType (IsZero_ e) =
  -- The operand must be an integer expression; the result is boolean.
  ExprAny BoolT . IsZero <$> typeCheck IntT e
inferType (Plus_ a b) =
  -- Both operands must be integer expressions.
  ExprAny IntT <$> (Plus <$> typeCheck IntT a <*> typeCheck IntT b)
inferType (If_ c a b) = do
  cond <- typeCheck BoolT c
  thenAny <- inferType a
  case thenAny of
    -- The then-branch fixes the type of the whole conditional; the
    -- else-branch is checked against that same type.
    ExprAny t thenE -> do
      elseE <- typeCheck t b
      pure $ ExprAny t (If cond thenE elseE)
 
-- | Decide whether two type witnesses denote the same type.  A 'Refl'
-- result lets the caller use values indexed by @a@ where @b@ is expected.
equalType :: ExprType a -> ExprType b -> Maybe (a :~: b)
equalType lhs rhs = case (lhs, rhs) of
  (IntT, IntT) -> Just Refl
  (BoolT, BoolT) -> Just Refl
  (IntT, BoolT) -> Nothing
  (BoolT, IntT) -> Nothing
 
-- | Check an untyped expression against an expected type.  Succeeds only if
-- 'inferType' accepts the expression and the inferred type is the expected
-- one; the resulting 'Expr' is then indexed by the expected type.
typeCheck :: ExprType a -> UntypedExpr -> Maybe (Expr a)
typeCheck expected untyped
  | Just (ExprAny inferred typed) <- inferType untyped
  , Just Refl <- equalType expected inferred
  = Just typed
  | otherwise
  = Nothing
 
-- | Print a report for one untyped expression: the expression itself, the
-- type 'inferType' assigns to it (or that it cannot be typed), and whether
-- 'typeCheck' accepts it at @expectedType@, printing the typed expression
-- when it does.
runTypeChecker :: Show a => ExprType a -> UntypedExpr -> IO ()
runTypeChecker expectedType expr = do
  putStrLn ("Type checking expression: " <> show expr)
  -- The type the expression has on its own, independent of expectedType.
  putStrLn inferredType
  -- Outcome of checking against the caller's expected type.
  maybe onMismatch onMatch (typeCheck expectedType expr)
  where
    wantedType :: String
    wantedType = show expectedType

    inferredType :: String
    inferredType = case inferType expr of
      Nothing -> "Expression is not valid"
      Just (ExprAny IntT _) -> "Type: IntT"
      Just (ExprAny BoolT _) -> "Type: BoolT"

    onMismatch :: IO ()
    onMismatch = do
      putStrLn ("Doesn't match wanted type (" <> wantedType <> ")")
      putStrLn "Cannot generate a typed expression"

    onMatch typedExpr = do
      putStrLn ("Matches wanted type (" <> wantedType <> ")")
      putStrLn "Typed expression:"
      print typedExpr
 
-- | Run the checker over a few sample expressions, separating consecutive
-- reports with a blank line.
main :: IO ()
main = do
  runTypeChecker IntT (IsZero_ (Int_ 10))
  mapM_ (putStrLn "" >>)
    [ runTypeChecker BoolT (IsZero_ (Int_ 10))
    , runTypeChecker BoolT (IsZero_ (Bool_ False))
    ]
 

{- Prints:

ghci> main
Type checking expression: IsZero_ (Int_ 10)
Type: BoolT
Doesn't match wanted type (IntT)
Cannot generate a typed expression

Type checking expression: IsZero_ (Int_ 10)
Type: BoolT
Matches wanted type (BoolT)
Typed expression:
IsZero (Int 10)

Type checking expression: IsZero_ (Bool_ False)
Expression is not valid
Doesn't match wanted type (BoolT)
Cannot generate a typed expression
-}