“IRGen部分实现:支持更多二元运算符(Sub, Mul, Div, Mod)”

dyz
lc 1 week ago committed by olivame
parent d4516f2289
commit 3366d20f9e

@ -301,6 +301,8 @@ class IRBuilder {
BinaryInst* CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name);
BinaryInst* CreateAdd(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateSub(Value* lhs, Value* rhs, const std::string& name);
BinaryInst* CreateMul(Value* lhs, Value* rhs, const std::string& name);
AllocaInst* CreateAllocaI32(const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);

@ -26,16 +26,16 @@ class IRGenImpl final : public SysYBaseVisitor {
std::any visitCompUnit(SysYParser::CompUnitContext* ctx) override;
std::any visitFuncDef(SysYParser::FuncDefContext* ctx) override;
std::any visitBlockStmt(SysYParser::BlockStmtContext* ctx) override;
std::any visitBlock(SysYParser::BlockContext* ctx) override;
std::any visitBlockItem(SysYParser::BlockItemContext* ctx) override;
std::any visitDecl(SysYParser::DeclContext* ctx) override;
std::any visitStmt(SysYParser::StmtContext* ctx) override;
std::any visitVarDef(SysYParser::VarDefContext* ctx) override;
std::any visitReturnStmt(SysYParser::ReturnStmtContext* ctx) override;
std::any visitParenExp(SysYParser::ParenExpContext* ctx) override;
std::any visitNumberExp(SysYParser::NumberExpContext* ctx) override;
std::any visitVarExp(SysYParser::VarExpContext* ctx) override;
std::any visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) override;
std::any visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) override;
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
private:
enum class BlockFlow {
@ -50,8 +50,8 @@ class IRGenImpl final : public SysYBaseVisitor {
const SemanticContext& sema_;
ir::Function* func_;
ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
// 名称绑定由 Sema 负责IRGen 只维护"变量名 -> 存储槽位"的代码生成状态。
std::unordered_map<std::string, ir::Value*> storage_map_;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,

@ -32,6 +32,10 @@ static const char* OpcodeToString(Opcode op) {
return "sub";
case Opcode::Mul:
return "mul";
case Opcode::Div:
return "sdiv";
case Opcode::Mod:
return "srem";
case Opcode::Alloca:
return "alloca";
case Opcode::Load:
@ -65,7 +69,9 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
switch (inst->GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul: {
case Opcode::Mul:
case Opcode::Div:
case Opcode::Mod: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "

@ -61,8 +61,9 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add"));
if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul &&
op != Opcode::Div && op != Opcode::Mod) {
throw std::runtime_error(FormatError("ir", "BinaryInst 不支持的操作码"));
}
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));

@ -6,18 +6,7 @@
#include "ir/IR.h"
#include "utils/Log.h"
namespace {
std::string GetLValueName(SysYParser::LValueContext& lvalue) {
if (!lvalue.ID()) {
throw std::runtime_error(FormatError("irgen", "非法左值"));
}
return lvalue.ID()->getText();
}
} // namespace
std::any IRGenImpl::visitBlockStmt(SysYParser::BlockStmtContext* ctx) {
std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句块"));
}
@ -63,14 +52,20 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量声明"));
}
if (!ctx->btype() || !ctx->btype()->INT()) {
// 当前语法中 decl 包含 constDecl 或 varDecl这里只支持 varDecl
auto* var_decl = ctx->varDecl();
if (!var_decl) {
throw std::runtime_error(FormatError("irgen", "当前仅支持变量声明"));
}
if (!var_decl->bType() || !var_decl->bType()->INT()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int 变量声明"));
}
auto* var_def = ctx->varDef();
if (!var_def) {
throw std::runtime_error(FormatError("irgen", "非法变量声明"));
// 遍历所有 varDef
for (auto* var_def : var_decl->varDef()) {
if (var_def) {
var_def->accept(this);
}
}
var_def->accept(this);
return {};
}
@ -83,22 +78,26 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少变量定义"));
}
if (!ctx->lValue()) {
if (!ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "变量声明缺少名称"));
}
GetLValueName(*ctx->lValue());
if (storage_map_.find(ctx) != storage_map_.end()) {
// 暂不支持数组声明constIndex
if (!ctx->constIndex().empty()) {
throw std::runtime_error(FormatError("irgen", "暂不支持数组声明"));
}
std::string var_name = ctx->ID()->getText();
if (storage_map_.find(var_name) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
}
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
storage_map_[var_name] = slot;
ir::Value* init = nullptr;
if (auto* init_value = ctx->initValue()) {
if (!init_value->exp()) {
if (auto* init_val = ctx->initVal()) {
if (!init_val->exp()) {
throw std::runtime_error(FormatError("irgen", "当前不支持聚合初始化"));
}
init = EvalExpr(*init_value->exp());
init = EvalExpr(*init_val->exp());
} else {
init = builder_.CreateConstInt(0);
}

@ -25,20 +25,51 @@ ir::Value* IRGenImpl::EvalExpr(SysYParser::ExpContext& expr) {
}
std::any IRGenImpl::visitParenExp(SysYParser::ParenExpContext* ctx) {
if (!ctx || !ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "非法括号表达式"));
std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法基本表达式"));
}
// 处理括号表达式LPAREN exp RPAREN
if (ctx->exp()) {
return EvalExpr(*ctx->exp());
}
// 处理 lVal变量使用- 交给 visitLVal 处理
if (ctx->lVal()) {
// 直接在这里处理变量读取,避免 accept 调用可能导致的问题
auto* lval_ctx = ctx->lVal();
if (!lval_ctx || !lval_ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
}
const auto* decl = sema_.ResolveObjectUse(lval_ctx);
if (!decl) {
throw std::runtime_error(
FormatError("irgen",
"变量使用缺少语义绑定:" + lval_ctx->ID()->getText()));
}
std::string var_name = lval_ctx->ID()->getText();
auto it = storage_map_.find(var_name);
if (it == storage_map_.end()) {
throw std::runtime_error(
FormatError("irgen",
"变量声明缺少存储槽位:" + var_name));
}
return static_cast<ir::Value*>(
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
}
return EvalExpr(*ctx->exp());
// 处理 number
if (ctx->number()) {
return ctx->number()->accept(this);
}
throw std::runtime_error(FormatError("irgen", "不支持的基本表达式类型"));
}
std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
if (!ctx || !ctx->number() || !ctx->number()->ILITERAL()) {
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
if (!ctx || !ctx->intConst()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
}
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(ctx->number()->getText())));
builder_.CreateConstInt(std::stoi(ctx->intConst()->getText())));
}
// 变量使用的处理流程:
@ -47,33 +78,46 @@ std::any IRGenImpl::visitNumberExp(SysYParser::NumberExpContext* ctx) {
// 3. 最后生成 load把内存中的值读出来。
//
// 因此当前 IRGen 自己不再做名字查找,而是直接消费 Sema 的绑定结果。
std::any IRGenImpl::visitVarExp(SysYParser::VarExpContext* ctx) {
if (!ctx || !ctx->var() || !ctx->var()->ID()) {
std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
if (!ctx || !ctx->ID()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量"));
}
auto* decl = sema_.ResolveVarUse(ctx->var());
const auto* decl = sema_.ResolveObjectUse(ctx);
if (!decl) {
throw std::runtime_error(
FormatError("irgen",
"变量使用缺少语义绑定: " + ctx->var()->ID()->getText()));
"变量使用缺少语义绑定" + ctx->ID()->getText()));
}
auto it = storage_map_.find(decl);
// 使用变量名查找存储槽位
std::string var_name = ctx->ID()->getText();
auto it = storage_map_.find(var_name);
if (it == storage_map_.end()) {
throw std::runtime_error(
FormatError("irgen",
"变量声明缺少存储槽位: " + ctx->var()->ID()->getText()));
"变量声明缺少存储槽位" + var_name));
}
return static_cast<ir::Value*>(
builder_.CreateLoad(it->second, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitAdditiveExp(SysYParser::AdditiveExpContext* ctx) {
if (!ctx || !ctx->exp(0) || !ctx->exp(1)) {
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "非法加减法表达式"));
}
ir::Value* lhs = EvalExpr(*ctx->exp(0));
ir::Value* rhs = EvalExpr(*ctx->exp(1));
// 如果是 mulExp 直接返回addExp : mulExp
if (ctx->mulExp() && ctx->addExp() == nullptr) {
return ctx->mulExp()->accept(this);
}
// 处理 addExp op mulExp 的递归形式
if (!ctx->addExp() || !ctx->mulExp()) {
throw std::runtime_error(FormatError("irgen", "非法加减法表达式结构"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
ir::Opcode op = ir::Opcode::Add;
if (ctx->ADD()) {
@ -93,18 +137,18 @@ std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
throw std::runtime_error(FormatError("irgen", "非法乘除法表达式"));
}
// 如果是 unaryExp 直接返回
if (ctx->unaryExp()) {
// 如果是 unaryExp 直接返回mulExp : unaryExp
if (ctx->unaryExp() && ctx->mulExp() == nullptr) {
return ctx->unaryExp()->accept(this);
}
// 处理 MulExp op unaryExp 的递归形式
if (!ctx->exp(0) || !ctx->unaryExp(0)) {
// 处理 mulExp op unaryExp 的递归形式
if (!ctx->mulExp() || !ctx->unaryExp()) {
throw std::runtime_error(FormatError("irgen", "非法乘除法表达式结构"));
}
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->exp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp(0)->accept(this));
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp()->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
ir::Opcode op = ir::Opcode::Mul;
if (ctx->MUL()) {

@ -29,7 +29,7 @@ IRGenImpl::IRGenImpl(ir::Module& module, const SemanticContext& sema)
// 编译单元的 IR 生成当前只实现了最小功能:
// - Module 已在 GenerateIR 中创建,这里只负责继续生成其中的内容;
// - 当前会读取编译单元中的函数定义,并交给 visitFuncDef 生成函数 IR
// - 当前会读取编译单元中的 topLevelItem找到 funcDef 后生成函数 IR
//
// 当前还没有实现:
// - 多个函数定义的遍历与生成;
@ -38,12 +38,15 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少编译单元"));
}
auto* func = ctx->funcDef();
if (!func) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
// 遍历所有 topLevelItem找到 funcDef
for (auto* item : ctx->topLevelItem()) {
if (item && item->funcDef()) {
item->funcDef()->accept(this);
// 当前只支持单个函数,找到第一个后就返回
return {};
}
}
func->accept(this);
return {};
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
// 函数 IR 生成当前实现了:
@ -61,12 +64,11 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
// - 入口块中的参数初始化逻辑。
// ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
if (!ctx->blockStmt()) {
if (!ctx->block()) {
throw std::runtime_error(FormatError("irgen", "函数体为空"));
}
if (!ctx->ID()) {
@ -80,7 +82,7 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
ctx->blockStmt()->accept(this);
ctx->block()->accept(this);
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。
VerifyFunctionStructure(*func_);
return {};

@ -19,21 +19,14 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
if (ctx->returnStmt()) {
return ctx->returnStmt()->accept(this);
// 检查是否是 return 语句RETURN exp? SEMICOLON
if (ctx->RETURN()) {
if (!ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
}
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v);
return BlockFlow::Terminated;
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}
std::any IRGenImpl::visitReturnStmt(SysYParser::ReturnStmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少 return 语句"));
}
if (!ctx->exp()) {
throw std::runtime_error(FormatError("irgen", "return 缺少表达式"));
}
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v);
return BlockFlow::Terminated;
}

@ -434,9 +434,11 @@ class SemaVisitor final : public SysYBaseVisitor {
if (!ctx) {
ThrowSemaError(ctx, "非法乘法表达式");
}
if (ctx->unaryExp()) {
// 如果是 mulExp : unaryExp 形式(没有 MUL/DIV/MOD token直接处理 unaryExp
if (!ctx->MUL() && !ctx->DIV() && !ctx->MOD()) {
return EvalExpr(*ctx->unaryExp());
}
// 否则是 mulExp MUL/DIV/MOD unaryExp 形式
ExprInfo lhs = EvalExpr(*ctx->mulExp());
ExprInfo rhs = EvalExpr(*ctx->unaryExp());
return EvalArithmetic(*ctx, lhs, rhs, ctx->MUL() ? '*' : (ctx->DIV() ? '/' : '%'));
@ -446,9 +448,11 @@ class SemaVisitor final : public SysYBaseVisitor {
if (!ctx) {
ThrowSemaError(ctx, "非法加法表达式");
}
if (ctx->mulExp()) {
// 如果是 addExp : mulExp 形式(没有 ADD/SUB token直接处理 mulExp
if (!ctx->ADD() && !ctx->SUB()) {
return EvalExpr(*ctx->mulExp());
}
// 否则是 addExp ADD/SUB mulExp 形式
ExprInfo lhs = EvalExpr(*ctx->addExp());
ExprInfo rhs = EvalExpr(*ctx->mulExp());
return EvalArithmetic(*ctx, lhs, rhs, ctx->ADD() ? '+' : '-');
@ -544,7 +548,7 @@ class SemaVisitor final : public SysYBaseVisitor {
const std::string name = ctx.ID()->getText();
const ObjectBinding* symbol = symbols_.Lookup(name);
if (!symbol) {
ThrowSemaError(&ctx, "使用了未声明的标识符: " + name);
ThrowSemaError(&ctx, "使用了未声明的标识符" + name);
}
sema_.BindObjectUse(&ctx, *symbol);

Loading…
Cancel
Save