Skip to content
Snippets Groups Projects
Select Git revision
  • 573f631465b1dbaf3fe9b85844ab15b11e8fbc9d
  • master default protected
  • stacked_borrows
  • generics_and_traits
  • crane_lift
  • spans
  • rust_syntax
  • type_check
  • expression
  • fallible
10 results

jit.rs

Blame
  • Forked from Per Lindgren / D7050E_2020
    Source project has a limited visibility.
    jit.rs 18.97 KiB
    use std::collections::HashMap;
    
    use cranelift::prelude::*;
    use cranelift_module::{DataContext, Linkage, Module};
    use cranelift_simplejit::{SimpleJITBackend, SimpleJITBuilder};
    // use frontend::*;
    use std::slice;
    
    use crate::ast::*;
    
    /// The basic JIT class.
    ///
    /// Bundles the Cranelift state needed to turn one function declaration of
    /// the toy language into machine code: a reusable builder context, the
    /// codegen context, a data context, and the SimpleJIT-backed module.
    pub struct JIT {
        /// The function builder context, which is reused across multiple
        /// FunctionBuilder instances.
        builder_context: FunctionBuilderContext,
    
        /// The main Cranelift context, which holds the state for codegen. Cranelift
        /// separates this from `Module` to allow for parallel compilation, with a
        /// context per thread, though this isn't in the simple demo here.
        ctx: codegen::Context,
    
        /// The data context, which is to data objects what `ctx` is to functions.
        /// In the code visible here it is only referenced by the commented-out
        /// `create_data` method.
        data_ctx: DataContext,
    
        /// The module, with the simplejit backend, which manages the JIT'd
        /// functions.
        module: Module<SimpleJITBackend>,
    }
    
    impl JIT {
        /// Create a new `JIT` instance.
        ///
        /// The module is built on the SimpleJIT backend using Cranelift's default
        /// libcall names; the codegen context is created from that module.
        pub fn new() -> Self {
            let builder = SimpleJITBuilder::new(cranelift_module::default_libcall_names());
            let module = Module::new(builder);
            Self {
                builder_context: FunctionBuilderContext::new(),
                ctx: module.make_context(),
                data_ctx: DataContext::new(),
                module,
            }
        }
    
        /// Compile the first function declaration of `p` and return a pointer to
        /// its finalized machine code.
        ///
        /// Only `p.fn_decls[0]` is compiled; any further declarations are ignored.
        ///
        /// # Errors
        ///
        /// Returns an error string when the program contains no function
        /// declarations, or when translating, declaring or defining the function
        /// fails.
        pub fn compile(&mut self, p: &Program) -> Result<*const u8, String> {
            // Pick the function to compile. An empty program is reported as an
            // error instead of panicking on an out-of-bounds index.
            let f = p
                .fn_decls
                .first()
                .ok_or_else(|| "program contains no function declarations".to_string())?;
    
            // Translate the AST nodes into Cranelift IR.
            self.translate(f)?;
    
            // Next, declare the function to simplejit. Functions must be declared
            // before they can be called, or defined.
            let id = self
                .module
                .declare_function(f.id.as_str(), Linkage::Export, &self.ctx.func.signature)
                .map_err(|e| e.to_string())?;
    
            // Define the function to simplejit. This finishes compilation, although
            // there may be outstanding relocations to perform. Currently, simplejit
            // cannot finish relocations until all functions to be called are
            // defined. For this toy demo for now, we'll just finalize the function
            // below.
            self.module
                .define_function(id, &mut self.ctx, &mut codegen::binemit::NullTrapSink {})
                .map_err(|e| e.to_string())?;
    
            // Now that compilation is finished, we can clear out the context state.
            self.module.clear_context(&mut self.ctx);
    
            // Finalize the functions which we just defined, which resolves any
            // outstanding relocations (patching in addresses, now that they're
            // available).
            self.module.finalize_definitions();
    
            // We can now retrieve a pointer to the machine code.
            let code = self.module.get_finalized_function(id);
    
            Ok(code)
        }
    
        //     /// Create a zero-initialized data section.
        //     pub fn create_data(&mut self, name: &str, contents: Vec<u8>) -> Result<&[u8], String> {
        //         // The steps here are analogous to `compile`, except that data is much
        //         // simpler than functions.
        //         self.data_ctx.define(contents.into_boxed_slice());
        //         let id = self
        //             .module
        //             .declare_data(name, Linkage::Export, true, false, None)
        //             .map_err(|e| e.to_string())?;
    
        //         self.module
        //             .define_data(id, &self.data_ctx)
        //             .map_err(|e| e.to_string())?;
        //         self.data_ctx.clear();
        //         self.module.finalize_definitions();
        //         let buffer = self.module.get_finalized_data(id);
        //         // TODO: Can we move the unsafe into cranelift?
        //         Ok(unsafe { slice::from_raw_parts(buffer.0, buffer.1) })
        //     }
    
        /// Translate one function declaration from toy-language AST nodes into
        /// Cranelift IR, leaving the result in `self.ctx.func`.
        ///
        /// Every parameter and the single return value use the target's pointer
        /// type. The generated function returns the value produced by the last
        /// translated statement, or `0` when the body has no statements.
        ///
        /// NOTE(review): returning the last statement's value replaces a lookup of
        /// a return variable that no longer exists in this function; confirm this
        /// matches the intended language semantics.
        fn translate(
            &mut self,
            f: &FnDecl,
            // params: Vec<String>,
            // the_return: String,
            // stmts: Vec<Expr>,
        ) -> Result<(), String> {
            // Our toy language currently only supports I64 values, though Cranelift
            // supports other types.
    
            // Per: ok let i32 be an I64 for now, does not really matter
            let int = self.module.target_config().pointer_type();
    
            println!("set params");
            for _p in &f.params.0 {
                println!("id {:?}", _p.id);
                self.ctx.func.signature.params.push(AbiParam::new(int));
            }
    
            println!("set return value");
            // Our toy language currently only supports one return value, though
            // Cranelift is designed to support more.
            self.ctx.func.signature.returns.push(AbiParam::new(int));
    
            // Create the builder to build a function.
            let mut builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
    
            // Create the entry block, to start emitting code in.
            let entry_block = builder.create_block();
    
            // Since this is the entry block, add block parameters corresponding to
            // the function's parameters.
            builder.append_block_params_for_function_params(entry_block);
    
            // Tell the builder to emit code in this block.
            builder.switch_to_block(entry_block);
    
            // And, tell the builder that this block will have no further
            // predecessors. Since it's the entry block, it won't have any
            // predecessors.
            builder.seal_block(entry_block);
    
            // Per: Not sure how we deal with introducing variables on the fly
            // let variables =
            //     declare_variables(int, &mut builder, &params, &the_return, &stmts, entry_block);
            // No variables are pre-declared. The key/value types are inferred from
            // the `FunctionTranslator::variables` field the map is moved into.
            let variables = HashMap::new();
    
            // Now translate the statements of the function body.
            let mut trans = FunctionTranslator {
                int,
                builder,
                variables,
                module: &mut self.module,
            };
    
            // Remember the value of the most recent statement; it becomes the
            // function's result.
            let mut last_value = None;
            for stmt in &f.body.stmts {
                last_value = Some(trans.translate_stmt(stmt));
            }
    
            // The signature promises exactly one return value, so an empty body
            // needs an explicit constant to return.
            let return_value = match last_value {
                Some(value) => value,
                None => trans.builder.ins().iconst(int, 0),
            };
    
            // Emit the return instruction.
            trans.builder.ins().return_(&[return_value]);
    
            // Tell the builder we're done with this function.
            trans.builder.finalize();
            Ok(())
        }
    }
    /// A collection of state used for translating from toy-language AST nodes
    /// into Cranelift IR.
    struct FunctionTranslator<'a> {
        /// The Cranelift type used for every integer value that is emitted.
        int: types::Type,
        /// The builder that receives the emitted instructions.
        builder: FunctionBuilder<'a>,
        /// Maps variable names to their Cranelift `Variable` handles.
        variables: HashMap<String, Variable>,
        /// The module being compiled into. It is only used by the commented-out
        /// call and global-data translation code visible in this file.
        module: &'a mut Module<SimpleJITBackend>,
    }
    
    impl<'a> FunctionTranslator<'a> {
        /// When you write out instructions in Cranelift, you get back `Value`s. You
        /// can then use these references in other instructions.
        fn translate_expr(&mut self, expr: Expr) -> Value {
            match expr {
                Expr::Num(literal) => self.builder.ins().iconst(self.int, i64::from(literal)),
                _ => unimplemented!(),
            }
        }
    
        fn translate_stmt(&mut self, stmt: &Stmt) -> Value {
            match stmt {
                Stmt::Assign(l, r) => {
                    let new_value = self.translate_expr(r);
                    let variable = self.variables.get("res").unwrap();
                    self.builder.def_var(*variable, new_value);
                    new_value
                }
                _ => unimplemented!(),
            }
        }
    }
    
    //             Expr::Add(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 self.builder.ins().iadd(lhs, rhs)
    //             }
    
    //             Expr::Sub(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 self.builder.ins().isub(lhs, rhs)
    //             }
    
    //             Expr::Mul(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 self.builder.ins().imul(lhs, rhs)
    //             }
    
    //             Expr::Div(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 self.builder.ins().udiv(lhs, rhs)
    //             }
    
    //             Expr::Eq(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 let c = self.builder.ins().icmp(IntCC::Equal, lhs, rhs);
    //                 self.builder.ins().bint(self.int, c)
    //             }
    
    //             Expr::Ne(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 let c = self.builder.ins().icmp(IntCC::NotEqual, lhs, rhs);
    //                 self.builder.ins().bint(self.int, c)
    //             }
    
    //             Expr::Lt(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 let c = self.builder.ins().icmp(IntCC::SignedLessThan, lhs, rhs);
    //                 self.builder.ins().bint(self.int, c)
    //             }
    
    //             Expr::Le(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 let c = self
    //                     .builder
    //                     .ins()
    //                     .icmp(IntCC::SignedLessThanOrEqual, lhs, rhs);
    //                 self.builder.ins().bint(self.int, c)
    //             }
    
    //             Expr::Gt(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 let c = self.builder.ins().icmp(IntCC::SignedGreaterThan, lhs, rhs);
    //                 self.builder.ins().bint(self.int, c)
    //             }
    
    //             Expr::Ge(lhs, rhs) => {
    //                 let lhs = self.translate_expr(*lhs);
    //                 let rhs = self.translate_expr(*rhs);
    //                 let c = self
    //                     .builder
    //                     .ins()
    //                     .icmp(IntCC::SignedGreaterThanOrEqual, lhs, rhs);
    //                 self.builder.ins().bint(self.int, c)
    //             }
    
    //             Expr::Call(name, args) => self.translate_call(name, args),
    
    //             Expr::GlobalDataAddr(name) => self.translate_global_data_addr(name),
    
    //             Expr::Identifier(name) => {
    //                 // `use_var` is used to read the value of a variable.
    //                 let variable = self.variables.get(&name).expect("variable not defined");
    //                 self.builder.use_var(*variable)
    //             }
    
    //             Expr::Assign(name, expr) => {
    //                 // `def_var` is used to write the value of a variable. Note that
    //                 // variables can have multiple definitions. Cranelift will
    //                 // convert them into SSA form for itself automatically.
    //                 let new_value = self.translate_expr(*expr);
    //                 let variable = self.variables.get(&name).unwrap();
    //                 self.builder.def_var(*variable, new_value);
    //                 new_value
    //             }
    
    //             Expr::IfElse(condition, then_body, else_body) => {
    //                 let condition_value = self.translate_expr(*condition);
    
    //                 let then_block = self.builder.create_block();
    //                 let else_block = self.builder.create_block();
    //                 let merge_block = self.builder.create_block();
    
    //                 // If-else constructs in the toy language have a return value.
    //                 // In traditional SSA form, this would produce a PHI between
    //                 // the then and else bodies. Cranelift uses block parameters,
    //                 // so set up a parameter in the merge block, and we'll pass
    //                 // the return values to it from the branches.
    //                 self.builder.append_block_param(merge_block, self.int);
    
    //                 // Test the if condition and conditionally branch.
    //                 self.builder.ins().brz(condition_value, else_block, &[]);
    //                 // Fall through to then block.
    //                 self.builder.ins().jump(then_block, &[]);
    
    //                 self.builder.switch_to_block(then_block);
    //                 self.builder.seal_block(then_block);
    //                 let mut then_return = self.builder.ins().iconst(self.int, 0);
    //                 for expr in then_body {
    //                     then_return = self.translate_expr(expr);
    //                 }
    
    //                 // Jump to the merge block, passing it the block return value.
    //                 self.builder.ins().jump(merge_block, &[then_return]);
    
    //                 self.builder.switch_to_block(else_block);
    //                 self.builder.seal_block(else_block);
    //                 let mut else_return = self.builder.ins().iconst(self.int, 0);
    //                 for expr in else_body {
    //                     else_return = self.translate_expr(expr);
    //                 }
    
    //                 // Jump to the merge block, passing it the block return value.
    //                 self.builder.ins().jump(merge_block, &[else_return]);
    
    //                 // Switch to the merge block for subsequent statements.
    //                 self.builder.switch_to_block(merge_block);
    
    //                 // We've now seen all the predecessors of the merge block.
    //                 self.builder.seal_block(merge_block);
    
    //                 // Read the value of the if-else by reading the merge block
    //                 // parameter.
    //                 let phi = self.builder.block_params(merge_block)[0];
    
    //                 phi
    //             }
    
    //             Expr::WhileLoop(condition, loop_body) => {
    //                 let header_block = self.builder.create_block();
    //                 let body_block = self.builder.create_block();
    //                 let exit_block = self.builder.create_block();
    
    //                 self.builder.ins().jump(header_block, &[]);
    //                 self.builder.switch_to_block(header_block);
    
    //                 let condition_value = self.translate_expr(*condition);
    //                 self.builder.ins().brz(condition_value, exit_block, &[]);
    //                 self.builder.ins().jump(body_block, &[]);
    
    //                 self.builder.switch_to_block(body_block);
    //                 self.builder.seal_block(body_block);
    
    //                 for expr in loop_body {
    //                     self.translate_expr(expr);
    //                 }
    //                 self.builder.ins().jump(header_block, &[]);
    
    //                 self.builder.switch_to_block(exit_block);
    
    //                 // We've reached the bottom of the loop, so there will be no
    //                 // more backedges to the header or exits to the bottom.
    //                 self.builder.seal_block(header_block);
    //                 self.builder.seal_block(exit_block);
    
    //                 // Just return 0 for now.
    //                 self.builder.ins().iconst(self.int, 0)
    //             }
    //         }
    
    //     fn translate_call(&mut self, name: String, args: Vec<Expr>) -> Value {
    //         let mut sig = self.module.make_signature();
    
    //         // Add a parameter for each argument.
    //         for _arg in &args {
    //             sig.params.push(AbiParam::new(self.int));
    //         }
    
    //         // For simplicity for now, just make all calls return a single I64.
    //         sig.returns.push(AbiParam::new(self.int));
    
    //         // TODO: Streamline the API here?
    //         let callee = self
    //             .module
    //             .declare_function(&name, Linkage::Import, &sig)
    //             .expect("problem declaring function");
    //         let local_callee = self
    //             .module
    //             .declare_func_in_func(callee, &mut self.builder.func);
    
    //         let mut arg_values = Vec::new();
    //         for arg in args {
    //             arg_values.push(self.translate_expr(arg))
    //         }
    //         let call = self.builder.ins().call(local_callee, &arg_values);
    //         self.builder.inst_results(call)[0]
    //     }
    
    //     fn translate_global_data_addr(&mut self, name: String) -> Value {
    //         let sym = self
    //             .module
    //             .declare_data(&name, Linkage::Export, true, false, None)
    //             .expect("problem declaring data object");
    //         let local_id = self
    //             .module
    //             .declare_data_in_func(sym, &mut self.builder.func);
    
    //         let pointer = self.module.target_config().pointer_type();
    //         self.builder.ins().symbol_value(pointer, local_id)
    //     }
    
    /// Declare a variable for every function parameter, for the return value,
    /// and for whatever the statement walk finds, and return the name-to-variable
    /// map.
    ///
    /// Each parameter variable is defined with the matching block parameter of
    /// `entry_block`; the return variable is defined with the constant `0`.
    fn declare_variables(
        int: types::Type,
        builder: &mut FunctionBuilder,
        params: &[String],
        the_return: &str,
        stmts: &[Expr],
        entry_block: Block,
    ) -> HashMap<String, Variable> {
        let mut declared = HashMap::new();
        let mut next_index = 0;
    
        // Bind each parameter name to the value the entry block receives for it.
        for (position, param_name) in params.iter().enumerate() {
            let incoming = builder.block_params(entry_block)[position];
            let slot = declare_variable(int, builder, &mut declared, &mut next_index, param_name);
            builder.def_var(slot, incoming);
        }
    
        // The return variable starts out as zero.
        let initial = builder.ins().iconst(int, 0);
        let result_slot = declare_variable(int, builder, &mut declared, &mut next_index, the_return);
        builder.def_var(result_slot, initial);
    
        // Let the statement walk add any further variables.
        stmts.iter().for_each(|stmt| {
            declare_variables_in_stmt(int, builder, &mut declared, &mut next_index, stmt)
        });
    
        declared
    }
    
    /// Placeholder for recursively descending through the AST and declaring all
    /// implicitly-declared variables.
    ///
    /// Every arm that performed a declaration or recursed is commented out, so
    /// the function currently does nothing for any expression and leaves
    /// `variables` and `index` untouched.
    fn declare_variables_in_stmt(
        int: types::Type,
        builder: &mut FunctionBuilder,
        variables: &mut HashMap<String, Variable>,
        index: &mut usize,
        expr: &Expr,
    ) {
        match *expr {
            // The arms below are disabled; they are kept as a reference for the
            // intended traversal.
            // Expr::Assign(ref name, _) => {
            //     declare_variable(int, builder, variables, index, name);
            // }
            // Expr::IfElse(ref _condition, ref then_body, ref else_body) => {
            //     for stmt in then_body {
            //         declare_variables_in_stmt(int, builder, variables, index, &stmt);
            //     }
            //     for stmt in else_body {
            //         declare_variables_in_stmt(int, builder, variables, index, &stmt);
            //     }
            // }
            // Expr::WhileLoop(ref _condition, ref loop_body) => {
            //     for stmt in loop_body {
            //         declare_variables_in_stmt(int, builder, variables, index, &stmt);
            //     }
            // }
            // Every expression kind is currently ignored.
            _ => (),
        }
    }
    
    /// Declare a single variable and return its handle.
    ///
    /// If `name` is already present in `variables`, the existing variable is
    /// returned and nothing else changes. Otherwise a new variable numbered
    /// `*index` is declared with type `int`, recorded under `name`, and `*index`
    /// is advanced by one.
    fn declare_variable(
        int: types::Type,
        builder: &mut FunctionBuilder,
        variables: &mut HashMap<String, Variable>,
        index: &mut usize,
        name: &str,
    ) -> Variable {
        // Hand back the variable registered earlier; a fresh handle here would
        // refer to a variable that was never declared to the builder.
        if let Some(existing) = variables.get(name) {
            return *existing;
        }
    
        let var = Variable::new(*index);
        variables.insert(name.into(), var);
        builder.declare_var(var, int);
        *index += 1;
        var
    }