// sprite-rs/src/jit/translate.rs

// Copyright (c) 2022 Marceline Cramer
// SPDX-License-Identifier: GPL-3.0-or-later
use crate::parse::ast;
use cranelift::prelude::*;
/// Translates a parsed function (`ast::FnImpl`) into Cranelift IR.
pub struct FunctionTranslator<'a> {
/// Cranelift type used for every integer value this translator creates:
/// function parameters/returns, literals, locals, and comparison results.
pub int: types::Type,
/// Builder that IR is emitted into; finalized at the end of `translate`.
pub builder: FunctionBuilder<'a>,
/// Locals currently in scope, innermost last. Lookups search from the back,
/// so later declarations shadow earlier ones; branch bodies truncate this
/// back to its previous length when they end.
pub locals: Vec<Local>,
/// Index passed to `Variable::new` for the next declared local; only ever
/// incremented, so every local gets a distinct Cranelift variable.
pub variable_index: usize,
}
/// A named local binding (function argument or `let` variable) in scope.
pub struct Local {
/// Source-level name used to look the local up.
pub name: String,
/// Cranelift variable backing this local.
pub var: Variable,
/// Whether assignment is allowed; assigning to an immutable local panics.
pub mutable: bool,
}
impl<'a> FunctionTranslator<'a> {
    /// Creates a translator that emits IR into `builder`, using `int` as the
    /// Cranelift type for every integer value, parameter, and local.
    pub fn new(int: types::Type, builder: FunctionBuilder<'a>) -> Self {
        Self {
            int,
            builder,
            locals: Vec::new(),
            variable_index: 0,
        }
    }

    /// Translates a complete function implementation and finalizes the builder.
    ///
    /// Every argument becomes an `int` parameter of the function signature, and
    /// named arguments are bound as immutable locals. The body's statements are
    /// translated in order, then a `return` is emitted: carrying the tail
    /// expression's value if the function has one, or no values otherwise.
    /// The generated IR is printed to stdout before the builder is finalized.
    ///
    /// # Panics
    ///
    /// - if the function has a tail expression but no return type, or a return
    ///   type but no tail expression;
    /// - if any argument type is not `i64` (unimplemented);
    /// - on any unsupported statement/expression or invalid local reference.
    pub fn translate(mut self, fn_impl: &ast::FnImpl) {
        let signature = &fn_impl.def.signature;

        // A tail expression and a declared return type must appear together.
        let return_info = if let Some(tail_expr) = &fn_impl.body.tail_expr {
            match signature.return_type {
                Some(t) => Some((tail_expr, t)),
                None => panic!("Function has tail expression but no return type"),
            }
        } else if signature.return_type.is_some() {
            panic!("Function has return type but no tail expression");
        } else {
            None
        };

        // Only i64 arguments are supported; they are all lowered to `self.int`.
        for ast::FnArg { type_name, .. } in signature.args.iter() {
            if type_name != &"i64" {
                unimplemented!("Non-i64 function arg types are unimplemented");
            }
            self.builder
                .func
                .signature
                .params
                .push(AbiParam::new(self.int));
        }

        if return_info.is_some() {
            self.builder
                .func
                .signature
                .returns
                .push(AbiParam::new(self.int));
        }

        let entry_block = self.builder.create_block();
        self.builder
            .append_block_params_for_function_params(entry_block);
        self.builder.switch_to_block(entry_block);
        // The entry block has no predecessors, so it can be sealed immediately.
        self.builder.seal_block(entry_block);

        // Bind each named argument to an immutable local holding its block param.
        for (i, arg) in signature.args.iter().enumerate() {
            if let Some(name) = arg.name {
                let val = self.builder.block_params(entry_block)[i];
                let var = self.add_local(self.int, name, false);
                self.builder.def_var(var, val);
            }
        }

        for stmt in fn_impl.body.statements.iter() {
            self.translate_statement(stmt);
        }

        // Every block must end in a terminator, so a `return` is emitted even
        // for functions without a value; otherwise the current block would be
        // left unfilled and `finalize` would reject the function.
        match return_info {
            Some((tail_expr, _)) => {
                let return_value = self.translate_expr(tail_expr);
                self.builder.ins().return_(&[return_value]);
            }
            None => {
                self.builder.ins().return_(&[]);
            }
        }

        println!("{}", self.builder.func.display());
        self.builder.finalize();
    }

