Skip to content
Snippets Groups Projects
Select Git revision
  • fb1f90a722c4b88c4255e9ff933b96ebc661bb09
  • master default protected
2 results

check.rs

Blame
  • check.rs 13.76 KiB
    use std::io::Write;
    
    use env_logger::Builder;
    use log::{trace, LevelFilter};
    use std::convert::From;
    
    use crate::ast::*;
    use crate::env::{new_fn_env, new_type_env, Error, FnEnv, IdType, TypeEnv, VarEnv};
    
    // type check
    
    impl Op {
        fn get_type(&self) -> Type {
            use Op::*;
            match self {
                Mul | Div | Add | Sub => Type::I32,
                _ => Type::Bool, // if not I32
            }
        }
    }
    
    impl From<&Op> for Type {
        /// Converts an operator into the type of the value it produces
        /// (`i32` for arithmetic operators, `bool` otherwise).
        fn from(op: &Op) -> Self {
            // `get_type` already returns an owned `Type`; cloning it again is redundant.
            op.get_type()
        }
    }
    
    impl From<&Param> for Type {
        fn from(p: &Param) -> Self {
            p.ty.clone()
        }
    }
    
    impl From<&FnDecl> for Type {
        fn from(f: &FnDecl) -> Self {
            f.result.clone()
        }
    }
    
    /// Computes the type of expression `e`.
    ///
    /// * `fn_env`   – declared functions, used to check calls.
    /// * `type_env` – declared types (only threaded through recursive calls here).
    /// * `var_env`  – variables in scope; a `let mut` binding is stored as `Type::Mut(_)`.
    ///
    /// When a sub-expression is used as a *value* (operands, call arguments) an
    /// outer `Type::Mut` is stripped, since reading a `mut` variable yields its
    /// inner type.
    ///
    /// # Errors
    /// Returns a message describing the first type error found.
    ///
    /// # Panics
    /// Explicit casts (`As`) and any expression form without an arm below are
    /// still `unimplemented!()`; an `Infix` with a non-binary operator panics.
    fn expr_type(
        e: &Expr,
        fn_env: &FnEnv,
        type_env: &TypeEnv,
        var_env: &VarEnv,
    ) -> Result<Type, Error> {
        use Expr::*;
        trace!("expr_type {}", e);
        match e {
            Num(_) => Ok(Type::I32),
            Bool(_) => Ok(Type::Bool),

            Infix(l, op, r) => {
                let lt = strip_mut(expr_type(l, fn_env, type_env, var_env)?);
                let rt = strip_mut(expr_type(r, fn_env, type_env, var_env)?);
                match op {
                    // Arithmetic and Boolean: both operands must have the operator's type
                    Op::Add | Op::Sub | Op::Mul | Op::Div | Op::And | Op::Or => {
                        let opt = op.get_type();
                        if lt == rt && lt == opt {
                            Ok(opt)
                        } else {
                            Err(format!("left {} op {} right {}", lt, opt, rt))
                        }
                    }
                    // Equality: operands must agree with each other, result is bool
                    Op::Eq | Op::Neq => {
                        if lt == rt {
                            Ok(Type::Bool)
                        } else {
                            Err(format!("left {} right {}", lt, rt))
                        }
                    }
                    // Comparison: only defined on i32
                    Op::Less | Op::Greater | Op::LessEq | Op::GreaterEq => {
                        if lt == Type::I32 && rt == Type::I32 {
                            Ok(Type::Bool)
                        } else {
                            Err(format!("left {} right {}", lt, rt))
                        }
                    }
                    _ => panic!("ICE on {}", e),
                }
            }

            Prefix(op, r) => {
                let rt = strip_mut(expr_type(r, fn_env, type_env, var_env)?);
                let opt = op.get_type();
                // operand and operator must be of the same type (e.g. `- i32`, `! bool`)
                if rt == opt {
                    Ok(opt)
                } else {
                    Err(format!("op {} rt {}", opt, rt))
                }
            }

            Call(s, args) => {
                trace!("call {}", s);

                // Arguments are passed by value, so `mut` on a variable does not matter.
                let argt: Vec<Type> = args
                    .0
                    .iter()
                    .map(|a| expr_type(a, fn_env, type_env, var_env).map(strip_mut))
                    .collect::<Result<_, _>>()?;
                trace!("arg types {:?}", argt);

                let f = fn_env
                    .get(s.as_str())
                    .ok_or_else(|| format!("{} not found", s))?;

                let part: Vec<Type> = f.params.0.iter().map(Type::from).collect();
                trace!("fn to call {} with param types {:?}", s, part);

                // Vec equality also rejects a wrong number of arguments.
                if argt == part {
                    Ok(f.result.clone())
                } else {
                    Err(format!(
                        "arguments types {:?} does not match parameter types {:?}",
                        argt, part
                    ))
                }
            }

            Id(id) => match var_env.get(id.to_string()) {
                // declared without annotation or initializer: no type to use yet
                Some(Type::Unknown) => Err(format!("variable {} has unknown type", id)),
                Some(t) => Ok(t.clone()),
                None => Err(format!("variable not found {}", id)),
            },

            As(_e, _t) => unimplemented!("here we implement explicit type cast"),

            // Convert Expr::Ref to Type::Ref; a shared reference never carries `mut`
            Ref(ref_e) => {
                let t = expr_type(ref_e, fn_env, type_env, var_env)?;
                trace!("ref_e {}, t {}", ref_e, t);
                Ok(Type::Ref(Box::new(strip_mut(t))))
            }

            // Convert Expr::Mut to Type::Mut; only a `mut` place can be borrowed mutably
            RefMut(ref_mut_e) => {
                let t = expr_type(ref_mut_e, fn_env, type_env, var_env)?;
                match t {
                    Type::Mut(_) => Ok(Type::Ref(Box::new(t))),
                    _ => Err(format!("{} is not mutable in {}", t, ref_mut_e)),
                }
            }

            DeRef(deref_e) => {
                let t = expr_type(deref_e, fn_env, type_env, var_env)?;
                trace!("deref_t {}", &t);
                // a `mut` variable holding a reference can still be dereferenced
                let t = strip_mut(t);
                trace!("strip deref_t {}", &t);
                match t {
                    Type::Ref(dt) => Ok(*dt),
                    _ => Err(format!("cannot deref {} of type {}", e, t)),
                }
            }
            _ => unimplemented!(),
        }
    }
    
    /// Removes one outer `Type::Mut` wrapper from `t`; any other type is returned unchanged.
    fn strip_mut(t: Type) -> Type {
        if let Type::Mut(inner) = t {
            *inner
        } else {
            t
        }
    }
    
    pub fn check_stmts(
        stmts: &Stmts,
        fn_env: &FnEnv,
        type_env: &TypeEnv,
        var_env: &mut VarEnv,
    ) -> Result<Type, Error> {
        use Stmt::*;
    
        var_env.push_empty_scope();
    
        let t = stmts.stmts.iter().try_fold(Type::Unit, |_, s| {
            trace!("stmt: {}", s);
            match s {
                Stmt::Block(b) => check_stmts(b, fn_env, type_env, var_env),
    
                Expr(e) => expr_type(&*e, fn_env, type_env, var_env),
    
                Let(var_id, is_mut, ot, oe) => {
                    let t: Type = match (ot, oe) {
                        (Some(t), Some(e)) => {
                            let e_type = expr_type(&*e, fn_env, type_env, var_env)?;
                            match *t == e_type {
                                true => t.clone(),
                                false => {
                                    trace!("e {}", e);
                                    Err(format!("incompatible types, {} <> {}", t, e_type))?
                                }
                            }
                        }
                        (None, Some(e)) => expr_type(&*e, fn_env, type_env, var_env)?,
                        (Some(t), None) => t.clone(),
                        _ => Type::Unknown,
                    };
                    let t = match is_mut {
                        true => Type::Mut(Box::new(t)),
                        false => t,
                    };
                    var_env.new_id(var_id.clone(), t);
                    trace!("var_env {:?}", var_env);
                    Ok(Type::Unit)
                }
    
                Assign(lh, e) => {
                    let e_type = expr_type(&*e, fn_env, type_env, var_env)?;
                    trace!("e_type = {}", e_type);
                    trace!("lh = {:?}", lh);
    
                    let lh_type = expr_type(lh, fn_env, type_env, var_env)?;
                    trace!("lh_type {}", &lh_type);
    
                    if match &lh_type {
                        Type::Mut(t) => **t == e_type,
    
                        _ => Err(format!("assignment to immutable"))?,
                    } {
                        Ok(Type::Unit)
                    } else {
                        Err(format!("cannot assign {} = {}", &lh_type, e_type))
                    }
                }
    
                While(e, block) => match expr_type(&*e, fn_env, type_env, var_env) {
                    Ok(Type::Bool) => {
                        let _ = check_stmts(&block, fn_env, type_env, var_env)?;
                        Ok(Type::Unit) // a while statement is of unit type;
                    }
                    _ => Err("Condition not Boolean".to_string()),
                },
    
                If(e, then, o_else) => match expr_type(&*e, fn_env, type_env, var_env)? {
                    Type::Bool => {
                        // The condition is of Bool type
                        let then_type = check_stmts(&then, fn_env, type_env, var_env)?;
                        trace!("then type {}", then_type);
                        match o_else {
                            None => Ok(then_type), // type of the arm
                            Some(else_stmts) => {
                                let else_type = check_stmts(&else_stmts, fn_env, type_env, var_env)?;
                                trace!("else type {}", else_type);
    
                                match then_type == else_type {
                                    true => Ok(then_type), // same type of both arms
                                    false => {
                                        trace!("error-----");
                                        Err(format!(
                                            "'then' arm :{} does not match 'else' arm :{}",
                                            then_type, else_type
                                        ))
                                    }
                                }
                            }
                        }
                    }
                    _ => Err("Condition not Boolean".to_string()),
                },
            }
        })?;
    
        var_env.pop_scope();
        if !stmts.ret {
            Ok(t.clone())
        } else {
            Ok(Type::Unit)
        }
    }
    
    #[test]
    fn test_stmts() {
        use crate::grammar::*;
        use crate::*;

        // Parse the sample program and derive the function/type environments from it.
        let program = ProgramParser::new().parse(prog1!()).unwrap();
        let fn_env = new_fn_env(&program.fn_decls);
        let type_env = new_type_env(&program.type_decls);
        let mut var_env = VarEnv::new();

        trace!("{}", &program);

        // Type-check the body of function `b` and log whatever comes back.
        let fn_b = fn_env.get(&"b").unwrap();
        let body = &fn_b.body;
        trace!("{}", body);

        let outcome = check_stmts(body, &fn_env, &type_env, &mut var_env);
        trace!("{:?}", outcome);
    }
    
    /// Builds the function environment and the type environment for program `p`.
    pub fn build_env(p: &Program) -> (FnEnv, TypeEnv) {
        (new_fn_env(&p.fn_decls), new_type_env(&p.type_decls))
    }
    
    /// Logs the debug representation of both environments at `trace` level,
    /// function environment first.
    pub fn dump_env(fn_env: &FnEnv, type_env: &TypeEnv) {
        trace!("fn_env {:?}", fn_env);
        trace!("type_env {:?}", type_env);
    }
    
    // check a whole Program
    pub fn check(p: &Program) -> Result<(), Error> {
        let _ = Builder::new()
            .filter(None, LevelFilter::Trace)
            .format(|buf, record| writeln!(buf, "{}", record.args()))
            .format_timestamp(None)
            .try_init();
    
        let (fn_env, type_env) = build_env(&p);
    
        trace!("Input program \n{}", &p);
    
        for fd in p.fn_decls.iter() {
            trace!("function id {}", fd.id);
            let mut var_env = VarEnv::new();
    
            // build a scope for the arguments
            let mut arg_ty = IdType::new();
            for Param { is_mut, id, ty } in fd.params.0.iter() {
                // TODO move is_mut to Type in parser
                arg_ty.insert(id.to_owned(), ty.clone());
            }
    
            var_env.push_param_scope(arg_ty);
    
            let stmt_type = check_stmts(&fd.body, &fn_env, &type_env, &mut var_env);
            trace!("function id {}: {:?}", fd.id, stmt_type);
    
            if stmt_type? != fd.result {
                Err(format!("return value does not match statements"))?;
            }
        }
        Ok(())
    }
    
    // unit test
    
    #[test]
    fn test_expr_type() {
        use crate::grammar::*;
        use crate::*;
        use Type::*;

        // setup environment
        let p = ProgramParser::new().parse(prog1!()).unwrap();
        let fn_env = new_fn_env(&p.fn_decls);
        let type_env = new_type_env(&p.type_decls);
        let mut var_env = VarEnv::new();
        var_env.push_empty_scope();

        // some test variables in scope
        var_env.new_id("i".to_string(), I32); // let i : i32 ...
        var_env.new_id("j".to_string(), Unknown); // let j ...
        // `n` is looked up (and passed to `c`) below, so it must be in scope
        var_env.new_id("n".to_string(), Named("A".to_string())); // let n: A ...

        println!("p {}", p);

        // type of number
        assert_eq!(
            expr_type(&Expr::Num(1), &fn_env, &type_env, &var_env),
            Ok(I32)
        );

        // type of variables

        // not found
        assert!(expr_type(&Expr::Id("a".to_string()), &fn_env, &type_env, &var_env).is_err());

        // let i: i32 ...
        assert_eq!(
            expr_type(&Expr::Id("i".to_string()), &fn_env, &type_env, &var_env),
            Ok(I32)
        );

        // let j ... (has no type yet)
        assert!(expr_type(&Expr::Id("j".to_string()), &fn_env, &type_env, &var_env).is_err());

        // let n: A ...
        assert_eq!(
            expr_type(&Expr::Id("n".to_string()), &fn_env, &type_env, &var_env),
            Ok(Named("A".to_string()))
        );

        // type of arithmetic operation (for now just i32)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("1 + 2 - 5").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );

        // type of arithmetic unary operation (for now just i32)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("- 5").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );

        // call, with check, ok
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("b(1)").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );

        // call, with check, ok (i: i32)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("b(i)").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );

        // call, with check, error wrong number args
        assert!(expr_type(
            &*ExprParser::new().parse("b(1, 2)").unwrap(),
            &fn_env,
            &type_env,
            &var_env
        )
        .is_err());

        // call, with check, error type of arg
        assert!(expr_type(
            &*ExprParser::new().parse("b(true)").unwrap(),
            &fn_env,
            &type_env,
            &var_env
        )
        .is_err());

        // call, with check, ok (n: A)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("c(n)").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(Unit)
        );

        // TODO, ref/ref mut/deref
    }