sprite-rs/src/jit/engine.rs

276 lines
8.1 KiB
Rust

// Copyright (c) 2022 Marceline Cramer
// SPDX-License-Identifier: GPL-3.0-or-later
use crate::parse::ast;
use cranelift::codegen::{self, Context};
use cranelift::frontend::FunctionBuilderContext;
use cranelift::prelude::*;
use cranelift_jit::{JITBuilder, JITModule};
use cranelift_module::{DataContext, Linkage, Module};
/// JIT compiler state: a cranelift [`JITModule`] plus the contexts used while
/// translating and compiling [`ast::FnImpl`]s into native code.
pub struct Engine {
    /// Scratch state handed to every `FunctionBuilder` created in `translate`.
    builder_context: FunctionBuilderContext,
    /// Codegen context holding the function currently being built; it is
    /// cleared after a function has been defined in `module`.
    ctx: Context,
    /// Data-object context; not used by any method in this file yet
    /// (presumably reserved for global data — TODO confirm).
    data_ctx: DataContext,
    /// JIT module that declares, defines, and finalizes compiled functions.
    module: JITModule,
}
impl Engine {
    /// Creates an engine backed by a fresh [`JITModule`] that uses cranelift's
    /// default libcall names.
    pub fn new() -> Self {
        let builder = JITBuilder::new(cranelift_module::default_libcall_names());
        let module = JITModule::new(builder);
        Self {
            builder_context: FunctionBuilderContext::new(),
            ctx: module.make_context(),
            data_ctx: DataContext::new(),
            module,
        }
    }

    /// JIT-compiles `fn_impl` and returns a pointer to its finalized machine code.
    ///
    /// The function is exported from the module under its source-level name, so a
    /// single engine can compile several differently-named functions.
    ///
    /// # Errors
    ///
    /// Returns the stringified cranelift error if declaring or defining the
    /// function fails — e.g. when this engine has already compiled a function
    /// with the same name.
    pub fn compile(&mut self, fn_impl: &ast::FnImpl) -> Result<*const u8, String> {
        // Start from a clean context: an earlier call that bailed out through `?`
        // would otherwise leave its half-built function behind in `self.ctx`.
        self.module.clear_context(&mut self.ctx);
        self.translate(fn_impl)?;

        // Export under the function's own name. Reusing one fixed symbol name
        // made every compile after the first on the same engine fail as a
        // duplicate (or incompatible) declaration.
        let name = fn_impl.def.name.to_string();
        let id = self
            .module
            .declare_function(&name, Linkage::Export, &self.ctx.func.signature)
            .map_err(|e| e.to_string())?;

        let mut trap_sink = codegen::binemit::NullTrapSink {};
        let mut stack_map_sink = codegen::binemit::NullStackMapSink {};
        self.module
            .define_function(id, &mut self.ctx, &mut trap_sink, &mut stack_map_sink)
            .map_err(|e| e.to_string())?;
        self.module.clear_context(&mut self.ctx);

        // Finalize (resolve relocations, make the code executable) before
        // handing out a pointer to it.
        self.module.finalize_definitions();
        Ok(self.module.get_finalized_function(id))
    }

    /// Builds cranelift IR for `fn_impl` into this engine's codegen context,
    /// using the target's pointer type as the integer type.
    ///
    /// Currently always returns `Ok(())`.
    pub fn translate(&mut self, fn_impl: &ast::FnImpl) -> Result<(), String> {
        let int = self.module.target_config().pointer_type();
        let builder = FunctionBuilder::new(&mut self.ctx.func, &mut self.builder_context);
        let trans = super::translate::FunctionTranslator::new(int, builder);
        trans.translate(fn_impl);
        Ok(())
    }
}

