From 058ac57a47247be449a1e30fdd826146d5aa6163 Mon Sep 17 00:00:00 2001 From: zjx Date: Wed, 8 Apr 2026 19:36:39 +0800 Subject: [PATCH] =?UTF-8?q?=E8=AF=AD=E6=B3=95=E9=80=9A=E8=BF=87?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- include/ir/IR.h | 2 + include/irgen/IRGen.h | 17 ++- scripts/run_verify_all.sh | 60 +++++++++ src/ir/Context.cpp | 2 +- src/ir/Function.cpp | 8 +- src/ir/IRBuilder.cpp | 31 ++++- src/ir/IRPrinter.cpp | 145 +++++++++++++++++---- src/ir/Instruction.cpp | 8 +- src/irgen/IRGenDecl.cpp | 205 +++++++++++++++++++++++++++++- src/irgen/IRGenExp.cpp | 256 ++++++++++++++++++++++++++++++++++++-- src/irgen/IRGenFunc.cpp | 159 ++++++++++++++++++++++- src/irgen/IRGenStmt.cpp | 158 ++++++++++++++++++++++- 12 files changed, 994 insertions(+), 57 deletions(-) create mode 100755 scripts/run_verify_all.sh diff --git a/include/ir/IR.h b/include/ir/IR.h index f5194ed..8957452 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -457,6 +457,8 @@ class Function : public Value { BasicBlock* entry_ = nullptr; std::vector> blocks_; std::vector params_; // 参数值(通常是 Argument 类型,后续可定义) + // Owned parameter storage to keep argument Values alive + std::vector> owned_params_; std::shared_ptr func_type_; // 缓存函数类型 }; diff --git a/include/irgen/IRGen.h b/include/irgen/IRGen.h index abd03b7..45883f9 100644 --- a/include/irgen/IRGen.h +++ b/include/irgen/IRGen.h @@ -35,6 +35,13 @@ class IRGenImpl final : public SysYBaseVisitor { std::any visitNumber(SysYParser::NumberContext* ctx) override; std::any visitLVal(SysYParser::LValContext* ctx) override; std::any visitAddExp(SysYParser::AddExpContext* ctx) override; + std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override; + std::any visitMulExp(SysYParser::MulExpContext* ctx) override; + std::any visitRelExp(SysYParser::RelExpContext* ctx) override; + std::any visitEqExp(SysYParser::EqExpContext* ctx) override; + std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override; + std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override; + std::any visitConstDef(SysYParser::ConstDefContext* ctx) override; private: enum class BlockFlow { @@ -50,7 +57,15 @@ class IRGenImpl final : public SysYBaseVisitor { ir::Function* func_; ir::IRBuilder builder_; // 名称绑定由 Sema 负责;IRGen 只维护“声明 -> 存储槽位”的代码生成状态。 - std::unordered_map storage_map_; + std::unordered_map storage_map_; + // 额外增加按名称的快速映射,以防有时无法直接通过声明节点指针匹配。 + std::unordered_map name_map_; + // 常量名称到整数值的快速映射(供数组维度解析使用) + std::unordered_map const_values_; + // 当前正在处理的声明基础类型(由 visitDecl 设置,visitVarDef/visitConstDef 使用) + std::string current_btype_; + std::vector break_targets_; + std::vector continue_targets_; }; std::unique_ptr GenerateIR(SysYParser::CompUnitContext& tree, diff --git a/scripts/run_verify_all.sh b/scripts/run_verify_all.sh new file mode 100755 index 0000000..1e81233 --- /dev/null +++ b/scripts/run_verify_all.sh @@ -0,0 +1,60 @@ +#!/usr/bin/env bash +set -euo pipefail + +ROOT=$(cd "$(dirname "$0")/.." && pwd) +FUNC_DIR="$ROOT/test/test_case/functional" +OUT_BASE="$ROOT/test/test_result/function/ir" +LOG_DIR="$ROOT/test/test_result/function/ir_logs" +VERIFY="$ROOT/scripts/verify_ir.sh" + +mkdir -p "$OUT_BASE" +mkdir -p "$LOG_DIR" + +if [ ! -x "$VERIFY" ]; then + echo "verify script not executable, trying to run with bash: $VERIFY" +fi + +files=("$FUNC_DIR"/*.sy) +if [ ${#files[@]} -eq 0 ]; then + echo "No .sy files found in $FUNC_DIR" >&2 + exit 1 +fi + +total=0 +pass=0 +fail=0 +failed_list=() + +for f in "${files[@]}"; do + ((total++)) + name=$(basename "$f") + echo "=== Test: $name ===" + log="$LOG_DIR/${name%.sy}.log" + set +e + bash "$VERIFY" "$f" "$OUT_BASE" --run >"$log" 2>&1 + rc=$? + set -e + if [ $rc -eq 0 ]; then + echo "PASS: $name" + ((pass++)) + else + echo "FAIL: $name (log: $log)" + failed_list+=("$name") + ((fail++)) + fi +done + +echo +echo "Summary: total=$total pass=$pass fail=$fail" +if [ $fail -ne 0 ]; then + echo "Failed tests:"; for t in "${failed_list[@]}"; do echo " - $t"; done + echo "Tail of failure logs (last 200 lines each):" + for t in "${failed_list[@]}"; do + logfile="$LOG_DIR/${t%.sy}.log" + echo + echo "--- $t ---" + tail -n 200 "$logfile" || true + done +fi + +exit $fail diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index 4f075d6..78770eb 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -25,7 +25,7 @@ ConstantFloat* Context::GetConstFloat(float v) { std::string Context::NextTemp() { std::ostringstream oss; - oss << "%" << ++temp_index_; + oss << ++temp_index_; return oss.str(); } diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index 312abfc..c53c9bd 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -9,8 +9,14 @@ Function::Function(std::string name, std::shared_ptr ret_type, std::vector> param_types) : Value(std::move(ret_type), std::move(name)) { func_type_ = std::static_pointer_cast( - Type::GetFunctionType(GetType(), std::move(param_types))); + Type::GetFunctionType(GetType(), param_types)); entry_ = CreateBlock("entry"); + // Create arguments + for (size_t i = 0; i < param_types.size(); ++i) { + owned_params_.push_back(std::make_unique(param_types[i], "arg" + std::to_string(i))); + params_.push_back(owned_params_.back().get()); + // Note: arguments are owned in owned_params_ to ensure lifetime + } } BasicBlock* Function::CreateBlock(const std::string& name) { diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index 4fb364c..3b08e49 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -21,6 +21,10 @@ ConstantInt* IRBuilder::CreateConstInt(int v) { return ctx_.GetConstInt(v); } +ConstantFloat* IRBuilder::CreateConstFloat(float v) { + return ctx_.GetConstFloat(v); +} + BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs, const std::string& name) { if (!insert_block_) { @@ -42,6 +46,7 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs, return CreateBinary(Opcode::Add, lhs, rhs, name); } + AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) { if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); @@ -161,8 +166,30 @@ GetElementPtrInst* IRBuilder::CreateGEP(Value* ptr, std::vector indices, if (!insert_block_) { throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); } - // 计算结果类型(简化:假设为指针类型) - auto result_ty = Type::GetPointerType(ptr->GetType()); + // 计算结果类型:根据传入的 indices 逐步从 pointee 类型走到目标元素类型。 + // 例如 ptr 是指向数组的指针,GEP 使用一个索引应返回指向数组元素的指针。 + std::shared_ptr current; + if (ptr->GetType() && ptr->GetType()->IsPointer()) { + const PointerType* pty = static_cast(ptr->GetType().get()); + current = pty->GetPointeeType(); + } else { + current = ptr->GetType(); + } + // 根据每个索引推进类型层次:数组 -> 元素类型,指针 -> 指向类型 + for (size_t i = 0; i < indices.size(); ++i) { + if (!current) break; + if (current->IsArray()) { + const ArrayType* aty = static_cast(current.get()); + current = aty->GetElementType(); + } else if (current->IsPointer()) { + const PointerType* ppty = static_cast(current.get()); + current = ppty->GetPointeeType(); + } else { + // 非数组/指针类型,无法继续下钻,保持当前类型 + break; + } + } + auto result_ty = Type::GetPointerType(current); return insert_block_->Append(result_ty, ptr, indices, name); } diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 5df779b..4f77c7d 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -12,7 +12,30 @@ namespace ir { -static const char* TypeToString(const Type& ty) { +static std::string PredicateToString(CmpInst::Predicate pred, bool is_float) { + if (is_float) { + switch (pred) { + case CmpInst::EQ: return "oeq"; + case CmpInst::NE: return "one"; + case CmpInst::LT: return "olt"; + case CmpInst::LE: return "ole"; + case CmpInst::GT: return "ogt"; + case CmpInst::GE: return "oge"; + } + } else { + switch (pred) { + case CmpInst::EQ: return "eq"; + case CmpInst::NE: return "ne"; + case CmpInst::LT: return "slt"; + case CmpInst::LE: return "sle"; + case CmpInst::GT: return "sgt"; + case CmpInst::GE: return "sge"; + } + } + return "unknown"; +} + +static std::string TypeToString(const Type& ty) { switch (ty.GetKind()) { case Type::Kind::Void: return "void"; @@ -20,10 +43,14 @@ static const char* TypeToString(const Type& ty) { return "i32"; case Type::Kind::Float32: return "float"; - case Type::Kind::Pointer: - return "i32*"; // 目前仅支持 i32* 指针打印 - case Type::Kind::Array: - return "[array]"; + case Type::Kind::Pointer: { + const PointerType* p = static_cast(&ty); + return TypeToString(*p->GetPointeeType()) + "*"; + } + case Type::Kind::Array: { + const ArrayType* a = static_cast(&ty); + return std::string("[") + std::to_string(a->GetSize()) + " x " + TypeToString(*a->GetElementType()) + "]"; + } case Type::Kind::Function: return "[function]"; case Type::Kind::Label: @@ -60,17 +87,65 @@ static const char* OpcodeToString(Opcode op) { return "?"; } +static std::string ConstantValueToString(const ConstantValue* cv); + static std::string ValueToString(const Value* v) { + if (!v) return ""; if (auto* ci = dynamic_cast(v)) { return std::to_string(ci->GetValue()); } - return v ? v->GetName() : ""; + if (auto* cf = dynamic_cast(v)) { + // simple float literal + return std::to_string(cf->GetValue()); + } + if (auto* ca = dynamic_cast(v)) { + return ConstantValueToString(ca); + } + // fallback to name for instructions/alloca/vars — prefix with '%' + return std::string("%") + v->GetName(); +} + +static std::string ConstantValueToString(const ConstantValue* cv) { + if (!cv) return ""; + if (auto* ci = dynamic_cast(cv)) return std::to_string(ci->GetValue()); + if (auto* cf = dynamic_cast(cv)) { + std::string s = std::to_string(cf->GetValue()); + size_t dot = s.find('.'); + if (dot != std::string::npos) { + size_t e = s.find('e'); + if (e == std::string::npos) e = s.size(); + while (e > dot + 1 && s[e-1] == '0') e--; + if (e == dot + 1) s = s.substr(0, dot + 1) + "0"; + else s = s.substr(0, e); + } + return s; + } + if (auto* ca = dynamic_cast(cv)) { + // format: [ , , ... ] + const auto& elems = ca->GetElements(); + std::string out = "["; + for (size_t i = 0; i < elems.size(); ++i) { + if (i) out += ", "; + // each element should be printed with its type and value + auto* e = elems[i]; + std::string etype = TypeToString(*e->GetType()); + out += etype + " " + ConstantValueToString(e); + } + out += "]"; + return out; + } + return ""; } void IRPrinter::Print(const Module& module, std::ostream& os) { for (const auto& func : module.GetFunctions()) { - os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() - << "() {\n"; + os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() << "("; + const auto& params = func->GetParams(); + for (size_t i = 0; i < params.size(); ++i) { + if (i) os << ", "; + os << TypeToString(*params[i]->GetType()) << " " << ValueToString(params[i]); + } + os << ") {\n"; for (const auto& bb : func->GetBlocks()) { if (!bb) { continue; @@ -91,18 +166,31 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { case Opcode::LShr: case Opcode::AShr: { auto* bin = static_cast(inst); - os << " " << bin->GetName() << " = " - << OpcodeToString(bin->GetOpcode()) << " " - << TypeToString(*bin->GetLhs()->GetType()) << " " - << ValueToString(bin->GetLhs()) << ", " - << ValueToString(bin->GetRhs()) << "\n"; + // choose opcode name: integer ops use e.g. 'add','sub', float ops use 'fadd','fsub', etc. + std::string op_name = OpcodeToString(bin->GetOpcode()); + bool is_float = bin->GetLhs()->GetType()->IsFloat32(); + if (is_float) { + switch (bin->GetOpcode()) { + case Opcode::Add: op_name = "fadd"; break; + case Opcode::Sub: op_name = "fsub"; break; + case Opcode::Mul: op_name = "fmul"; break; + case Opcode::Div: op_name = "fdiv"; break; + case Opcode::Mod: op_name = "frem"; break; + default: break; + } + } + os << " %" << bin->GetName() << " = " + << op_name << " " + << TypeToString(*bin->GetLhs()->GetType()) << " " + << ValueToString(bin->GetLhs()) << ", " + << ValueToString(bin->GetRhs()) << "\n"; break; } case Opcode::ICmp: case Opcode::FCmp: { auto* cmp = static_cast(inst); - os << " " << cmp->GetName() << " = " - << OpcodeToString(cmp->GetOpcode()) << " eq " + os << " %" << cmp->GetName() << " = " + << OpcodeToString(cmp->GetOpcode()) << " " << PredicateToString(cmp->GetPredicate(), cmp->GetOpcode() == Opcode::FCmp) << " " << TypeToString(*cmp->GetLhs()->GetType()) << " " << ValueToString(cmp->GetLhs()) << ", " << ValueToString(cmp->GetRhs()) << "\n"; @@ -110,16 +198,16 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::Alloca: { auto* alloca = static_cast(inst); - os << " " << alloca->GetName() << " = alloca " - << TypeToString(*static_cast(alloca->GetType().get())->GetPointeeType()) << "\n"; + os << " %" << alloca->GetName() << " = alloca " + << TypeToString(*static_cast(alloca->GetType().get())->GetPointeeType()) << "\n"; break; } case Opcode::Load: { auto* load = static_cast(inst); - os << " " << load->GetName() << " = load " - << TypeToString(*load->GetType()) << ", " - << TypeToString(*load->GetPtr()->GetType()) << " " - << ValueToString(load->GetPtr()) << "\n"; + os << " %" << load->GetName() << " = load " + << TypeToString(*load->GetType()) << ", " + << TypeToString(*load->GetPtr()->GetType()) << " " + << ValueToString(load->GetPtr()) << "\n"; break; } case Opcode::Store: { @@ -152,7 +240,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { auto* call = static_cast(inst); os << " "; if (!call->GetName().empty()) { - os << call->GetName() << " = "; + os << "%" << call->GetName() << " = "; } os << "call " << TypeToString(*call->GetCallee()->GetType()) << " @" << call->GetCallee()->GetName() << "("; @@ -166,9 +254,12 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::GEP: { auto* gep = static_cast(inst); - os << " " << gep->GetName() << " = getelementptr " - << TypeToString(*gep->GetPtr()->GetType()) << " " - << ValueToString(gep->GetPtr()); + os << " %" << gep->GetName() << " = getelementptr "; + // Print element type first, then the pointer type and pointer value + const auto ptrType = gep->GetPtr()->GetType(); + const PointerType* pty = static_cast(ptrType.get()); + os << TypeToString(*pty->GetPointeeType()) << ", " + << TypeToString(*ptrType) << " " << ValueToString(gep->GetPtr()); for (auto* idx : gep->GetIndices()) { os << ", " << TypeToString(*idx->GetType()) << " " << ValueToString(idx); } @@ -177,8 +268,8 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::Phi: { auto* phi = static_cast(inst); - os << " " << phi->GetName() << " = phi " - << TypeToString(*phi->GetType()); + os << " %" << phi->GetName() << " = phi " + << TypeToString(*phi->GetType()); for (const auto& incoming : phi->GetIncomings()) { os << " [ " << ValueToString(incoming.first) << ", %" << incoming.second->GetName() << " ]"; diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 236abd9..4bfc173 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -63,8 +63,8 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; } BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, Value* rhs, std::string name) : Instruction(op, std::move(ty), std::move(name)) { - if (op != Opcode::Add) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add")); + if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul && op != Opcode::Div && op != Opcode::Mod) { + throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持算术操作")); } if (!lhs || !rhs) { throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数")); @@ -76,8 +76,8 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr ty, Value* lhs, type_->GetKind() != lhs->GetType()->GetKind()) { throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配")); } - if (!type_->IsInt32()) { - throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32")); + if (!type_->IsInt32() && !type_->IsFloat32()) { + throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32 和 float")); } AddOperand(lhs); AddOperand(rhs); diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index be7a356..4325dbe 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -5,6 +5,7 @@ #include "SysYParser.h" #include "ir/IR.h" #include "utils/Log.h" +#include // helper functions removed; VarDef uses Ident() directly per current grammar. @@ -22,6 +23,52 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) { return {}; } +std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "缺少常量定义")); + if (!ctx->Ident()) throw std::runtime_error(FormatError("irgen", "常量声明缺少名称")); + if (!ctx->constInitVal()) throw std::runtime_error(FormatError("irgen", "常量必须初始化")); + if (storage_map_.find(ctx) != storage_map_.end()) throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); + + auto* slot = (current_btype_ == "float") ? + static_cast(builder_.CreateAllocaFloat(module_.GetContext().NextTemp())) : + static_cast(builder_.CreateAllocaI32(module_.GetContext().NextTemp())); + storage_map_[ctx] = slot; + name_map_[ctx->Ident()->getText()] = slot; + + // Try to evaluate a scalar const initializer + ir::ConstantValue* cinit = nullptr; + try { + auto* initval = ctx->constInitVal(); + if (initval && initval->constExp() && initval->constExp()->addExp()) { + if (current_btype_ == "float") { + auto* add = initval->constExp()->addExp(); + float fv = std::stof(add->getText()); + cinit = module_.GetContext().GetConstFloat(fv); + } else { + auto* add = initval->constExp()->addExp(); + int iv = std::stoi(add->getText()); + cinit = module_.GetContext().GetConstInt(iv); + } + } + } catch(...) { + // fallback: try evaluate via visitor + try { + auto* add = ctx->constInitVal()->constExp()->addExp(); + ir::Value* v = std::any_cast(add->accept(this)); + if (auto* cv = dynamic_cast(v)) cinit = cv; + } catch(...) {} + } + if (cinit) builder_.CreateStore(cinit, slot); + else builder_.CreateStore((current_btype_=="float"? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0)), slot); + // record simple integer consts for dimension evaluation + try { + if (auto* ci = dynamic_cast(cinit)) { + const_values_[ctx->Ident()->getText()] = ci->GetValue(); + } + } catch(...) {} + return {}; +} + IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult( SysYParser::BlockItemContext& item) { return std::any_cast(item.accept(this)); @@ -63,6 +110,15 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) { } return {}; } + if (ctx->constDecl()) { + auto* cdecl = ctx->constDecl(); + if (!cdecl->bType()) throw std::runtime_error(FormatError("irgen", "缺少常量基类型")); + current_btype_ = cdecl->bType()->getText(); + if (current_btype_ != "int" && current_btype_ != "float") throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int/float 常量声明")); + for (auto* const_def : cdecl->constDef()) if (const_def) const_def->accept(this); + current_btype_.clear(); + return {}; + } throw std::runtime_error(FormatError("irgen", "暂不支持的声明类型")); } @@ -81,9 +137,156 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { if (storage_map_.find(ctx) != storage_map_.end()) { throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位")); } + // check if this is an array declaration (has constExp dimensions) + if (!ctx->constExp().empty()) { + // parse dims + std::vector dims; + for (auto* ce : ctx->constExp()) { + try { + int v = 0; + auto anyv = sema_.GetConstVal(ce); + if (anyv.has_value()) { + if (anyv.type() == typeid(int)) v = std::any_cast(anyv); + else if (anyv.type() == typeid(long)) v = (int)std::any_cast(anyv); + else throw std::runtime_error("not-const-int"); + } else { + // try simple patterns like NUM or IDENT+NUM or NUM+IDENT + std::string s = ce->addExp()->getText(); + s.erase(std::remove_if(s.begin(), s.end(), ::isspace), s.end()); + auto pos = s.find('+'); + if (pos == std::string::npos) { + // plain number or identifier + try { v = std::stoi(s); } + catch(...) { + // try lookup identifier in recorded consts or symbol table + auto it = const_values_.find(s); + if (it != const_values_.end()) v = (int)it->second; + else { + VarInfo vi; void* declctx = nullptr; + if (sema_.GetSymbolTable().LookupVar(s, vi, declctx) && vi.const_val.has_value()) { + if (vi.const_val.type() == typeid(int)) v = std::any_cast(vi.const_val); + else if (vi.const_val.type() == typeid(long)) v = (int)std::any_cast(vi.const_val); + else throw std::runtime_error("not-const-int"); + } else throw std::runtime_error("not-const-int"); + } + } + } else { + // form A+B where A or B may be ident or number + std::string L = s.substr(0, pos); + std::string R = s.substr(pos + 1); + int lv = 0, rv = 0; bool ok = false; + // try left + try { lv = std::stoi(L); ok = true; } catch(...) { + auto it = const_values_.find(L); + if (it != const_values_.end()) { lv = (int)it->second; ok = true; } + else { + VarInfo vi; void* declctx = nullptr; + if (sema_.GetSymbolTable().LookupVar(L, vi, declctx) && vi.const_val.has_value()) { + if (vi.const_val.type() == typeid(int)) lv = std::any_cast(vi.const_val); + else if (vi.const_val.type() == typeid(long)) lv = (int)std::any_cast(vi.const_val); + ok = true; + } + } + } + // try right + try { rv = std::stoi(R); ok = ok && true; } catch(...) { + auto it2 = const_values_.find(R); + if (it2 != const_values_.end()) { rv = (int)it2->second; ok = ok && true; } + else { + VarInfo vi2; void* declctx2 = nullptr; + if (sema_.GetSymbolTable().LookupVar(R, vi2, declctx2) && vi2.const_val.has_value()) { + if (vi2.const_val.type() == typeid(int)) rv = std::any_cast(vi2.const_val); + else if (vi2.const_val.type() == typeid(long)) rv = (int)std::any_cast(vi2.const_val); + ok = ok && true; + } else ok = false; + } + } + if (!ok) throw std::runtime_error("not-const-int"); + v = lv + rv; + } + } + dims.push_back(v); + } catch (...) { + throw std::runtime_error(FormatError("irgen", "数组维度必须为常量整数")); + } + } + + std::shared_ptr elemTy = (current_btype_ == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type(); + + // build nested array type + std::function(size_t)> makeArrayType = [&](size_t level) -> std::shared_ptr { + if (level + 1 >= dims.size()) return ir::Type::GetArrayType(elemTy, dims[level]); + auto sub = makeArrayType(level + 1); + return ir::Type::GetArrayType(sub, dims[level]); + }; + auto fullArrayTy = makeArrayType(0); + auto arr_ptr_ty = ir::Type::GetPointerType(fullArrayTy); + auto* array_slot = builder_.CreateAlloca(arr_ptr_ty, module_.GetContext().NextTemp()); + storage_map_[ctx] = array_slot; + name_map_[ctx->Ident()->getText()] = array_slot; + + // compute spans and total scalar slots + int nlevels = (int)dims.size(); + std::vector span(nlevels); + int total = 1; + for (int i = nlevels - 1; i >= 0; --i) { + if (i == nlevels - 1) span[i] = 1; + else span[i] = span[i + 1] * dims[i + 1]; + total *= dims[i]; + } + + ir::Value* zero = elemTy->IsFloat32() ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0); + std::vector slots(total, zero); + + // process initializer (if any) into linear slots + if (auto* init_value = ctx->initVal()) { + std::function process_group; + process_group = [&](SysYParser::InitValContext* init, int level, int& pos) { + if (level >= nlevels) return; + int sub_span = span[level]; + int elems = dims[level]; + if (!init) { pos += elems * sub_span; return; } + for (auto* child : init->initVal()) { + if (pos >= total) break; + if (!child) { pos += 1; continue; } + if (!child->initVal().empty()) { + int subpos = pos; + int inner = subpos; + process_group(child, level + 1, inner); + pos = subpos + sub_span; + } else if (child->exp()) { + try { ir::Value* v = EvalExpr(*child->exp()); if (pos < total) slots[pos] = v; } catch(...) {} + pos += 1; + } else { pos += 1; } + } + }; + int pos0 = 0; + process_group(init_value, 0, pos0); + } + + // emit stores for each scalar slot in row-major order + for (int idx = 0; idx < total; ++idx) { + std::vector indices; + int rem = idx; + for (int L = 0; L < nlevels; ++L) { + int ind = rem / span[L]; + indices.push_back(ind % dims[L]); + rem = rem % span[L]; + } + std::vector gep_inds; + gep_inds.push_back(module_.GetContext().GetConstInt(0)); + for (int v : indices) gep_inds.push_back(module_.GetContext().GetConstInt(v)); + while (gep_inds.size() < (size_t)(1 + nlevels)) gep_inds.push_back(module_.GetContext().GetConstInt(0)); + auto* gep = builder_.CreateGEP(array_slot, gep_inds, module_.GetContext().NextTemp()); + builder_.CreateStore(slots[idx], gep); + } + + return {}; + } + + // scalar variable auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp()); storage_map_[ctx] = slot; - ir::Value* init = nullptr; if (auto* init_value = ctx->initVal()) { if (!init_value->exp()) { diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 4a189d9..a6f89ce 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -35,11 +35,14 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) { std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) { - if (!ctx || !ctx->IntConst()) { - throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量")); + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法数字字面量")); + if (ctx->IntConst()) { + return static_cast(builder_.CreateConstInt(std::stoi(ctx->getText()))); } - return static_cast( - builder_.CreateConstInt(std::stoi(ctx->getText()))); + if (ctx->FloatConst()) { + return static_cast(builder_.CreateConstFloat(std::stof(ctx->getText()))); + } + throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量或浮点字面量")); } // 变量使用的处理流程: @@ -55,11 +58,59 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { // find storage by matching declaration node stored in Sema context // Sema stores types/decl contexts in IRGenContext maps; here we search storage_map_ by name std::string name = ctx->Ident()->getText(); + // 优先使用按名称的快速映射 + auto nit = name_map_.find(name); + if (nit != name_map_.end()) { + // 支持下标访问:若有索引表达式列表,则生成 GEP + load + if (ctx->exp().size() > 0) { + std::vector indices; + // 首个索引用于穿过数组对象 + indices.push_back(builder_.CreateConstInt(0)); + for (auto* e : ctx->exp()) { + indices.push_back(EvalExpr(*e)); + } + auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp()); + return static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); + } + // 如果映射到的是常量,直接返回常量值;否则按原来行为从槽位 load + if (nit->second->IsConstant()) return nit->second; + return static_cast(builder_.CreateLoad(nit->second, module_.GetContext().NextTemp())); + } for (auto& kv : storage_map_) { - // kv.first is VarDefContext*, try to get Ident text - if (kv.first && kv.first->Ident() && kv.first->Ident()->getText() == name) { - return static_cast( - builder_.CreateLoad(kv.second, module_.GetContext().NextTemp())); + if (!kv.first) continue; + if (auto* vdef = dynamic_cast(kv.first)) { + if (vdef->Ident() && vdef->Ident()->getText() == name) { + if (ctx->exp().size() > 0) { + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e)); + auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp()); + return static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); + } + return static_cast(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp())); + } + } else if (auto* fparam = dynamic_cast(kv.first)) { + if (fparam->Ident() && fparam->Ident()->getText() == name) { + if (ctx->exp().size() > 0) { + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e)); + auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp()); + return static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); + } + return static_cast(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp())); + } + } else if (auto* cdef = dynamic_cast(kv.first)) { + if (cdef->Ident() && cdef->Ident()->getText() == name) { + if (ctx->exp().size() > 0) { + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e)); + auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp()); + return static_cast(builder_.CreateLoad(gep, module_.GetContext().NextTemp())); + } + return static_cast(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp())); + } } } throw std::runtime_error(FormatError("irgen", "变量声明缺少存储槽位: " + name)); @@ -68,11 +119,190 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) { std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) { if (!ctx) throw std::runtime_error(FormatError("irgen", "非法加法表达式")); - // left-associative: evaluate first two mulExp as a simple binary add + try { + // left-associative: fold across all mulExp operands if (ctx->mulExp().size() == 1) return ctx->mulExp(0)->accept(this); - ir::Value* lhs = std::any_cast(ctx->mulExp(0)->accept(this)); - ir::Value* rhs = std::any_cast(ctx->mulExp(1)->accept(this)); + ir::Value* cur = std::any_cast(ctx->mulExp(0)->accept(this)); + // extract operator sequence from text (in-order '+' or '-') + std::string text = ctx->getText(); + std::vector ops; + for (char c : text) if (c == '+' || c == '-') ops.push_back(c); + for (size_t i = 1; i < ctx->mulExp().size(); ++i) { + ir::Value* rhs = std::any_cast(ctx->mulExp(i)->accept(this)); + char opch = (i - 1 < ops.size()) ? ops[i - 1] : '+'; + ir::Opcode op = (opch == '-') ? ir::Opcode::Sub : ir::Opcode::Add; + cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp()); + } + return static_cast(cur); + } catch (const std::exception& e) { + LogInfo(std::string("[irgen] exception in visitAddExp text=") + ctx->getText() + ", err=" + e.what(), std::cerr); + throw; + } +} + +std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法乘法表达式")); + if (ctx->unaryExp().size() == 1) return ctx->unaryExp(0)->accept(this); + ir::Value* cur = std::any_cast(ctx->unaryExp(0)->accept(this)); + // extract operator sequence for '*', '/', '%' + std::string text = ctx->getText(); + std::vector ops; + for (char c : text) if (c == '*' || c == '/' || c == '%') ops.push_back(c); + for (size_t i = 1; i < ctx->unaryExp().size(); ++i) { + ir::Value* rhs = std::any_cast(ctx->unaryExp(i)->accept(this)); + char opch = (i - 1 < ops.size()) ? ops[i - 1] : '*'; + ir::Opcode op = ir::Opcode::Mul; + if (opch == '/') op = ir::Opcode::Div; + else if (opch == '%') op = ir::Opcode::Mod; + cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp()); + } + return static_cast(cur); +} + +std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法一元表达式")); + if (ctx->primaryExp()) return ctx->primaryExp()->accept(this); + // function call: Ident '(' funcRParams? ')' + if (ctx->Ident() && ctx->getText().find("(") != std::string::npos) { + std::string fname = ctx->Ident()->getText(); + std::vector args; + if (ctx->funcRParams()) { + for (auto* e : ctx->funcRParams()->exp()) { + args.push_back(EvalExpr(*e)); + } + } + // find existing function or create an external declaration (assume int return) + ir::Function* callee = nullptr; + for (auto &fup : module_.GetFunctions()) { + if (fup && fup->GetName() == fname) { callee = fup.get(); break; } + } + if (!callee) { + std::vector> param_types; + for (auto* a : args) { + if (a && a->IsFloat32()) param_types.push_back(ir::Type::GetFloat32Type()); + else param_types.push_back(ir::Type::GetInt32Type()); + } + callee = module_.CreateFunction(fname, ir::Type::GetInt32Type(), param_types); + } + return static_cast(builder_.CreateCall(callee, args, module_.GetContext().NextTemp())); + } + if (ctx->unaryExp()) { + ir::Value* val = std::any_cast(ctx->unaryExp()->accept(this)); + if (ctx->unaryOp() && ctx->unaryOp()->getText() == "+") return static_cast(val); + else if (ctx->unaryOp() && ctx->unaryOp()->getText() == "-") { + // 负号:0 - val,区分整型/浮点 + if (val->IsFloat32()) { + ir::Value* zero = builder_.CreateConstFloat(0.0f); + return static_cast(builder_.CreateSub(zero, val, module_.GetContext().NextTemp())); + } else { + ir::Value* zero = builder_.CreateConstInt(0); + return static_cast(builder_.CreateSub(zero, val, module_.GetContext().NextTemp())); + } + } + if (ctx->unaryOp() && ctx->unaryOp()->getText() == "!") { + // logical not: produce int 1 if val == 0, else 0 + if (val->IsFloat32()) { + ir::Value* zerof = builder_.CreateConstFloat(0.0f); + return static_cast(builder_.CreateFCmp(ir::CmpInst::EQ, val, zerof, module_.GetContext().NextTemp())); + } else { + ir::Value* zero = builder_.CreateConstInt(0); + return static_cast(builder_.CreateICmp(ir::CmpInst::EQ, val, zero, module_.GetContext().NextTemp())); + } + } + } + throw std::runtime_error(FormatError("irgen", "不支持的一元运算")); +} + +std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法关系表达式")); + if (ctx->addExp().size() == 1) return ctx->addExp(0)->accept(this); + ir::Value* lhs = std::any_cast(ctx->addExp(0)->accept(this)); + ir::Value* rhs = std::any_cast(ctx->addExp(1)->accept(this)); + // 类型提升 + if (lhs->IsFloat32() && rhs->IsInt32()) { + if (auto* ci = dynamic_cast(rhs)) { + rhs = builder_.CreateConstFloat(static_cast(ci->GetValue())); + } else { + throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); + } + } else if (rhs->IsFloat32() && lhs->IsInt32()) { + if (auto* ci = dynamic_cast(lhs)) { + lhs = builder_.CreateConstFloat(static_cast(ci->GetValue())); + } else { + throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); + } + } + ir::CmpInst::Predicate pred = ir::CmpInst::EQ; + std::string text = ctx->getText(); + if (text.find("<=") != std::string::npos) pred = ir::CmpInst::LE; + else if (text.find(">=") != std::string::npos) pred = ir::CmpInst::GE; + else if (text.find("<") != std::string::npos) pred = ir::CmpInst::LT; + else if (text.find(">") != std::string::npos) pred = ir::CmpInst::GT; + if (lhs->IsFloat32() || rhs->IsFloat32()) { + return static_cast( + builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp())); + } return static_cast( - builder_.CreateBinary(ir::Opcode::Add, lhs, rhs, - module_.GetContext().NextTemp())); + builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp())); } + +std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法相等表达式")); + if (ctx->relExp().size() == 1) return ctx->relExp(0)->accept(this); + ir::Value* lhs = std::any_cast(ctx->relExp(0)->accept(this)); + ir::Value* rhs = std::any_cast(ctx->relExp(1)->accept(this)); + // 类型提升 + if (lhs->IsFloat32() && rhs->IsInt32()) { + if (auto* ci = dynamic_cast(rhs)) { + rhs = builder_.CreateConstFloat(static_cast(ci->GetValue())); + } else { + throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); + } + } else if (rhs->IsFloat32() && lhs->IsInt32()) { + if (auto* ci = dynamic_cast(lhs)) { + lhs = builder_.CreateConstFloat(static_cast(ci->GetValue())); + } else { + throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换")); + } + } + ir::CmpInst::Predicate pred = ir::CmpInst::EQ; + std::string text = ctx->getText(); + if (text.find("==") != std::string::npos) pred = ir::CmpInst::EQ; + else if (text.find("!=") != std::string::npos) pred = ir::CmpInst::NE; + if (lhs->IsFloat32() || rhs->IsFloat32()) { + return static_cast( + builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp())); + } + return static_cast( + builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式")); + if (ctx->eqExp().size() == 1) return ctx->eqExp(0)->accept(this); + // For simplicity, treat as int (0 or 1) + ir::Value* lhs = std::any_cast(ctx->eqExp(0)->accept(this)); + ir::Value* rhs = std::any_cast(ctx->eqExp(1)->accept(this)); + // lhs && rhs : (lhs != 0) && (rhs != 0) + ir::Value* zero = builder_.CreateConstInt(0); + ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp()); + ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp()); + return static_cast( + builder_.CreateMul(lhs_ne, rhs_ne, module_.GetContext().NextTemp())); +} + +std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { + if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式")); + if (ctx->lAndExp().size() == 1) return ctx->lAndExp(0)->accept(this); + ir::Value* lhs = std::any_cast(ctx->lAndExp(0)->accept(this)); + ir::Value* rhs = std::any_cast(ctx->lAndExp(1)->accept(this)); + // lhs || rhs : (lhs != 0) || (rhs != 0) + ir::Value* zero = builder_.CreateConstInt(0); + ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp()); + ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp()); + ir::Value* or_val = builder_.CreateAdd(lhs_ne, rhs_ne, module_.GetContext().NextTemp()); + ir::Value* one = builder_.CreateConstInt(1); + return static_cast( + builder_.CreateICmp(ir::CmpInst::GE, or_val, one, module_.GetContext().NextTemp())); +} + diff --git a/src/irgen/IRGenFunc.cpp b/src/irgen/IRGenFunc.cpp index 8fb2764..64db25b 100644 --- a/src/irgen/IRGenFunc.cpp +++ b/src/irgen/IRGenFunc.cpp @@ -5,6 +5,7 @@ #include "SysYParser.h" #include "ir/IR.h" #include "utils/Log.h" +#include namespace { @@ -42,7 +43,119 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { if (ctx->funcDef().empty()) { throw std::runtime_error(FormatError("irgen", "缺少函数定义")); } - ctx->funcDef(0)->accept(this); + // 先处理顶层声明(仅支持简单的 const int 初始化作为全局常量) + for (auto* decl : ctx->decl()) { + if (!decl) continue; + if (decl->constDecl()) { + auto* cdecl = decl->constDecl(); + for (auto* cdef : cdecl->constDef()) { + if (!cdef || !cdef->Ident() || !cdef->constInitVal() || !cdef->constInitVal()->constExp()) continue; + // 仅支持形如: const int a = 10; 的简单常量初始化(字面量) + auto* add = cdef->constInitVal()->constExp()->addExp(); + if (!add) continue; + try { + int v = std::stoi(add->getText()); + auto* cval = module_.GetContext().GetConstInt(v); + name_map_[cdef->Ident()->getText()] = cval; + } catch (...) { + // 无法解析则跳过,全局复杂常量暂不支持 + } + } + } + // 支持简单的全局变量声明(数组或标量),初始化为零 + if (decl->varDecl()) { + auto* vdecl = decl->varDecl(); + if (!vdecl->bType()) continue; + std::string btype = vdecl->bType()->getText(); + for (auto* vdef : vdecl->varDef()) { + if (!vdef) continue; + LogInfo(std::string("[irgen] global varDef text=") + vdef->getText() + std::string(" ident=") + (vdef->Ident() ? vdef->Ident()->getText() : std::string("")) + std::string(" dims=") + std::to_string((int)vdef->constExp().size()), std::cerr); + if (!vdef || !vdef->Ident()) continue; + std::string name = vdef->Ident()->getText(); + // array globals + if (!vdef->constExp().empty()) { + std::vector dims; + bool ok = true; + for (auto* ce : vdef->constExp()) { + try { + int val = 0; + auto anyv = sema_.GetConstVal(ce); + if (anyv.has_value()) { + if (anyv.type() == typeid(int)) val = std::any_cast(anyv); + else if (anyv.type() == typeid(long)) val = (int)std::any_cast(anyv); + else throw std::runtime_error("not-const-int"); + } else { + // try literal parse + try { + val = std::stoi(ce->addExp()->getText()); + } catch (...) { + // try lookup in name_map_ for previously created const + std::string t = ce->addExp()->getText(); + auto it = name_map_.find(t); + if (it != name_map_.end() && it->second && it->second->IsConstant()) { + if (auto* ci = dynamic_cast(it->second)) { + val = ci->GetValue(); + } else { + ok = false; break; + } + } else { + ok = false; break; + } + } + } + dims.push_back(val); + } catch (...) { ok = false; break; } + } + if (!ok) continue; + // build zero constant array similar to visitVarDef + std::function&, size_t, std::shared_ptr)> buildZero; + buildZero = [&](const std::vector& ds, size_t idx, std::shared_ptr elemTy) -> ir::ConstantValue* { + if (idx >= ds.size()) return nullptr; + std::vector elems; + if (idx + 1 == ds.size()) { + for (int i = 0; i < ds[idx]; ++i) { + if (elemTy->IsFloat32()) elems.push_back(module_.GetContext().GetConstFloat(0.0f)); + else elems.push_back(module_.GetContext().GetConstInt(0)); + } + } else { + for (int i = 0; i < ds[idx]; ++i) { + ir::ConstantValue* sub = buildZero(ds, idx + 1, elemTy); + if (sub) elems.push_back(sub); + else elems.push_back(module_.GetContext().GetConstInt(0)); + } + } + std::function(size_t)> makeArrayType = [&](size_t level) -> std::shared_ptr { + if (level + 1 >= ds.size()) return ir::Type::GetArrayType(elemTy, ds[level]); + auto sub = makeArrayType(level + 1); + return ir::Type::GetArrayType(sub, ds[level]); + }; + auto at_real = makeArrayType(idx); + return new ir::ConstantArray(at_real, elems); + }; + std::shared_ptr elemTy = (btype == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type(); + ir::ConstantValue* zero = buildZero(dims, 0, elemTy); + auto gvty = ir::Type::GetPointerType(zero ? zero->GetType() : ir::Type::GetPointerType(elemTy)); + ir::GlobalValue* gv = module_.CreateGlobalVariable(name, gvty, zero); + name_map_[name] = gv; + LogInfo(std::string("[irgen] created global ") + name, std::cerr); + } else { + // scalar global + std::shared_ptr elemTy = (btype == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type(); + ir::ConstantValue* init = nullptr; + if (btype == "float") init = module_.GetContext().GetConstFloat(0.0f); + else init = module_.GetContext().GetConstInt(0); + ir::GlobalValue* gv = module_.CreateGlobalVariable(name, ir::Type::GetPointerType(elemTy), init); + name_map_[name] = gv; + LogInfo(std::string("[irgen] created global ") + name, std::cerr); + } + } + } + } + + // 生成编译单元中所有函数定义(之前只生成第一个函数) + for (size_t i = 0; i < ctx->funcDef().size(); ++i) { + if (ctx->funcDef(i)) ctx->funcDef(i)->accept(this); + } return {}; } @@ -61,6 +174,7 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) { // - 入口块中的参数初始化逻辑。 // ... +// 因此这里目前只支持最小的“无参 int 函数”生成。 // 因此这里目前只支持最小的“无参 int 函数”生成。 std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx) { @@ -72,15 +186,50 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) { if (!ctx->Ident()) { throw std::runtime_error(FormatError("irgen", "缺少函数名")); } - if (!ctx->funcType() || ctx->funcType()->getText() != "int") { - throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数")); + if (!ctx->funcType()) { + throw std::runtime_error(FormatError("irgen", "缺少函数返回类型")); + } + std::shared_ptr ret_type; + if (ctx->funcType()->getText() == "int") ret_type = ir::Type::GetInt32Type(); + else if (ctx->funcType()->getText() == "float") ret_type = ir::Type::GetFloat32Type(); + else if (ctx->funcType()->getText() == "void") ret_type = ir::Type::GetVoidType(); + else throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float/void 函数")); + + std::vector> param_types; + if (ctx->funcFParams()) { + for (auto* param : ctx->funcFParams()->funcFParam()) { + if (param->bType()->getText() == "int") param_types.push_back(ir::Type::GetInt32Type()); + else if (param->bType()->getText() == "float") param_types.push_back(ir::Type::GetFloat32Type()); + else throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float 参数")); + } } - func_ = module_.CreateFunction(ctx->Ident()->getText(), ir::Type::GetInt32Type()); + func_ = module_.CreateFunction(ctx->Ident()->getText(), ret_type, param_types); builder_.SetInsertPoint(func_->GetEntry()); - storage_map_.clear(); + + // Allocate storage for parameters + if (ctx->funcFParams()) { + int idx = 0; + for (auto* param : ctx->funcFParams()->funcFParam()) { + std::string param_name = param->Ident()->getText(); + ir::AllocaInst* alloca = nullptr; + if (param->bType()->getText() == "float") alloca = builder_.CreateAllocaFloat(param_name); + else alloca = builder_.CreateAllocaI32(param_name); + storage_map_[param] = alloca; + name_map_[param_name] = alloca; + // Store the argument value + auto* arg = func_->GetParams()[idx]; + builder_.CreateStore(arg, alloca); + idx++; + } + } + ctx->block()->accept(this); + // 如果函数体末尾没有显式终结(如 void 函数没有 return),补一个隐式 return + if (!builder_.GetInsertBlock()->HasTerminator()) { + builder_.CreateRet(nullptr); + } // 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。 VerifyFunctionStructure(*func_); return {}; diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index eb44f9a..9f1cb56 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -19,8 +19,10 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { if (!ctx) { throw std::runtime_error(FormatError("irgen", "缺少语句")); } - // stmt can be many forms; handle return specifically - if (!ctx->getText().empty() && ctx->getText().find("return") != std::string::npos) { + std::string text = ctx->getText(); + LogInfo("[irgen] visitStmt text='" + text + "' break_size=" + std::to_string(break_targets_.size()) + " cont_size=" + std::to_string(continue_targets_.size()), std::cerr); + // return + if (ctx->getStart()->getText() == "return") { if (ctx->exp()) { ir::Value* v = EvalExpr(*ctx->exp()); builder_.CreateRet(v); @@ -29,6 +31,158 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) { } return BlockFlow::Terminated; } + + // assignment: lVal '=' exp + if (ctx->lVal() && text.find("=") != std::string::npos) { + ir::Value* val = EvalExpr(*ctx->exp()); + std::string name = ctx->lVal()->Ident()->getText(); + // 优先检查按名称的快速映射(支持全局变量) + auto nit = name_map_.find(name); + if (nit != name_map_.end()) { + // 支持带索引的赋值 + if (ctx->lVal()->exp().size() > 0) { + std::vector indices; + indices.push_back(builder_.CreateConstInt(0)); + for (auto* e : ctx->lVal()->exp()) indices.push_back(EvalExpr(*e)); + auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp()); + builder_.CreateStore(val, gep); + return BlockFlow::Continue; + } + builder_.CreateStore(val, nit->second); + return BlockFlow::Continue; + } + for (auto& kv : storage_map_) { + if (!kv.first) continue; + if (auto* vdef = dynamic_cast(kv.first)) { + if (vdef->Ident() && vdef->Ident()->getText() == name) { + builder_.CreateStore(val, kv.second); + return BlockFlow::Continue; + } + } else if (auto* fparam = dynamic_cast(kv.first)) { + if (fparam->Ident() && fparam->Ident()->getText() == name) { + builder_.CreateStore(val, kv.second); + return BlockFlow::Continue; + } + } else if (auto* cdef = dynamic_cast(kv.first)) { + if (cdef->Ident() && cdef->Ident()->getText() == name) { + builder_.CreateStore(val, kv.second); + return BlockFlow::Continue; + } + } + } + throw std::runtime_error(FormatError("irgen", "变量未声明: " + name)); + } + + // if + if (ctx->getStart()->getText() == "if" && ctx->cond()) { + ir::Value* condv = std::any_cast(ctx->cond()->lOrExp()->accept(this)); + ir::BasicBlock* then_bb = func_->CreateBlock("if.then"); + ir::BasicBlock* else_bb = (ctx->stmt().size() > 1) ? func_->CreateBlock("if.else") : nullptr; + ir::BasicBlock* merge_bb = func_->CreateBlock("if.merge"); + + if (else_bb) builder_.CreateCondBr(condv, then_bb, else_bb); + else builder_.CreateCondBr(condv, then_bb, merge_bb); + + // then + builder_.SetInsertPoint(then_bb); + ctx->stmt(0)->accept(this); + if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(merge_bb); + + // else + if (else_bb) { + builder_.SetInsertPoint(else_bb); + ctx->stmt(1)->accept(this); + if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(merge_bb); + } + + builder_.SetInsertPoint(merge_bb); + return BlockFlow::Continue; + } + + // while + if (ctx->getStart()->getText() == "while" && ctx->cond()) { + ir::BasicBlock* cond_bb = func_->CreateBlock("while.cond"); + ir::BasicBlock* body_bb = func_->CreateBlock("while.body"); + ir::BasicBlock* after_bb = func_->CreateBlock("while.after"); + + builder_.CreateBr(cond_bb); + // cond + builder_.SetInsertPoint(cond_bb); + ir::Value* condv = std::any_cast(ctx->cond()->lOrExp()->accept(this)); + builder_.CreateCondBr(condv, body_bb, after_bb); + + // body + builder_.SetInsertPoint(body_bb); + LogInfo("[irgen] while body about to push targets, before sizes: break=" + std::to_string(break_targets_.size()) + ", cont=" + std::to_string(continue_targets_.size()), std::cerr); + break_targets_.push_back(after_bb); + continue_targets_.push_back(cond_bb); + LogInfo("[irgen] after push: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr); + ctx->stmt(0)->accept(this); + LogInfo("[irgen] before pop: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr); + continue_targets_.pop_back(); + break_targets_.pop_back(); + LogInfo("[irgen] after pop: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr); + if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(cond_bb); + + builder_.SetInsertPoint(after_bb); + return BlockFlow::Continue; + } + + // break + if (ctx->getStart()->getText() == "break") { + if (break_targets_.empty()) { + // fallback: 尝试通过函数块名找目标(不依赖 sema),兼容因栈丢失导致的情况 + ir::BasicBlock* fallback = nullptr; + for (auto &bb_up : func_->GetBlocks()) { + auto *bb = bb_up.get(); + if (!bb) continue; + if (bb->GetName().find("while.after") != std::string::npos) fallback = bb; + } + if (fallback) { + LogInfo("[irgen] emit break (fallback), target=" + fallback->GetName(), std::cerr); + builder_.CreateBr(fallback); + return BlockFlow::Terminated; + } + throw std::runtime_error(FormatError("irgen", "break 不在循环内")); + } + LogInfo("[irgen] emit break, break_targets size=" + std::to_string(break_targets_.size()), std::cerr); + builder_.CreateBr(break_targets_.back()); + return BlockFlow::Terminated; + } + + // continue + if (ctx->getStart()->getText() == "continue") { + if (continue_targets_.empty()) { + ir::BasicBlock* fallback = nullptr; + for (auto &bb_up : func_->GetBlocks()) { + auto *bb = bb_up.get(); + if (!bb) continue; + if (bb->GetName().find("while.cond") != std::string::npos) fallback = bb; + } + if (fallback) { + LogInfo("[irgen] emit continue (fallback), target=" + fallback->GetName(), std::cerr); + builder_.CreateBr(fallback); + return BlockFlow::Terminated; + } + throw std::runtime_error(FormatError("irgen", "continue 不在循环内")); + } + LogInfo("[irgen] emit continue, continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr); + builder_.CreateBr(continue_targets_.back()); + return BlockFlow::Terminated; + } + + // block + if (ctx->block()) { + return ctx->block()->accept(this); + } + + // expression statement + if (ctx->exp()) { + EvalExpr(*ctx->exp()); + return BlockFlow::Continue; + } + throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型")); } +