diff --git a/drivers/c.drv b/drivers/c.drv index 42599d0e7831bdd61e8811bd63a09a95a6d796d5..444b1b5f4545362ff0e09aa2a1edd24ea350aa55 100644 --- a/drivers/c.drv +++ b/drivers/c.drv @@ -27,17 +27,17 @@ module mach.int.Int32 syntax type int32 "int32_t" - syntax val (+) "(%1) + (%2)" - syntax val (-) "(%1) - (%2)" - syntax val (-_) "-(%1)" - syntax val ( * ) "(%1) * (%2)" - syntax val (/) "(%1) / (%2)" - syntax val (%) "(%1) % (%2)" - syntax val (=) "(%1) == (%2)" - syntax val (<=) "(%1) <= (%2)" - syntax val (<) "(%1) < (%2)" - syntax val (>=) "(%1) >= (%2)" - syntax val (>) "(%1) > (%2)" + syntax val (+) "%1 + %2" + syntax val (-) "%1 - %2" + syntax val (-_) "-%1" + syntax val ( * ) "%1 * %2" + syntax val (/) "%1 / %2" + syntax val (%) "%1 % %2" + syntax val (=) "%1 == %2" + syntax val (<=) "%1 <= %2" + syntax val (<) "%1 < %2" + syntax val (>=) "%1 >= %2" + syntax val (>) "%1 > %2" end module mach.int.UInt32Gen @@ -54,16 +54,16 @@ module mach.int.UInt32 syntax converter of_int "%1U" - syntax val (+) "(%1) + (%2)" - syntax val (-) "(%1) - (%2)" - syntax val ( * ) "(%1) * (%2)" - syntax val (/) "(%1) / (%2)" - syntax val (%) "(%1) % (%2)" - syntax val (=) "(%1) == (%2)" - syntax val (<=) "(%1) <= (%2)" - syntax val (<) "(%1) < (%2)" - syntax val (>=) "(%1) >= (%2)" - syntax val (>) "(%1) > (%2)" + syntax val (+) "%1 + %2" + syntax val (-) "%1 - %2" + syntax val ( * ) "%1 * %2" + syntax val (/) "%1 / %2" + syntax val (%) "%1 % %2" + syntax val (=) "%1 == %2" + syntax val (<=) "%1 <= %2" + syntax val (<) "%1 < %2" + syntax val (>=) "%1 >= %2" + syntax val (>) "%1 > %2" end @@ -185,16 +185,16 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt); syntax converter of_int "%1U" - syntax val (+) "(%1) + (%2)" - syntax val (-) "(%1) - (%2)" - syntax val ( * ) "(%1) * (%2)" - syntax val (/) "(%1) / (%2)" - syntax val (%) "(%1) % (%2)" - syntax val (=) "(%1) == (%2)" - syntax val (<=) "(%1) <= (%2)" - syntax val (<) "(%1) < (%2)" - syntax val (>=) "(%1) >= (%2)" - syntax val (>) "(%1) > (%2)" + syntax val (+) "%1 + %2" + syntax val (-) "%1 - %2" + syntax val ( * ) "%1 * %2" + syntax val (/) "%1 / %2" + syntax val (%) "%1 % %2" + syntax val (=) "%1 == %2" + syntax val (<=) "%1 <= %2" + syntax val (<) "%1 < %2" + syntax val (>=) "%1 >= %2" + syntax val (>) "%1 > %2" syntax val add_with_carry "add32_with_carry" syntax val sub_with_borrow "sub32_with_borrow" @@ -202,17 +202,17 @@ struct __lsld32_result lsld32(uint32_t x, uint32_t cnt); syntax val add3 "add32_3" syntax val lsld "lsld32" - syntax val add_mod "(%1) + (%2)" - syntax val sub_mod "(%1) - (%2)" - syntax val mul_mod "(%1) * (%2)" + syntax val add_mod "%1 + %2" + syntax val sub_mod "%1 - %2" + syntax val mul_mod "%1 * %2" syntax val div2by1 "(uint32_t)((((uint64_t)%1) | (((uint64_t)%2) << 32))/(uint64_t)(%3))" - syntax val lsl "(%1) << (%2)" - syntax val lsr "(%1) >> (%2)" + syntax val lsl "%1 << %2" + syntax val lsr "%1 >> %2" - syntax val is_msb_set "(%1) & 0x80000000U" + syntax val is_msb_set "%1 & 0x80000000U" syntax val count_leading_zeros "__builtin_clz(%1)" @@ -227,17 +227,17 @@ module mach.int.Int64 syntax type int64 "int64_t" - syntax val (+) "(%1) + (%2)" - syntax val (-) "(%1) - (%2)" - syntax val (-_) "-(%1)" - syntax val ( * ) "(%1) * (%2)" - syntax val (/) "(%1) / (%2)" - syntax val (%) "(%1) % (%2)" - syntax val (=) "(%1) == (%2)" - syntax val (<=) "(%1) <= (%2)" - syntax val (<) "(%1) < (%2)" - syntax val (>=) "(%1) >= (%2)" - syntax val (>) "(%1) > (%2)" + syntax val (+) "%1 + %2" + syntax val (-) "%1 - %2" + syntax val (-_) "-%1" + syntax val ( * ) "%1 * %2" + syntax val (/) "%1 / %2" + syntax val (%) "%1 % %2" + syntax val (=) "%1 == %2" + syntax val (<=) "%1 <= %2" + syntax val (<) "%1 < %2" + syntax val (>=) "%1 >= %2" + syntax val (>) "%1 > %2" end module mach.int.UInt64Gen @@ -253,17 +253,17 @@ module mach.int.UInt64 syntax converter of_int "%1ULL" - syntax val (+) "(%1) + (%2)" - syntax val (-) "(%1) - (%2)" - syntax val (-_) "-(%1)" - syntax val ( * ) "(%1) * (%2)" - syntax val (/) "(%1) / (%2)" - syntax val (%) "(%1) % (%2)" - syntax val (=) "(%1) == (%2)" - syntax val (<=) "(%1) <= (%2)" - syntax val (<) "(%1) < (%2)" - syntax val (>=) "(%1) >= (%2)" - syntax val (>) "(%1) > (%2)" + syntax val (+) "%1 + %2" + syntax val (-) "%1 - %2" + syntax val (-_) "-%1" + syntax val ( * ) "%1 * %2" + syntax val (/) "%1 / %2" + syntax val (%) "%1 % %2" + syntax val (=) "%1 == %2" + syntax val (<=) "%1 <= %2" + syntax val (<) "%1 < %2" + syntax val (>=) "%1 >= %2" + syntax val (>) "%1 > %2" end @@ -508,16 +508,16 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt) " syntax converter of_int "%1ULL" - syntax val (+) "(%1) + (%2)" - syntax val (-) "(%1) - (%2)" - syntax val ( * ) "(%1) * (%2)" - syntax val (/) "(%1) / (%2)" - syntax val (%) "(%1) % (%2)" - syntax val (=) "(%1) == (%2)" - syntax val (<=) "(%1) <= (%2)" - syntax val (<) "(%1) < (%2)" - syntax val (>=) "(%1) >= (%2)" - syntax val (>) "(%1) > (%2)" + syntax val (+) "%1 + %2" + syntax val (-) "%1 - %2" + syntax val ( * ) "%1 * %2" + syntax val (/) "%1 / %2" + syntax val (%) "%1 % %2" + syntax val (=) "%1 == %2" + syntax val (<=) "%1 <= %2" + syntax val (<) "%1 < %2" + syntax val (>=) "%1 >= %2" + syntax val (>) "%1 > %2" syntax val add_with_carry "add64_with_carry" syntax val add_double "add64_double" @@ -532,14 +532,14 @@ static struct __lsld64_result lsld64(uint64_t x, uint64_t cnt) syntax val add3 "add64_3" syntax val lsld "lsld64" - syntax val add_mod "(%1) + (%2)" - syntax val sub_mod "(%1) - (%2)" - syntax val mul_mod "(%1) * (%2)" + syntax val add_mod "%1 + %2" + syntax val sub_mod "%1 - %2" + syntax val mul_mod "%1 * %2" - syntax val lsl "(%1) << (%2)" - syntax val lsr "(%1) >> (%2)" + syntax val lsl "%1 << %2" + syntax val lsr "%1 >> %2" - syntax val is_msb_set "(%1) & 0x8000000000000000ULL" + syntax val is_msb_set "%1 & 0x8000000000000000ULL" syntax val count_leading_zeros "__builtin_clzll(%1)" @@ -570,16 +570,16 @@ module mach.c.C syntax val is_not_null "(%1) != NULL" syntax val null "NULL" - syntax val incr "%1+(%2)" + syntax val incr "%1 + %2" - syntax val get "*(%1)" - syntax val get_ofs "*(%1+(%2))" + syntax val get "*%1" + syntax val get_ofs "%1[%2]" - syntax val set "*(%1) = %2" - syntax val set_ofs "*(%1+(%2)) = %3" + syntax val set "*%1 = %2" + syntax val set_ofs "%1[%2] = %3" - syntax val incr_split "%1+(%2)" - syntax val decr_split "%1-(%2)" + syntax val incr_split "%1 + %2" + syntax val decr_split "%1 - %2" syntax val join "IGNORE2" syntax val join_r "IGNORE2" diff --git a/src/mlw/cprinter.ml b/src/mlw/cprinter.ml index f88ef2601452bdc6e7c7aa8c06e99a2e262b6189..a9f2b21cde59c77ef9acd3adb2d2ca796139ae9c 100644 --- a/src/mlw/cprinter.ml +++ b/src/mlw/cprinter.ml @@ -360,6 +360,20 @@ module C = struct else b | d,s -> d, s + +(* Operator precedence, needed to compute which parentheses can be removed *) + + let prec_unop = function + | Unot | Ustar | Uaddr | Upreincr | Upredecr -> 2 + | Upostincr | Upostdecr -> 1 + + let prec_binop = function + | Band -> 11 + | Bor -> 11 (* really 12, but this avoids Wparentheses *) + | Beq | Bne -> 7 + | Bassign -> 14 + | Blt | Ble | Bgt | Bge -> 6 + end type info = Pdriver.printer_args = private { @@ -419,7 +433,7 @@ module Print = struct (* should be handled in extract_stars *) | Tarray (ty, expr) -> fprintf fmt (protect_on paren "%a[%a]") - (print_ty ~paren:true) ty (print_expr ~paren:false) expr + (print_ty ~paren:true) ty (print_expr ~prec:1) expr | Tstruct (s,_) -> fprintf fmt "struct %s" s | Tunion _ -> raise (Unprinted "unions") | Tnamed id -> print_global_ident fmt id @@ -447,47 +461,57 @@ module Print = struct | Bgt -> fprintf fmt ">" | Bge -> fprintf fmt ">=" - and print_expr ~paren fmt = function + and print_expr ~prec fmt = function | Enothing -> Debug.dprintf debug_c_extraction "enothing"; () | Eunop(u,e) -> + let p = prec_unop u in if unop_postfix u - then fprintf fmt (protect_on paren "%a%a") - (print_expr ~paren:true) e print_unop u - else fprintf fmt (protect_on paren "%a%a") - print_unop u (print_expr ~paren:true) e + then fprintf fmt (protect_on (prec <= p) "%a%a") + (print_expr ~prec:p) e print_unop u + else fprintf fmt (protect_on (prec <= p) "%a%a") + print_unop u (print_expr ~prec:p) e | Ebinop(b,e1,e2) -> - fprintf fmt (protect_on paren "%a %a %a") - (print_expr ~paren:true) e1 print_binop b (print_expr ~paren:true) e2 + let p = prec_binop b in + fprintf fmt (protect_on (prec <= p) "%a %a %a") + (print_expr ~prec:p) e1 print_binop b (print_expr ~prec:p) e2 | Equestion(c,t,e) -> - fprintf fmt (protect_on paren "%a ? %a : %a") - (print_expr ~paren:true) c - (print_expr ~paren:true) t - (print_expr ~paren:true) e + fprintf fmt (protect_on (prec <= 13) "%a ? %a : %a") + (print_expr ~prec:13) c + (print_expr ~prec:13) t + (print_expr ~prec:13) e | Ecast(ty, e) -> - fprintf fmt (protect_on paren "(%a)%a") - (print_ty ~paren:false) ty (print_expr ~paren:true) e - | Ecall (Esyntax (s, _, _, [], _), l) -> (* function defined in the prelude *) - fprintf fmt (protect_on paren "%s(%a)") - s (print_list comma (print_expr ~paren:false)) l - | Ecall (e,l) -> fprintf fmt (protect_on paren "%a(%a)") - (print_expr ~paren:true) e (print_list comma (print_expr ~paren:false)) l + fprintf fmt (protect_on (prec <= 2) "(%a)%a") + (print_ty ~paren:false) ty (print_expr ~prec:2) e + | Ecall (Esyntax (s, _, _, [], _), l) -> + (* function defined in the prelude *) + fprintf fmt (protect_on (prec <= 1) "%s(%a)") + s (print_list comma (print_expr ~prec:15)) l + | Ecall (e,l) -> + fprintf fmt (protect_on (prec <= 1) "%a(%a)") + (print_expr ~prec:1) e (print_list comma (print_expr ~prec:15)) l | Econst c -> print_const fmt c | Evar id -> print_local_ident fmt id - | Elikely e -> fprintf fmt (protect_on paren "__builtin_expect(%a,1)") - (print_expr ~paren:true) e - | Eunlikely e -> fprintf fmt (protect_on paren "__builtin_expect(%a,0)") - (print_expr ~paren:true) e + | Elikely e -> fprintf fmt + (protect_on (prec <= 1) "__builtin_expect(%a,1)") + (print_expr ~prec:15) e + | Eunlikely e -> fprintf fmt + (protect_on (prec <= 1) "__builtin_expect(%a,0)") + (print_expr ~prec:15) e | Esize_expr e -> - fprintf fmt (protect_on paren "sizeof(%a)") (print_expr ~paren:false) e + fprintf fmt (protect_on (prec <= 2) "sizeof(%a)") (print_expr ~prec:15) e | Esize_type ty -> - fprintf fmt (protect_on paren "sizeof(%a)") (print_ty ~paren:false) ty - | Edot (e,s) -> fprintf fmt "%a.%s" (print_expr ~paren:true) e s - | Eindex _ | Earrow _ -> raise (Unprinted "struct/array access") - | Esyntax (s, t, args, lte,_) -> - gen_syntax_arguments_typed snd (fun _ -> args) - (if paren then ("("^s^")") else s) - (fun fmt (e,_t) -> print_expr ~paren:false fmt e) - (print_ty ~paren:false) (C.Enothing,t) fmt lte + fprintf fmt (protect_on (prec <= 2) "sizeof(%a)") + (print_ty ~paren:false) ty + | Edot (e,s) -> + fprintf fmt (protect_on (prec <= 1) "%a.%s") + (print_expr ~prec:1) e s + | Eindex _ | Earrow _ -> raise (Unprinted "struct/union access") + | Esyntax (s, t, args, lte, c) -> + (* no way to know precedence, so full parentheses*) + gen_syntax_arguments_typed snd (fun _ -> args) + (if prec <= 13 && not c then ("("^s^")") else s) + (fun fmt (e,_t) -> print_expr ~prec:1 fmt e) + (print_ty ~paren:false) (C.Enothing,t) fmt lte and print_const fmt = function | Cint s | Cfloat s | Cchar s | Cstring s -> fprintf fmt "%s" s @@ -499,11 +523,13 @@ module Print = struct match ie with | id, Enothing -> print_local_ident fmt id | id,e -> fprintf fmt "%a = %a" - print_local_ident id (print_expr ~paren:false) e + print_local_ident id (print_expr ~prec:(prec_binop Bassign)) e + + let print_expr_no_paren fmt expr = print_expr ~prec:max_int fmt expr let rec print_stmt ~braces fmt = function | Snop -> Debug.dprintf debug_c_extraction "snop"; () - | Sexpr e -> fprintf fmt "%a;" (print_expr ~paren:false) e; + | Sexpr e -> fprintf fmt "%a;" print_expr_no_paren e; | Sblock ([] ,s) when not braces -> (print_stmt ~braces:false) fmt s | Sblock b -> fprintf fmt "@[<hov>{@\n @[<hov>%a@]@\n}@]" print_body b @@ -511,23 +537,23 @@ module Print = struct (print_stmt ~braces:false) s1 (print_stmt ~braces:false) s2 | Sif(c,t,e) when is_nop e -> - fprintf fmt "if(%a)@\n%a" (print_expr ~paren:false) c + fprintf fmt "if(%a)@\n%a" print_expr_no_paren c (print_stmt ~braces:true) (Sblock([],t)) | Sif (c,t,e) -> fprintf fmt "if(%a)@\n%a@\nelse@\n%a" - (print_expr ~paren:false) c + print_expr_no_paren c (print_stmt ~braces:true) (Sblock([],t)) (print_stmt ~braces:true) (Sblock([],e)) | Swhile (e,b) -> fprintf fmt "while (%a)@;<1 2>%a" - (print_expr ~paren:false) e (print_stmt ~braces:true) (Sblock([],b)) + print_expr_no_paren e (print_stmt ~braces:true) (Sblock([],b)) | Sfor (einit, etest, eincr, s) -> fprintf fmt "for (%a; %a; %a)@;<1 2>%a" - (print_expr ~paren:false) einit - (print_expr ~paren:false) etest - (print_expr ~paren:false) eincr + print_expr_no_paren einit + print_expr_no_paren etest + print_expr_no_paren eincr (print_stmt ~braces:true) (Sblock([],s)) | Sbreak -> fprintf fmt "break;" | Sreturn Enothing -> fprintf fmt "return;" - | Sreturn e -> fprintf fmt "return %a;" (print_expr ~paren:true) e + | Sreturn e -> fprintf fmt "return %a;" print_expr_no_paren e and print_def fmt def = try match def with