Skip to content
Snippets Groups Projects
Select Git revision
  • ab54bad0e5751c0b314bc7e8a9681faf151dadcc
  • master default
2 results

main.rs

Blame
  • check.rs 18.33 KiB
    use std::io::Write;
    
    use env_logger::Builder;
    use log::{trace, LevelFilter};
    use std::convert::From;
    
    use crate::ast::*;
    use crate::env::{new_fn_env, new_type_env, Error, FnEnv, IdType, TypeEnv, VarEnv};
    
    // type check
    
    impl Op {
        fn get_type(&self) -> Type {
            use Op::*;
            match self {
                Mul | Div | Add | Sub => Type::I32,
                _ => Type::Bool, // if not I32
            }
        }
    }
    
    impl From<&Op> for Type {
        fn from(op: &Op) -> Self {
            op.get_type().clone()
        }
    }
    
    impl From<&Param> for Type {
        fn from(p: &Param) -> Self {
            p.ty.clone()
        }
    }
    
    impl From<&FnDecl> for Type {
        fn from(f: &FnDecl) -> Self {
            f.result.clone()
        }
    }
    
    fn expr_type(
        e: &Expr,
        fn_env: &FnEnv,
        type_env: &TypeEnv,
        var_env: &VarEnv,
    ) -> Result<Type, Error> {
        use Expr::*;
        trace!("expr_type {}", e);
        match e {
            Num(_) => Ok(Type::I32),
            Bool(_) => Ok(Type::Bool),
            // Infix(l, op, r) => {
            //     let lt = expr_type(l, fn_env, type_env, var_env)?;
            //     let rt = expr_type(r, fn_env, type_env, var_env)?;
            //     use Op::*;
            //     match op {
            //         // Arithmetic and Boolean
            //         Add | Mul | Div | Sub | And | Or => {
            //             let opt = From::from(op);
    
            //             // check if op and args are of same type
            //             if lt == rt && lt == opt {
            //                 Ok((false, opt))
            //             } else {
            //                 Err(format!("left {} op {} right {}", lt, opt, rt))
            //             }
            //         }
            //         // Equality
            //         Eq | Neq => {
            //             let opt = From::from(op);
    
            //             // check if op and args are of same type
            //             if lt == rt && lt == opt {
            //                 Ok((false, Type::Bool))
            //             } else {
            //                 Err(format!("left {} op {} right {}", lt, opt, rt))
            //             }
            //         }
            //         // Comparison
            //         Less | Greater | LessEq | GreaterEq => {
            //             if lt == Type::I32 && rt == Type::I32 {
            //                 Ok((false, Type::Bool))
            //             } else {
            //                 Err(format!("left {} right {}", lt, rt))
            //             }
            //         }
    
            //         _ => panic!("ICE on {}", e),
            //     }
            // }
            // Prefix(op, r) => {
            //     let rt = expr_type(r, fn_env, type_env, var_env)?;
            //     let opt = op.get_type();
    
            //     // check if both of same type
            //     if rt == opt {
            //         Ok((false, opt))
            //     } else {
            //         Err(format!("op {} rt {}", opt, rt))
            //     }
            // }
            // Call(s, args) => {
            //     trace!("call {} with {}", s, args);
    
            //     let argt: Vec<Type> = args
            //         .clone()
            //         .0
            //         .into_iter()
            //         .map(|e| expr_type(&*e, fn_env, type_env, var_env))
            //         .collect::<Result<_, _>>()?;
    
            //     trace!("arg types {:?}", argt);
            //     let f = match fn_env.get(s.as_str()) {
            //         Some(f) => f,
            //         None => Err(format!("{} not found", s))?,
            //     };
    
            //     let part: Vec<Type> = (f.params.0.clone())
            //         .into_iter()
            //         .map(|p| From::from(&p))
            //         .collect();
    
            //     trace!(
            //         "fn to call {} with params {} and types {:?}",
            //         f,
            //         f.params,
            //         part
            //     );
    
            //     if argt == part {
            //         Ok((false, f.result.clone()))
            //     } else {
            //         Err(format!(
            //             "arguments types {:?} does not match parameter types {:?}",
            //             argt, part
            //         ))
            //     }
            // }
            Id(id) => match var_env.get(id.to_string()) {
                Some(t) => match t {
                    // variable bound to type
                    Some(t) => Ok(t.clone()),
                    // not yet bound
                    None => Err(format!("variable {:?} has no type yet", id)),
                },
                None => Err(format!("variable not found {}", id)),
            },
    
            As(_e, _t) => unimplemented!("here we implement explicit type cast"),
    
            // Convert Expr::Ref to Type::Ref
            Ref(ref_e) => {
                let t = expr_type(ref_e, fn_env, type_env, var_env)?;
                Ok(Type::Ref(Box::new(t)))
            }
    
            // Convert Expr::Mut to Type::Mut
            RefMut(ref_mut_e) => {
                let t = expr_type(ref_mut_e, fn_env, type_env, var_env)?;
                Ok(Type::Mut(Box::new(t)))
            }
    
            DeRef(deref_e) => {
                let t = expr_type(deref_e, fn_env, type_env, var_env)?;
                trace!("deref_t {}", &t);
                match t {
                    Type::Ref(dt) => Ok(*dt),
                    _ => Err(format!("cannot deref {} of type {}", e, t)),
                }
            }
            _ => unimplemented!(),
        }
    }
    
    // // walk the left hand expression in an assignment.
    // // Ok(Type), the type of the left hand
    // // Err(), illegal left hand
    // fn left_expr_type(e: &Expr, to_type: Type, var_env: &mut VarEnv) -> Result<(), Error> {
    //     match e {
    //         Expr::Id(id) => {
    //             match var_env.get(id.clone()) {
    //                 Some((true, Some(t))) => {
    //                     match *t == to_type {
    //                         true => Ok(()), // mutable, defined of same type
    //                         false => Err("types differ".to_string()),
    //                     }
    //                 }
    //                 Some((true, None)) => unimplemented!(),
    //                 Some((false, _)) => Err("variable declared immutable".to_string()),
    //                 None => Err("variable not in scope".to_string()),
    //             }
    //         }
    //         Expr::DeRef(ref_mut_e) => {
    //             let e: &Expr = &*ref_mut_e;
    //             trace!("DeRef({})", e);
    //             match e {
    //                 Expr::RefMut(e) => left_expr_type(&*e, to_type, var_env),
    //                 _ => Err("cannot deref left hand expression".to_string()),
    //             }
    //             // }
    //             // let ref_mut_id = left_expr_type(ref_mut_e, var_env)?;
    //             // trace!("left: RefMut({})", ref_mut_id);
    
    //             // match var_env.get(ref_mut_id) {
    //             //     Some(Val::Ref(id)) => {
    //             //         println!("Ref id {}", id);
    //             //         id.clone()
    //             //     }
    //             //     Some(Val::RefMut(id)) => {
    //             //         println!("RefMut id {}", id);
    //             //         id.clone()
    //             //     }
    //             //     _ => panic!("deref failed"),
    //             // }
    
    //             // Ok(ref_mut_id)
    //         }
    
    //         // let ev = eval_left_expr(e, m, fn_env);
    //         // println!("eval_left {:?}", ev);
    
    //         // match m.get(ev) {
    //         //     Some(Val::Ref(id)) => {
    //         //         println!("Ref id {}", id);
    //         //         id.clone()
    //         //     }
    //         //     Some(Val::RefMut(id)) => {
    //         //         println!("RefMut id {}", id);
    //         //         id.clone()
    //         //     }
    //         //     _ => panic!("deref failed"),
    //         // }
    //         _ => Err(format!("illegal left hand expression {}", e)),
    //     }
    // }
    
    // walk the left hand expression in an assignment.
    // Ok(Type), the type of the left hand
    // Err(), illegal left hand
    fn left_expr_type(e: &Expr, var_env: &mut VarEnv) -> Result<Id, Error> {
        match e {
            Expr::Id(id) => Ok(id.to_owned()),
            //     match var_env.get(id.clone()) {
            //         Some((true, Some(t))) => {
            //             match *t == to_type {
            //                 true => Ok(()), // mutable, defined of same type
            //                 false => Err("types differ".to_string()),
            //             }
            //         }
            //         Some((true, None)) => unimplemented!(),
            //         Some((false, _)) => Err("variable declared immutable".to_string()),
            //         None => Err("variable not in scope".to_string()),
            //     }
            // }
            Expr::DeRef(ref_mut_e) => {
                trace!("DeRef({})", ref_mut_e);
                let lh = left_expr_type(ref_mut_e, var_env)?;
                trace!("DeRef({})", lh);
                match var_env.get(lh) {
                    Some(t) => {
                        trace!("t {:?}", t);
                        panic!();
                    }
                    //     match *t == to_type {
                    //         true => Ok(()), // mutable, defined of same type
                    //         false => Err("types differ".to_string()),
                    //     }
                    // }
                    // Some((true, None)) => unimplemented!(),
                    // Some((false, _)) => Err("variable declared immutable".to_string()),
                    None => Err("variable not in scope".to_string()),
                }
            }
    
            //     match e {
            //         Expr::RefMut(e) => left_expr_type(&*e, to_type, var_env),
            //         _ => Err("cannot deref left hand expression".to_string()),
            //     }
            // }
            //  println!("eval_left deref e {:?}", e);
    
            //     let ev = eval_left_expr(e, m, fn_env);
            //     println!("eval_left {:?}", ev);
    
            //     match m.get(ev) {
            //         Some(Val::Ref(id)) => {
            //             println!("Ref id {}", id);
            //             id.clone()
            //         }
            //         Some(Val::RefMut(id)) => {
            //             println!("RefMut id {}", id);
            //             id.clone()
            //         }
            //         _ => panic!("deref failed"),
            //     }
    
            // Ok(ref_mut_id)
            // }
    
            // let ev = eval_left_expr(e, m, fn_env);
            // println!("eval_left {:?}", ev);
    
            // match m.get(ev) {
            //     Some(Val::Ref(id)) => {
            //         println!("Ref id {}", id);
            //         id.clone()
            //     }
            //     Some(Val::RefMut(id)) => {
            //         println!("RefMut id {}", id);
            //         id.clone()
            //     }
            //     _ => panic!("deref failed"),
            // }
            _ => Err(format!("illegal left hand expression {}", e)),
        }
    }
    
    pub fn check_stmts(
        stmts: &Stmts,
        fn_env: &FnEnv,
        type_env: &TypeEnv,
        var_env: &mut VarEnv,
    ) -> Result<Type, Error> {
        use Stmt::*;
    
        var_env.push_empty_scope();
    
        let t = stmts.stmts.iter().try_fold(Type::Unit, |_, s| {
            trace!("stmt: {}", s);
            match s {
                Stmt::Block(b) => check_stmts(b, fn_env, type_env, var_env),
    
                Expr(e) => expr_type(&*e, fn_env, type_env, var_env),
    
                Let(var_id, is_mut, ot, oe) => {
                    let ot: Option<Type> = match (ot, oe) {
                        (Some(t), Some(e)) => {
                            let e_type = expr_type(&*e, fn_env, type_env, var_env)?;
                            match *t == e_type {
                                true => Some(t.clone()),
                                false => Err("incompatible types")?,
                            }
                        }
                        (None, Some(e)) => Some(expr_type(&*e, fn_env, type_env, var_env)?),
                        (ot, None) => ot.clone(),
                    };
                    let ot = match is_mut {
                        true => Type::Mut(Box::new(ot)),
                        false => ot,
                    };
                    var_env.new_id(var_id.clone(), ot);
                    trace!("var_env {:?}", var_env);
                    Ok(Type::Unit)
                }
    
                Assign(lh, e) => {
                    let expr_type = expr_type(&*e, fn_env, type_env, var_env)?;
                    trace!("expr_type = {}", expr_type);
                    trace!("v = {:?}", lh);
    
                    let id = left_expr_type(&*lh, var_env)?;
                    //  var_env.update(id, expr_type)?;
                    Ok(Type::Unit)
                }
    
                While(e, block) => match expr_type(&*e, fn_env, type_env, var_env) {
                    Ok(Type::Bool) => {
                        let _ = check_stmts(&block, fn_env, type_env, var_env)?;
                        Ok(Type::Unit) // a while statement is of unit type;
                    }
                    _ => Err("Condition not Boolean".to_string()),
                },
    
                If(e, then, o_else) => match expr_type(&*e, fn_env, type_env, var_env)? {
                    Type::Bool => {
                        // The condition is of Bool type
                        let then_type = check_stmts(&then, fn_env, type_env, var_env)?;
                        trace!("then type {}", then_type);
                        match o_else {
                            None => Ok(then_type), // type of the arm
                            Some(else_stmts) => {
                                let else_type = check_stmts(&else_stmts, fn_env, type_env, var_env)?;
                                trace!("else type {}", else_type);
    
                                match then_type == else_type {
                                    true => Ok(then_type), // same type of both arms
                                    false => {
                                        trace!("error-----");
                                        Err(format!(
                                            "'then' arm :{} does not match 'else' arm :{}",
                                            then_type, else_type
                                        ))
                                    }
                                }
                            }
                        }
                    }
                    _ => Err("Condition not Boolean".to_string()),
                },
            }
        })?;
    
        var_env.pop_scope();
        if !stmts.ret {
            Ok(t.clone())
        } else {
            Ok(Type::Unit)
        }
    }
    
    #[test]
    fn test_stmts() {
        use crate::grammar::*;
        use crate::*;
        // use Type::*;
    
        // setup environment
        let p = ProgramParser::new().parse(prog1!()).unwrap();
        let fn_env = new_fn_env(&p.fn_decls);
        let type_env = new_type_env(&p.type_decls);
        let mut var_env = VarEnv::new();
    
        trace!("{}", &p);
    
        let b = fn_env.get(&"b").unwrap();
    
        let body = &b.body;
        trace!("{}", &body);
    
        trace!("{:?}", check_stmts(body, &fn_env, &type_env, &mut var_env));
    }
    
    pub fn build_env(p: &Program) -> (FnEnv, TypeEnv) {
        let fn_env = new_fn_env(&p.fn_decls);
        let type_env = new_type_env(&p.type_decls);
    
        (fn_env, type_env)
    }
    
    pub fn dump_env(fn_env: &FnEnv, type_env: &TypeEnv) {
        trace!("fn_env {:?}", fn_env);
        trace!("type_env {:?}", type_env);
    }
    
    // check a whole Program
    pub fn check(p: &Program) -> Result<(), Error> {
        let _ = Builder::new()
            .filter(None, LevelFilter::Trace)
            .format(|buf, record| writeln!(buf, "{}", record.args()))
            .format_timestamp(None)
            .try_init();
    
        let (fn_env, type_env) = build_env(&p);
    
        trace!("Input program \n{}", &p);
    
        for fd in p.fn_decls.iter() {
            trace!("function id {}", fd.id);
            let mut var_env = VarEnv::new();
    
            // build a scope for the arguments
            let mut arg_ty = IdType::new();
            for Param { is_mut, id, ty } in fd.params.0.iter() {
                // TODO move is_mut to Type
                arg_ty.insert(id.to_owned(), Some(ty.clone()));
            }
    
            var_env.push_param_scope(arg_ty);
    
            let stmt_type = check_stmts(&fd.body, &fn_env, &type_env, &mut var_env);
            trace!("function id {}: {:?}", fd.id, stmt_type);
    
            if stmt_type? != fd.result {
                Err(format!("return value does not match statements"))?;
            }
        }
        Ok(())
    }
    
    // unit test
    
    #[test]
    fn test_expr_type() {
        use crate::grammar::*;
        use crate::*;
        use Type::*;
    
        // setup environment
        let p = ProgramParser::new().parse(prog1!()).unwrap();
        let fn_env = new_fn_env(&p.fn_decls);
        let type_env = new_type_env(&p.type_decls);
        let mut var_env = VarEnv::new();
        var_env.push_empty_scope();
    
        // // some test variables in scope
        var_env.new_id("i".to_string(), Some(I32)); // let i : i32 ...
        var_env.new_id("j".to_string(), None); // let i ...
        var_env.new_id("n".to_string(), Some(Type::Named("A".to_string()))); // let n : A ...
    
        println!("p {}", p);
    
        // type of number
        assert_eq!(
            expr_type(&Expr::Num(1), &fn_env, &type_env, &var_env),
            Ok(I32)
        );
    
        // type of variables
    
        // not found
        assert!(expr_type(&Expr::Id("a".to_string()), &fn_env, &type_env, &var_env).is_err());
    
        // let i: i32 ...
        assert_eq!(
            expr_type(&Expr::Id("i".to_string()), &fn_env, &type_env, &var_env),
            Ok(I32)
        );
    
        // let j ... (has no type yet)
        assert!(expr_type(&Expr::Id("j".to_string()), &fn_env, &type_env, &var_env).is_err());
    
        // let n: A ...
        assert_eq!(
            expr_type(&Expr::Id("n".to_string()), &fn_env, &type_env, &var_env),
            Ok(Named("A".to_string()))
        );
    
        // type of arithmetic operation (for now just i32)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("1 + 2 - 5").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );
    
        // type of arithmetic unary operation (for now just i32)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("- 5").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );
    
        // call, with check, ok
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("b(1)").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );
    
        // call, with check, ok (i: i32)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("b(i)").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(I32)
        );
    
        // call, with check, error wrong number args
        assert!(expr_type(
            &*ExprParser::new().parse("b(1, 2)").unwrap(),
            &fn_env,
            &type_env,
            &var_env
        )
        .is_err());
    
        // call, with check, error type of arg
        assert!(expr_type(
            &*ExprParser::new().parse("b(true)").unwrap(),
            &fn_env,
            &type_env,
            &var_env
        )
        .is_err());
    
        // call, with check, ok (i: i32)
        assert_eq!(
            expr_type(
                &*ExprParser::new().parse("c(n)").unwrap(),
                &fn_env,
                &type_env,
                &var_env
            ),
            Ok(Unit)
        );
    
        // TODO, ref/ref mut/deref
    }