Shrink: Compile pass with IRGen fixed

实现合并
Shrink
Shrink 2 weeks ago
parent 97d5ec1d48
commit c33d36e040

@ -29,13 +29,22 @@ class IRGenImpl final : public SysYBaseVisitor {
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitVarDecl(SysYParser::VarDeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
std::any visitExp(SysYParser::ExpContext* ctx) override;
std::any visitCond(SysYParser::CondContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLValue(SysYParser::LValueContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
private:
enum class BlockFlow {
@ -43,8 +52,16 @@ class IRGenImpl final : public SysYBaseVisitor {
Terminated,
};
struct LoopTargets {
ir::BasicBlock* continue_target;
ir::BasicBlock* break_target;
};
BlockFlow VisitBlockItemResult(SysYParser::BlockItemContext& item);
ir::Value* EvalExpr(SysYParser::ExpContext& expr);
ir::Value* EvalCond(SysYParser::CondContext& cond);
ir::Value* ToBoolValue(ir::Value* v);
std::string NextBlockName();
ir::Module& module_;
const SemanticContext& sema_;
@ -52,6 +69,8 @@ class IRGenImpl final : public SysYBaseVisitor {
ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
std::unordered_map<std::string, ir::Value*> named_storage_;
std::vector<LoopTargets> loop_stack_;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,

@ -60,17 +60,29 @@ std::any IRGenImpl::visitBlockItem(SysYParser::BlockItemContext* ctx) {
// - const、数组、全局变量等不同声明形态
// - 更丰富的类型系统。
std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
}
if (!ctx->varDecl()) {
// 当前先忽略 constDecl 与其它声明形态。
return {};
}
return ctx->varDecl()->accept(this);
}
std::any IRGenImpl::visitVarDecl(SysYParser::VarDeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
}
if (!ctx->btype() || !ctx->btype()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
for (auto* var_def : ctx->varDef()) {
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
}
var_def->accept(this);
}
var_def->accept(this);
return {};
}
@ -83,15 +95,16 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
}
if (!ctx->lValue()) {
if (!ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
}
GetLValueName(*ctx->lValue());
const std::string name = ctx->ID()->getText();
if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
}
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
named_storage_[name] = slot;
ir::Value* init = nullptr;
if (auto* init_value = ctx->initValue()) {

@ -24,21 +24,62 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
return std::any_cast<ir::Value*>(expr.accept(this));
}
ir::Value* IRGenImpl::EvalCond(SysYParser::CondContext& cond) {
return std::any_cast<ir::Value*>(cond.accept(this));
}
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
ir::Value* IRGenImpl::ToBoolValue(ir::Value* v) {
if (!v) {
throw std::runtime_error(FormatError("irgen", "条件值为空"));
}
return EvalExpr(*ctx->exp());
auto* zero = builder_.CreateConstInt(0);
return builder_.CreateCmp(ir::CmpOp::Ne, v, zero, module_.GetContext().NextTemp());
}
std::string IRGenImpl::NextBlockName() {
std::string temp = module_.GetContext().NextTemp();
if (!temp.empty() && temp.front() == '%') {
return "bb" + temp.substr(1);
}
return "bb" + temp;
}
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
std::any IRGenImpl::visitExp(SysYParser::ExpContext* ctx) {
if (!ctx || !ctx->addExp()) {
throw std::runtime_error(FormatError("irgen", "非法表达式"));
}
return ctx->addExp()->accept(this);
}
std::any IRGenImpl::visitCond(SysYParser::CondContext* ctx) {
if (!ctx || !ctx->lOrExp()) {
throw std::runtime_error(FormatError("irgen", "非法条件表达式"));
}
return ctx->lOrExp()->accept(this);
}
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
}
if (ctx->exp()) {
return EvalExpr(*ctx->exp());
}
if (ctx->number()) {
return ctx->number()->accept(this);
}
if (ctx->lValue()) {
return ctx->lValue()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式"));
}
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
if (!ctx || !ctx->ILITERAL()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
}
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
builder_.CreateConstInt(std::stoi(ctx->getText())));
}
// 变量使用的处理流程:
@ -47,34 +88,192 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
// 3. 最后生成 load把内存中的值读出来。
//
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
std::any IRGenImpl::visitLValue(SysYParser::LValueContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
}
auto* decl = sema_.ResolveVarUse(ctx->var());
if (!decl) {
throw std::runtime_error(
FormatError("irgen",
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
}
auto it = storage_map_.find(decl);
if (it == storage_map_.end()) {
const std::string name = ctx->ID()->getText();
auto it = named_storage_.find(name);
if (it == named_storage_.end()) {
throw std::runtime_error(
FormatError("irgen",
"变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
FormatError("irgen", "变量声明缺少存储槽位: " + name));
}
return static_cast<ir::Value*>(
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
}
if (ctx->primaryExp()) {
return ctx->primaryExp()->accept(this);
}
if (ctx->unaryOp() && ctx->unaryExp()) {
ir::Value* v = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
if (ctx->unaryOp()->SUB()) {
auto* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(builder_.CreateSub(
zero, v, module_.GetContext().NextTemp()));
}
if (ctx->unaryOp()->ADD()) {
return v;
}
throw std::runtime_error(FormatError("irgen", "当前不支持逻辑非运算"));
}
throw std::runtime_error(FormatError("irgen", "当前不支持函数调用表达式"));
}
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
if (ctx->mulExp()) {
if (!ctx->unaryExp()) {
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
if (ctx->MUL()) {
return static_cast<ir::Value*>(
builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->DIV()) {
return static_cast<ir::Value*>(
builder_.CreateDiv(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->MOD()) {
return static_cast<ir::Value*>(
builder_.CreateMod(lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
if (ctx->unaryExp()) {
return ctx->unaryExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
}
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
return static_cast<ir::Value*>(
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
module_.GetContext().NextTemp()));
if (ctx->addExp()) {
if (!ctx->mulExp()) {
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
if (ctx->ADD()) {
return static_cast<ir::Value*>(
builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->SUB()) {
return static_cast<ir::Value*>(
builder_.CreateSub(lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
if (ctx->mulExp()) {
return ctx->mulExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
if (ctx->relExp()) {
if (!ctx->addExp()) {
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
if (ctx->LT()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Lt, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->LE()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Le, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->GT()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Gt, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->GE()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Ge, lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
if (ctx->addExp()) {
return ctx->addExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
}
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
if (ctx->eqExp()) {
if (!ctx->relExp()) {
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp()->accept(this));
if (ctx->EQ()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Eq, lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->NE()) {
return static_cast<ir::Value*>(builder_.CreateCmp(
ir::CmpOp::Ne, lhs, rhs, module_.GetContext().NextTemp()));
}
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
if (ctx->relExp()) {
return ctx->relExp()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
}
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
}
if (ctx->lAndExp()) {
if (!ctx->eqExp()) {
throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
}
auto* lhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this)));
auto* rhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->eqExp()->accept(this)));
return static_cast<ir::Value*>(
builder_.CreateMul(lhs, rhs, module_.GetContext().NextTemp()));
}
if (ctx->eqExp()) {
return ToBoolValue(std::any_cast<ir::Value*>(ctx->eqExp()->accept(this)));
}
throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
}
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
}
if (ctx->lOrExp()) {
if (!ctx->lAndExp()) {
throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
}
auto* lhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->lOrExp()->accept(this)));
auto* rhs = ToBoolValue(std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this)));
auto* sum = builder_.CreateAdd(lhs, rhs, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(ToBoolValue(sum));
}
if (ctx->lAndExp()) {
return ToBoolValue(std::any_cast<ir::Value*>(ctx->lAndExp()->accept(this)));
}
throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
}

@ -38,11 +38,14 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func) {
if (ctx->funcDef().empty()) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
func->accept(this);
for (auto* func : ctx->funcDef()) {
if (func) {
func->accept(this);
}
}
return {};
}
@ -79,6 +82,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
func_ = module_.CreateFunction(ctx->ID()->getText(), ir::Type::GetInt32Type());
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
named_storage_.clear();
ctx->blockStmt()->accept(this);
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。

@ -19,9 +19,101 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
if (ctx->lValue() && ctx->ASSIGN() && ctx->exp()) {
if (!ctx->lValue()->ID()) {
throw std::runtime_error(FormatError("irgen", "赋值语句左值非法"));
}
const std::string name = ctx->lValue()->ID()->getText();
auto slot_it = named_storage_.find(name);
if (slot_it == named_storage_.end()) {
throw std::runtime_error(FormatError("irgen", "赋值目标未声明: " + name));
}
ir::Value* rhs = EvalExpr(*ctx->exp());
builder_.CreateStore(rhs, slot_it->second);
return BlockFlow::Continue;
}
if (ctx->blockStmt()) {
ctx->blockStmt()->accept(this);
return builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()
? BlockFlow::Terminated
: BlockFlow::Continue;
}
if (ctx->IF()) {
if (!ctx->cond() || ctx->stmt().empty()) {
throw std::runtime_error(FormatError("irgen", "if 语句不完整"));
}
auto* then_bb = func_->CreateBlock(NextBlockName());
auto* merge_bb = func_->CreateBlock(NextBlockName());
auto* else_bb = ctx->ELSE() ? func_->CreateBlock(NextBlockName()) : merge_bb;
ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond()));
builder_.CreateCondBr(cond, then_bb, else_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (then_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
if (ctx->ELSE()) {
builder_.SetInsertPoint(else_bb);
auto else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
if (else_flow != BlockFlow::Terminated) {
builder_.CreateBr(merge_bb);
}
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (ctx->WHILE()) {
if (!ctx->cond() || ctx->stmt().empty()) {
throw std::runtime_error(FormatError("irgen", "while 语句不完整"));
}
auto* cond_bb = func_->CreateBlock(NextBlockName());
auto* body_bb = func_->CreateBlock(NextBlockName());
auto* exit_bb = func_->CreateBlock(NextBlockName());
builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(cond_bb);
ir::Value* cond = ToBoolValue(EvalCond(*ctx->cond()));
builder_.CreateCondBr(cond, body_bb, exit_bb);
loop_stack_.push_back({cond_bb, exit_bb});
builder_.SetInsertPoint(body_bb);
auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
if (body_flow != BlockFlow::Terminated) {
builder_.CreateBr(cond_bb);
}
loop_stack_.pop_back();
builder_.SetInsertPoint(exit_bb);
return BlockFlow::Continue;
}
if (ctx->BREAK()) {
if (loop_stack_.empty()) {
throw std::runtime_error(FormatError("irgen", "break 不在循环中"));
}
builder_.CreateBr(loop_stack_.back().break_target);
return BlockFlow::Terminated;
}
if (ctx->CONTINUE()) {
if (loop_stack_.empty()) {
throw std::runtime_error(FormatError("irgen", "continue 不在循环中"));
}
builder_.CreateBr(loop_stack_.back().continue_target);
return BlockFlow::Terminated;
}
if (ctx->returnStmt()) {
return ctx->returnStmt()->accept(this);
}
if (ctx->exp()) {
EvalExpr(*ctx->exp());
return BlockFlow::Continue;
}
if (ctx->SEMICOLON()) {
return BlockFlow::Continue;
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}

Loading…
Cancel
Save