297 lines
9.3 KiB
Rust
297 lines
9.3 KiB
Rust
// Copyright (c) 2022 Marceline Cramer
|
|
// SPDX-License-Identifier: GPL-3.0-or-later
|
|
|
|
use crate::parse::ast;
|
|
use cranelift::prelude::*;
|
|
|
|
pub struct FunctionTranslator<'a> {
|
|
pub int: types::Type,
|
|
pub builder: FunctionBuilder<'a>,
|
|
pub locals: Vec<Local>,
|
|
pub variable_index: usize,
|
|
}
|
|
|
|
pub struct Local {
|
|
pub name: String,
|
|
pub var: Variable,
|
|
pub mutable: bool,
|
|
}
|
|
|
|
impl<'a> FunctionTranslator<'a> {
|
|
pub fn new(int: types::Type, builder: FunctionBuilder<'a>) -> Self {
|
|
Self {
|
|
int,
|
|
builder,
|
|
locals: Vec::new(),
|
|
variable_index: 0,
|
|
}
|
|
}
|
|
|
|
pub fn translate(mut self, fn_impl: &ast::FnImpl) {
|
|
let signature = &fn_impl.def.signature;
|
|
|
|
let return_info = if let Some(tail_expr) = &fn_impl.body.tail_expr {
|
|
match signature.return_type {
|
|
Some(t) => Some((tail_expr, t)),
|
|
None => panic!("Function has tail expression but no return type"),
|
|
}
|
|
} else if let Some(_) = signature.return_type {
|
|
panic!("Function has return type but no tail expression");
|
|
} else {
|
|
None
|
|
};
|
|
|
|
for ast::FnArg { type_name, .. } in signature.args.iter() {
|
|
if type_name != &"i64" {
|
|
unimplemented!("Non-i64 function arg types are unimplemented");
|
|
}
|
|
|
|
self.builder
|
|
.func
|
|
.signature
|
|
.params
|
|
.push(AbiParam::new(self.int));
|
|
}
|
|
|
|
if let Some((_, _return_type)) = return_info {
|
|
self.builder
|
|
.func
|
|
.signature
|
|
.returns
|
|
.push(AbiParam::new(self.int));
|
|
}
|
|
|
|
let entry_block = self.builder.create_block();
|
|
self.builder
|
|
.append_block_params_for_function_params(entry_block);
|
|
self.builder.switch_to_block(entry_block);
|
|
self.builder.seal_block(entry_block);
|
|
|
|
for (i, arg) in signature.args.iter().enumerate() {
|
|
if let Some(name) = arg.name {
|
|
let val = self.builder.block_params(entry_block)[i];
|
|
let var = self.add_local(self.int, name, false);
|
|
self.builder.def_var(var, val);
|
|
}
|
|
}
|
|
|
|
for stmt in fn_impl.body.statements.iter() {
|
|
self.translate_statement(stmt);
|
|
}
|
|
|
|
if let Some((tail_expr, _)) = return_info {
|
|
let return_value = self.translate_expr(tail_expr);
|
|
self.builder.ins().return_(&[return_value]);
|
|
}
|
|
|
|
println!("{}", self.builder.func.display());
|
|
self.builder.finalize();
|
|
}
|
|
|
|
pub fn translate_statement(&mut self, stmt: &ast::Statement) {
|
|
use ast::Statement::*;
|
|
match stmt {
|
|
Expr(expr) => {
|
|
self.translate_expr(expr);
|
|
}
|
|
Let { var, mutable, expr } => {
|
|
let val = self.translate_expr(expr);
|
|
let var = self.add_local(self.int, var, *mutable);
|
|
self.builder.def_var(var, val);
|
|
}
|
|
If(if_stmt) => self.translate_if_stmt(if_stmt),
|
|
_ => unimplemented!(),
|
|
}
|
|
}
|
|
|
|
pub fn translate_expr(&mut self, expr: &ast::Expr) -> Value {
|
|
use ast::Expr::*;
|
|
match expr {
|
|
If(if_expr) => self.translate_if_expr(if_expr),
|
|
Local(name) => {
|
|
let var = self
|
|
.get_local(name)
|
|
.expect(&format!("Unrecognized local {}", name))
|
|
.var;
|
|
self.builder.use_var(var)
|
|
}
|
|
Literal(ast::Literal::DecimalInteger(literal)) => {
|
|
// TODO parse integers while building ast so that codegen doesn't have to
|
|
let val: i64 = literal.parse().unwrap();
|
|
self.builder.ins().iconst(self.int, val)
|
|
}
|
|
BinaryOp(op, terms) => self.translate_binary_op(op, &terms.0, &terms.1),
|
|
// TODO the AST doesn't need this either
|
|
Group(expr) => self.translate_expr(expr),
|
|
_ => unimplemented!("Expression: {:#?}", expr),
|
|
}
|
|
}
|
|
|
|
pub fn translate_if_expr(&mut self, if_expr: &ast::IfExpr) -> Value {
|
|
let ast::IfExpr {
|
|
test_expr,
|
|
then_body,
|
|
else_body,
|
|
} = if_expr;
|
|
|
|
let test_val = self.translate_expr(test_expr);
|
|
|
|
let then_block = self.builder.create_block();
|
|
let else_block = self.builder.create_block();
|
|
let merge_block = self.builder.create_block();
|
|
|
|
self.builder.append_block_param(merge_block, self.int);
|
|
|
|
self.builder.ins().brz(test_val, else_block, &[]);
|
|
self.builder.ins().jump(then_block, &[]);
|
|
|
|
let mut translate_branch = |block, body| {
|
|
self.builder.switch_to_block(block);
|
|
self.builder.seal_block(block);
|
|
let val = self.translate_branch_body(body).unwrap();
|
|
self.builder.ins().jump(merge_block, &[val]);
|
|
};
|
|
|
|
translate_branch(then_block, then_body);
|
|
translate_branch(else_block, else_body);
|
|
|
|
self.builder.switch_to_block(merge_block);
|
|
self.builder.seal_block(merge_block);
|
|
|
|
self.builder.block_params(merge_block)[0]
|
|
}
|
|
|
|
pub fn translate_if_stmt(&mut self, if_stmt: &ast::IfStmt) {
|
|
let ast::IfStmt {
|
|
test_expr,
|
|
then_body,
|
|
else_body,
|
|
} = if_stmt;
|
|
|
|
let test_val = self.translate_expr(test_expr);
|
|
|
|
let then_block = self.builder.create_block();
|
|
let else_info = else_body.as_ref().map(|b| (b, self.builder.create_block()));
|
|
let merge_block = self.builder.create_block();
|
|
|
|
let else_jump = else_info.map(|b| b.1).unwrap_or(merge_block);
|
|
self.builder.ins().brz(test_val, else_jump, &[]);
|
|
self.builder.ins().jump(then_block, &[]);
|
|
|
|
let mut translate_branch = |block, body| {
|
|
self.builder.switch_to_block(block);
|
|
self.builder.seal_block(block);
|
|
let val = self.translate_branch_body(body);
|
|
assert_eq!(val, None, "If statement branch has tail expression");
|
|
self.builder.ins().jump(merge_block, &[]);
|
|
};
|
|
|
|
translate_branch(then_block, then_body);
|
|
|
|
if let Some((else_body, else_block)) = else_info {
|
|
translate_branch(else_block, else_body);
|
|
}
|
|
|
|
self.builder.switch_to_block(merge_block);
|
|
self.builder.seal_block(merge_block);
|
|
}
|
|
|
|
pub fn translate_branch_body(&mut self, branch_body: &ast::BranchBody) -> Option<Value> {
|
|
let scope_size = self.locals.len();
|
|
|
|
for stmt in branch_body.statements.iter() {
|
|
self.translate_statement(stmt);
|
|
}
|
|
|
|
let tail_val = if let Some(tail_expr) = &branch_body.tail_expr {
|
|
Some(self.translate_expr(tail_expr))
|
|
} else {
|
|
None
|
|
};
|
|
|
|
self.locals.truncate(scope_size);
|
|
|
|
tail_val
|
|
}
|
|
|
|
pub fn translate_binary_op(
|
|
&mut self,
|
|
op: &ast::BinaryOp,
|
|
lhs: &ast::Expr,
|
|
rhs: &ast::Expr,
|
|
) -> Value {
|
|
let rhs = self.translate_expr(rhs);
|
|
|
|
if *op == ast::BinaryOp::Assign {
|
|
if let ast::Expr::Local(var) = lhs {
|
|
let local = self
|
|
.get_local(var)
|
|
.expect(&format!("Unrecognized local {}", var));
|
|
|
|
if !local.mutable {
|
|
panic!("Attempted to assign to immutable variable {}", var);
|
|
}
|
|
|
|
let var = local.var;
|
|
self.builder.def_var(var, rhs);
|
|
return rhs;
|
|
} else {
|
|
unimplemented!("Assign to non-local lhs");
|
|
}
|
|
}
|
|
|
|
let lhs = self.translate_expr(lhs);
|
|
|
|
// TODO refactor ast to separate binary ops into types
|
|
if let Some(val) = self.translate_bool_op(op, lhs, rhs) {
|
|
val
|
|
} else {
|
|
use ast::BinaryOp::*;
|
|
let ins = self.builder.ins();
|
|
match op {
|
|
Add => ins.iadd(lhs, rhs),
|
|
Sub => ins.isub(lhs, rhs),
|
|
Mul => ins.imul(lhs, rhs),
|
|
Div => ins.udiv(lhs, rhs),
|
|
_ => unimplemented!("Binary operation: {:#?}", op),
|
|
}
|
|
}
|
|
}
|
|
|
|
pub fn translate_bool_op(
|
|
&mut self,
|
|
op: &ast::BinaryOp,
|
|
lhs: Value,
|
|
rhs: Value,
|
|
) -> Option<Value> {
|
|
use ast::BinaryOp::*;
|
|
if let Some(cmp) = match op {
|
|
Eq => Some(IntCC::Equal),
|
|
Neq => Some(IntCC::NotEqual),
|
|
Less => Some(IntCC::SignedLessThan),
|
|
LessEq => Some(IntCC::SignedLessThanOrEqual),
|
|
Greater => Some(IntCC::SignedGreaterThan),
|
|
GreaterEq => Some(IntCC::SignedGreaterThanOrEqual),
|
|
_ => None,
|
|
} {
|
|
let val = self.builder.ins().icmp(cmp, lhs, rhs);
|
|
Some(self.builder.ins().bint(self.int, val))
|
|
} else {
|
|
None
|
|
}
|
|
}
|
|
|
|
pub fn get_local(&self, name: &str) -> Option<&Local> {
|
|
self.locals.iter().rev().find(|l| &l.name == name)
|
|
}
|
|
|
|
pub fn add_local(&mut self, var_type: types::Type, name: &str, mutable: bool) -> Variable {
|
|
let name = name.to_string();
|
|
let var = Variable::new(self.variable_index);
|
|
self.builder.declare_var(var, var_type);
|
|
self.locals.push(Local { name, var, mutable });
|
|
self.variable_index += 1;
|
|
var
|
|
}
|
|
}
|