You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/irgen/IRGenStmt.cpp

304 lines
10 KiB

#include "include/irgen/IRGen.h"
#include <stdexcept>
#include "SysYParser.h"
#include "include/ir/IR.h"
#include "include/utils/Log.h"
// Builds a basic-block label by appending the next value of the block counter
// to `prefix`. The counter is incremented first, so the appended number is the
// post-increment value.
std::string IRGenImpl::NextBlockName(const std::string& prefix) {
  ++block_index_;
  std::string label = prefix;
  label += std::to_string(block_index_);
  return label;
}
// Emits branching code for a `cond` parse node: control transfers to `true_bb`
// or `false_bb` depending on the wrapped logical-or expression.
// Throws std::runtime_error when the node has no logical-or child.
void IRGenImpl::EmitCondBranch(SysYParser::CondContext& cond,
                               ir::BasicBlock* true_bb,
                               ir::BasicBlock* false_bb) {
  auto* lor_expr = cond.lOrExp();
  if (lor_expr != nullptr) {
    EmitLOrBranch(*lor_expr, true_bb, false_bb);
    return;
  }
  throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
}
// Emits branching code for a logical-or expression.
//
// A node without a nested lOrExp is just a single lAndExp and is forwarded
// unchanged. Otherwise the nested (left) operand branches to `true_bb` when it
// holds and falls into a freshly created "lor.rhs." block when it does not;
// the right operand is then emitted in that block with the caller's targets.
//
// Throws std::runtime_error when the lAndExp child is missing.
void IRGenImpl::EmitLOrBranch(SysYParser::LOrExpContext& expr,
                              ir::BasicBlock* true_bb,
                              ir::BasicBlock* false_bb) {
  auto* left_operand = expr.lOrExp();
  auto* right_operand = expr.lAndExp();

  // The lAndExp child is required in both the leaf and the binary form.
  if (right_operand == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
  }

  if (left_operand == nullptr) {
    EmitLAndBranch(*right_operand, true_bb, false_bb);
    return;
  }

  // The block is created before recursing so that block numbering follows
  // the same order as the operands are visited.
  auto* fallthrough_bb = func_->CreateBlock(NextBlockName("lor.rhs."));
  EmitLOrBranch(*left_operand, true_bb, fallthrough_bb);
  builder_.SetInsertPoint(fallthrough_bb);
  EmitLAndBranch(*right_operand, true_bb, false_bb);
}
// Emits branching code for a logical-and expression.
//
// A node without a nested lAndExp is just a single eqExp and is forwarded
// unchanged. Otherwise the nested (left) operand branches to `false_bb` when
// it fails and continues into a freshly created "land.rhs." block when it
// holds; the right operand is then emitted in that block with the caller's
// targets.
//
// Throws std::runtime_error when the eqExp child is missing.
void IRGenImpl::EmitLAndBranch(SysYParser::LAndExpContext& expr,
                               ir::BasicBlock* true_bb,
                               ir::BasicBlock* false_bb) {
  auto* left_operand = expr.lAndExp();
  auto* right_operand = expr.eqExp();

  // The eqExp child is required in both the leaf and the binary form.
  if (right_operand == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
  }

  if (left_operand == nullptr) {
    EmitEqBranch(*right_operand, true_bb, false_bb);
    return;
  }

  // The block is created before recursing so that block numbering follows
  // the same order as the operands are visited.
  auto* continue_bb = func_->CreateBlock(NextBlockName("land.rhs."));
  EmitLAndBranch(*left_operand, continue_bb, false_bb);
  builder_.SetInsertPoint(continue_bb);
  EmitEqBranch(*right_operand, true_bb, false_bb);
}
// Evaluates a relational expression to an IR value.
//
// A node without a nested relExp yields the value of its addExp. Otherwise the
// nested relExp is evaluated first (left operand), then the addExp (right
// operand), and both are combined through EvalBinaryOrFold with the opcode
// matching the operator token (<, >, <=, >=).
//
// Throws std::runtime_error when the addExp child is missing or no known
// relational operator token is present.
ir::Value* IRGenImpl::EvalRelValue(SysYParser::RelExpContext& expr) {
  auto* add_operand = expr.addExp();
  if (add_operand == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  auto* nested_rel = expr.relExp();
  if (nested_rel == nullptr) {
    return std::any_cast<ir::Value*>(add_operand->accept(this));
  }

  // Operands are evaluated left to right so emitted IR keeps source order.
  ir::Value* left = EvalRelValue(*nested_rel);
  ir::Value* right = std::any_cast<ir::Value*>(add_operand->accept(this));

  // Maps the operator token on this node to its comparison opcode.
  const auto select_opcode = [&]() -> ir::Opcode {
    if (expr.LtOp()) {
      return ir::Opcode::Lt;
    }
    if (expr.GtOp()) {
      return ir::Opcode::Gt;
    }
    if (expr.LeOp()) {
      return ir::Opcode::Le;
    }
    if (expr.GeOp()) {
      return ir::Opcode::Ge;
    }
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  };

  const ir::Opcode opcode = select_opcode();
  return EvalBinaryOrFold(opcode, left, right);
}
// Evaluates an equality expression to an IR value.
//
// A node without a nested eqExp yields the value of its relExp. Otherwise the
// nested eqExp is evaluated first (left operand), then the relExp (right
// operand), and both are combined through EvalBinaryOrFold using Eq or Ne
// according to the operator token.
//
// Throws std::runtime_error when the relExp child is missing or neither
// equality operator token is present.
ir::Value* IRGenImpl::EvalEqValue(SysYParser::EqExpContext& expr) {
  auto* rel_operand = expr.relExp();
  if (rel_operand == nullptr) {
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  }

  auto* nested_eq = expr.eqExp();
  if (nested_eq == nullptr) {
    return EvalRelValue(*rel_operand);
  }

  // Operands are evaluated left to right so emitted IR keeps source order.
  ir::Value* left = EvalEqValue(*nested_eq);
  ir::Value* right = EvalRelValue(*rel_operand);

  // Maps the operator token on this node to its comparison opcode.
  const auto select_opcode = [&]() -> ir::Opcode {
    if (expr.EqOp()) {
      return ir::Opcode::Eq;
    }
    if (expr.NeOp()) {
      return ir::Opcode::Ne;
    }
    throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
  };

  const ir::Opcode opcode = select_opcode();
  return EvalBinaryOrFold(opcode, left, right);
}
// Emits a conditional branch driven by an equality expression: the expression
// is evaluated to a value, compared "not equal" against a zero constant of the
// matching kind (float zero for Float32 values, int zero otherwise), and the
// comparison result selects between `true_bb` and `false_bb`.
void IRGenImpl::EmitEqBranch(SysYParser::EqExpContext& expr,
                             ir::BasicBlock* true_bb,
                             ir::BasicBlock* false_bb) {
  ir::Value* tested = EvalEqValue(expr);

  // Pick the zero constant whose type matches the tested value.
  ir::Value* zero = nullptr;
  if (tested->GetType()->IsFloat32()) {
    zero = builder_.CreateConstFloat(0.0);
  } else {
    zero = builder_.CreateConstInt(0);
  }

  auto* is_nonzero = builder_.CreateBinary(ir::Opcode::Ne, tested, zero,
                                           module_.GetContext().NextTemp());
  builder_.CreateCondBr(is_nonzero, true_bb, false_bb);
}
// Emits a conditional branch driven by a relational expression.
//
// - Single operand (no nested relExp): the addExp value is tested for
//   "not equal to zero". The zero constant matches the operand's type (float
//   zero for Float32 values, int zero otherwise), mirroring EmitEqBranch, so a
//   float condition is not compared against an integer constant.
// - Binary form: exactly one comparison `a <op> b` is supported; both operands
//   are evaluated left to right and compared with the opcode matching the
//   operator token.
//
// Throws std::runtime_error when a required child is missing, when the
// comparison is chained (the left side is itself a comparison), or when no
// known relational operator token is present.
void IRGenImpl::EmitRelBranch(SysYParser::RelExpContext& expr,
                              ir::BasicBlock* true_bb,
                              ir::BasicBlock* false_bb) {
  if (!expr.relExp()) {
    if (!expr.addExp()) {
      throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
    }
    auto* value = std::any_cast<ir::Value*>(expr.addExp()->accept(this));
    // Choose a zero of the same kind as the tested value.
    ir::Value* zero =
        value->GetType()->IsFloat32()
            ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0))
            : static_cast<ir::Value*>(builder_.CreateConstInt(0));
    auto* cond = builder_.CreateBinary(ir::Opcode::Ne, value, zero,
                                       module_.GetContext().NextTemp());
    builder_.CreateCondBr(cond, true_bb, false_bb);
    return;
  }

  // Only a single, non-chained comparison is handled on this path.
  if (!expr.addExp() || !expr.relExp()->addExp() || expr.relExp()->relExp()) {
    throw std::runtime_error(
        FormatError("irgen", "当前不支持链式关系比较表达式"));
  }

  auto* lhs = std::any_cast<ir::Value*>(expr.relExp()->addExp()->accept(this));
  auto* rhs = std::any_cast<ir::Value*>(expr.addExp()->accept(this));

  ir::Opcode cmp = ir::Opcode::Lt;
  if (expr.LtOp()) {
    cmp = ir::Opcode::Lt;
  } else if (expr.GtOp()) {
    cmp = ir::Opcode::Gt;
  } else if (expr.LeOp()) {
    cmp = ir::Opcode::Le;
  } else if (expr.GeOp()) {
    cmp = ir::Opcode::Ge;
  } else {
    throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
  }

  auto* cond = builder_.CreateBinary(cmp, lhs, rhs,
                                     module_.GetContext().NextTemp());
  builder_.CreateCondBr(cond, true_bb, false_bb);
}
// Generates IR for one statement node and reports how control leaves it.
//
// Handled forms, checked in this order: nested block, assignment, if/else,
// while, break, continue, return, empty statement, expression statement.
//
// Returns (wrapped in std::any) BlockFlow::Continue when code following the
// statement should still be emitted at the current insert point, or
// BlockFlow::Terminated when the statement ended the current block. For a
// nested block the block visitor's result is returned as is.
//
// Throws std::runtime_error for a null context, structurally incomplete
// statements, assignment to a whole array, break/continue outside a loop,
// a value returned from a void function, a missing return value in a non-void
// function, and any statement form not listed above.
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "缺少语句"));
  }

  // --- Nested block -------------------------------------------------------
  if (ctx->block()) {
    return ctx->block()->accept(this);
  }

  // --- Assignment: lVal = exp; --------------------------------------------
  if (ctx->lVal() && ctx->Assign()) {
    if (!ctx->exp()) {
      throw std::runtime_error(FormatError("irgen", "赋值语句缺少右值表达式"));
    }
    bool is_array = false;
    const auto extents = GetArrayExtentsForLVal(*ctx->lVal(), is_array);
    // Fewer subscripts than array dimensions means the target is still an
    // array object rather than a single element.
    if (is_array && ctx->lVal()->exp().size() < extents.size()) {
      throw std::runtime_error(FormatError("irgen", "不能给数组对象赋值"));
    }
    auto* slot = GetLValAddress(*ctx->lVal());
    ir::Value* rhs = EvalExpr(*ctx->exp());
    // Convert the right-hand side to the element type of the destination.
    auto target_ty = slot->GetType()->IsPtrFloat32()
                         ? ir::Type::GetFloat32Type()
                         : ir::Type::GetInt32Type();
    rhs = CastValueTo(rhs, target_ty);
    builder_.CreateStore(rhs, slot);
    return BlockFlow::Continue;
  }

  // --- if / if-else -------------------------------------------------------
  if (ctx->If()) {
    if (!ctx->cond() || ctx->stmt().empty()) {
      throw std::runtime_error(FormatError("irgen", "if 语句缺少必要组成部分"));
    }
    const bool has_else = ctx->Else() != nullptr;
    // An `else` token without a second child statement would make
    // ctx->stmt(1) null below; reject it before any block is created.
    if (has_else && ctx->stmt().size() < 2) {
      throw std::runtime_error(FormatError("irgen", "if 语句缺少必要组成部分"));
    }

    auto* then_bb = func_->CreateBlock(NextBlockName("if.then."));
    ir::BasicBlock* else_bb = nullptr;
    if (has_else) {
      else_bb = func_->CreateBlock(NextBlockName("if.else."));
    }
    auto* merge_bb = func_->CreateBlock(NextBlockName("if.end."));

    // Without an else branch a false condition jumps straight to the merge.
    EmitCondBranch(*ctx->cond(), then_bb, has_else ? else_bb : merge_bb);

    builder_.SetInsertPoint(then_bb);
    auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    if (then_flow == BlockFlow::Continue &&
        !builder_.GetInsertBlock()->HasTerminator()) {
      builder_.CreateBr(merge_bb);
    }

    BlockFlow else_flow = BlockFlow::Continue;
    if (has_else) {
      builder_.SetInsertPoint(else_bb);
      else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
      if (else_flow == BlockFlow::Continue &&
          !builder_.GetInsertBlock()->HasTerminator()) {
        builder_.CreateBr(merge_bb);
      }
    }

    builder_.SetInsertPoint(merge_bb);
    // When both branches terminated, nothing branches to the merge block;
    // it is still given a return matching the function's type so that it
    // ends in a terminator, and the whole statement counts as terminated.
    if (has_else && then_flow == BlockFlow::Terminated &&
        else_flow == BlockFlow::Terminated) {
      if (func_->GetType()->IsInt32()) {
        builder_.CreateRet(builder_.CreateConstInt(0));
      } else if (func_->GetType()->IsFloat32()) {
        builder_.CreateRet(builder_.CreateConstFloat(0.0));
      } else {
        builder_.CreateRetVoid();
      }
      return BlockFlow::Terminated;
    }
    return BlockFlow::Continue;
  }

  // --- while --------------------------------------------------------------
  if (ctx->While()) {
    if (!ctx->cond() || ctx->stmt().empty()) {
      throw std::runtime_error(FormatError("irgen", "while 语句缺少必要组成部分"));
    }
    auto* cond_bb = func_->CreateBlock(NextBlockName("while.cond."));
    auto* body_bb = func_->CreateBlock(NextBlockName("while.body."));
    auto* exit_bb = func_->CreateBlock(NextBlockName("while.end."));

    builder_.CreateBr(cond_bb);
    builder_.SetInsertPoint(cond_bb);
    EmitCondBranch(*ctx->cond(), body_bb, exit_bb);

    // Loop targets for the body: first = continue target (condition block),
    // second = break target (exit block).
    loop_stack_.push_back({cond_bb, exit_bb});
    builder_.SetInsertPoint(body_bb);
    auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    if (body_flow == BlockFlow::Continue &&
        !builder_.GetInsertBlock()->HasTerminator()) {
      // Back edge: re-evaluate the condition after the body.
      builder_.CreateBr(cond_bb);
    }
    loop_stack_.pop_back();

    builder_.SetInsertPoint(exit_bb);
    return BlockFlow::Continue;
  }

  // --- break --------------------------------------------------------------
  if (ctx->Break()) {
    if (loop_stack_.empty()) {
      throw std::runtime_error(FormatError("irgen", "break 不在循环体内"));
    }
    builder_.CreateBr(loop_stack_.back().second);
    return BlockFlow::Terminated;
  }

  // --- continue -----------------------------------------------------------
  if (ctx->Continue()) {
    if (loop_stack_.empty()) {
      throw std::runtime_error(FormatError("irgen", "continue 不在循环体内"));
    }
    builder_.CreateBr(loop_stack_.back().first);
    return BlockFlow::Terminated;
  }

  // --- return -------------------------------------------------------------
  if (ctx->Return()) {
    if (func_->GetType()->IsVoid()) {
      if (ctx->exp()) {
        throw std::runtime_error(FormatError("irgen", "void 函数不应返回值"));
      }
      builder_.CreateRetVoid();
      return BlockFlow::Terminated;
    }
    if (!ctx->exp()) {
      throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
    }
    // The returned value is converted to the type reported by the function.
    ir::Value* v = CastValueTo(EvalExpr(*ctx->exp()), func_->GetType());
    builder_.CreateRet(v);
    return BlockFlow::Terminated;
  }

  // --- empty statement: ; -------------------------------------------------
  if (ctx->Semi() && !ctx->exp()) {
    return BlockFlow::Continue;
  }

  // --- expression statement: exp; -----------------------------------------
  if (ctx->exp() && ctx->Semi()) {
    // Evaluated only for its side effects; the value is discarded.
    (void)EvalExpr(*ctx->exp());
    return BlockFlow::Continue;
  }

  throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}