// Copyright (c) 2022 Marceline Cramer // SPDX-License-Identifier: GPL-3.0-or-later use crate::parse::ast; use cranelift::codegen::{self, Context}; use cranelift::frontend::FunctionBuilderContext; use cranelift::prelude::*; use cranelift_jit::{JITBuilder, JITModule}; use cranelift_module::{DataContext, Linkage, Module}; pub struct Engine { builder_context: FunctionBuilderContext, ctx: Context, data_ctx: DataContext, module: JITModule, } impl Engine { pub fn new() -> Self { let builder = JITBuilder::new(cranelift_module::default_libcall_names()); let module = JITModule::new(builder); Self { builder_context: FunctionBuilderContext::new(), ctx: module.make_context(), data_ctx: DataContext::new(), module, } } pub fn compile(&mut self, fn_impl: &ast::FnImpl) -> Result<*const u8, String> { self.translate(fn_impl)?; let name = "dummy_function"; let id = self .module .declare_function(name, Linkage::Export, &self.ctx.func.signature) .map_err(|e| e.to_string())?; let mut trap_sink = codegen::binemit::NullTrapSink {}; let mut stack_map_sink = codegen::binemit::NullStackMapSink {}; self.module .define_function(id, &mut self.ctx, &mut trap_sink, &mut stack_map_sink) .map_err(|e| e.to_string())?; self.module.clear_context(&mut self.ctx); self.module.finalize_definitions(); let code = self.module.get_finalized_function(id); Ok(code) } pub fn translate(&mut self, fn_impl: &ast::FnImpl) -> Result<(), String> { let int = self.module.target_config().pointer_type(); let builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context); let trans = super::translate::FunctionTranslator::new(int, builder); trans.translate(fn_impl); Ok(()) } } #[cfg(test)] mod tests { use super::*; fn jit_compile(source: &str) -> *const u8 { use crate::parse::lexer::Lexer; use crate::parse::rd::RecursiveDescent; println!("JIT-compiling source: {}", source); let lexer = Lexer::new(source); let mut rd = RecursiveDescent::new(lexer); let def = rd.build_fn(false, None); if let ast::Definition::Function { implementation, .. } = def { Engine::new().compile(&implementation).unwrap() } else { panic!("Failed to parse function from test source"); } } fn jit_fn(source: &str) -> fn(I) -> O { let code_ptr = jit_compile(source); unsafe { std::mem::transmute::<_, fn(I) -> O>(code_ptr) } } #[test] fn simple_math() { let source = "simple_math() i64 { (1 + 2) * ((4 + 5) / 3) }"; let code_fn = jit_fn::<(), i64>(source); assert_eq!(code_fn(()), 9); } #[test] fn simple_args() { let source = "simple_args(i64 a, i64 b) i64 { a * (a + b) }"; let code_fn = jit_fn::<(i64, i64), i64>(source); assert_eq!(code_fn((2, 3)), 10); } #[test] fn let_statement() { let source = "let_stmt(i64 a, i64 b) i64 { let c = a + b; b * c }"; let code_fn = jit_fn::<(i64, i64), i64>(source); assert_eq!(code_fn((2, 3)), 15); } #[test] fn let_shadowing() { let source = r#" let_shadowing(i64 a, i64 b) i64 { let c = a + b; let c = (c * b); let c = (c + a); let c = (c * a); c }"#; let code_fn = jit_fn::<(i64, i64), i64>(source); assert_eq!(code_fn((2, 3)), 34); } #[test] fn let_mutable() { let source = r#" let_mutable(i64 a, i64 b) i64 { let mut c = a + b; c = (c * b); c = (c + a); c = (c * a); c }"#; let code_fn = jit_fn::<(i64, i64), i64>(source); assert_eq!(code_fn((2, 3)), 34); } }