You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

197 lines
6.9 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include "irgen/IRGen.h"
#include <stdexcept>
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
// Lowers a single SysY statement to IR.
//
// Statement forms handled here:
// - assignment:            lVal = exp;   (plain variables and indexed elements)
// - conditional:           if (cond) stmt [else stmt]
// - loop:                  while (cond) stmt, plus break; / continue;
// - return:                return; / return <exp>;
// - nested block:          { ... }  (delegated to the block visitor)
// - expression statement:  exp;
// - empty statement:       ;
//
// Result (wrapped in std::any):
// - BlockFlow::Terminated for break / continue / return, i.e. when the
//   statement itself ended the current basic block with a terminator;
// - whatever the block visitor returns for a nested block;
// - BlockFlow::Continue for every other form.
//
// Throws std::runtime_error (message built with FormatError) when the
// statement node or one of its required children is missing, when the
// assignment target has no semantic binding or storage slot, when break /
// continue appears outside a loop, or when the statement form is unknown.
std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
  if (!ctx) {
    throw std::runtime_error(FormatError("irgen", "缺少语句"));
  }

  // ---- Assignment: lVal = exp; ------------------------------------------
  if (ctx->lVal() && ctx->ASSIGN()) {
    if (!ctx->exp()) {
      throw std::runtime_error(FormatError("irgen", "赋值语句缺少表达式"));
    }
    // The right-hand side is emitted before any address computation for the
    // target, so its instructions come first in the block.
    ir::Value* rhs = EvalExpr(*ctx->exp());

    auto* lval_ctx = ctx->lVal();
    if (!lval_ctx->ID()) {
      throw std::runtime_error(FormatError("irgen", "当前仅支持普通整型变量赋值"));
    }
    const std::string var_name = lval_ctx->ID()->getText();
    if (!sema_.ResolveObjectUse(lval_ctx)) {
      throw std::runtime_error(
          FormatError("irgen", "变量使用缺少语义绑定:" + var_name));
    }
    ir::Value* slot = FindStorage(var_name);
    if (!slot) {
      throw std::runtime_error(
          FormatError("irgen", "变量声明缺少存储槽位:" + var_name));
    }

    ir::Value* ptr = slot;
    auto slot_ty = ptr->GetType();
    bool is_param = false;
    // A pointer-to-pointer slot is treated as a function parameter: the slot
    // holds the incoming pointer, so load it to obtain the real base address.
    if (slot_ty->IsPointer() && slot_ty->GetPointedType()->IsPointer()) {
      ptr = builder_.CreateLoad(ptr, module_.GetContext().NextTemp());
      is_param = true;
    }
    if (ptr->IsArgument()) {
      is_param = true;
    }

    // Indexed target: compute the element address with a GEP.
    if (!lval_ctx->exp().empty()) {
      std::vector<ir::Value*> indices;
      // A non-parameter pointer-to-array needs a leading 0 index to step
      // through the pointer before indexing into the array itself.
      if (ptr->GetType()->IsPointer() &&
          ptr->GetType()->GetPointedType()->IsArray() && !is_param) {
        indices.push_back(builder_.CreateConstInt(0));
      }
      for (auto* index_ctx : lval_ctx->exp()) {
        indices.push_back(EvalExpr(*index_ctx));
      }
      auto elem_ptr_ty = GetGEPResultType(ptr, indices);
      ptr = builder_.CreateGEP(elem_ptr_ty, ptr, indices,
                               module_.GetContext().NextTemp());
    }

    // Implicit int <-> float conversion so the stored value matches the
    // destination element type.
    const bool dst_is_float =
        ptr->GetType()->IsPtrFloat() ||
        (ptr->GetType()->IsArray() && ptr->GetType()->GetElementType()->IsFloat());
    if (dst_is_float && !rhs->GetType()->IsFloat()) {
      rhs = builder_.CreateSIToFP(rhs, module_.GetContext().NextTemp());
    } else if ((ptr->GetType()->IsPtrInt32() ||
                (ptr->GetType()->IsArray() &&
                 ptr->GetType()->GetElementType()->IsInt32())) &&
               rhs->GetType()->IsFloat()) {
      rhs = builder_.CreateFPToSI(rhs, module_.GetContext().NextTemp());
    }
    builder_.CreateStore(rhs, ptr);
    return BlockFlow::Continue;
  }

  // Evaluates a condition node and, when the result is an i32 or a float,
  // compares it against zero (Ne) so that it can feed a conditional branch.
  // Values of any other type are passed through unchanged.
  auto emit_condition = [this](auto* cond_ctx) -> ir::Value* {
    ir::Value* cond = std::any_cast<ir::Value*>(cond_ctx->accept(this));
    if (cond->GetType()->IsInt32()) {
      ir::Value* zero = builder_.CreateConstInt(0);
      return builder_.CreateCmp(ir::CmpOp::Ne, cond, zero,
                                module_.GetContext().NextTemp());
    }
    if (cond->GetType()->IsFloat()) {
      ir::Value* zero = module_.GetContext().GetConstFloat(0.0f);
      return builder_.CreateCmp(ir::CmpOp::Ne, cond, zero,
                                module_.GetContext().NextTemp());
    }
    return cond;
  };

  // ---- if (cond) stmt [else stmt] ---------------------------------------
  if (ctx->IF()) {
    const bool has_else = ctx->ELSE() != nullptr;
    // Validate the parse tree before emitting anything: the child accessors
    // return nullptr when a child is absent (e.g. after parser error recovery).
    if (!ctx->cond()) {
      throw std::runtime_error(FormatError("irgen", "if 语句缺少条件表达式"));
    }
    if (!ctx->stmt(0) || (has_else && !ctx->stmt(1))) {
      throw std::runtime_error(FormatError("irgen", "if 语句缺少分支语句"));
    }

    ir::Value* cond_val = emit_condition(ctx->cond());

    ir::BasicBlock* then_bb = func_->CreateBlock(NextBlockName("if_then"));
    ir::BasicBlock* else_bb =
        has_else ? func_->CreateBlock(NextBlockName("if_else")) : nullptr;
    ir::BasicBlock* merge_bb = func_->CreateBlock(NextBlockName("if_merge"));
    // Without an else arm the false edge goes straight to the merge block.
    builder_.CreateCondBr(cond_val, then_bb, else_bb ? else_bb : merge_bb);

    builder_.SetInsertPoint(then_bb);
    const auto then_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    if (then_flow == BlockFlow::Continue) {
      builder_.CreateBr(merge_bb);
    }

    if (has_else) {
      builder_.SetInsertPoint(else_bb);
      const auto else_flow = std::any_cast<BlockFlow>(ctx->stmt(1)->accept(this));
      if (else_flow == BlockFlow::Continue) {
        builder_.CreateBr(merge_bb);
      }
    }

    // NOTE(review): when both arms terminate, merge_bb has no predecessors,
    // yet the statement still reports Continue and following code is emitted
    // into it. Confirm that later stages tolerate such an unreachable block.
    builder_.SetInsertPoint(merge_bb);
    return BlockFlow::Continue;
  }

  // ---- while (cond) stmt ------------------------------------------------
  if (ctx->WHILE()) {
    if (!ctx->cond()) {
      throw std::runtime_error(FormatError("irgen", "while 语句缺少条件表达式"));
    }
    if (!ctx->stmt(0)) {
      throw std::runtime_error(FormatError("irgen", "while 语句缺少循环体"));
    }

    ir::BasicBlock* cond_bb = func_->CreateBlock(NextBlockName("while_cond"));
    ir::BasicBlock* body_bb = func_->CreateBlock(NextBlockName("while_body"));
    ir::BasicBlock* exit_bb = func_->CreateBlock(NextBlockName("while_exit"));

    // The condition lives in its own block so the back edge can re-test it.
    builder_.CreateBr(cond_bb);
    builder_.SetInsertPoint(cond_bb);
    ir::Value* cond_val = emit_condition(ctx->cond());
    builder_.CreateCondBr(cond_val, body_bb, exit_bb);

    builder_.SetInsertPoint(body_bb);
    // Save the enclosing loop's targets so break / continue inside this body
    // bind to the innermost loop, then restore them afterwards.
    ir::BasicBlock* outer_cond = current_loop_cond_bb_;
    ir::BasicBlock* outer_exit = current_loop_exit_bb_;
    current_loop_cond_bb_ = cond_bb;
    current_loop_exit_bb_ = exit_bb;
    const auto body_flow = std::any_cast<BlockFlow>(ctx->stmt(0)->accept(this));
    if (body_flow == BlockFlow::Continue) {
      builder_.CreateBr(cond_bb);  // back edge
    }
    current_loop_cond_bb_ = outer_cond;
    current_loop_exit_bb_ = outer_exit;

    builder_.SetInsertPoint(exit_bb);
    return BlockFlow::Continue;
  }

  // ---- break; -----------------------------------------------------------
  if (ctx->BREAK()) {
    if (!current_loop_exit_bb_) {
      throw std::runtime_error(FormatError("irgen", "break 必须在循环内"));
    }
    builder_.CreateBr(current_loop_exit_bb_);
    return BlockFlow::Terminated;
  }

  // ---- continue; --------------------------------------------------------
  if (ctx->CONTINUE()) {
    if (!current_loop_cond_bb_) {
      throw std::runtime_error(FormatError("irgen", "continue 必须在循环内"));
    }
    builder_.CreateBr(current_loop_cond_bb_);
    return BlockFlow::Terminated;
  }

  // ---- return; / return <exp>; ------------------------------------------
  if (ctx->RETURN()) {
    if (ctx->exp()) {
      ir::Value* ret_val = EvalExpr(*ctx->exp());
      // Convert the value when it disagrees with func_->GetType()
      // (float vs. i32).
      if (func_->GetType()->IsFloat() && !ret_val->GetType()->IsFloat()) {
        ret_val = builder_.CreateSIToFP(ret_val, module_.GetContext().NextTemp());
      } else if (func_->GetType()->IsInt32() && ret_val->GetType()->IsFloat()) {
        ret_val = builder_.CreateFPToSI(ret_val, module_.GetContext().NextTemp());
      }
      builder_.CreateRet(ret_val);
    } else {
      builder_.CreateRet(nullptr);  // nullptr stands for a void return
    }
    return BlockFlow::Terminated;
  }

  // ---- nested block -----------------------------------------------------
  if (ctx->block()) {
    return ctx->block()->accept(this);
  }

  // ---- expression statement: evaluated, result discarded ----------------
  if (ctx->exp()) {
    EvalExpr(*ctx->exp());
    return BlockFlow::Continue;
  }

  // ---- empty statement --------------------------------------------------
  if (ctx->SEMICOLON()) {
    return BlockFlow::Continue;
  }

  throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}