impl Default for Engine {
    /// Equivalent to [`Engine::new`].
    fn default() -> Self {
        Self::new()
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::parse::lexer::Lexer;
    use crate::parse::rd::RecursiveDescent;
    use crate::parse::BuildAst;
    use std::collections::HashMap;

    // TODO this will all be made obsolete with more module abstractions
    /// Finalized code pointers, keyed by source-level function name.
    type CompiledFns = HashMap<String, *const u8>;

    /// Parses `source` as a whole file and JIT-compiles every function
    /// definition in it with one shared [`Engine`].
    fn jit_compile_ast(source: &str) -> CompiledFns {
        println!("JIT-compiling source: {}", source);
        let ast = RecursiveDescent::build_source_ast(source);
        let mut engine = Engine::new();
        let mut fns = CompiledFns::new();
        for def in ast.defs.iter() {
            if let ast::Definition::Function { implementation, .. } = def {
                let key = implementation.def.name.to_string();
                let val = engine.compile(implementation).unwrap();
                fns.insert(key, val);
            }
        }
        // NOTE(review): `engine` is dropped here while the returned pointers are
        // still called later; this relies on `JITModule` not unmapping its code
        // memory on drop — confirm for the cranelift_jit version in use.
        fns
    }

    /// Reinterprets a finalized code pointer as a callable function pointer `F`.
    ///
    /// `F` must be an `extern "C" fn(..) -> ..` whose parameter and return types
    /// match the compiled function. Rust-ABI `fn` types (and tuple-packed
    /// arguments) are not valid here: the Rust ABI is unspecified and, for
    /// example, passes large tuples by reference.
    ///
    /// # Panics
    ///
    /// Panics if `F` is not pointer-sized.
    fn code_ptr_as_fn<F: Copy>(code_ptr: *const u8) -> F {
        assert_eq!(
            std::mem::size_of::<F>(),
            std::mem::size_of::<*const u8>(),
            "target type must be a plain function pointer"
        );
        // SAFETY: `code_ptr` points at finalized JIT code, and the caller chooses
        // `F` to match the compiled signature. This assumes the code uses the
        // host's default C calling convention (from `make_context`) — TODO
        // confirm if the translator ever overrides `call_conv`.
        unsafe { std::mem::transmute_copy::<*const u8, F>(&code_ptr) }
    }

    /// Looks up the function `name` in `fns` and returns it as an `F`.
    fn get_compiled_fn<F: Copy>(fns: &CompiledFns, name: &str) -> F {
        // Dereference the map entry: transmuting the `&*const u8` itself would
        // yield a "function" pointing at the HashMap's storage.
        let code_ptr = *fns
            .get(name)
            .unwrap_or_else(|| panic!("no compiled function named `{}`", name));
        code_ptr_as_fn(code_ptr)
    }

    /// Parses a single function definition from `source` and JIT-compiles it.
    fn jit_compile_fn(source: &str) -> *const u8 {
        println!("JIT-compiling source: {}", source);
        let lexer = Lexer::new(source);
        let mut rd = RecursiveDescent::new(lexer);
        let def = rd.build_fn(false, None);
        if let ast::Definition::Function { implementation, .. } = def {
            Engine::new().compile(&implementation).unwrap()
        } else {
            panic!("Failed to parse function from test source");
        }
    }

    /// JIT-compiles a single function and returns it as the `extern "C"` fn type `F`.
    fn jit_fn<F: Copy>(source: &str) -> F {
        code_ptr_as_fn(jit_compile_fn(source))
    }

    #[test]
    fn i64_math() {
        let source = "simple_math() i64 { (1 + 2) * ((4 + 5) / 3) }";
        let code_fn: extern "C" fn() -> i64 = jit_fn(source);
        assert_eq!(code_fn(), 9);
    }

    #[test]
    fn simple_args() {
        let source = "simple_args(i64 a, i64 b) i64 { a * (a + b) }";
        let code_fn: extern "C" fn(i64, i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(2, 3), 10);
    }

    #[test]
    fn let_statement() {
        let source = "let_stmt(i64 a, i64 b) i64 { let c = a + b; b * c }";
        let code_fn: extern "C" fn(i64, i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(2, 3), 15);
    }

    #[test]
    fn let_shadowing() {
        let source = r#"
let_shadowing(i64 a, i64 b) i64 {
let c = a + b;
let c = (c * b);
let c = (c + a);
let c = (c * a);
c
}"#;
        let code_fn: extern "C" fn(i64, i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(2, 3), 34);
    }

    #[test]
    fn let_mutable() {
        let source = r#"
let_mutable(i64 a, i64 b) i64 {
let mut c = a + b;
c = (c * b);
c = (c + a);
c = (c * a);
c
}"#;
        let code_fn: extern "C" fn(i64, i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(2, 3), 34);
    }

    #[test]
    fn if_expr() {
        let source = r#"
min(i64 a, i64 b) i64 {
if a > b {
b
} else {
a
}
}"#;
        let code_fn: extern "C" fn(i64, i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(2, 4), 2);
        assert_eq!(code_fn(5, 3), 3);
    }

    #[test]
    fn if_scopes() {
        // TODO this is overcomplicated for a test but i had fun writing it
        let source = r#"
if_scopes(i64 a, i64 b) i64 {
let mut c = (a + b); // Test 1: 6, Test 2: 10
let d = (a * b); // Test 1: 8, Test 2: 21
let e = if a < b {
// Test 1 only
let d = (b - a); // 2
c = (c * d); // 12
d
} else {
// Test 2 only
let d = (a - b); // 4
c = (c * d); // 40
d
};
c + d + e // Test 1: 22, Test 2: 65
}"#;
        let code_fn: extern "C" fn(i64, i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(2, 4), 22); // Test 1
        assert_eq!(code_fn(7, 3), 65); // Test 2
    }

    #[test]
    fn if_statement() {
        let source = r#"
if_statement(i64 a, i64 b) i64 {
let mut c = a + b;
if a < b {
c = (c * a);
} else {
c = (c * b);
}
c
}"#;
        let code_fn: extern "C" fn(i64, i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(2, 4), 12);
        assert_eq!(code_fn(7, 3), 30);
    }

    #[test]
    fn while_statement() {
        let source = r#"
iterative_fibonacci(i64 iterations) i64 {
let mut a = 0;
let mut b = 1;
let mut i = 0;
while i < iterations {
let c = a + b;
a = b;
b = c;
i = (i + 1);
}
b
}"#;
        let code_fn: extern "C" fn(i64) -> i64 = jit_fn(source);
        assert_eq!(code_fn(10), 89);
    }

    #[test]
    fn mandelbrot_scalar() {
        let source = include_str!("../test/mandelbrot_scalar.fae");
        let fns = jit_compile_ast(source);
        let mandelbrot: extern "C" fn(f32, f32, f32, f32, u64) -> u64 =
            get_compiled_fn(&fns, "mandelbrot");
        // display size in characters
        let w = 40u32;
        let h = 40u32;
        // left, top, right, and bottom input bounds
        let l = -1.0f32;
        let t = 1.0f32;
        let r = 1.0f32;
        let b = -1.0f32;
        // c input to mandelbrot
        let ca = 0.0;
        let cb = 0.0;
        // max iterations
        let max_iter = 100;
        for row in 0..h {
            // row 0 maps to the top bound, moving toward the bottom bound
            let y = t + (row as f32) / (h as f32) * (b - t);
            for col in 0..w {
                // column 0 maps to the left bound, moving toward the right bound
                let x = l + (col as f32) / (w as f32) * (r - l);
                let iters = mandelbrot(x, y, ca, cb, max_iter);
                print!("{}", iters % 10);
            }
            println!();
        }
    }
}