    /// Translates one statement into the current block.
    ///
    /// `let` evaluates its initializer before declaring the new local, so the
    /// initializer still sees any earlier binding with the same name.
    ///
    /// # Panics
    ///
    /// On statement kinds other than expressions, `let`, and `if`.
    pub fn translate_statement(&mut self, stmt: &ast::Statement) {
        use ast::Statement::*;
        match stmt {
            Expr(expr) => {
                self.translate_expr(expr);
            }
            Let { var, mutable, expr } => {
                let val = self.translate_expr(expr);
                let var = self.add_local(self.int, var, *mutable);
                self.builder.def_var(var, val);
            }
            If(if_stmt) => self.translate_if_stmt(if_stmt),
            _ => unimplemented!(),
        }
    }

    /// Translates an expression and returns the SSA value holding its result.
    ///
    /// # Panics
    ///
    /// On unsupported expression kinds, unknown locals, or integer literals
    /// that do not fit in an `i64`.
    pub fn translate_expr(&mut self, expr: &ast::Expr) -> Value {
        use ast::Expr::*;
        match expr {
            If(if_expr) => self.translate_if_expr(if_expr),
            Local(name) => {
                let var = self.expect_local(name).var;
                self.builder.use_var(var)
            }
            Literal(ast::Literal::DecimalInteger(literal)) => {
                // TODO parse integers while building ast so that codegen doesn't have to
                let val: i64 = literal.parse().unwrap();
                self.builder.ins().iconst(self.int, val)
            }
            BinaryOp(op, terms) => self.translate_binary_op(op, &terms.0, &terms.1),
            // TODO the AST doesn't need this either
            Group(expr) => self.translate_expr(expr),
            _ => unimplemented!("Expression: {:#?}", expr),
        }
    }

    /// Translates an `if` expression into a diamond of then/else/merge blocks.
    ///
    /// The test value is compared against zero. Each branch must end with a
    /// tail expression, whose value is passed to the merge block as its single
    /// parameter; that parameter is the result of the whole expression.
    ///
    /// # Panics
    ///
    /// If either branch lacks a tail expression.
    pub fn translate_if_expr(&mut self, if_expr: &ast::IfExpr) -> Value {
        let ast::IfExpr {
            test_expr,
            then_body,
            else_body,
        } = if_expr;
        let test_val = self.translate_expr(test_expr);
        let then_block = self.builder.create_block();
        let else_block = self.builder.create_block();
        let merge_block = self.builder.create_block();
        self.builder.append_block_param(merge_block, self.int);
        self.builder.ins().brz(test_val, else_block, &[]);
        self.builder.ins().jump(then_block, &[]);
        let mut translate_branch = |block, body| {
            self.builder.switch_to_block(block);
            // Each branch block's only predecessor has already been emitted.
            self.builder.seal_block(block);
            let val = self
                .translate_branch_body(body)
                .expect("If expression branch has no tail expression");
            self.builder.ins().jump(merge_block, &[val]);
        };
        translate_branch(then_block, then_body);
        translate_branch(else_block, else_body);
        self.builder.switch_to_block(merge_block);
        // Both jumps into the merge block exist now, so it can be sealed.
        self.builder.seal_block(merge_block);
        self.builder.block_params(merge_block)[0]
    }

    /// Translates an `if` statement, whose `else` branch is optional.
    ///
    /// Without an `else`, a zero test value jumps straight to the merge block.
    ///
    /// # Panics
    ///
    /// If either branch ends with a tail expression.
    pub fn translate_if_stmt(&mut self, if_stmt: &ast::IfStmt) {
        let ast::IfStmt {
            test_expr,
            then_body,
            else_body,
        } = if_stmt;
        let test_val = self.translate_expr(test_expr);
        let then_block = self.builder.create_block();
        let else_info = else_body.as_ref().map(|b| (b, self.builder.create_block()));
        let merge_block = self.builder.create_block();
        let else_jump = else_info.map(|b| b.1).unwrap_or(merge_block);
        self.builder.ins().brz(test_val, else_jump, &[]);
        self.builder.ins().jump(then_block, &[]);
        let mut translate_branch = |block, body| {
            self.builder.switch_to_block(block);
            self.builder.seal_block(block);
            let val = self.translate_branch_body(body);
            assert_eq!(val, None, "If statement branch has tail expression");
            self.builder.ins().jump(merge_block, &[]);
        };
        translate_branch(then_block, then_body);
        if let Some((else_body, else_block)) = else_info {
            translate_branch(else_block, else_body);
        }
        self.builder.switch_to_block(merge_block);
        self.builder.seal_block(merge_block);
    }

