diff --git a/.gitignore b/.gitignore index f680eb7..de15aa3 100644 --- a/.gitignore +++ b/.gitignore @@ -54,6 +54,7 @@ compile_commands.json .fleet/ .vs/ *.code-workspace +CLAUDE.md # CLion cmake-build-debug/ diff --git a/include/ir/IR.h b/include/ir/IR.h index 0315e28..c49b931 100644 --- a/include/ir/IR.h +++ b/include/ir/IR.h @@ -60,12 +60,14 @@ class Context { ~Context(); ConstantInt* GetConstInt(int v); ConstantFloat* GetConstFloat(float v); - std::string NextTemp(); + std::string NextTemp(); // 用于指令名(数字,连续) + std::string NextLabel(); // 用于块名(字母前缀,独立计数) private: std::unordered_map> const_ints_; std::unordered_map> const_floats_; int temp_index_ = -1; + int label_index_ = -1; }; // ─── Type ───────────────────────────────────────────────────────────────────── @@ -198,16 +200,28 @@ class GlobalValue : public User { class GlobalVariable : public Value { public: GlobalVariable(std::string name, bool is_const, int init_val, - int num_elements = 1); + int num_elements = 1, bool is_array_decl = false, + bool is_float = false); bool IsConst() const { return is_const_; } + bool IsFloat() const { return is_float_; } int GetInitVal() const { return init_val_; } + float GetInitValF() const { return init_val_f_; } int GetNumElements() const { return num_elements_; } - bool IsArray() const { return num_elements_ > 1; } - // GlobalVariable 的"指针类型"是 i32*,访问时使用 load/store + bool IsArray() const { return is_array_decl_ || num_elements_ > 1; } + void SetInitVals(std::vector v) { init_vals_ = std::move(v); } + void SetInitValsF(std::vector v) { init_vals_f_ = std::move(v); } + const std::vector& GetInitVals() const { return init_vals_; } + const std::vector& GetInitValsF() const { return init_vals_f_; } + bool HasInitVals() const { return !init_vals_.empty() || !init_vals_f_.empty(); } private: bool is_const_; + bool is_float_; int init_val_; + float init_val_f_; int num_elements_; + bool is_array_decl_; + std::vector init_vals_; + std::vector init_vals_f_; }; // ─── Instruction ────────────────────────────────────────────────────────────── @@ -409,6 +423,8 @@ class Function : public Value { Argument* GetArgument(size_t i) const; size_t GetNumArgs() const { return args_.size(); } bool IsVoidReturn() const { return type_->IsVoid(); } + // 将某个块移动到 blocks_ 列表末尾(用于确保块顺序正确) + void MoveBlockToEnd(BasicBlock* bb); private: BasicBlock* entry_ = nullptr; @@ -437,7 +453,9 @@ class Module { const std::vector>& GetFunctions() const; GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const, - int init_val, int num_elements = 1); + int init_val, int num_elements = 1, + bool is_array_decl = false, + bool is_float = false); GlobalVariable* GetGlobalVariable(const std::string& name) const; const std::vector>& GetGlobalVariables() const; @@ -494,9 +512,12 @@ class IRBuilder { AllocaInst* CreateAllocaI32(const std::string& name); AllocaInst* CreateAllocaF32(const std::string& name); AllocaInst* CreateAllocaArray(int num_elements, const std::string& name); + AllocaInst* CreateAllocaArrayF32(int num_elements, const std::string& name); GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name); LoadInst* CreateLoad(Value* ptr, const std::string& name); StoreInst* CreateStore(Value* val, Value* ptr); + // 零初始化数组(emit memset call) + void CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod); // 控制流 ReturnInst* CreateRet(Value* v); diff --git a/scripts/run_ir_test.sh b/scripts/run_ir_test.sh index 0baaa0d..f6b21af 100755 --- a/scripts/run_ir_test.sh +++ b/scripts/run_ir_test.sh @@ -4,6 +4,9 @@ PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd) COMPILER="$PROJECT_ROOT/build/bin/compiler" TEST_CASE_DIR="$PROJECT_ROOT/test/test_case" TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir" +VERIFY_SCRIPT="$PROJECT_ROOT/scripts/verify_ir.sh" +PARALLEL=${PARALLEL:-$(nproc)} +LOG_FILE="$TEST_RESULT_DIR/verify.log" if [ ! -x "$COMPILER" ]; then echo "错误:编译器不存在或不可执行: $COMPILER" @@ -12,47 +15,130 @@ if [ ! -x "$COMPILER" ]; then fi mkdir -p "$TEST_RESULT_DIR" +> "$LOG_FILE" -pass_count=0 -fail_count=0 -failed_cases=() +# ── 阶段1:IR 生成(并行)──────────────────────────────────────────────────── +echo "=== 阶段1:IR 生成 ===" | tee -a "$LOG_FILE" +echo "" | tee -a "$LOG_FILE" -echo "=== 开始测试 IR 生成 ===" -echo "" +GEN_TMPDIR=$(mktemp -d) -while IFS= read -r test_file; do +gen_one() { + test_file="$1" relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll" - mkdir -p "$(dirname "$output_file")" - - echo -n "测试: $relative_path ... " - "$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1 exit_code=$? - + # Use a per-case tmp file to avoid concurrent write issues + case_id=$(echo "$relative_path" | tr '/' '_') if [ $exit_code -eq 0 ] && [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then - echo "通过" - pass_count=$((pass_count + 1)) + echo "通过: $relative_path" > "$GEN_TMPDIR/pass_${case_id}" else - echo "失败" - fail_count=$((fail_count + 1)) + echo "$relative_path" > "$GEN_TMPDIR/fail_${case_id}" + echo "失败: $relative_path" > "$GEN_TMPDIR/line_fail_${case_id}" + fi +} +export -f gen_one +export COMPILER TEST_CASE_DIR TEST_RESULT_DIR GEN_TMPDIR + +find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort | \ + xargs -P "$PARALLEL" -I{} bash -c 'gen_one "$@"' _ {} + +# Collect results in sorted order +failed_cases=() +for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do + relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f") + case_id=$(echo "$relative_path" | tr '/' '_') + if [ -f "$GEN_TMPDIR/pass_${case_id}" ]; then + cat "$GEN_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE" + elif [ -f "$GEN_TMPDIR/fail_${case_id}" ]; then + cat "$GEN_TMPDIR/line_fail_${case_id}" | tee -a "$LOG_FILE" failed_cases+=("$relative_path") - echo " 错误信息已保存到: $output_file" fi -done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort) +done + +pass_count=$(ls "$GEN_TMPDIR"/pass_* 2>/dev/null | wc -l) +fail_count=${#failed_cases[@]} +rm -rf "$GEN_TMPDIR" + +echo "" | tee -a "$LOG_FILE" +echo "--- 生成完成: 通过 $pass_count / 失败 $fail_count ---" | tee -a "$LOG_FILE" + +# ── 阶段2:IR 运行验证(并行,需要 llc + clang)────────────────────────────── +if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then + echo "" | tee -a "$LOG_FILE" + echo "=== 跳过阶段2:未找到 llc 或 clang,无法运行 IR ===" | tee -a "$LOG_FILE" +else + echo "" | tee -a "$LOG_FILE" + echo "=== 阶段2:IR 运行验证 ===" | tee -a "$LOG_FILE" + echo "" | tee -a "$LOG_FILE" + + VRF_TMPDIR=$(mktemp -d) + + verify_one() { + test_file="$1" + relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file") + relative_dir=$(dirname "$relative_path") + out_dir="$TEST_RESULT_DIR/$relative_dir" + stem=$(basename "${test_file%.sy}") + case_log="$out_dir/$stem.verify.log" + case_id=$(echo "$relative_path" | tr '/' '_') + if bash "$VERIFY_SCRIPT" "$test_file" "$out_dir" --run > "$case_log" 2>&1; then + echo "通过: $relative_path" > "$VRF_TMPDIR/pass_${case_id}" + else + extra=$(grep -E '(退出码|输出不匹配|错误)' "$case_log" | head -3 | sed 's/^/ /' || true) + { echo "失败: $relative_path"; [ -n "$extra" ] && echo "$extra"; } > "$VRF_TMPDIR/fail_${case_id}" + echo "$relative_path" > "$VRF_TMPDIR/failname_${case_id}" + fi + } + export -f verify_one + export TEST_CASE_DIR TEST_RESULT_DIR VERIFY_SCRIPT VRF_TMPDIR + + find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort | \ + xargs -P "$PARALLEL" -I{} bash -c 'verify_one "$@"' _ {} + + # Collect results in sorted order + verify_failed_cases=() + for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do + relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f") + case_id=$(echo "$relative_path" | tr '/' '_') + if [ -f "$VRF_TMPDIR/pass_${case_id}" ]; then + cat "$VRF_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE" + elif [ -f "$VRF_TMPDIR/fail_${case_id}" ]; then + cat "$VRF_TMPDIR/fail_${case_id}" | tee -a "$LOG_FILE" + verify_failed_cases+=("$relative_path") + fi + done + + verify_pass=$(ls "$VRF_TMPDIR"/pass_* 2>/dev/null | wc -l) + verify_fail=${#verify_failed_cases[@]} + rm -rf "$VRF_TMPDIR" + + echo "" | tee -a "$LOG_FILE" + echo "--- 验证完成: 通过 $verify_pass / 失败 $verify_fail ---" | tee -a "$LOG_FILE" + + if [ ${#verify_failed_cases[@]} -gt 0 ]; then + echo "" | tee -a "$LOG_FILE" + echo "=== 验证失败的用例 ===" | tee -a "$LOG_FILE" + for f in "${verify_failed_cases[@]}"; do + [ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE" + done + fi +fi -echo "" -echo "=== 测试完成 ===" -echo "通过: $pass_count" -echo "失败: $fail_count" -echo "结果保存在: $TEST_RESULT_DIR" +# ── 汇总 ───────────────────────────────────────────────────────────────────── +echo "" | tee -a "$LOG_FILE" +echo "=== 测试完成 ===" | tee -a "$LOG_FILE" +echo "IR生成 通过: $pass_count 失败: $fail_count" | tee -a "$LOG_FILE" +echo "结果保存在: $TEST_RESULT_DIR" | tee -a "$LOG_FILE" +echo "日志保存在: $LOG_FILE" | tee -a "$LOG_FILE" if [ ${#failed_cases[@]} -gt 0 ]; then - echo "" - echo "=== 失败的用例 ===" + echo "" | tee -a "$LOG_FILE" + echo "=== IR生成失败的用例 ===" | tee -a "$LOG_FILE" for f in "${failed_cases[@]}"; do - echo " - $f" + [ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE" done exit 1 fi diff --git a/scripts/verify_ir.sh b/scripts/verify_ir.sh index 049b725..ce543ac 100755 --- a/scripts/verify_ir.sh +++ b/scripts/verify_ir.sh @@ -3,6 +3,8 @@ set -euo pipefail +PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd) + if [[ $# -lt 1 || $# -gt 3 ]]; then echo "用法: $0 [output_dir] [--run]" >&2 exit 1 @@ -31,7 +33,7 @@ if [[ ! -f "$input" ]]; then exit 1 fi -compiler="./build/bin/compiler" +compiler="$PROJECT_ROOT/build/bin/compiler" if [[ ! -x "$compiler" ]]; then echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j)" >&2 exit 1 @@ -60,13 +62,13 @@ if [[ "$run_exec" == true ]]; then stdout_file="$out_dir/$stem.stdout" actual_file="$out_dir/$stem.actual.out" llc -filetype=obj "$out_file" -o "$obj" - clang "$obj" -o "$exe" + clang "$obj" "$PROJECT_ROOT/sylib/sylib.c" -o "$exe" -lm echo "运行 $exe ..." set +e if [[ -f "$stdin_file" ]]; then - "$exe" < "$stdin_file" > "$stdout_file" + (ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file" else - "$exe" > "$stdout_file" + (ulimit -s unlimited; "$exe") > "$stdout_file" fi status=$? set -e diff --git a/src/ir/Context.cpp b/src/ir/Context.cpp index a6a6ac5..dea8bca 100644 --- a/src/ir/Context.cpp +++ b/src/ir/Context.cpp @@ -29,4 +29,10 @@ std::string Context::NextTemp() { return oss.str(); } +std::string Context::NextLabel() { + std::ostringstream oss; + oss << "L" << ++label_index_; + return oss.str(); +} + } // namespace ir diff --git a/src/ir/Function.cpp b/src/ir/Function.cpp index 4cc7067..804229c 100644 --- a/src/ir/Function.cpp +++ b/src/ir/Function.cpp @@ -17,6 +17,17 @@ BasicBlock* Function::CreateBlock(const std::string& name) { return ptr; } +void Function::MoveBlockToEnd(BasicBlock* bb) { + for (size_t i = 0; i < blocks_.size(); ++i) { + if (blocks_[i].get() == bb) { + auto tmp = std::move(blocks_[i]); + blocks_.erase(blocks_.begin() + i); + blocks_.push_back(std::move(tmp)); + return; + } + } +} + BasicBlock* Function::GetEntry() { return entry_; } const BasicBlock* Function::GetEntry() const { return entry_; } const std::vector>& Function::GetBlocks() const { diff --git a/src/ir/IRBuilder.cpp b/src/ir/IRBuilder.cpp index c11568f..49060c0 100644 --- a/src/ir/IRBuilder.cpp +++ b/src/ir/IRBuilder.cpp @@ -136,6 +136,15 @@ AllocaInst* IRBuilder::CreateAllocaArray(int num_elements, num_elements, name); } +AllocaInst* IRBuilder::CreateAllocaArrayF32(int num_elements, + const std::string& name) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + return insert_block_->Append(Type::GetPtrFloat32Type(), + num_elements, name); +} + GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index, const std::string& name) { if (!insert_block_) { @@ -237,4 +246,20 @@ FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) { return insert_block_->Append(val, name); } +void IRBuilder::CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod) { + if (!insert_block_) { + throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点")); + } + // declare memset if not already declared + if (!mod.HasExternalDecl("memset")) { + mod.DeclareExternalFunc("memset", Type::GetVoidType(), + {Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()}); + } + int byte_count = num_elements * 4; + insert_block_->Append( + std::string("memset"), Type::GetVoidType(), + std::vector{ptr, ctx.GetConstInt(0), ctx.GetConstInt(byte_count)}, + std::string("")); +} + } // namespace ir diff --git a/src/ir/IRPrinter.cpp b/src/ir/IRPrinter.cpp index 6e0a827..8867dfc 100644 --- a/src/ir/IRPrinter.cpp +++ b/src/ir/IRPrinter.cpp @@ -1,6 +1,9 @@ #include "ir/IR.h" +#include +#include #include +#include #include #include @@ -50,7 +53,13 @@ static std::string ValStr(const Value* v) { return std::to_string(ci->GetValue()); } if (auto* cf = dynamic_cast(v)) { - return std::to_string(cf->GetValue()); + // LLVM IR 要求 float 常量用 64 位十六进制表示(double 精度) + double d = static_cast(cf->GetValue()); + uint64_t bits; + std::memcpy(&bits, &d, sizeof(bits)); + std::ostringstream oss; + oss << "0x" << std::hex << std::uppercase << bits; + return oss.str(); } // BasicBlock: 打印为 label %name if (dynamic_cast(v)) { @@ -59,10 +68,11 @@ static std::string ValStr(const Value* v) { // GlobalVariable: 打印为 @name if (auto* gv = dynamic_cast(v)) { if (gv->IsArray()) { - // 数组全局变量的指针:getelementptr [N x i32], [N x i32]* @name, i32 0, i32 0 - return "getelementptr ([" + std::to_string(gv->GetNumElements()) + - " x i32], [" + std::to_string(gv->GetNumElements()) + - " x i32]* @" + gv->GetName() + ", i32 0, i32 0)"; + // 数组全局变量的指针:getelementptr [N x T], [N x T]* @name, i32 0, i32 0 + const char* elem_ty = gv->IsFloat() ? "float" : "i32"; + return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) + + " x " + elem_ty + "], [" + std::to_string(gv->GetNumElements()) + + " x " + elem_ty + "]* @" + gv->GetName() + ", i32 0, i32 0)"; } return "@" + v->GetName(); } @@ -76,8 +86,12 @@ static std::string TypeVal(const Value* v) { std::to_string(ci->GetValue()); } if (auto* cf = dynamic_cast(v)) { - return std::string(TypeToStr(*cf->GetType())) + " " + - std::to_string(cf->GetValue()); + double d = static_cast(cf->GetValue()); + uint64_t bits; + std::memcpy(&bits, &d, sizeof(bits)); + std::ostringstream oss; + oss << "float 0x" << std::hex << std::uppercase << bits; + return oss.str(); } return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v); } @@ -86,13 +100,34 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { // 1. 全局变量/常量 for (const auto& g : module.GetGlobalVariables()) { if (g->IsArray()) { - // 全局数组:zeroinitializer - if (g->IsConst()) { - os << "@" << g->GetName() << " = constant [" << g->GetNumElements() - << " x i32] zeroinitializer\n"; + const char* linkage = g->IsConst() ? "constant" : "global"; + const char* elem_ty = g->IsFloat() ? "float" : "i32"; + os << "@" << g->GetName() << " = " << linkage + << " [" << g->GetNumElements() << " x " << elem_ty << "] "; + if (g->HasInitVals()) { + os << "["; + if (g->IsFloat()) { + const auto& vals = g->GetInitValsF(); + for (int i = 0; i < g->GetNumElements(); ++i) { + if (i > 0) os << ", "; + float fv = (i < (int)vals.size()) ? vals[i] : 0.0f; + double d = static_cast(fv); + uint64_t bits; + std::memcpy(&bits, &d, sizeof(bits)); + std::ostringstream oss; + oss << "float 0x" << std::hex << std::uppercase << bits; + os << oss.str(); + } + } else { + const auto& vals = g->GetInitVals(); + for (int i = 0; i < g->GetNumElements(); ++i) { + if (i > 0) os << ", "; + os << "i32 " << (i < (int)vals.size() ? vals[i] : 0); + } + } + os << "]\n"; } else { - os << "@" << g->GetName() << " = global [" << g->GetNumElements() - << " x i32] zeroinitializer\n"; + os << "zeroinitializer\n"; } } else { if (g->IsConst()) { @@ -209,8 +244,11 @@ void IRPrinter::Print(const Module& module, std::ostream& os) { } case Opcode::Gep: { auto* gep = static_cast(inst); + bool is_float_ptr = gep->GetBasePtr()->GetType()->IsPtrFloat32(); + const char* elem_type = is_float_ptr ? "float" : "i32"; + const char* ptr_type = is_float_ptr ? "float*" : "i32*"; os << " %" << gep->GetName() - << " = getelementptr i32, i32* " + << " = getelementptr " << elem_type << ", " << ptr_type << " " << ValStr(gep->GetBasePtr()) << ", i32 " << ValStr(gep->GetIndex()) << "\n"; break; diff --git a/src/ir/Instruction.cpp b/src/ir/Instruction.cpp index 7830c0e..8475546 100644 --- a/src/ir/Instruction.cpp +++ b/src/ir/Instruction.cpp @@ -224,7 +224,11 @@ AllocaInst::AllocaInst(std::shared_ptr ptr_ty, int num_elements, // ─── GepInst ────────────────────────────────────────────────────────────────── GepInst::GepInst(Value* base_ptr, Value* index, std::string name) - : Instruction(Opcode::Gep, Type::GetPtrInt32Type(), std::move(name)) { + : Instruction(Opcode::Gep, + (base_ptr && base_ptr->GetType()->IsPtrFloat32()) + ? Type::GetPtrFloat32Type() + : Type::GetPtrInt32Type(), + std::move(name)) { if (!base_ptr || !index) { throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数")); } @@ -265,10 +269,15 @@ GlobalValue::GlobalValue(std::shared_ptr ty, std::string name) // ─── GlobalVariable ──────────────────────────────────────────────────────────── GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val, - int num_elements) - : Value(Type::GetPtrInt32Type(), std::move(name)), + int num_elements, bool is_array_decl, + bool is_float) + : Value(is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type(), + std::move(name)), is_const_(is_const), + is_float_(is_float), init_val_(init_val), - num_elements_(num_elements) {} + init_val_f_(0.0f), + num_elements_(num_elements), + is_array_decl_(is_array_decl) {} } // namespace ir diff --git a/src/ir/Module.cpp b/src/ir/Module.cpp index 5d46d90..76ea77b 100644 --- a/src/ir/Module.cpp +++ b/src/ir/Module.cpp @@ -28,9 +28,10 @@ const std::vector>& Module::GetFunctions() const { // ─── 全局变量管理 ───────────────────────────────────────────────────────────── GlobalVariable* Module::CreateGlobalVariable(const std::string& name, bool is_const, int init_val, - int num_elements) { + int num_elements, bool is_array_decl, + bool is_float) { globals_.push_back( - std::make_unique(name, is_const, init_val, num_elements)); + std::make_unique(name, is_const, init_val, num_elements, is_array_decl, is_float)); GlobalVariable* g = globals_.back().get(); global_map_[name] = g; return g; diff --git a/src/irgen/IRGenDecl.cpp b/src/irgen/IRGenDecl.cpp index f0c6a50..b69d29a 100644 --- a/src/irgen/IRGenDecl.cpp +++ b/src/irgen/IRGenDecl.cpp @@ -117,33 +117,101 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) { for (int d : dims) total *= (d > 0 ? d : 1); if (in_global_scope_) { - auto* gv = module_.CreateGlobalVariable(name, true, 0, total); + auto* gv = module_.CreateGlobalVariable(name, true, 0, total, true); global_storage_map_[constDef] = gv; global_array_dims_[constDef] = dims; + // 计算初始值并存入全局变量 + if (constDef->constInitVal()) { + std::vector strides(dims.size(), 1); + for (int i = (int)dims.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * dims[i + 1]; + int top_stride = strides[0]; + std::vector flat(total, 0); + std::function fill; + fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) { + if (!iv || pos >= total) return; + if (iv->constExp()) { flat[pos] = EvalConstExprInt(iv->constExp()); return; } + int sub_stride = 1; + for (int k = 0; k < (int)strides.size() - 1; ++k) + if (strides[k] == stride) { sub_stride = strides[k + 1]; break; } + int cur = pos; + for (auto* sub : iv->constInitVal()) { + if (cur >= pos + stride || cur >= total) break; + if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); } + else { int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; fill(sub,a,sub_stride); cur=a+sub_stride; } + } + }; + int cur = 0; + if (constDef->constInitVal()->constExp()) { + flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp()); + } else { + for (auto* sub : constDef->constInitVal()->constInitVal()) { + if (cur >= total) break; + if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); } + else { int a = ((cur+top_stride-1)/top_stride)*top_stride; fill(sub,a,top_stride); cur=a+top_stride; } + } + } + gv->SetInitVals(flat); + } } else { auto* slot = builder_.CreateAllocaArray(total, name); storage_map_[constDef] = slot; array_dims_[constDef] = dims; - // 扁平化初始化 + // 按 C 语义扁平化初始化(子列表对齐到维度边界) if (constDef->constInitVal()) { - std::vector flat; - flat.reserve(total); - std::function flatten = - [&](SysYParser::ConstInitValContext* iv) { - if (!iv) return; - if (iv->constExp()) { - flat.push_back(EvalConstExprInt(iv->constExp())); - } else { - for (auto* sub : iv->constInitVal()) flatten(sub); - } - }; - flatten(constDef->constInitVal()); + std::vector flat(total, 0); + + std::vector strides(dims.size(), 1); + for (int i = (int)dims.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * dims[i + 1]; + int top_stride = strides[0]; + + std::function fill; + fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) { + if (!iv || pos >= total) return; + if (iv->constExp()) { + flat[pos] = EvalConstExprInt(iv->constExp()); + return; + } + int sub_stride = 1; + for (int k = 0; k < (int)strides.size() - 1; ++k) + if (strides[k] == stride) { sub_stride = strides[k + 1]; break; } + int cur = pos; + for (auto* sub : iv->constInitVal()) { + if (cur >= pos + stride || cur >= total) break; + if (sub->constExp()) { + flat[cur] = EvalConstExprInt(sub->constExp()); + cur++; + } else { + int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos; + fill(sub, aligned, sub_stride); + cur = aligned + sub_stride; + } + } + }; + + int cur = 0; + if (constDef->constInitVal()->constExp()) { + flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp()); + } else { + for (auto* sub : constDef->constInitVal()->constInitVal()) { + if (cur >= total) break; + if (sub->constExp()) { + flat[cur] = EvalConstExprInt(sub->constExp()); + cur++; + } else { + int aligned = ((cur + top_stride - 1) / top_stride) * top_stride; + fill(sub, aligned, top_stride); + cur = aligned + top_stride; + } + } + } + for (int i = 0; i < total; ++i) { - int v = (i < (int)flat.size()) ? flat[i] : 0; auto* ptr = builder_.CreateGep( slot, builder_.CreateConstInt(i), module_.GetContext().NextTemp()); - builder_.CreateStore(builder_.CreateConstInt(v), ptr); + builder_.CreateStore(builder_.CreateConstInt(flat[i]), ptr); } } } @@ -194,9 +262,97 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { } else { int total = 1; for (int d : dims) total *= (d > 0 ? d : 1); - auto* gv = module_.CreateGlobalVariable(name, false, 0, total); + auto* gv = module_.CreateGlobalVariable(name, false, 0, total, true, is_float); global_storage_map_[ctx] = gv; global_array_dims_[ctx] = dims; + // 计算初始值 + if (ctx->initVal()) { + std::vector strides(dims.size(), 1); + for (int i = (int)dims.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * dims[i + 1]; + int top_stride = strides[0]; + if (is_float) { + std::vector flat(total, 0.0f); + std::function fill_f; + fill_f = [&](SysYParser::InitValContext* iv, int pos, int stride) { + if (!iv || pos >= total) return; + if (iv->exp()) { + try { flat[pos] = static_cast(sem::EvaluateExp(*iv->exp()->addExp()).float_val); } catch (...) {} + return; + } + int sub_stride = 1; + for (int k = 0; k < (int)strides.size() - 1; ++k) + if (strides[k] == stride) { sub_stride = strides[k + 1]; break; } + int cur = pos; + for (auto* sub : iv->initVal()) { + if (cur >= pos + stride || cur >= total) break; + if (sub->exp()) { + try { flat[cur] = static_cast(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {} + cur++; + } else { + int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; + fill_f(sub, a, sub_stride); cur = a + sub_stride; + } + } + }; + int cur = 0; + if (ctx->initVal()->exp()) { + try { flat[0] = static_cast(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {} + } else { + for (auto* sub : ctx->initVal()->initVal()) { + if (cur >= total) break; + if (sub->exp()) { + try { flat[cur] = static_cast(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {} + cur++; + } else { + int a = ((cur+top_stride-1)/top_stride)*top_stride; + fill_f(sub, a, top_stride); cur = a + top_stride; + } + } + } + gv->SetInitValsF(flat); + } else { + std::vector flat(total, 0); + std::function fill; + fill = [&](SysYParser::InitValContext* iv, int pos, int stride) { + if (!iv || pos >= total) return; + if (iv->exp()) { + try { flat[pos] = static_cast(sem::EvaluateExp(*iv->exp()->addExp()).int_val); } catch (...) {} + return; + } + int sub_stride = 1; + for (int k = 0; k < (int)strides.size() - 1; ++k) + if (strides[k] == stride) { sub_stride = strides[k + 1]; break; } + int cur = pos; + for (auto* sub : iv->initVal()) { + if (cur >= pos + stride || cur >= total) break; + if (sub->exp()) { + try { flat[cur] = static_cast(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {} + cur++; + } else { + int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; + fill(sub, a, sub_stride); cur = a + sub_stride; + } + } + }; + int cur = 0; + if (ctx->initVal()->exp()) { + try { flat[0] = static_cast(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).int_val); } catch (...) {} + } else { + for (auto* sub : ctx->initVal()->initVal()) { + if (cur >= total) break; + if (sub->exp()) { + try { flat[cur] = static_cast(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {} + cur++; + } else { + int a = ((cur+top_stride-1)/top_stride)*top_stride; + fill(sub, a, top_stride); cur = a + top_stride; + } + } + } + gv->SetInitVals(flat); + } + } } } else { if (storage_map_.count(ctx)) { @@ -211,6 +367,14 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { ir::Value* init; if (ctx->initVal() && ctx->initVal()->exp()) { init = EvalExpr(*ctx->initVal()->exp()); + // Coerce init value to slot type + if (!is_float && init->IsFloat32()) { + init = ToInt(init); + } else if (is_float && init->IsInt32()) { + init = ToFloat(init); + } else if (!is_float && init->IsInt1()) { + init = ToI32(init); + } } else { init = is_float ? static_cast(builder_.CreateConstFloat(0.0f)) : static_cast(builder_.CreateConstInt(0)); @@ -219,40 +383,95 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) { } else { int total = 1; for (int d : dims) total *= (d > 0 ? d : 1); - auto* slot = builder_.CreateAllocaArray(total, name); + auto* slot = is_float ? builder_.CreateAllocaArrayF32(total, module_.GetContext().NextTemp()) + : builder_.CreateAllocaArray(total, module_.GetContext().NextTemp()); storage_map_[ctx] = slot; array_dims_[ctx] = dims; + ir::Value* zero_init = is_float ? static_cast(builder_.CreateConstFloat(0.0f)) + : static_cast(builder_.CreateConstInt(0)); + if (ctx->initVal()) { - // 收集扁平化初始值 - std::vector flat; - flat.reserve(total); - std::function flatten = - [&](SysYParser::InitValContext* iv) { - if (!iv) return; - if (iv->exp()) { - flat.push_back(EvalExpr(*iv->exp())); - } else { - for (auto* sub : iv->initVal()) flatten(sub); - } - }; - flatten(ctx->initVal()); - for (int i = 0; i < total; ++i) { - ir::Value* v = (i < (int)flat.size()) ? flat[i] - : builder_.CreateConstInt(0); - auto* ptr = builder_.CreateGep( - slot, builder_.CreateConstInt(i), - module_.GetContext().NextTemp()); - builder_.CreateStore(v, ptr); + // 按 C 语义扁平化初始值:子列表对齐到对应维度边界 + std::vector flat(total, zero_init); + + // 计算各维度的 stride:stride[i] = dims[i]*dims[i+1]*...*dims[n-1] + // 但我们需要「子列表对应第几维的 stride」 + // 顶层:stride = total / dims[0](即每行的元素数) + // 递归时 stride 继续除以当前维度大小 + std::vector strides(dims.size(), 1); + for (int i = (int)dims.size() - 2; i >= 0; --i) + strides[i] = strides[i + 1] * dims[i + 1]; + int top_stride = strides[0]; // 每个顶层子列表占用的元素数 + + // fill(iv, pos, stride):将 iv 的内容填入 flat[pos..pos+stride) + // stride 表示当前层子列表对应的元素个数 + std::function fill; + fill = [&](SysYParser::InitValContext* iv, int pos, int stride) { + if (!iv || pos >= total) return; + if (iv->exp()) { + flat[pos] = EvalExpr(*iv->exp()); + return; + } + // 子列表内的 stride = stride / (当前层首维大小) + // 找到对应的 strides 层:stride == strides[k] → 子stride = strides[k+1] + int sub_stride = 1; + for (int k = 0; k < (int)strides.size() - 1; ++k) { + if (strides[k] == stride) { sub_stride = strides[k + 1]; break; } + } + int cur = pos; + for (auto* sub : iv->initVal()) { + if (cur >= pos + stride || cur >= total) break; + if (sub->exp()) { + flat[cur] = EvalExpr(*sub->exp()); + cur++; + } else { + // 对齐到 sub_stride 边界 + int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos; + fill(sub, aligned, sub_stride); + cur = aligned + sub_stride; + } + } + }; + + // 顶层扫描 + int cur = 0; + if (ctx->initVal()->exp()) { + flat[0] = EvalExpr(*ctx->initVal()->exp()); + } else { + for (auto* sub : ctx->initVal()->initVal()) { + if (cur >= total) break; + if (sub->exp()) { + flat[cur] = EvalExpr(*sub->exp()); + cur++; + } else { + // 对齐到 top_stride 边界 + int aligned = ((cur + top_stride - 1) / top_stride) * top_stride; + fill(sub, aligned, top_stride); + cur = aligned + top_stride; + } + } } - } else { - // 零初始化 + + // 先 memset 归零,再只写入非零元素 + builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_); for (int i = 0; i < total; ++i) { + bool is_zero = false; + if (auto* ci = dynamic_cast(flat[i])) { + is_zero = (ci->GetValue() == 0); + } else if (auto* cf = dynamic_cast(flat[i])) { + is_zero = (cf->GetValue() == 0.0f); + } + if (is_zero) continue; auto* ptr = builder_.CreateGep( slot, builder_.CreateConstInt(i), module_.GetContext().NextTemp()); - builder_.CreateStore(builder_.CreateConstInt(0), ptr); + builder_.CreateStore(flat[i], ptr); } + } else { + // 零初始化:用 memset 归零 + builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_); + (void)zero_init; } } } diff --git a/src/irgen/IRGenExp.cpp b/src/irgen/IRGenExp.cpp index 39a04f3..224e606 100644 --- a/src/irgen/IRGenExp.cpp +++ b/src/irgen/IRGenExp.cpp @@ -10,10 +10,15 @@ // ─── 辅助 ───────────────────────────────────────────────────────────────────── -// 把 i32 值转成 i1(icmp ne i32 v, 0) +// 把 i32/float 值转成 i1 ir::Value* IRGenImpl::ToI1(ir::Value* v) { if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value")); if (v->IsInt1()) return v; + if (v->IsFloat32()) { + return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v, + builder_.CreateConstFloat(0.0f), + module_.GetContext().NextTemp()); + } return builder_.CreateICmp(ir::ICmpPredicate::NE, v, builder_.CreateConstInt(0), module_.GetContext().NextTemp()); @@ -87,7 +92,13 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) { } else if (name == "getch") { module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {}); } else if (name == "getfloat") { - module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似 + module_.DeclareExternalFunc(name, ir::Type::GetFloat32Type(), {}); + } else if (name == "getarray") { + module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), + {ir::Type::GetPtrInt32Type()}); + } else if (name == "getfarray") { + module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), + {ir::Type::GetPtrFloat32Type()}); } else if (name == "putint") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); @@ -95,10 +106,16 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); } else if (name == "putfloat") { - module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), + {ir::Type::GetFloat32Type()}); } else if (name == "putarray") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), - {ir::Type::GetInt32Type()}); + {ir::Type::GetInt32Type(), + ir::Type::GetPtrInt32Type()}); + } else if (name == "putfarray") { + module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), + {ir::Type::GetInt32Type(), + ir::Type::GetPtrFloat32Type()}); } else if (name == "starttime" || name == "stoptime") { module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {ir::Type::GetInt32Type()}); @@ -227,13 +244,113 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { std::vector args; if (ctx->funcRParams()) { for (auto* exp : ctx->funcRParams()->exp()) { - args.push_back(EvalExpr(*exp)); + // 检查是否是数组变量(无索引的 lVar),若是则传指针而非 load + ir::Value* arg = nullptr; + auto* add = exp->addExp(); + if (add && add->mulExp().size() == 1) { + auto* mul = add->mulExp(0); + if (mul && mul->unaryExp().size() == 1) { + auto* unary = mul->unaryExp(0); + if (unary && !unary->unaryOp() && unary->primaryExp()) { + auto* primary = unary->primaryExp(); + if (primary && primary->lVar() && primary->lVar()->exp().empty()) { + auto* lvar = primary->lVar(); + auto* decl = sema_.ResolveVarUse(lvar->Ident()); + if (decl) { + // 检查是否是数组参数(storage_map_ 里存的是指针) + auto it = storage_map_.find(decl); + if (it != storage_map_.end()) { + auto* val = it->second; + if (val && (val->IsPtrInt32() || val->IsPtrFloat32())) { + // 检查是否是 Argument(数组参数,直接传指针) + if (dynamic_cast(val)) { + arg = val; + } else if (array_dims_.count(decl)) { + // 本地数组(含 dims 记录):传首元素地址 + arg = builder_.CreateGep(val, builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + } + } + } + // 检查全局数组 + if (!arg) { + auto git = global_storage_map_.find(decl); + if (git != global_storage_map_.end()) { + auto* gv = dynamic_cast(git->second); + if (gv && gv->IsArray()) { + arg = builder_.CreateGep(git->second, + builder_.CreateConstInt(0), + module_.GetContext().NextTemp()); + } + } + } + } + } + } + } + } + // Also handle partially-indexed multi-dim arrays: arr[i] where arr is + // int arr[M][N] should pass &arr[i*N] as i32*, not load arr[i] as i32. + if (!arg) { + auto* add2 = exp->addExp(); + if (add2 && add2->mulExp().size() == 1) { + auto* mul2 = add2->mulExp(0); + if (mul2 && mul2->unaryExp().size() == 1) { + auto* unary2 = mul2->unaryExp(0); + if (unary2 && !unary2->unaryOp() && unary2->primaryExp()) { + auto* primary2 = unary2->primaryExp(); + if (primary2 && primary2->lVar() && !primary2->lVar()->exp().empty()) { + auto* lvar2 = primary2->lVar(); + auto* decl2 = sema_.ResolveVarUse(lvar2->Ident()); + if (decl2) { + std::vector dims2; + ir::Value* base2 = nullptr; + auto it2 = array_dims_.find(decl2); + if (it2 != array_dims_.end()) { + dims2 = it2->second; + auto sit = storage_map_.find(decl2); + if (sit != storage_map_.end()) base2 = sit->second; + } else { + auto git2 = global_array_dims_.find(decl2); + if (git2 != global_array_dims_.end()) { + dims2 = git2->second; + auto gsit = global_storage_map_.find(decl2); + if (gsit != global_storage_map_.end()) base2 = gsit->second; + } + } + // Partially indexed: fewer indices than dimensions -> pass pointer + bool is_param = !dims2.empty() && dims2[0] == -1; + size_t effective_dims = is_param ? dims2.size() - 1 : dims2.size(); + if (base2 && !dims2.empty() && + lvar2->exp().size() < effective_dims + (is_param ? 1 : 0)) { + arg = EvalLVarAddr(lvar2); + } + } + } + } + } + } + } + if (!arg) arg = EvalExpr(*exp); + args.push_back(arg); } } // 模块内已知函数? ir::Function* callee = module_.GetFunction(callee_name); if (callee) { + // Coerce args to match parameter types + for (size_t i = 0; i < args.size() && i < callee->GetNumArgs(); ++i) { + auto* param = callee->GetArgument(i); + if (!param || !args[i]) continue; + if (param->IsInt32() && args[i]->IsFloat32()) { + args[i] = ToInt(args[i]); + } else if (param->IsFloat32() && args[i]->IsInt32()) { + args[i] = ToFloat(args[i]); + } else if (param->IsInt32() && args[i]->IsInt1()) { + args[i] = ToI32(args[i]); + } + } std::string ret_name = callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp(); auto* call = @@ -246,15 +363,28 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) { // 外部函数 EnsureExternalDecl(callee_name); - // 获取返回类型 + // 获取返回类型和参数类型 std::shared_ptr ret_type = ir::Type::GetInt32Type(); + std::vector> param_types; for (const auto& decl : module_.GetExternalDecls()) { if (decl.name == callee_name) { ret_type = decl.ret_type; + param_types = decl.param_types; break; } } bool is_void = ret_type->IsVoid(); + // Coerce args to match external function parameter types + for (size_t i = 0; i < args.size() && i < param_types.size(); ++i) { + if (!args[i]) continue; + if (param_types[i]->IsInt32() && args[i]->IsFloat32()) { + args[i] = ToInt(args[i]); + } else if (param_types[i]->IsFloat32() && args[i]->IsInt32()) { + args[i] = ToFloat(args[i]); + } else if (param_types[i]->IsInt32() && args[i]->IsInt1()) { + args[i] = ToI32(args[i]); + } + } std::string ret_name = is_void ? "" : module_.GetContext().NextTemp(); auto* call = builder_.CreateCallExternal(callee_name, ret_type, std::move(args), ret_name); @@ -331,40 +461,26 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) { throw std::runtime_error(FormatError("irgen", "数组索引维度过多")); } - ir::Value* offset = builder_.CreateConstInt(0); + ir::Value* offset = nullptr; if (is_array_param) { // 数组参数:dims[0]=-1, dims[1..n]是已知维度 - // 索引:indices[0]对应第一维,indices[1]对应第二维... for (size_t i = 0; i < indices.size(); ++i) { ir::Value* idx = EvalExpr(*indices[i]); - if (i == 0) { - // 第一维:stride = dims[1] * dims[2] * ... (如果有的话) - int stride = 1; - for (size_t j = 1; j < dims.size(); ++j) { - stride *= dims[j]; - } - if (stride > 1) { - ir::Value* scaled = builder_.CreateMul( - idx, builder_.CreateConstInt(stride), - module_.GetContext().NextTemp()); - offset = builder_.CreateAdd(offset, scaled, - module_.GetContext().NextTemp()); - } else { - offset = builder_.CreateAdd(offset, idx, - module_.GetContext().NextTemp()); - } + int stride = 1; + size_t start = (i == 0) ? 1 : i + 1; + for (size_t j = start; j < dims.size(); ++j) stride *= dims[j]; + ir::Value* term; + if (stride == 1) { + term = idx; } else { - // 后续维度 - int stride = 1; - for (size_t j = i + 1; j < dims.size(); ++j) { - stride *= dims[j]; - } - ir::Value* scaled = builder_.CreateMul( - idx, builder_.CreateConstInt(stride), - module_.GetContext().NextTemp()); - offset = builder_.CreateAdd(offset, scaled, - module_.GetContext().NextTemp()); + term = builder_.CreateMul(idx, builder_.CreateConstInt(stride), + module_.GetContext().NextTemp()); + } + if (!offset) { + offset = term; + } else { + offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp()); } } } else { @@ -374,15 +490,24 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) { stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1]; if (i < (int)indices.size()) { ir::Value* idx = EvalExpr(*indices[i]); - ir::Value* scaled = builder_.CreateMul( - idx, builder_.CreateConstInt(stride), - module_.GetContext().NextTemp()); - offset = builder_.CreateAdd(offset, scaled, + ir::Value* term; + if (stride == 1) { + term = idx; + } else { + term = builder_.CreateMul(idx, builder_.CreateConstInt(stride), module_.GetContext().NextTemp()); + } + if (!offset) { + offset = term; + } else { + offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp()); + } } } } + if (!offset) offset = builder_.CreateConstInt(0); + return builder_.CreateGep(base, offset, module_.GetContext().NextTemp()); } @@ -486,8 +611,8 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { ir::Value* res_ext = ToI32(result); builder_.CreateStore(res_ext, res_slot); - ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs"); - ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end"); + ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.rhs"); + ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.end"); builder_.CreateCondBr(result, end_bb, rhs_bb); builder_.SetInsertPoint(rhs_bb); @@ -498,6 +623,7 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) { builder_.CreateBr(end_bb); } + func_->MoveBlockToEnd(end_bb); builder_.SetInsertPoint(end_bb); result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); } @@ -523,8 +649,8 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { ir::Value* res_ext = ToI32(result); builder_.CreateStore(res_ext, res_slot); - ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs"); - ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end"); + ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.rhs"); + ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.end"); builder_.CreateCondBr(result, rhs_bb, end_bb); builder_.SetInsertPoint(rhs_bb); @@ -535,6 +661,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) { builder_.CreateBr(end_bb); } + func_->MoveBlockToEnd(end_bb); builder_.SetInsertPoint(end_bb); result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp())); } diff --git a/src/irgen/IRGenStmt.cpp b/src/irgen/IRGenStmt.cpp index 75eb78a..965826d 100644 --- a/src/irgen/IRGenStmt.cpp +++ b/src/irgen/IRGenStmt.cpp @@ -25,6 +25,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { if (ctx->Return()) { if (ctx->exp()) { ir::Value* v = EvalExpr(*ctx->exp()); + // Coerce return value to function return type + if (func_->GetType()->IsInt32() && v->IsFloat32()) { + v = ToInt(v); + } else if (func_->GetType()->IsFloat32() && v->IsInt32()) { + v = ToFloat(v); + } else if (func_->GetType()->IsInt32() && v->IsInt1()) { + v = ToI32(v); + } builder_.CreateRet(v); } else { builder_.CreateRetVoid(); @@ -54,6 +62,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { if (ctx->lVar() && ctx->Assign()) { ir::Value* rhs = EvalExpr(*ctx->exp()); ir::Value* addr = EvalLVarAddr(ctx->lVar()); + // Coerce rhs to match slot type + if (addr->IsPtrInt32() && rhs->IsFloat32()) { + rhs = ToInt(rhs); + } else if (addr->IsPtrFloat32() && rhs->IsInt32()) { + rhs = ToFloat(rhs); + } else if (addr->IsPtrInt32() && rhs->IsInt1()) { + rhs = ToI32(rhs); + } builder_.CreateStore(rhs, addr); return BlockFlow::Continue; } @@ -74,32 +90,47 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { } auto stmts = ctx->stmt(); + + // Step 1: evaluate condition (may create short-circuit blocks with lower + // SSA numbers — must happen before any branch-target blocks are created). + ir::Value* cond_val = EvalCond(*ctx->cond()); + + // Step 2: create then_bb now (its label number will be >= all short-circuit + // block numbers allocated during EvalCond). ir::BasicBlock* then_bb = func_->CreateBlock( - module_.GetContext().NextTemp() + ".if.then"); + module_.GetContext().NextLabel() + ".if.then"); + + // Step 3: create else_bb/merge_bb as placeholders. They will be moved to + // the end of the block list after their predecessors are filled in, so the + // block ordering in the output will be correct even though their label + // numbers are allocated here (before then-body sub-blocks are created). ir::BasicBlock* else_bb = nullptr; - ir::BasicBlock* merge_bb = func_->CreateBlock( - module_.GetContext().NextTemp() + ".if.end"); + ir::BasicBlock* merge_bb = nullptr; - // 求值条件(可能创建短路求值块) - ir::Value* cond_val = EvalCond(*ctx->cond()); + if (stmts.size() >= 2) { + else_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.else"); + merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end"); + } else { + merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end"); + } - // 检查当前块是否已终结(短路求值可能导致) + // Check if current block already terminated (short-circuit may do this) if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) { - // 条件求值已经终结了当前块,无法继续 - // 这种情况下,我们需要在merge_bb继续 + func_->MoveBlockToEnd(then_bb); + if (else_bb) func_->MoveBlockToEnd(else_bb); + func_->MoveBlockToEnd(merge_bb); builder_.SetInsertPoint(merge_bb); return BlockFlow::Continue; } if (stmts.size() >= 2) { - // if-else - else_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".if.else"); builder_.CreateCondBr(cond_val, then_bb, else_bb); } else { builder_.CreateCondBr(cond_val, then_bb, merge_bb); } - // then 分支 + // then 分支 — visit body (may create many sub-blocks with higher numbers) + func_->MoveBlockToEnd(then_bb); builder_.SetInsertPoint(then_bb); auto then_flow = VisitStmt(*stmts[0]); if (then_flow != BlockFlow::Terminated) { @@ -108,6 +139,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { // else 分支 if (else_bb) { + func_->MoveBlockToEnd(else_bb); builder_.SetInsertPoint(else_bb); auto else_flow = VisitStmt(*stmts[1]); if (else_flow != BlockFlow::Terminated) { @@ -115,6 +147,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { } } + func_->MoveBlockToEnd(merge_bb); builder_.SetInsertPoint(merge_bb); return BlockFlow::Continue; } @@ -124,28 +157,32 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { if (!ctx->cond()) { throw std::runtime_error(FormatError("irgen", "while 缺少条件")); } + ir::BasicBlock* cond_bb = func_->CreateBlock( - module_.GetContext().NextTemp() + ".while.cond"); - ir::BasicBlock* body_bb = func_->CreateBlock( - module_.GetContext().NextTemp() + ".while.body"); - ir::BasicBlock* after_bb = func_->CreateBlock( - module_.GetContext().NextTemp() + ".while.end"); + module_.GetContext().NextLabel() + ".while.cond"); // 跳转到条件块 if (!builder_.GetInsertBlock()->HasTerminator()) { builder_.CreateBr(cond_bb); } - // 条件块 + // EvalCond MUST come before creating body_bb/after_bb so that + // short-circuit blocks get lower SSA numbers than the loop body blocks. builder_.SetInsertPoint(cond_bb); ir::Value* cond_val = EvalCond(*ctx->cond()); + ir::BasicBlock* body_bb = func_->CreateBlock( + module_.GetContext().NextLabel() + ".while.body"); + ir::BasicBlock* after_bb = func_->CreateBlock( + module_.GetContext().NextLabel() + ".while.end"); + // 检查条件求值后是否已终结 if (!builder_.GetInsertBlock()->HasTerminator()) { builder_.CreateCondBr(cond_val, body_bb, after_bb); } // 循环体(压入循环栈) + func_->MoveBlockToEnd(body_bb); loop_stack_.push_back({cond_bb, after_bb}); builder_.SetInsertPoint(body_bb); auto stmts = ctx->stmt(); @@ -159,6 +196,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) { } loop_stack_.pop_back(); + func_->MoveBlockToEnd(after_bb); builder_.SetInsertPoint(after_bb); return BlockFlow::Continue; } diff --git a/src/sem/func.cpp b/src/sem/func.cpp index 8da07f8..3053065 100644 --- a/src/sem/func.cpp +++ b/src/sem/func.cpp @@ -1,5 +1,6 @@ #include "sem/func.h" +#include #include #include @@ -7,6 +8,12 @@ namespace sem { +// Truncate double to float32 precision (mimics C float arithmetic) +static double ToFloat32(double v) { + float f = static_cast(v); + return static_cast(f); +} + // 编译时求值常量表达式 ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) { return EvaluateExp(*ctx.addExp()); @@ -73,14 +80,65 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) { if (ctx.exp()) { return EvaluateExp(*ctx.exp()->addExp()); } else if (ctx.lVar()) { - // 处理变量引用(必须是已定义的常量) + // 处理变量引用:向上遍历 AST 找到对应的常量定义并求值 auto* ident = ctx.lVar()->Ident(); if (!ident) { throw std::runtime_error(FormatError("sema", "非法变量引用")); } std::string name = ident->getText(); - // 这里简化处理,实际应该在符号表中查找常量 - // 暂时假设常量已经在前面被处理过 + // 向上遍历 AST 找到作用域内的 constDef + antlr4::ParserRuleContext* scope = + dynamic_cast(ctx.lVar()->parent); + while (scope) { + // 检查当前作用域中的所有 constDecl + for (auto* tree_child : scope->children) { + auto* child = dynamic_cast(tree_child); + if (!child) continue; + auto* block_item = dynamic_cast(child); + if (block_item && block_item->decl()) { + auto* decl = block_item->decl(); + if (decl->constDecl()) { + for (auto* def : decl->constDecl()->constDef()) { + if (def->Ident() && def->Ident()->getText() == name) { + if (def->constInitVal() && def->constInitVal()->constExp()) { + ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp()); + bool decl_is_int = decl->constDecl()->bType() && + decl->constDecl()->bType()->Int(); + if (decl_is_int) { + cv.is_int = true; + cv.int_val = static_cast(static_cast(cv.float_val)); + cv.float_val = static_cast(cv.int_val); + } + return cv; + } + } + } + } + } + // compUnit 级别的 constDecl + auto* decl = dynamic_cast(child); + if (decl && decl->constDecl()) { + for (auto* def : decl->constDecl()->constDef()) { + if (def->Ident() && def->Ident()->getText() == name) { + if (def->constInitVal() && def->constInitVal()->constExp()) { + ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp()); + // If declared as int, truncate to integer + bool decl_is_int = decl->constDecl()->bType() && + decl->constDecl()->bType()->Int(); + if (decl_is_int) { + cv.is_int = true; + cv.int_val = static_cast(static_cast(cv.float_val)); + cv.float_val = static_cast(cv.int_val); + } + return cv; + } + } + } + } + } + scope = dynamic_cast(scope->parent); + } + // 未找到常量定义,返回 0 ConstValue val; val.is_int = true; val.int_val = 0; @@ -94,11 +152,11 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) { ConstValue val; if (int_const) { val.is_int = true; - val.int_val = std::stoll(int_const->getText()); + val.int_val = std::stoll(int_const->getText(), nullptr, 0); val.float_val = static_cast(val.int_val); } else if (float_const) { val.is_int = false; - val.float_val = std::stod(float_const->getText()); + val.float_val = ToFloat32(std::stod(float_const->getText())); val.int_val = static_cast(val.float_val); } else { throw std::runtime_error(FormatError("sema", "非法数字字面量")); @@ -127,8 +185,9 @@ ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) { result.float_val = static_cast(result.int_val); } else { result.is_int = false; - result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) + - (rhs.is_int ? rhs.int_val : rhs.float_val); + double l = lhs.is_int ? lhs.int_val : lhs.float_val; + double r = rhs.is_int ? rhs.int_val : rhs.float_val; + result.float_val = ToFloat32(l + r); result.int_val = static_cast(result.float_val); } return result; @@ -143,8 +202,9 @@ ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) { result.float_val = static_cast(result.int_val); } else { result.is_int = false; - result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) - - (rhs.is_int ? rhs.int_val : rhs.float_val); + double l = lhs.is_int ? lhs.int_val : lhs.float_val; + double r = rhs.is_int ? rhs.int_val : rhs.float_val; + result.float_val = ToFloat32(l - r); result.int_val = static_cast(result.float_val); } return result; @@ -159,8 +219,9 @@ ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) { result.float_val = static_cast(result.int_val); } else { result.is_int = false; - result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) * - (rhs.is_int ? rhs.int_val : rhs.float_val); + double l = lhs.is_int ? lhs.int_val : lhs.float_val; + double r = rhs.is_int ? rhs.int_val : rhs.float_val; + result.float_val = ToFloat32(l * r); result.int_val = static_cast(result.float_val); } return result; @@ -175,8 +236,9 @@ ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) { result.float_val = static_cast(result.int_val); } else { result.is_int = false; - result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) / - (rhs.is_int ? rhs.int_val : rhs.float_val); + double l = lhs.is_int ? lhs.int_val : lhs.float_val; + double r = rhs.is_int ? rhs.int_val : rhs.float_val; + result.float_val = ToFloat32(l / r); result.int_val = static_cast(result.float_val); } return result; diff --git a/sylib/sylib.c b/sylib/sylib.c index 7f26d0b..20ddb7d 100644 --- a/sylib/sylib.c +++ b/sylib/sylib.c @@ -2,3 +2,41 @@ // - 按实验/评测规范提供 I/O 等函数实现 // - 与编译器生成的目标代码链接,支撑运行时行为 +#include +#include"sylib.h" +/* Input & output functions */ +int getint(){int t; scanf("%d",&t); return t; } +int getch(){char c; scanf("%c",&c); return (int)c; } +int getarray(int a[]){ + int n; + scanf("%d",&n); + for(int i=0;i +#include +#include +/* Input & output functions */ +int getint(),getch(),getarray(int a[]); +void putint(int a),putch(int a),putarray(int n,int a[]); +float getfloat(); +void putfloat(float a); +int getfarray(float a[]); +void putfarray(int n,float a[]); +/* Timing functions */ +void starttime(); +void stoptime(); +#endif