From 7e92be3d91c81d374a9af826e8ce37137be233d0 Mon Sep 17 00:00:00 2001
From: Per Lindgren <per.lindgren@ltu.se>
Date: Thu, 17 Sep 2020 11:16:01 +0200
Subject: [PATCH] check binaray and unary ops

---
 examples/deref_assign2.rs |   2 +-
 src/ast.rs                |   2 +-
 src/check.rs              | 199 +++++++++++++++++++++-----------------
 src/env.rs                |   1 -
 src/grammar.lalrpop       |   2 +-
 tests/test_check.rs       |   5 +
 6 files changed, 118 insertions(+), 93 deletions(-)

diff --git a/examples/deref_assign2.rs b/examples/deref_assign2.rs
index 53a1872..968723a 100644
--- a/examples/deref_assign2.rs
+++ b/examples/deref_assign2.rs
@@ -1,6 +1,6 @@
 fn main() {
     let mut a = 7;
-    let mut b = &a;
+    let mut b = &mut a;
     let c = &mut b;
     *(*c) = 9;
     let d = a;
diff --git a/src/ast.rs b/src/ast.rs
index c54d740..3afa200 100644
--- a/src/ast.rs
+++ b/src/ast.rs
@@ -300,7 +300,7 @@ impl Display for Type {
             I32 => write!(fmt, "i32"),
             Ref(t) => write!(fmt, "&{}", t),
             Mut(t) => write!(fmt, "mut {}", t),
-            _ => unimplemented!(),
+            Unknown => write!(fmt, "Unknown"),
         }
     }
 }
diff --git a/src/check.rs b/src/check.rs
index 0fe868d..8ebb842 100644
--- a/src/check.rs
+++ b/src/check.rs
@@ -48,93 +48,105 @@ fn expr_type(
     match e {
         Num(_) => Ok(Type::I32),
         Bool(_) => Ok(Type::Bool),
-        // Infix(l, op, r) => {
-        //     let lt = expr_type(l, fn_env, type_env, var_env)?;
-        //     let rt = expr_type(r, fn_env, type_env, var_env)?;
-        //     use Op::*;
-        //     match op {
-        //         // Arithmetic and Boolean
-        //         Add | Mul | Div | Sub | And | Or => {
-        //             let opt = From::from(op);
-
-        //             // check if op and args are of same type
-        //             if lt == rt && lt == opt {
-        //                 Ok((false, opt))
-        //             } else {
-        //                 Err(format!("left {} op {} right {}", lt, opt, rt))
-        //             }
-        //         }
-        //         // Equality
-        //         Eq | Neq => {
-        //             let opt = From::from(op);
-
-        //             // check if op and args are of same type
-        //             if lt == rt && lt == opt {
-        //                 Ok((false, Type::Bool))
-        //             } else {
-        //                 Err(format!("left {} op {} right {}", lt, opt, rt))
-        //             }
-        //         }
-        //         // Comparison
-        //         Less | Greater | LessEq | GreaterEq => {
-        //             if lt == Type::I32 && rt == Type::I32 {
-        //                 Ok((false, Type::Bool))
-        //             } else {
-        //                 Err(format!("left {} right {}", lt, rt))
-        //             }
-        //         }
-
-        //         _ => panic!("ICE on {}", e),
-        //     }
-        // }
-        // Prefix(op, r) => {
-        //     let rt = expr_type(r, fn_env, type_env, var_env)?;
-        //     let opt = op.get_type();
-
-        //     // check if both of same type
-        //     if rt == opt {
-        //         Ok((false, opt))
-        //     } else {
-        //         Err(format!("op {} rt {}", opt, rt))
-        //     }
-        // }
-        // Call(s, args) => {
-        //     trace!("call {} with {}", s, args);
-
-        //     let argt: Vec<Type> = args
-        //         .clone()
-        //         .0
-        //         .into_iter()
-        //         .map(|e| expr_type(&*e, fn_env, type_env, var_env))
-        //         .collect::<Result<_, _>>()?;
-
-        //     trace!("arg types {:?}", argt);
-        //     let f = match fn_env.get(s.as_str()) {
-        //         Some(f) => f,
-        //         None => Err(format!("{} not found", s))?,
-        //     };
-
-        //     let part: Vec<Type> = (f.params.0.clone())
-        //         .into_iter()
-        //         .map(|p| From::from(&p))
-        //         .collect();
-
-        //     trace!(
-        //         "fn to call {} with params {} and types {:?}",
-        //         f,
-        //         f.params,
-        //         part
-        //     );
-
-        //     if argt == part {
-        //         Ok((false, f.result.clone()))
-        //     } else {
-        //         Err(format!(
-        //             "arguments types {:?} does not match parameter types {:?}",
-        //             argt, part
-        //         ))
-        //     }
-        // }
+        Infix(l, op, r) => {
+            let lt = expr_type(l, fn_env, type_env, var_env)?;
+            let rt = expr_type(r, fn_env, type_env, var_env)?;
+
+            let lt = strip_mut(lt);
+            let rt = strip_mut(rt);
+
+            use Op::*;
+            match op {
+                // Arithmetic and Boolean
+                Add | Mul | Div | Sub | And | Or => {
+                    // check if op and args are of i32
+                    if lt == Type::I32 && rt == Type::I32 {
+                        Ok(Type::I32)
+                    } else {
+                        Err(format!(
+                            "Num operation requires i32, left {}  right {}",
+                            lt, rt
+                        ))
+                    }
+                }
+                // Equality
+                Eq | Neq => {
+                    // check if args are of same type
+                    if lt == rt {
+                        Ok(Type::Bool)
+                    } else {
+                        Err(format!(
+                            "Comparison requires operands of same type, left {}, right {}",
+                            lt, rt
+                        ))
+                    }
+                }
+                // Comparison
+                Less | Greater | LessEq | GreaterEq => {
+                    if lt == Type::I32 && rt == Type::I32 {
+                        // check if args are of i32
+                        Ok(Type::Bool)
+                    } else {
+                        Err(format!(
+                            "Comparison requires operands of same type, left {}, right {}",
+                            lt, rt
+                        ))
+                    }
+                }
+                _ => panic!("ICE on {}", e),
+            }
+        }
+
+        Prefix(op, r) => {
+            let rt = expr_type(r, fn_env, type_env, var_env)?;
+            let opt = op.get_type();
+
+            // check if both of same type
+            if rt == opt {
+                Ok(opt)
+            } else {
+                Err(format!("op {} rt {}", opt, rt))
+            }
+        }
+
+        Call(s, args) => {
+            trace!("call {} with {}", s, args);
+
+            let argt: Vec<Type> = args
+                .clone()
+                .0
+                .into_iter()
+                .map(|e| expr_type(&*e, fn_env, type_env, var_env))
+                .collect::<Result<_, _>>()?;
+
+            trace!("arg types {:?}", argt);
+            let f = match fn_env.get(s.as_str()) {
+                Some(f) => f,
+                None => Err(format!("{} not found", s))?,
+            };
+
+            let part: Vec<Type> = (f.params.0.clone())
+                .into_iter()
+                .map(|p| From::from(&p))
+                .collect();
+
+            trace!(
+                "fn to call {} with params {} and types {:?}",
+                f,
+                f.params,
+                part
+            );
+
+            if argt == part {
+                Ok(f.result.clone())
+            } else {
+                Err(format!(
+                    "arguments types {:?} does not match parameter types {:?}",
+                    argt, part
+                ))
+            }
+        }
+
         Id(id) => match var_env.get(id.to_string()) {
             Some(t) => Ok(t.clone()),
             None => Err(format!("variable not found {}", id)),
@@ -172,7 +184,6 @@ fn expr_type(
                 _ => Err(format!("cannot deref {} of type {}", e, t)),
             }
         }
-        _ => unimplemented!(),
     }
 }
 
