forked from NUDT-compiler/nudt-compiler-cpp
You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
304 lines
10 KiB
304 lines
10 KiB
#include "include/irgen/IRGen.h"
|
|
|
|
#include <stdexcept>
|
|
|
|
#include "SysYParser.h"
|
|
#include "include/ir/IR.h"
|
|
#include "include/utils/Log.h"
|
|
|
|
std::string IRGenImpl::NextBlockName(const std::string& prefix) {
|
|
return prefix + std::to_string(++block_index_);
|
|
}
|
|
|
|
// Lowers a SysY `cond` into control flow: control reaches `true_bb` when the
// condition holds and `false_bb` otherwise. The condition's single child is a
// logical-or expression, whose lowering is delegated to EmitLOrBranch.
// Throws std::runtime_error if the parse tree lacks that child.
void IRGenImpl::EmitCondBranch(SysYParser::CondContext& cond,
                               ir::BasicBlock* true_bb,
                               ir::BasicBlock* false_bb) {
  auto* lor = cond.lOrExp();
  if (lor == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
  }
  EmitLOrBranch(*lor, true_bb, false_bb);
}
|
|
|
|
// Short-circuit lowering of a logical-or expression into branches.
//
// For `lhs || rhs`, a fresh "lor.rhs.N" block is created. If the left operand
// is true, control goes straight to `true_bb`; otherwise it continues in that
// block, where the right operand decides between `true_bb` and `false_bb`.
// A bare lAndExp (no left operand) is lowered directly.
// Throws std::runtime_error if the right-hand lAndExp child is missing.
void IRGenImpl::EmitLOrBranch(SysYParser::LOrExpContext& expr,
                              ir::BasicBlock* true_bb,
                              ir::BasicBlock* false_bb) {
  auto* lhs = expr.lOrExp();
  auto* rhs = expr.lAndExp();

  // Both shapes of the rule need the lAndExp operand.
  if (rhs == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
  }

  if (lhs != nullptr) {
    auto* rhs_bb = func_->CreateBlock(NextBlockName("lor.rhs."));
    // A true left side short-circuits to true_bb. A false one falls into
    // rhs_bb, where the right operand is evaluated.
    EmitLOrBranch(*lhs, true_bb, rhs_bb);
    builder_.SetInsertPoint(rhs_bb);
  }

  EmitLAndBranch(*rhs, true_bb, false_bb);
}
|
|
|
|
// Short-circuit lowering of a logical-and expression into branches.
//
// For `lhs && rhs`, a fresh "land.rhs.N" block is created. If the left operand
// is false, control goes straight to `false_bb`; otherwise it continues in that
// block, where the right operand decides between `true_bb` and `false_bb`.
// A bare eqExp (no left operand) is lowered directly.
// Throws std::runtime_error if the right-hand eqExp child is missing.
void IRGenImpl::EmitLAndBranch(SysYParser::LAndExpContext& expr,
                               ir::BasicBlock* true_bb,
                               ir::BasicBlock* false_bb) {
  auto* lhs = expr.lAndExp();
  auto* rhs = expr.eqExp();

  // Both shapes of the rule need the eqExp operand.
  if (rhs == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
  }

  if (lhs != nullptr) {
    auto* rhs_bb = func_->CreateBlock(NextBlockName("land.rhs."));
    // A false left side short-circuits to false_bb. A true one falls into
    // rhs_bb, where the right operand is evaluated.
    EmitLAndBranch(*lhs, rhs_bb, false_bb);
    builder_.SetInsertPoint(rhs_bb);
  }

  EmitEqBranch(*rhs, true_bb, false_bb);
}
|
|
|
|
// Evaluates a relational expression to an IR value.
//
// A bare addExp yields the value of that expression. `lhs OP rhs`, with OP one
// of <, >, <=, >=, works as follows:
// - The left relExp is evaluated first, recursively, so chains associate to
//   the left.
// - The right addExp is evaluated next.
// - The two values are combined with EvalBinaryOrFold.
// Throws std::runtime_error on a malformed tree or an unknown operator.
ir::Value* IRGenImpl::EvalRelValue(SysYParser::RelExpContext& expr) {
  auto* operand = expr.addExp();
  if (operand == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  auto* prefix = expr.relExp();
  if (prefix == nullptr) {
    return std::any_cast<ir::Value*>(operand->accept(this));
  }

  ir::Value* left = EvalRelValue(*prefix);
  ir::Value* right = std::any_cast<ir::Value*>(operand->accept(this));

  const ir::Opcode op = [&]() {
    if (expr.LtOp()) return ir::Opcode::Lt;
    if (expr.GtOp()) return ir::Opcode::Gt;
    if (expr.LeOp()) return ir::Opcode::Le;
    if (expr.GeOp()) return ir::Opcode::Ge;
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }();

  return EvalBinaryOrFold(op, left, right);
}
|
|
|
|
// Evaluates an equality expression to an IR value.
//
// A bare relExp is forwarded to EvalRelValue. `lhs == rhs` or `lhs != rhs`
// works as follows:
// - The left eqExp is evaluated first, recursively, so chains associate to
//   the left.
// - The right relExp is evaluated next.
// - The two values are combined with EvalBinaryOrFold.
// Throws std::runtime_error on a malformed tree or an unknown operator.
ir::Value* IRGenImpl::EvalEqValue(SysYParser::EqExpContext& expr) {
  auto* operand = expr.relExp();
  if (operand == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }

  auto* prefix = expr.eqExp();
  if (prefix == nullptr) {
    return EvalRelValue(*operand);
  }

  ir::Value* left = EvalEqValue(*prefix);
  ir::Value* right = EvalRelValue(*operand);

  const ir::Opcode op = [&]() {
    if (expr.EqOp()) return ir::Opcode::Eq;
    if (expr.NeOp()) return ir::Opcode::Ne;
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }();

  return EvalBinaryOrFold(op, left, right);
}
|
|
|
|
// Lowers an equality expression used as a branch condition.
//
// The expression is evaluated as a value through EvalEqValue. The code then
// emits `value != 0` and conditionally branches to `true_bb` or `false_bb`.
// The zero constant is a float when the value's type is Float32, and an int
// otherwise.
void IRGenImpl::EmitEqBranch(SysYParser::EqExpContext& expr,
                             ir::BasicBlock* true_bb,
                             ir::BasicBlock* false_bb) {
  ir::Value* value = EvalEqValue(expr);

  ir::Value* zero = nullptr;
  if (value->GetType()->IsFloat32()) {
    zero = builder_.CreateConstFloat(0.0);
  } else {
    zero = builder_.CreateConstInt(0);
  }

  auto* is_nonzero = builder_.CreateBinary(ir::Opcode::Ne, value, zero,
                                           module_.GetContext().NextTemp());
  builder_.CreateCondBr(is_nonzero, true_bb, false_bb);
}
|
|
|
|
// Lowers a relational expression directly into a conditional branch to
// `true_bb` (condition holds) or `false_bb` (it does not).
//
// * Bare addExp: branch on `value != 0`. The zero constant matches the
//   operand's type (a float zero for Float32 operands), the same rule
//   EmitEqBranch applies.
// * Single comparison `a OP b` with OP in <, >, <=, >=: a single comparison
//   instruction is emitted and branched on. If exactly one side is Float32,
//   both sides are first converted to float with CastValueTo. This avoids
//   emitting the compare on mismatched operand types.
// * Chained comparison such as `a < b < c`: the chain is evaluated
//   left-to-right through EvalRelValue, then the code branches on
//   `result != 0`.
//
// Throws std::runtime_error on a malformed tree or an unknown operator.
void IRGenImpl::EmitRelBranch(SysYParser::RelExpContext& expr,
                              ir::BasicBlock* true_bb,
                              ir::BasicBlock* false_bb) {
  auto* rhs_expr = expr.addExp();
  if (rhs_expr == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  // Emits `value != 0`, with a zero of the same scalar type as `value`, and
  // branches on the result.
  auto branch_if_nonzero = [&](ir::Value* value) {
    ir::Value* zero =
        value->GetType()->IsFloat32()
            ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0))
            : static_cast<ir::Value*>(builder_.CreateConstInt(0));
    auto* cond = builder_.CreateBinary(ir::Opcode::Ne, value, zero,
                                       module_.GetContext().NextTemp());
    builder_.CreateCondBr(cond, true_bb, false_bb);
  };

  auto* lhs_rel = expr.relExp();
  if (lhs_rel == nullptr) {
    branch_if_nonzero(std::any_cast<ir::Value*>(rhs_expr->accept(this)));
    return;
  }

  if (lhs_rel->relExp() != nullptr || lhs_rel->addExp() == nullptr) {
    // Chained (or malformed) left operand. EvalRelValue validates the tree and
    // yields the value of the whole left-associated chain.
    branch_if_nonzero(EvalRelValue(expr));
    return;
  }

  ir::Value* lhs = std::any_cast<ir::Value*>(lhs_rel->addExp()->accept(this));
  ir::Value* rhs = std::any_cast<ir::Value*>(rhs_expr->accept(this));

  ir::Opcode cmp = ir::Opcode::Lt;
  if (expr.LtOp()) {
    cmp = ir::Opcode::Lt;
  } else if (expr.GtOp()) {
    cmp = ir::Opcode::Gt;
  } else if (expr.LeOp()) {
    cmp = ir::Opcode::Le;
  } else if (expr.GeOp()) {
    cmp = ir::Opcode::Ge;
  } else {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  // Mixed int/float comparison: promote both sides to float so the compare
  // sees operands of one type.
  if (lhs->GetType()->IsFloat32() != rhs->GetType()->IsFloat32()) {
    auto float_ty = ir::Type::GetFloat32Type();
    lhs = CastValueTo(lhs, float_ty);
    rhs = CastValueTo(rhs, float_ty);
  }

  auto* cond =
      builder_.CreateBinary(cmp, lhs, rhs, module_.GetContext().NextTemp());
  builder_.CreateCondBr(cond, true_bb, false_bb);
}
|
|
|
|
// Lowers one SysY statement at the builder's current insert point.
//
// Returns (wrapped in std::any) a BlockFlow:
// - BlockFlow::Terminated when the statement ended the current block
//   (return, break, continue, or an if/else whose arms both terminated).
// - BlockFlow::Continue otherwise.
//
// Nested blocks return whatever their own visitor returns. The statement
// kinds are tested in a fixed order: block, assignment, if, while, break,
// continue, return, empty/expression statement. Throws std::runtime_error for
// malformed or unsupported statements.
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
  if (ctx == nullptr) {
    throw std::runtime_error(FormatError("irgen", "缺少语句"));
  }

  // When the statement just lowered fell through and left the current block
  // open, close that block with a jump to `target`.
  auto close_with_branch = [this](BlockFlow flow, ir::BasicBlock* target) {
    const bool falls_through = flow == BlockFlow::Continue;
    if (falls_through && !builder_.GetInsertBlock()->HasTerminator()) {
      builder_.CreateBr(target);
    }
  };

  // { ... }
  if (auto* nested = ctx->block()) {
    return nested->accept(this);
  }

  // lVal = exp;
  if (ctx->lVal() != nullptr && ctx->Assign() != nullptr) {
    auto* target = ctx->lVal();
    auto* value_expr = ctx->exp();
    if (value_expr == nullptr) {
      throw std::runtime_error(FormatError("irgen", "赋值语句缺少右值表达式"));
    }
    bool is_array = false;
    const auto extents = GetArrayExtentsForLVal(*target, is_array);
    // Fewer subscripts than dimensions means the target is (part of) an array
    // object rather than a scalar element.
    const bool assigns_whole_array =
        is_array && target->exp().size() < extents.size();
    if (assigns_whole_array) {
      throw std::runtime_error(FormatError("irgen", "不能给数组对象赋值"));
    }
    auto* address = GetLValAddress(*target);
    ir::Value* stored = EvalExpr(*value_expr);
    auto element_ty = address->GetType()->IsPtrFloat32()
                          ? ir::Type::GetFloat32Type()
                          : ir::Type::GetInt32Type();
    stored = CastValueTo(stored, element_ty);
    builder_.CreateStore(stored, address);
    return BlockFlow::Continue;
  }

  // if (cond) stmt [else stmt]
  if (ctx->If() != nullptr) {
    if (ctx->cond() == nullptr || ctx->stmt().empty()) {
      throw std::runtime_error(FormatError("irgen", "if 语句缺少必要组成部分"));
    }
    const bool has_else = ctx->Else() != nullptr;
    // Creation order (then, else, end) fixes the numeric label suffixes.
    auto* then_bb = func_->CreateBlock(NextBlockName("if.then."));
    ir::BasicBlock* else_bb =
        has_else ? func_->CreateBlock(NextBlockName("if.else.")) : nullptr;
    auto* end_bb = func_->CreateBlock(NextBlockName("if.end."));

    EmitCondBranch(*ctx->cond(), then_bb, has_else ? else_bb : end_bb);

    builder_.SetInsertPoint(then_bb);
    const auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    close_with_branch(then_flow, end_bb);

    auto else_flow = BlockFlow::Continue;
    if (has_else) {
      builder_.SetInsertPoint(else_bb);
      else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
      close_with_branch(else_flow, end_bb);
    }

    builder_.SetInsertPoint(end_bb);
    const bool both_arms_terminate = has_else &&
                                     then_flow == BlockFlow::Terminated &&
                                     else_flow == BlockFlow::Terminated;
    if (!both_arms_terminate) {
      return BlockFlow::Continue;
    }
    // No arm reaches end_bb, so nothing jumps to it. It still gets a
    // terminator: a return of a zero (or void) matching the function's type.
    if (func_->GetType()->IsInt32()) {
      builder_.CreateRet(builder_.CreateConstInt(0));
    } else if (func_->GetType()->IsFloat32()) {
      builder_.CreateRet(builder_.CreateConstFloat(0.0));
    } else {
      builder_.CreateRetVoid();
    }
    return BlockFlow::Terminated;
  }

  // while (cond) stmt
  if (ctx->While() != nullptr) {
    if (ctx->cond() == nullptr || ctx->stmt().empty()) {
      throw std::runtime_error(FormatError("irgen", "while 语句缺少必要组成部分"));
    }
    auto* header_bb = func_->CreateBlock(NextBlockName("while.cond."));
    auto* body_bb = func_->CreateBlock(NextBlockName("while.body."));
    auto* exit_bb = func_->CreateBlock(NextBlockName("while.end."));

    builder_.CreateBr(header_bb);

    builder_.SetInsertPoint(header_bb);
    EmitCondBranch(*ctx->cond(), body_bb, exit_bb);

    // Pair is (continue target, break target) for statements in the body.
    loop_stack_.push_back({header_bb, exit_bb});
    builder_.SetInsertPoint(body_bb);
    const auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    close_with_branch(body_flow, header_bb);
    loop_stack_.pop_back();

    builder_.SetInsertPoint(exit_bb);
    return BlockFlow::Continue;
  }

  // break;
  if (ctx->Break() != nullptr) {
    if (loop_stack_.empty()) {
      throw std::runtime_error(FormatError("irgen", "break 不在循环体内"));
    }
    auto* break_target = loop_stack_.back().second;
    builder_.CreateBr(break_target);
    return BlockFlow::Terminated;
  }

  // continue;
  if (ctx->Continue() != nullptr) {
    if (loop_stack_.empty()) {
      throw std::runtime_error(FormatError("irgen", "continue 不在循环体内"));
    }
    auto* continue_target = loop_stack_.back().first;
    builder_.CreateBr(continue_target);
    return BlockFlow::Terminated;
  }

  // return [exp];
  if (ctx->Return() != nullptr) {
    auto* ret_expr = ctx->exp();
    if (func_->GetType()->IsVoid()) {
      if (ret_expr != nullptr) {
        throw std::runtime_error(FormatError("irgen", "void 函数不应返回值"));
      }
      builder_.CreateRetVoid();
      return BlockFlow::Terminated;
    }
    if (ret_expr == nullptr) {
      throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
    }
    ir::Value* result = CastValueTo(EvalExpr(*ret_expr), func_->GetType());
    builder_.CreateRet(result);
    return BlockFlow::Terminated;
  }

  // ;  or  exp;  (the expression is evaluated only for its side effects)
  if (ctx->Semi() != nullptr) {
    if (auto* discarded = ctx->exp()) {
      (void)EvalExpr(*discarded);
    }
    return BlockFlow::Continue;
  }

  throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}
|