    /// Translates a branch body in its own lexical scope.
    ///
    /// Locals declared inside the body are dropped from `self.locals` when it
    /// ends. Returns the tail expression's value, or `None` if there is none.
    pub fn translate_branch_body(&mut self, branch_body: &ast::BranchBody) -> Option<Value> {
        let scope_size = self.locals.len();
        for stmt in branch_body.statements.iter() {
            self.translate_statement(stmt);
        }
        let tail_val = branch_body
            .tail_expr
            .as_ref()
            .map(|tail_expr| self.translate_expr(tail_expr));
        self.locals.truncate(scope_size);
        tail_val
    }

    /// Translates a binary operation and returns its result.
    ///
    /// Operands are evaluated left to right. Assignment yields the assigned
    /// value, and comparisons yield `1`/`0` as an `int`. Arithmetic is signed.
    ///
    /// # Panics
    ///
    /// On unsupported operators, on assignment to anything but a mutable local,
    /// or (at run time, via Cranelift traps) on division by zero.
    pub fn translate_binary_op(
        &mut self,
        op: &ast::BinaryOp,
        lhs: &ast::Expr,
        rhs: &ast::Expr,
    ) -> Value {
        if *op == ast::BinaryOp::Assign {
            return self.translate_assign(lhs, rhs);
        }
        // Assignment is an expression, so operand evaluation order is
        // observable (e.g. `x + (x = 5)`); evaluate left to right.
        let lhs = self.translate_expr(lhs);
        let rhs = self.translate_expr(rhs);
        // TODO refactor ast to separate binary ops into types
        if let Some(val) = self.translate_bool_op(op, lhs, rhs) {
            return val;
        }
        use ast::BinaryOp::*;
        let ins = self.builder.ins();
        match op {
            Add => ins.iadd(lhs, rhs),
            Sub => ins.isub(lhs, rhs),
            Mul => ins.imul(lhs, rhs),
            // Integers are signed i64 (literals parse as i64 and comparisons
            // use signed condition codes), so division must be signed too.
            Div => ins.sdiv(lhs, rhs),
            _ => unimplemented!("Binary operation: {:#?}", op),
        }
    }

    /// Translates `lhs = rhs` and returns the assigned value.
    ///
    /// `rhs` is evaluated first; `lhs` must name a mutable local in scope.
    fn translate_assign(&mut self, lhs: &ast::Expr, rhs: &ast::Expr) -> Value {
        let rhs = self.translate_expr(rhs);
        let name = match lhs {
            ast::Expr::Local(name) => name,
            _ => unimplemented!("Assign to non-local lhs"),
        };
        let local = self.expect_local(name);
        if !local.mutable {
            panic!("Attempted to assign to immutable variable {}", name);
        }
        let var = local.var;
        self.builder.def_var(var, rhs);
        rhs
    }

    /// Translates a comparison operator into an `int` holding `1` or `0`.
    ///
    /// Returns `None` if `op` is not a comparison, leaving it to the caller.
    /// Ordering comparisons are signed.
    pub fn translate_bool_op(
        &mut self,
        op: &ast::BinaryOp,
        lhs: Value,
        rhs: Value,
    ) -> Option<Value> {
        use ast::BinaryOp::*;
        let cmp = match op {
            Eq => IntCC::Equal,
            Neq => IntCC::NotEqual,
            Less => IntCC::SignedLessThan,
            LessEq => IntCC::SignedLessThanOrEqual,
            Greater => IntCC::SignedGreaterThan,
            GreaterEq => IntCC::SignedGreaterThanOrEqual,
            _ => return None,
        };
        let val = self.builder.ins().icmp(cmp, lhs, rhs);
        Some(self.builder.ins().bint(self.int, val))
    }

    /// Returns the innermost in-scope local named `name`, if any.
    ///
    /// Searches from the most recent declaration so shadowing works.
    pub fn get_local(&self, name: &str) -> Option<&Local> {
        self.locals.iter().rev().find(|l| l.name == name)
    }

    /// Like [`Self::get_local`], but panics if the local is not in scope.
    fn expect_local(&self, name: &str) -> &Local {
        self.get_local(name)
            .unwrap_or_else(|| panic!("Unrecognized local {}", name))
    }

    /// Declares a new Cranelift variable of `var_type` and pushes it as a local.
    ///
    /// Returns the variable so the caller can define its initial value.
    pub fn add_local(&mut self, var_type: types::Type, name: &str, mutable: bool) -> Variable {
        let name = name.to_string();
        let var = Variable::new(self.variable_index);
        self.builder.declare_var(var, var_type);
        self.locals.push(Local { name, var, mutable });
        self.variable_index += 1;
        var
    }
}