// Copyright (c) 2022 Marceline Cramer
// SPDX-License-Identifier: GPL-3.0-or-later
|
|
|
|
use crate::parse::ast;
|
|
use cranelift::codegen::{self, Context};
|
|
use cranelift::frontend::FunctionBuilderContext;
|
|
use cranelift::prelude::*;
|
|
use cranelift_jit::{JITBuilder, JITModule};
|
|
use cranelift_module::{DataContext, Linkage, Module};
|
|
|
|
pub struct Engine {
|
|
builder_context: FunctionBuilderContext,
|
|
ctx: Context,
|
|
data_ctx: DataContext,
|
|
module: JITModule,
|
|
}
|
|
|
|
impl Engine {
|
|
pub fn new() -> Self {
|
|
let builder = JITBuilder::new(cranelift_module::default_libcall_names());
|
|
let module = JITModule::new(builder);
|
|
Self {
|
|
builder_context: FunctionBuilderContext::new(),
|
|
ctx: module.make_context(),
|
|
data_ctx: DataContext::new(),
|
|
module,
|
|
}
|
|
}
|
|
|
|
pub fn compile(&mut self, fn_impl: &ast::FnImpl) -> Result<*const u8, String> {
|
|
self.translate(fn_impl)?;
|
|
|
|
let name = "dummy_function";
|
|
let id = self
|
|
.module
|
|
.declare_function(name, Linkage::Export, &self.ctx.func.signature)
|
|
.map_err(|e| e.to_string())?;
|
|
|
|
let mut trap_sink = codegen::binemit::NullTrapSink {};
|
|
let mut stack_map_sink = codegen::binemit::NullStackMapSink {};
|
|
self.module
|
|
.define_function(id, &mut self.ctx, &mut trap_sink, &mut stack_map_sink)
|
|
.map_err(|e| e.to_string())?;
|
|
|
|
self.module.clear_context(&mut self.ctx);
|
|
self.module.finalize_definitions();
|
|
|
|
let code = self.module.get_finalized_function(id);
|
|
Ok(code)
|
|
}
|
|
|
|
pub fn translate(&mut self, fn_impl: &ast::FnImpl) -> Result<(), String> {
|
|
let int = self.module.target_config().pointer_type();
|
|
let builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
|
|
let trans = super::translate::FunctionTranslator::new(int, builder);
|
|
trans.translate(fn_impl);
|
|
Ok(())
|
|
}
|
|
}
|
|
|
|
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses `source` as a single function definition, JIT-compiles it, and
    /// returns a pointer to the generated code.
    ///
    /// The engine is deliberately leaked so the code memory it owns can never
    /// be released while a test still holds a pointer into it.
    ///
    /// Panics if `source` does not parse to a function or fails to compile.
    fn jit_compile(source: &str) -> *const u8 {
        use crate::parse::lexer::Lexer;
        use crate::parse::rd::RecursiveDescent;

        println!("JIT-compiling source: {}", source);
        let lexer = Lexer::new(source);
        let mut rd = RecursiveDescent::new(lexer);
        let def = rd.build_fn(false, None);
        if let ast::Definition::Function { implementation, .. } = def {
            let engine: &'static mut Engine = Box::leak(Box::new(Engine::new()));
            engine.compile(&implementation).unwrap()
        } else {
            panic!("Failed to parse function from test source");
        }
    }

    /// JIT-compiles a function that takes no arguments and returns `i64`.
    fn jit_fn0(source: &str) -> extern "C" fn() -> i64 {
        let code_ptr = jit_compile(source);
        // SAFETY: `code_ptr` is finalized code owned by a leaked engine, so it
        // stays valid for the whole process. The test source declares no
        // parameters and an `i64` result. The module context uses the host's
        // default calling convention, which is the C ABI (assumes the
        // translator does not override it — TODO confirm).
        unsafe { std::mem::transmute::<*const u8, extern "C" fn() -> i64>(code_ptr) }
    }

    /// JIT-compiles a function that takes two `i64` arguments and returns `i64`.
    fn jit_fn2(source: &str) -> extern "C" fn(i64, i64) -> i64 {
        let code_ptr = jit_compile(source);
        // SAFETY: same as `jit_fn0`; the test source declares exactly two
        // `i64` parameters and an `i64` result.
        unsafe { std::mem::transmute::<*const u8, extern "C" fn(i64, i64) -> i64>(code_ptr) }
    }

    #[test]
    fn simple_math() {
        let source = "simple_math() i64 { (1 + 2) * ((4 + 5) / 3) }";
        let code_fn = jit_fn0(source);
        assert_eq!(code_fn(), 9);
    }

    #[test]
    fn simple_args() {
        let source = "simple_args(i64 a, i64 b) i64 { a * (a + b) }";
        let code_fn = jit_fn2(source);
        assert_eq!(code_fn(2, 3), 10);
    }

    #[test]
    fn let_statement() {
        let source = "let_stmt(i64 a, i64 b) i64 { let c = a + b; b * c }";
        let code_fn = jit_fn2(source);
        assert_eq!(code_fn(2, 3), 15);
    }

    #[test]
    fn let_shadowing() {
        let source = r#"
            let_shadowing(i64 a, i64 b) i64 {
                let c = a + b;
                let c = (c * b);
                let c = (c + a);
                let c = (c * a);
                c
            }"#;
        let code_fn = jit_fn2(source);
        assert_eq!(code_fn(2, 3), 34);
    }

    #[test]
    fn let_mutable() {
        let source = r#"
            let_mutable(i64 a, i64 b) i64 {
                let mut c = a + b;
                c = (c * b);
                c = (c + a);
                c = (c * a);
                c
            }"#;
        let code_fn = jit_fn2(source);
        assert_eq!(code_fn(2, 3), 34);
    }

    #[test]
    fn if_expr() {
        let source = r#"
            min(i64 a, i64 b) i64 {
                if a > b {
                    b
                } else {
                    a
                }
            }"#;
        let code_fn = jit_fn2(source);
        assert_eq!(code_fn(2, 4), 2);
        assert_eq!(code_fn(5, 3), 3);
    }

    #[test]
    fn if_scopes() {
        // TODO this is overcomplicated for a test but i had fun writing it
        let source = r#"
            if_scopes(i64 a, i64 b) i64 {
                let mut c = (a + b); // Test 1: 6, Test 2: 10
                let d = (a * b); // Test 1: 8, Test 2: 21

                let e = if a < b {
                    // Test 1 only
                    let d = (b - a); // 2
                    c = (c * d); // 12
                    d
                } else {
                    // Test 2 only
                    let d = (a - b); // 4
                    c = (c * d); // 40
                    d
                };

                c + d + e // Test 1: 22, Test 2: 65
            }"#;
        let code_fn = jit_fn2(source);
        assert_eq!(code_fn(2, 4), 22); // Test 1
        assert_eq!(code_fn(7, 3), 65); // Test 2
    }
}
|