@@ -234,7 +245,17 @@ pub fn check_stmts(
                 trace!("lh_type {}", &lh_type);
 
                 if match &lh_type {
-                    Type::Mut(t) => **t == e_type,
+                    Type::Unknown => {
+                        trace!("assign to unknown");
+                        true
+                    }
+                    Type::Mut(t) => match **t {
+                        Type::Unknown => {
+                            trace!("assign to `mut Unknown`");
+                            true
+                        }
+                        _ => **t == e_type,
+                    },
 
                     _ => Err(format!("assignment to immutable"))?,
                 } {
diff --git a/src/env.rs b/src/env.rs
index c2f4aba..86fdfbe 100644
--- a/src/env.rs
+++ b/src/env.rs
@@ -53,7 +53,6 @@ impl VarEnv {
     }
 
     pub fn update(&mut self, id: String, ty: Type) -> Result<(), Error> {
-        use Type::*;
         match self.get_mut(id.clone()) {
             Some(ot) => match ot {
                 Type::Unknown => {
diff --git a/src/grammar.lalrpop b/src/grammar.lalrpop
index adee04c..134ef77 100644
--- a/src/grammar.lalrpop
+++ b/src/grammar.lalrpop
@@ -218,5 +218,5 @@ Num: i32 = {
 };
 
 Id: String = {
-    r"([a-z]|[A-Z])([a-z]|[A-Z]|[0-9]|_)*" => String::from_str(<>).unwrap(),
+    r"([a-z]|[A-Z]|_)([a-z]|[A-Z]|[0-9]|_)*" => String::from_str(<>).unwrap(),
 };
diff --git a/tests/test_check.rs b/tests/test_check.rs
index 844b0b1..d120948 100644
--- a/tests/test_check.rs
+++ b/tests/test_check.rs
@@ -1,6 +1,11 @@
 use erode::check::check;
 use erode::read_file;
 
+#[test]
+fn check_w1_2() {
+    assert!(check(&read_file::parse("examples/w1_2.rs")).is_ok());
+}
+
 #[test]
 fn check_prog() {
     check(&read_file::parse("examples/minimal.rs")).unwrap();
-- 
GitLab