diff --git a/src/jit/engine.rs b/src/jit/engine.rs index 6af7ebf..bbd8441 100644 --- a/src/jit/engine.rs +++ b/src/jit/engine.rs @@ -61,10 +61,37 @@ impl Engine { #[cfg(test)] mod tests { use super::*; + use crate::parse::lexer::Lexer; + use crate::parse::rd::RecursiveDescent; + use crate::parse::BuildAst; + use std::collections::HashMap; - fn jit_compile(source: &str) -> *const u8 { - use crate::parse::lexer::Lexer; - use crate::parse::rd::RecursiveDescent; + // TODO this will all be made obsolete with more module abstractions + type CompiledFns = HashMap; + + fn jit_compile_ast(source: &str) -> CompiledFns { + println!("JIT-compiling source: {}", source); + let ast = RecursiveDescent::build_source_ast(source); + let mut engine = Engine::new(); + let mut fns = CompiledFns::new(); + + for def in ast.defs.iter() { + if let ast::Definition::Function { implementation, .. } = def { + let key = implementation.def.name.to_string(); + let val = engine.compile(&implementation).unwrap(); + fns.insert(key, val); + } + } + + fns + } + + fn get_compiled_fn(fns: &CompiledFns, name: &str) -> fn(I) -> O { + let code_ptr = fns.get(name).unwrap(); + unsafe { std::mem::transmute::<_, fn(I) -> O>(code_ptr) } + } + + fn jit_compile_fn(source: &str) -> *const u8 { println!("JIT-compiling source: {}", source); let lexer = Lexer::new(source); let mut rd = RecursiveDescent::new(lexer); @@ -77,7 +104,7 @@ mod tests { } fn jit_fn(source: &str) -> fn(I) -> O { - let code_ptr = jit_compile(source); + let code_ptr = jit_compile_fn(source); unsafe { std::mem::transmute::<_, fn(I) -> O>(code_ptr) } } @@ -210,4 +237,39 @@ mod tests { let code_fn = jit_fn::(source); assert_eq!(code_fn(10), 89); } + + #[test] + fn mandelbrot_scalar() { + let source = include_str!("../test/mandelbrot_scalar.fae"); + let fns = jit_compile_ast(source); + type Inputs = (f32, f32, f32, f32, u64); + let mandelbrot = get_compiled_fn::(&fns, "mandelbrot"); + + // display size in characters + let w = 40; + let h = 40; + + // left, top, right, and bottom input bounds + let l = -1.0; + let t = 1.0; + let r = 1.0; + let b = -1.0; + + // c input to mandelbrot + let ca = 0.0; + let cb = 0.0; + + // max iterations + let max_iter = 100; + + for y in 0..h { + let y = (y as f32) / (t - b) + b; + for x in 0..w { + let x = (x as f32) / (r - l) + l; + let iters = mandelbrot((x, y, ca, cb, max_iter)); + print!("{}", iters % 10); + } + println!(""); + } + } } diff --git a/src/parse/mod.rs b/src/parse/mod.rs index 348f600..641d824 100644 --- a/src/parse/mod.rs +++ b/src/parse/mod.rs @@ -37,6 +37,7 @@ mod tests { parse_test!(interface, $parser, "../test/interface.fae"); parse_test!(example, $parser, "../test/example.fae"); parse_test!(clock, $parser, "../test/clock.fae"); + parse_test!(mandelbrot_scalar, $parser, "../test/mandelbrot_scalar.fae"); } }; } diff --git a/src/test/mandelbrot_scalar.fae b/src/test/mandelbrot_scalar.fae new file mode 100644 index 0000000..77b635b --- /dev/null +++ b/src/test/mandelbrot_scalar.fae @@ -0,0 +1,18 @@ +// TODO remove parentheses once order of operations is implemented +// implements the calculation of the Mandelbrot set using scalars + +fn mandelbrot(f32 za, f32 zb, f32 ca, f32 cb, u64 max_iters) u64 { + // TODO figure out mutable function arguments + let mut za = za; + let mut zb = zb; + + let mut i = 0; + while (i < max_iters) and (((za * za) + (zb * zb)) < 4.0) { + let new_za = (za * za) - (zb * zb) + ca; + zb = ((2.0 * za * zb) + cb); + za = new_za; + i = (i + 1); + } + + i +}