Skip to content
Snippets Groups Projects
Commit c3b9c04a authored by nilfit's avatar nilfit
Browse files

add partial application and splitting into multiple calls

Rust is more sensitive to how arguments are applied, so
mlw `f a b` could mean Rust `f(a, b)`, `f(a)(b)` or `f()(a, b)`, etc.
parent 473991fc
No related branches found
No related tags found
No related merge requests found
...@@ -6,11 +6,6 @@ theory BuiltIn ...@@ -6,11 +6,6 @@ theory BuiltIn
syntax predicate (=) "%1 == %2" syntax predicate (=) "%1 == %2"
end end
module HighOrd
(* TODO this does not work when there are multiple @ *)
syntax val ( @ ) "%1(%2)"
end
theory Bool theory Bool
syntax type bool "bool" syntax type bool "bool"
syntax function True "true" syntax function True "true"
......
...@@ -23,8 +23,6 @@ module Rust = struct ...@@ -23,8 +23,6 @@ module Rust = struct
| Tapp of ident * ty list * Slt.t (* lifetimes last *) | Tapp of ident * ty list * Slt.t (* lifetimes last *)
| Tsyn of string * ty list (* syntax *) | Tsyn of string * ty list (* syntax *)
| Tfn of ty list * ty (* arg types, result type*) | Tfn of ty list * ty (* arg types, result type*)
(* | Tstruct *)
(* | Tsyntax *)
type pat = type pat =
| Pwild (* _ *) | Pwild (* _ *)
...@@ -47,20 +45,21 @@ module Rust = struct ...@@ -47,20 +45,21 @@ module Rust = struct
and expr = and expr =
| Econst of string option * BigInt.t | Econst of string option * BigInt.t
| Evar of pvsymbol | Evar of ident
| Efield of expr * ident | Efield of expr * ident
| Eassign of pvsymbol * rsymbol * pvsymbol | Eassign of pvsymbol * rsymbol * pvsymbol
| Etup of expr list | Etup of expr list
| Eenum of rsymbol * expr list | Eenum of rsymbol * expr list
| Estruct of ident * rsymbol list * expr list (* Rs { r1: e1, r2: e2 } *) | Estruct of ident * rsymbol list * expr list (* Rs { r1: e1, r2: e2 } *)
| Eblock of expr list | Eblock of expr list
| Ecall of rsymbol * expr list | Ecall of expr * expr list (* e(e1, e2, ...)*)
| Ers of rsymbol (* for using a function name as an expression *)
| Esyntax of string * expr list | Esyntax of string * expr list
(* optional syntax, r(e1, ...) *) (* optional syntax, r(e1, ...) *)
| Elet of pat * expr * expr (* let p = e1; e2 *) | Elet of pat * expr * expr (* let p = e1; e2 *)
| Ematch of expr * branch list | Ematch of expr * branch list
| Eif of expr * expr * expr | Eif of expr * expr * expr
| Eclosure of var list * expr | Eclosure of ident list * expr
| Efn of fn * expr | Efn of fn * expr
(* fn rs(vl) -> t {e1} e2 *) (* fn rs(vl) -> t {e1} e2 *)
| Ewhile of expr * expr | Ewhile of expr * expr
...@@ -109,6 +108,7 @@ module Rust = struct ...@@ -109,6 +108,7 @@ module Rust = struct
| [Eblock el'] -> clean_expr (Eblock el') | [Eblock el'] -> clean_expr (Eblock el')
| _ -> Eblock (clean_list el)) | _ -> Eblock (clean_list el))
| Ecall (rs, el) -> Ecall (rs, (clean_list el)) | Ecall (rs, el) -> Ecall (rs, (clean_list el))
| Ers _ -> e
| Esyntax (s, el) -> Esyntax (s, (clean_list el)) | Esyntax (s, el) -> Esyntax (s, (clean_list el))
| Elet (p, e1, e2) -> Elet (p, clean_expr e1, clean_expr e2) | Elet (p, e1, e2) -> Elet (p, clean_expr e1, clean_expr e2)
| Efield (e, id) -> Efield (clean_expr e, id) | Efield (e, id) -> Efield (clean_expr e, id)
...@@ -157,6 +157,8 @@ module Translate = struct ...@@ -157,6 +157,8 @@ module Translate = struct
let boxes = { box_fields = Sid.empty; box_enum = Sbe.empty } let boxes = { box_fields = Sid.empty; box_enum = Sbe.empty }
let rs_nargs = Hrs.create 42
(* TODO make sure we don't extract any ghosts *) (* TODO make sure we don't extract any ghosts *)
let is_id_func = id_equal ts_func.ts_name let is_id_func = id_equal ts_func.ts_name
...@@ -190,6 +192,8 @@ module Translate = struct ...@@ -190,6 +192,8 @@ module Translate = struct
(id, translate_ty info ty) (id, translate_ty info ty)
let translate_vars info = List.map (translate_var info) let translate_vars info = List.map (translate_var info)
let translate_vars_closure info vl = List.map (fun (id, _) -> id) (translate_vars info vl)
let get_its_defn info rs = let get_its_defn info rs =
match Mid.find_opt rs.rs_name info.mo_known_map with match Mid.find_opt rs.rs_name info.mo_known_map with
| Some {pd_node = PDtype itdl} -> | Some {pd_node = PDtype itdl} ->
...@@ -267,8 +271,50 @@ module Translate = struct ...@@ -267,8 +271,50 @@ module Translate = struct
exception MissingSyntaxLiteral of string exception MissingSyntaxLiteral of string
(* requires that all rsymbols that will be used have been added *)
(* Produces a list [i_1, i_2, .. i_n] where i_1 indicates how many arguments
* e can be applied to. After applying those arguments, i_2 indicates how many
* arguments the result can be applied to. *)
let rec find_arg_counts (e:expr) =
match e.e_node with
| Eapp (rs, el) ->
let rec consume k nargs =
match nargs with
| [] -> raise (TODO "Eapp with rs that takes too few args")
| n :: nargs when k < n -> n - k :: nargs
| n :: nargs when k = n -> nargs
| n :: nargs when k > n -> consume (k - n) nargs
| _ :: _ -> assert false in
(try
let nargs = Hrs.find rs_nargs rs in
consume (List.length el) nargs
with Not_found -> []) (* rs is not an extracted symbol *)
(* TODO mutually recursive functions can get Not_found *)
| Efun (args, e) -> List.length args :: find_arg_counts e
| Elet (_, e) -> find_arg_counts e
| Eif (_, e1, e2) ->
let l1 = find_arg_counts e1 in
let l2 = find_arg_counts e2 in
assert (l1 = l2);
l1
| Ematch (_, bl, []) ->
(match List.map (fun (_, e) -> find_arg_counts e) bl with
| [] -> assert false
| argl :: rest ->
List.iter (fun l -> assert (l = argl)) rest;
argl)
| Ematch (_, _, _) -> raise (TODO "arg_counts match with exceptions")
| Eblock el -> find_arg_counts (List.hd (List.rev el))
| Eexn (_, _, e) -> find_arg_counts e
| Evar _ (* TODO can this ever be of function type? *)
| Econst _ | Eassign _ | Ewhile (_, _) | Efor (_, _, _, _, _)
| Eraise (_, _) | Eignore _ | Eabsurd | Ehole -> []
| Eany _ -> assert false
let box_expr (e:Rust.expr) = Rust.Esyntax ("Box::new(%1)", [e]) let box_expr (e:Rust.expr) = Rust.Esyntax ("Box::new(%1)", [e])
let pv_name pv = pv.pv_vs.vs_name
let rec translate_expr (info:info) (e:expr) : Rust.expr = let rec translate_expr (info:info) (e:expr) : Rust.expr =
Rust.clean_expr (match e.e_node with Rust.clean_expr (match e.e_node with
| Econst c -> | Econst c ->
...@@ -282,8 +328,19 @@ module Translate = struct ...@@ -282,8 +328,19 @@ module Translate = struct
(* there is no default, driver must make a decision about how to represent (* there is no default, driver must make a decision about how to represent
literals *) literals *)
(* TODO show the path to the ident e.g. mach.int.Int64.int64 instead of int64 *) (* TODO show the path to the ident e.g. mach.int.Int64.int64 instead of int64 *)
| Evar pvs -> Rust.Evar pvs | Evar pvs -> Rust.Evar pvs.pv_vs.vs_name
| Eapp (rs, el) -> | Eapp (rs, el) ->
let isfuncapp = id_equal rs.rs_name fs_func_app.ls_name in (* a @ b *)
if isfuncapp then
(match el with
| {e_node = Eapp (rs, el)} as e' :: el2 ->
(* note that the ity won't be accurate *)
let e'' = {e' with e_node = Eapp (rs, el @ el2)} in
translate_expr info e''
| e1 :: el' -> Rust.Ecall (translate_expr info e1, List.map (translate_expr info) el')
| _ -> raise (TODO ("infix @ with " ^
string_of_int (List.length el) ^ " args")))
else (
let el = List.map (translate_expr info) el in let el = List.map (translate_expr info) el in
let isfield = let isfield =
match rs.rs_field with match rs.rs_field with
...@@ -320,12 +377,57 @@ module Translate = struct ...@@ -320,12 +377,57 @@ module Translate = struct
let id = rs_type_ident rs in let id = rs_type_ident rs in
Rust.Estruct (id, rsl, el) Rust.Estruct (id, rsl, el)
) )
| None, None, _ -> Rust.Ecall (rs, el) | None, None, _ ->
(* TODO partial application *) (* use rs_nargs to know when and where to split into multiple calls *)
(* TODO application like x()(3)(4,5) *) let nargs = (try Hrs.find rs_nargs rs
) with Not_found -> raise (
TODO ("rs_nargs not found " ^ rs.rs_name.id_string))) in
(* the second result signals that the last application is partial
* and how many args are not applied *)
let rec split_args el nargs : Rust.expr list list * int option =
match nargs with
| n :: nargs' ->
if List.length el < n then
([el], Some (n - List.length el))
else
let args =
try Lists.prefix n el
with Invalid_argument s -> raise (TODO
(rs.rs_name.id_string ^ " partial application (" ^
string_of_int n ^ "/" ^ string_of_int (List.length el) ^
"; " ^ s)) in
let el' =
try Lists.chop n el
(* TODO remove try since short el is already handled *)
with Invalid_argument s -> raise (TODO
("partial application; " ^ s)) in
let splits, partial = split_args el' nargs' in
(args :: splits, partial)
| [] ->
assert (List.length el = 0);
[], None in
let nest_ecall acc args = Rust.Ecall (acc, args) in
let rec nest_partial acc al n =
(* the last arg list in al is partially applied *)
match al with
| [] -> acc
| args :: [] ->
let fresh_idents = List.init n (fun _ -> id_register (id_fresh "a")) in
let fresh_evars = List.map (fun id -> Rust.Evar id) fresh_idents in
Rust.Eclosure (fresh_idents, Rust.Ecall (acc, args @ fresh_evars))
| args :: al' ->
let acc' = nest_ecall acc args in
nest_partial acc' al' n
in
let splits, partial = split_args el nargs in
match partial with
| Some n -> nest_partial (Rust.Ers rs) splits n
| None -> List.fold_left nest_ecall (Rust.Ers rs) splits
))
| Efun (argl, e) -> | Efun (argl, e) ->
Rust.Eclosure (translate_vars info argl, translate_expr info e) Rust.Eclosure (translate_vars_closure info argl, translate_expr info e)
| Elet (Lvar (pv, e_var), e) -> | Elet (Lvar (pv, e_var), e) ->
(* If the type has some mutable component, make the variable mutable. (* If the type has some mutable component, make the variable mutable.
* This will result in variables being unnecessarily marked mutable. *) * This will result in variables being unnecessarily marked mutable. *)
...@@ -335,8 +437,10 @@ module Translate = struct ...@@ -335,8 +437,10 @@ module Translate = struct
Rust.Elet (pat, translate_expr info e_var, translate_expr info e) Rust.Elet (pat, translate_expr info e_var, translate_expr info e)
| Elet (Lsym (rs, ty, vars, e_body), e_after) -> | Elet (Lsym (rs, ty, vars, e_body), e_after) ->
let er = translate_expr info e_body in let er = translate_expr info e_body in
let vr = translate_vars info vars in let vr = translate_vars_closure info vars in
let e_after = translate_expr info e_after in let e_after = translate_expr info e_after in
let arg_counts = List.length vars :: find_arg_counts e_body in
Hrs.add rs_nargs rs arg_counts;
let pat = Rust.Pvar (rs.rs_name, false) in (* TODO mutability *) let pat = Rust.Pvar (rs.rs_name, false) in (* TODO mutability *)
Rust.Elet(pat, Rust.Eclosure (vr, er), e_after) Rust.Elet(pat, Rust.Eclosure (vr, er), e_after)
| Elet (Lany (_, _, _), _) -> raise (TODO "Elet(Lany)") | Elet (Lany (_, _, _), _) -> raise (TODO "Elet(Lany)")
...@@ -364,12 +468,13 @@ module Translate = struct ...@@ -364,12 +468,13 @@ module Translate = struct
| Ewhile (test, body) -> | Ewhile (test, body) ->
Rust.Ewhile (translate_expr info test, Rust.Eblock [(translate_expr info body)]) Rust.Ewhile (translate_expr info test, Rust.Eblock [(translate_expr info body)])
| Efor (pv1, pv2, dir, pv3, e) -> | Efor (pv1, pv2, dir, pv3, e) ->
(* TODO special case for bigint *)
let b = translate_expr info e in let b = translate_expr info e in
let r = (match dir with let r = (match dir with
| To -> Rust.Erange_inc (Rust.Evar pv2, Rust.Evar pv3) | To -> Rust.Erange_inc (Rust.Evar (pv_name pv2), Rust.Evar (pv_name pv3))
| DownTo -> | DownTo ->
(* (pv3..pv2).rev() *) (* (pv3..pv2).rev() *)
let range = Rust.Erange_inc (Rust.Evar pv3, Rust.Evar pv2) in let range = Rust.Erange_inc (Rust.Evar (pv_name pv3), Rust.Evar (pv_name pv2)) in
Rust.Esyntax("(%1).rev()", [range]) Rust.Esyntax("(%1).rev()", [range])
) in ) in
Rust.Efor (pv1, r, Rust.Eblock [b]) Rust.Efor (pv1, r, Rust.Eblock [b])
...@@ -411,7 +516,8 @@ module Translate = struct ...@@ -411,7 +516,8 @@ module Translate = struct
(List.map discover_lts_tvs_expr el) in (List.map discover_lts_tvs_expr el) in
(* let map_and_union el = union_stv_list (List.map discover_lts_tvs_expr el) in *) (* let map_and_union el = union_stv_list (List.map discover_lts_tvs_expr el) in *)
(match e with (match e with
| Rust.Econst _ | Rust.Evar _ | Rust.Eunreachable -> Slt.empty, Stv.empty | Rust.Econst _ | Rust.Evar _ | Rust.Eunreachable | Rust.Ers _
-> Slt.empty, Stv.empty
| Rust.Ecall (_, el) | Rust.Esyntax (_, el) | Rust.Eblock el -> | Rust.Ecall (_, el) | Rust.Esyntax (_, el) | Rust.Eblock el ->
map_and_union el map_and_union el
| Rust.Elet (_, e1, e2) -> map_and_union [e1;e2] | Rust.Elet (_, e1, e2) -> map_and_union [e1;e2]
...@@ -489,6 +595,9 @@ module Translate = struct ...@@ -489,6 +595,9 @@ module Translate = struct
let e = translate_expr info r.rec_exp in let e = translate_expr info r.rec_exp in
let slt, stv = discover_lts_tvs_def (Rust.Dfn (rs, vars, result_ty, e, r.rec_svar, let slt, stv = discover_lts_tvs_def (Rust.Dfn (rs, vars, result_ty, e, r.rec_svar,
Slt.empty)) in Slt.empty)) in
(* TODO make sure all related rdefs have been discovered *)
let arg_counts = List.length vars :: find_arg_counts r.rec_exp in
Hrs.add rs_nargs rs arg_counts;
Rust.Dfn (rs, vars, result_ty, e, stv, slt) Rust.Dfn (rs, vars, result_ty, e, stv, slt)
let box_ty (t:Rust.ty) = Rust.Tsyn ("Box<%1>", [t]) let box_ty (t:Rust.ty) = Rust.Tsyn ("Box<%1>", [t])
...@@ -638,6 +747,8 @@ module Translate = struct ...@@ -638,6 +747,8 @@ module Translate = struct
let er = translate_expr info e in let er = translate_expr info e in
let tr = translate_ty info ty in let tr = translate_ty info ty in
let vr = translate_vars info vars in let vr = translate_vars info vars in
let arg_counts = List.length vars :: find_arg_counts e in
Hrs.add rs_nargs rs arg_counts;
(* TODO can the tvs be found from rs_cty? *) (* TODO can the tvs be found from rs_cty? *)
let slt, stv = discover_lts_tvs_def (Rust.Dfn (rs, vr, tr, er, let slt, stv = discover_lts_tvs_def (Rust.Dfn (rs, vr, tr, er,
Stv.empty, Slt.empty)) in Stv.empty, Slt.empty)) in
...@@ -711,6 +822,7 @@ module Print = struct ...@@ -711,6 +822,7 @@ module Print = struct
create_ident_printer rust_keywords ~sanitizer:lsanitize create_ident_printer rust_keywords ~sanitizer:lsanitize
let forget_id id = forget_id iprinter id let forget_id id = forget_id iprinter id
let forget_ids = List.iter forget_id
let forget_var (id, _) = forget_id id let forget_var (id, _) = forget_id id
let forget_vars = List.iter forget_var let forget_vars = List.iter forget_var
...@@ -860,17 +972,13 @@ module Print = struct ...@@ -860,17 +972,13 @@ module Print = struct
let print_arg info fmt (id, ty) = let print_arg info fmt (id, ty) =
fprintf fmt "%a:@ %a" (print_lident info) id (print_ty info) ty fprintf fmt "%a:@ %a" (print_lident info) id (print_ty info) ty
(* By not printing the arg types, they will be inferred on the Rust side *)
let print_arg_closure info fmt (id, _) =
print_lident info fmt id
let rec print_expr info fmt (e:expr) = let rec print_expr info fmt (e:expr) =
match e with match e with
| Econst (s_opt, n) -> | Econst (s_opt, n) ->
(match s_opt with (match s_opt with
| Some s -> syntax_arguments s print_constant fmt [n] | Some s -> syntax_arguments s print_constant fmt [n]
| None -> raise (TODO "const no syntax")) | None -> raise (TODO "const no syntax"))
| Evar pv -> print_pv info fmt pv | Evar id -> print_lident info fmt id
| Efield (e, id) -> | Efield (e, id) ->
fprintf fmt "%a.%a" (print_expr info) e (print_lident info) id fprintf fmt "%a.%a" (print_expr info) e (print_lident info) id
| Eassign (rho, rs, pv) -> | Eassign (rho, rs, pv) ->
...@@ -889,9 +997,10 @@ module Print = struct ...@@ -889,9 +997,10 @@ module Print = struct
(print_list2 comma colon (print_rs info) (print_expr info)) (fl, el) (print_list2 comma colon (print_rs info) (print_expr info)) (fl, el)
| Eblock el -> | Eblock el ->
fprintf fmt "@[<hov 2>{@\n%a@]@\n}" (print_list semi (print_expr info)) el fprintf fmt "@[<hov 2>{@\n%a@]@\n}" (print_list semi (print_expr info)) el
| Ecall (rs, el) -> | Ecall (e, el) ->
fprintf fmt "%a(@[%a@])" (print_rs info) rs fprintf fmt "%a(@[%a@])" (print_expr info) e
(print_list comma (print_expr info)) el (print_list comma (print_expr info)) el
| Ers rs -> print_rs info fmt rs
| Esyntax (s, el) -> syntax_arguments s (print_expr info) fmt el | Esyntax (s, el) -> syntax_arguments s (print_expr info) fmt el
| Elet (pat, e, e_after) -> | Elet (pat, e, e_after) ->
fprintf fmt "let@ %a@ =@ {@[%a@]};@\n%a" (print_pat info) pat fprintf fmt "let@ %a@ =@ {@[%a@]};@\n%a" (print_pat info) pat
...@@ -920,13 +1029,12 @@ module Print = struct ...@@ -920,13 +1029,12 @@ module Print = struct
fprintf fmt "for %a in %a@ %a" (print_pv info) p fprintf fmt "for %a in %a@ %a" (print_pv info) p
(print_expr info) r (print_expr info) b (print_expr info) r (print_expr info) b
| Erange_inc (lo, hi) -> | Erange_inc (lo, hi) ->
(* TODO special case for bigint *)
fprintf fmt "%a..%a" (print_expr info) lo (print_expr info) hi fprintf fmt "%a..%a" (print_expr info) lo (print_expr info) hi
| Eclosure (args, e) -> | Eclosure (args, e) ->
fprintf fmt "@[<hov 2>|%a| {@\n%a@]@\n}\n" fprintf fmt "@[<hov 2>(move |%a| {@\n%a@]@\n})\n"
(print_list comma (print_arg_closure info)) args (print_list comma (print_lident info)) args
(print_expr info) e; (print_expr info) e;
forget_vars args forget_ids args
| Efn (fn, e) -> | Efn (fn, e) ->
fprintf fmt "%a@\n%a" (print_fn info) fn (print_expr info) e fprintf fmt "%a@\n%a" (print_fn info) fn (print_expr info) e
| Eunreachable -> fprintf fmt "unreachable!()" | Eunreachable -> fprintf fmt "unreachable!()"
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment