// Copyright (c) 2022 Marceline Cramer // SPDX-License-Identifier: GPL-3.0-or-later use crate::parse::ast; use cranelift::prelude::*; pub struct FunctionTranslator<'a> { pub int: types::Type, pub builder: FunctionBuilder<'a>, pub locals: Vec, pub variable_index: usize, } pub struct Local { pub name: String, pub var: Variable, pub mutable: bool, } impl<'a> FunctionTranslator<'a> { pub fn new(int: types::Type, builder: FunctionBuilder<'a>) -> Self { Self { int, builder, locals: Vec::new(), variable_index: 0, } } pub fn translate(mut self, fn_impl: &ast::FnImpl) { let signature = &fn_impl.def.signature; let return_info = if let Some(tail_expr) = &fn_impl.body.tail_expr { match signature.return_type { Some(t) => Some((tail_expr, t)), None => panic!("Function has tail expression but no return type"), } } else if let Some(_) = signature.return_type { panic!("Function has return type but no tail expression"); } else { None }; for ast::FnArg { type_name, .. } in signature.args.iter() { if type_name != &"i64" { unimplemented!("Non-i64 function arg types are unimplemented"); } self.builder .func .signature .params .push(AbiParam::new(self.int)); } if let Some((_, _return_type)) = return_info { self.builder .func .signature .returns .push(AbiParam::new(self.int)); } let entry_block = self.builder.create_block(); self.builder .append_block_params_for_function_params(entry_block); self.builder.switch_to_block(entry_block); self.builder.seal_block(entry_block); for (i, arg) in signature.args.iter().enumerate() { if let Some(name) = arg.name { let val = self.builder.block_params(entry_block)[i]; let var = self.add_local(self.int, name, false); self.builder.def_var(var, val); } } for stmt in fn_impl.body.statements.iter() { self.translate_statement(stmt); } if let Some((tail_expr, _)) = return_info { let return_value = self.translate_expr(tail_expr); self.builder.ins().return_(&[return_value]); } println!("{}", self.builder.func.display()); self.builder.finalize(); } pub fn translate_statement(&mut self, stmt: &ast::Statement) { use ast::Statement::*; match stmt { Expr(expr) => { self.translate_expr(expr); } While { test_expr, loop_body, } => self.translate_while_stmt(test_expr, loop_body), Let { var, mutable, expr } => { let val = self.translate_expr(expr); let var = self.add_local(self.int, var, *mutable); self.builder.def_var(var, val); } If(if_stmt) => self.translate_if_stmt(if_stmt), } } pub fn translate_expr(&mut self, expr: &ast::Expr) -> Value { use ast::Expr::*; match expr { If(if_expr) => self.translate_if_expr(if_expr), Local(name) => { let var = self .get_local(name) .expect(&format!("Unrecognized local {}", name)) .var; self.builder.use_var(var) } Literal(ast::Literal::DecimalInteger(literal)) => { // TODO parse integers while building ast so that codegen doesn't have to let val: i64 = literal.parse().unwrap(); self.builder.ins().iconst(self.int, val) } BinaryOp(op, terms) => self.translate_binary_op(op, &terms.0, &terms.1), // TODO the AST doesn't need this either Group(expr) => self.translate_expr(expr), _ => unimplemented!("Expression: {:#?}", expr), } } pub fn translate_if_expr(&mut self, if_expr: &ast::IfExpr) -> Value { let ast::IfExpr { test_expr, then_body, else_body, } = if_expr; let test_val = self.translate_expr(test_expr); let then_block = self.builder.create_block(); let else_block = self.builder.create_block(); let merge_block = self.builder.create_block(); self.builder.append_block_param(merge_block, self.int); self.builder.ins().brz(test_val, else_block, &[]); self.builder.ins().jump(then_block, &[]); let mut translate_branch = |block, body| { self.builder.switch_to_block(block); self.builder.seal_block(block); let val = self.translate_branch_body(body).unwrap(); self.builder.ins().jump(merge_block, &[val]); }; translate_branch(then_block, then_body); translate_branch(else_block, else_body); self.builder.switch_to_block(merge_block); self.builder.seal_block(merge_block); self.builder.block_params(merge_block)[0] } pub fn translate_if_stmt(&mut self, if_stmt: &ast::IfStmt) { let ast::IfStmt { test_expr, then_body, else_body, } = if_stmt; let test_val = self.translate_expr(test_expr); let then_block = self.builder.create_block(); let else_info = else_body.as_ref().map(|b| (b, self.builder.create_block())); let merge_block = self.builder.create_block(); let else_jump = else_info.map(|b| b.1).unwrap_or(merge_block); self.builder.ins().brz(test_val, else_jump, &[]); self.builder.ins().jump(then_block, &[]); let mut translate_branch = |block, body| { self.builder.switch_to_block(block); self.builder.seal_block(block); let val = self.translate_branch_body(body); assert_eq!(val, None, "If statement branch has tail expression"); self.builder.ins().jump(merge_block, &[]); }; translate_branch(then_block, then_body); if let Some((else_body, else_block)) = else_info { translate_branch(else_block, else_body); } self.builder.switch_to_block(merge_block); self.builder.seal_block(merge_block); } pub fn translate_while_stmt(&mut self, test_expr: &ast::Expr, loop_body: &ast::BranchBody) { let header_block = self.builder.create_block(); let body_block = self.builder.create_block(); let exit_block = self.builder.create_block(); self.builder.ins().jump(header_block, &[]); self.builder.switch_to_block(header_block); let test_val = self.translate_expr(test_expr); self.builder.ins().brz(test_val, exit_block, &[]); self.builder.ins().jump(body_block, &[]); self.builder.switch_to_block(body_block); self.builder.seal_block(body_block); let val = self.translate_branch_body(loop_body); assert_eq!(val, None, "While statement body has tail expression"); self.builder.ins().jump(header_block, &[]); self.builder.switch_to_block(exit_block); self.builder.seal_block(header_block); self.builder.seal_block(exit_block); } pub fn translate_branch_body(&mut self, branch_body: &ast::BranchBody) -> Option { let scope_size = self.locals.len(); for stmt in branch_body.statements.iter() { self.translate_statement(stmt); } let tail_val = if let Some(tail_expr) = &branch_body.tail_expr { Some(self.translate_expr(tail_expr)) } else { None }; self.locals.truncate(scope_size); tail_val } pub fn translate_binary_op( &mut self, op: &ast::BinaryOp, lhs: &ast::Expr, rhs: &ast::Expr, ) -> Value { let rhs = self.translate_expr(rhs); if *op == ast::BinaryOp::Assign { if let ast::Expr::Local(var) = lhs { let local = self .get_local(var) .expect(&format!("Unrecognized local {}", var)); if !local.mutable { panic!("Attempted to assign to immutable variable {}", var); } let var = local.var; self.builder.def_var(var, rhs); return rhs; } else { unimplemented!("Assign to non-local lhs"); } } let lhs = self.translate_expr(lhs); // TODO refactor ast to separate binary ops into types if let Some(val) = self.translate_bool_op(op, lhs, rhs) { val } else { use ast::BinaryOp::*; let ins = self.builder.ins(); match op { Add => ins.iadd(lhs, rhs), Sub => ins.isub(lhs, rhs), Mul => ins.imul(lhs, rhs), Div => ins.udiv(lhs, rhs), _ => unimplemented!("Binary operation: {:#?}", op), } } } pub fn translate_bool_op( &mut self, op: &ast::BinaryOp, lhs: Value, rhs: Value, ) -> Option { use ast::BinaryOp::*; if let Some(cmp) = match op { Eq => Some(IntCC::Equal), Neq => Some(IntCC::NotEqual), Less => Some(IntCC::SignedLessThan), LessEq => Some(IntCC::SignedLessThanOrEqual), Greater => Some(IntCC::SignedGreaterThan), GreaterEq => Some(IntCC::SignedGreaterThanOrEqual), _ => None, } { let val = self.builder.ins().icmp(cmp, lhs, rhs); Some(self.builder.ins().bint(self.int, val)) } else { None } } pub fn get_local(&self, name: &str) -> Option<&Local> { self.locals.iter().rev().find(|l| &l.name == name) } pub fn add_local(&mut self, var_type: types::Type, name: &str, mutable: bool) -> Variable { let name = name.to_string(); let var = Variable::new(self.variable_index); self.builder.declare_var(var, var_type); self.locals.push(Local { name, var, mutable }); self.variable_index += 1; var } }