Merge pull request '初步通过verify测试' (#5) from ptabmhn4l/nudt-compiler-cpp:fix/irgen into develop

develop
ptabmhn4l 2 weeks ago
commit 80c46cee7e

1
.gitignore vendored

@ -54,6 +54,7 @@ compile_commands.json
.fleet/
.vs/
*.code-workspace
CLAUDE.md
# CLion
cmake-build-debug/

@ -60,12 +60,14 @@ class Context {
~Context();
ConstantInt* GetConstInt(int v);
ConstantFloat* GetConstFloat(float v);
std::string NextTemp();
std::string NextTemp(); // 用于指令名(数字,连续)
std::string NextLabel(); // 用于块名(字母前缀,独立计数)
private:
std::unordered_map<int, std::unique_ptr<ConstantInt>> const_ints_;
std::unordered_map<float, std::unique_ptr<ConstantFloat>> const_floats_;
int temp_index_ = -1;
int label_index_ = -1;
};
// ─── Type ─────────────────────────────────────────────────────────────────────
@ -198,16 +200,28 @@ class GlobalValue : public User {
class GlobalVariable : public Value {
public:
GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements = 1);
int num_elements = 1, bool is_array_decl = false,
bool is_float = false);
bool IsConst() const { return is_const_; }
bool IsFloat() const { return is_float_; }
int GetInitVal() const { return init_val_; }
float GetInitValF() const { return init_val_f_; }
int GetNumElements() const { return num_elements_; }
bool IsArray() const { return num_elements_ > 1; }
// GlobalVariable 的"指针类型"是 i32*,访问时使用 load/store
bool IsArray() const { return is_array_decl_ || num_elements_ > 1; }
void SetInitVals(std::vector<int> v) { init_vals_ = std::move(v); }
void SetInitValsF(std::vector<float> v) { init_vals_f_ = std::move(v); }
const std::vector<int>& GetInitVals() const { return init_vals_; }
const std::vector<float>& GetInitValsF() const { return init_vals_f_; }
bool HasInitVals() const { return !init_vals_.empty() || !init_vals_f_.empty(); }
private:
bool is_const_;
bool is_float_;
int init_val_;
float init_val_f_;
int num_elements_;
bool is_array_decl_;
std::vector<int> init_vals_;
std::vector<float> init_vals_f_;
};
// ─── Instruction ──────────────────────────────────────────────────────────────
@ -409,6 +423,8 @@ class Function : public Value {
Argument* GetArgument(size_t i) const;
size_t GetNumArgs() const { return args_.size(); }
bool IsVoidReturn() const { return type_->IsVoid(); }
// 将某个块移动到 blocks_ 列表末尾(用于确保块顺序正确)
void MoveBlockToEnd(BasicBlock* bb);
private:
BasicBlock* entry_ = nullptr;
@ -437,7 +453,9 @@ class Module {
const std::vector<std::unique_ptr<Function>>& GetFunctions() const;
GlobalVariable* CreateGlobalVariable(const std::string& name, bool is_const,
int init_val, int num_elements = 1);
int init_val, int num_elements = 1,
bool is_array_decl = false,
bool is_float = false);
GlobalVariable* GetGlobalVariable(const std::string& name) const;
const std::vector<std::unique_ptr<GlobalVariable>>& GetGlobalVariables() const;
@ -494,9 +512,12 @@ class IRBuilder {
AllocaInst* CreateAllocaI32(const std::string& name);
AllocaInst* CreateAllocaF32(const std::string& name);
AllocaInst* CreateAllocaArray(int num_elements, const std::string& name);
AllocaInst* CreateAllocaArrayF32(int num_elements, const std::string& name);
GepInst* CreateGep(Value* base_ptr, Value* index, const std::string& name);
LoadInst* CreateLoad(Value* ptr, const std::string& name);
StoreInst* CreateStore(Value* val, Value* ptr);
// 零初始化数组emit memset call
void CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod);
// 控制流
ReturnInst* CreateRet(Value* v);

@ -4,6 +4,9 @@ PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
COMPILER="$PROJECT_ROOT/build/bin/compiler"
TEST_CASE_DIR="$PROJECT_ROOT/test/test_case"
TEST_RESULT_DIR="$PROJECT_ROOT/test/test_result/ir"
VERIFY_SCRIPT="$PROJECT_ROOT/scripts/verify_ir.sh"
PARALLEL=${PARALLEL:-$(nproc)}
LOG_FILE="$TEST_RESULT_DIR/verify.log"
if [ ! -x "$COMPILER" ]; then
echo "错误:编译器不存在或不可执行: $COMPILER"
@ -12,47 +15,130 @@ if [ ! -x "$COMPILER" ]; then
fi
mkdir -p "$TEST_RESULT_DIR"
> "$LOG_FILE"
pass_count=0
fail_count=0
failed_cases=()
# ── 阶段1IR 生成(并行)────────────────────────────────────────────────────
echo "=== 阶段1IR 生成 ===" | tee -a "$LOG_FILE"
echo "" | tee -a "$LOG_FILE"
echo "=== 开始测试 IR 生成 ==="
echo ""
GEN_TMPDIR=$(mktemp -d)
while IFS= read -r test_file; do
gen_one() {
test_file="$1"
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
output_file="$TEST_RESULT_DIR/${relative_path%.sy}.ll"
mkdir -p "$(dirname "$output_file")"
echo -n "测试: $relative_path ... "
"$COMPILER" --emit-ir "$test_file" > "$output_file" 2>&1
exit_code=$?
# Use a per-case tmp file to avoid concurrent write issues
case_id=$(echo "$relative_path" | tr '/' '_')
if [ $exit_code -eq 0 ] && [ -s "$output_file" ] && ! grep -q '\[error\]' "$output_file"; then
echo "通过"
pass_count=$((pass_count + 1))
echo "通过: $relative_path" > "$GEN_TMPDIR/pass_${case_id}"
else
echo "失败"
fail_count=$((fail_count + 1))
echo "$relative_path" > "$GEN_TMPDIR/fail_${case_id}"
echo "失败: $relative_path" > "$GEN_TMPDIR/line_fail_${case_id}"
fi
}
export -f gen_one
export COMPILER TEST_CASE_DIR TEST_RESULT_DIR GEN_TMPDIR
find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort | \
xargs -P "$PARALLEL" -I{} bash -c 'gen_one "$@"' _ {}
# Collect results in sorted order
failed_cases=()
for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f")
case_id=$(echo "$relative_path" | tr '/' '_')
if [ -f "$GEN_TMPDIR/pass_${case_id}" ]; then
cat "$GEN_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE"
elif [ -f "$GEN_TMPDIR/fail_${case_id}" ]; then
cat "$GEN_TMPDIR/line_fail_${case_id}" | tee -a "$LOG_FILE"
failed_cases+=("$relative_path")
echo " 错误信息已保存到: $output_file"
fi
done < <(find "$TEST_CASE_DIR" -name "*.sy" | sort)
done
pass_count=$(ls "$GEN_TMPDIR"/pass_* 2>/dev/null | wc -l)
fail_count=${#failed_cases[@]}
rm -rf "$GEN_TMPDIR"
echo "" | tee -a "$LOG_FILE"
echo "--- 生成完成: 通过 $pass_count / 失败 $fail_count ---" | tee -a "$LOG_FILE"
# ── 阶段2IR 运行验证(并行,需要 llc + clang──────────────────────────────
if ! command -v llc >/dev/null 2>&1 || ! command -v clang >/dev/null 2>&1; then
echo "" | tee -a "$LOG_FILE"
echo "=== 跳过阶段2未找到 llc 或 clang无法运行 IR ===" | tee -a "$LOG_FILE"
else
echo "" | tee -a "$LOG_FILE"
echo "=== 阶段2IR 运行验证 ===" | tee -a "$LOG_FILE"
echo "" | tee -a "$LOG_FILE"
VRF_TMPDIR=$(mktemp -d)
verify_one() {
test_file="$1"
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$test_file")
relative_dir=$(dirname "$relative_path")
out_dir="$TEST_RESULT_DIR/$relative_dir"
stem=$(basename "${test_file%.sy}")
case_log="$out_dir/$stem.verify.log"
case_id=$(echo "$relative_path" | tr '/' '_')
if bash "$VERIFY_SCRIPT" "$test_file" "$out_dir" --run > "$case_log" 2>&1; then
echo "通过: $relative_path" > "$VRF_TMPDIR/pass_${case_id}"
else
extra=$(grep -E '(退出码|输出不匹配|错误)' "$case_log" | head -3 | sed 's/^/ /' || true)
{ echo "失败: $relative_path"; [ -n "$extra" ] && echo "$extra"; } > "$VRF_TMPDIR/fail_${case_id}"
echo "$relative_path" > "$VRF_TMPDIR/failname_${case_id}"
fi
}
export -f verify_one
export TEST_CASE_DIR TEST_RESULT_DIR VERIFY_SCRIPT VRF_TMPDIR
find "$TEST_CASE_DIR" -name "*.sy" -not -path '*/*performance*/*' | sort | \
xargs -P "$PARALLEL" -I{} bash -c 'verify_one "$@"' _ {}
# Collect results in sorted order
verify_failed_cases=()
for f in $(find "$TEST_CASE_DIR" -name "*.sy" | sort); do
relative_path=$(realpath --relative-to="$TEST_CASE_DIR" "$f")
case_id=$(echo "$relative_path" | tr '/' '_')
if [ -f "$VRF_TMPDIR/pass_${case_id}" ]; then
cat "$VRF_TMPDIR/pass_${case_id}" | tee -a "$LOG_FILE"
elif [ -f "$VRF_TMPDIR/fail_${case_id}" ]; then
cat "$VRF_TMPDIR/fail_${case_id}" | tee -a "$LOG_FILE"
verify_failed_cases+=("$relative_path")
fi
done
verify_pass=$(ls "$VRF_TMPDIR"/pass_* 2>/dev/null | wc -l)
verify_fail=${#verify_failed_cases[@]}
rm -rf "$VRF_TMPDIR"
echo "" | tee -a "$LOG_FILE"
echo "--- 验证完成: 通过 $verify_pass / 失败 $verify_fail ---" | tee -a "$LOG_FILE"
if [ ${#verify_failed_cases[@]} -gt 0 ]; then
echo "" | tee -a "$LOG_FILE"
echo "=== 验证失败的用例 ===" | tee -a "$LOG_FILE"
for f in "${verify_failed_cases[@]}"; do
[ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE"
done
fi
fi
echo ""
echo "=== 测试完成 ==="
echo "通过: $pass_count"
echo "失败: $fail_count"
echo "结果保存在: $TEST_RESULT_DIR"
# ── 汇总 ─────────────────────────────────────────────────────────────────────
echo "" | tee -a "$LOG_FILE"
echo "=== 测试完成 ===" | tee -a "$LOG_FILE"
echo "IR生成 通过: $pass_count 失败: $fail_count" | tee -a "$LOG_FILE"
echo "结果保存在: $TEST_RESULT_DIR" | tee -a "$LOG_FILE"
echo "日志保存在: $LOG_FILE" | tee -a "$LOG_FILE"
if [ ${#failed_cases[@]} -gt 0 ]; then
echo ""
echo "=== 失败的用例 ==="
echo "" | tee -a "$LOG_FILE"
echo "=== IR生成失败的用例 ===" | tee -a "$LOG_FILE"
for f in "${failed_cases[@]}"; do
echo " - $f"
[ -n "$f" ] && echo " - $f" | tee -a "$LOG_FILE"
done
exit 1
fi

@ -3,6 +3,8 @@
set -euo pipefail
PROJECT_ROOT=$(cd "$(dirname "$0")/.." ; pwd)
if [[ $# -lt 1 || $# -gt 3 ]]; then
echo "用法: $0 <input.sy> [output_dir] [--run]" >&2
exit 1
@ -31,7 +33,7 @@ if [[ ! -f "$input" ]]; then
exit 1
fi
compiler="./build/bin/compiler"
compiler="$PROJECT_ROOT/build/bin/compiler"
if [[ ! -x "$compiler" ]]; then
echo "未找到编译器: $compiler ,请先构建(如: mkdir -p build && cd build && cmake .. && make -j" >&2
exit 1
@ -60,13 +62,13 @@ if [[ "$run_exec" == true ]]; then
stdout_file="$out_dir/$stem.stdout"
actual_file="$out_dir/$stem.actual.out"
llc -filetype=obj "$out_file" -o "$obj"
clang "$obj" -o "$exe"
clang "$obj" "$PROJECT_ROOT/sylib/sylib.c" -o "$exe" -lm
echo "运行 $exe ..."
set +e
if [[ -f "$stdin_file" ]]; then
"$exe" < "$stdin_file" > "$stdout_file"
(ulimit -s unlimited; "$exe" < "$stdin_file") > "$stdout_file"
else
"$exe" > "$stdout_file"
(ulimit -s unlimited; "$exe") > "$stdout_file"
fi
status=$?
set -e

@ -29,4 +29,10 @@ std::string Context::NextTemp() {
return oss.str();
}
std::string Context::NextLabel() {
std::ostringstream oss;
oss << "L" << ++label_index_;
return oss.str();
}
} // namespace ir

@ -17,6 +17,17 @@ BasicBlock* Function::CreateBlock(const std::string& name) {
return ptr;
}
void Function::MoveBlockToEnd(BasicBlock* bb) {
for (size_t i = 0; i < blocks_.size(); ++i) {
if (blocks_[i].get() == bb) {
auto tmp = std::move(blocks_[i]);
blocks_.erase(blocks_.begin() + i);
blocks_.push_back(std::move(tmp));
return;
}
}
}
BasicBlock* Function::GetEntry() { return entry_; }
const BasicBlock* Function::GetEntry() const { return entry_; }
const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {

@ -136,6 +136,15 @@ AllocaInst* IRBuilder::CreateAllocaArray(int num_elements,
num_elements, name);
}
AllocaInst* IRBuilder::CreateAllocaArrayF32(int num_elements,
const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
return insert_block_->Append<AllocaInst>(Type::GetPtrFloat32Type(),
num_elements, name);
}
GepInst* IRBuilder::CreateGep(Value* base_ptr, Value* index,
const std::string& name) {
if (!insert_block_) {
@ -237,4 +246,20 @@ FPToSIInst* IRBuilder::CreateFPToSI(Value* val, const std::string& name) {
return insert_block_->Append<FPToSIInst>(val, name);
}
void IRBuilder::CreateMemsetZero(Value* ptr, int num_elements, Context& ctx, Module& mod) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// declare memset if not already declared
if (!mod.HasExternalDecl("memset")) {
mod.DeclareExternalFunc("memset", Type::GetVoidType(),
{Type::GetPtrInt32Type(), Type::GetInt32Type(), Type::GetInt32Type()});
}
int byte_count = num_elements * 4;
insert_block_->Append<CallInst>(
std::string("memset"), Type::GetVoidType(),
std::vector<Value*>{ptr, ctx.GetConstInt(0), ctx.GetConstInt(byte_count)},
std::string(""));
}
} // namespace ir

@ -1,6 +1,9 @@
#include "ir/IR.h"
#include <cstdint>
#include <cstring>
#include <ostream>
#include <sstream>
#include <stdexcept>
#include <string>
@ -50,7 +53,13 @@ static std::string ValStr(const Value* v) {
return std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::to_string(cf->GetValue());
// LLVM IR 要求 float 常量用 64 位十六进制表示double 精度)
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "0x" << std::hex << std::uppercase << bits;
return oss.str();
}
// BasicBlock: 打印为 label %name
if (dynamic_cast<const BasicBlock*>(v)) {
@ -59,10 +68,11 @@ static std::string ValStr(const Value* v) {
// GlobalVariable: 打印为 @name
if (auto* gv = dynamic_cast<const GlobalVariable*>(v)) {
if (gv->IsArray()) {
// 数组全局变量的指针getelementptr [N x i32], [N x i32]* @name, i32 0, i32 0
return "getelementptr ([" + std::to_string(gv->GetNumElements()) +
" x i32], [" + std::to_string(gv->GetNumElements()) +
" x i32]* @" + gv->GetName() + ", i32 0, i32 0)";
// 数组全局变量的指针getelementptr [N x T], [N x T]* @name, i32 0, i32 0
const char* elem_ty = gv->IsFloat() ? "float" : "i32";
return std::string("getelementptr ([") + std::to_string(gv->GetNumElements()) +
" x " + elem_ty + "], [" + std::to_string(gv->GetNumElements()) +
" x " + elem_ty + "]* @" + gv->GetName() + ", i32 0, i32 0)";
}
return "@" + v->GetName();
}
@ -76,8 +86,12 @@ static std::string TypeVal(const Value* v) {
std::to_string(ci->GetValue());
}
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
return std::string(TypeToStr(*cf->GetType())) + " " +
std::to_string(cf->GetValue());
double d = static_cast<double>(cf->GetValue());
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
return oss.str();
}
return std::string(TypeToStr(*v->GetType())) + " " + ValStr(v);
}
@ -86,13 +100,34 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
// 1. 全局变量/常量
for (const auto& g : module.GetGlobalVariables()) {
if (g->IsArray()) {
// 全局数组zeroinitializer
if (g->IsConst()) {
os << "@" << g->GetName() << " = constant [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
const char* linkage = g->IsConst() ? "constant" : "global";
const char* elem_ty = g->IsFloat() ? "float" : "i32";
os << "@" << g->GetName() << " = " << linkage
<< " [" << g->GetNumElements() << " x " << elem_ty << "] ";
if (g->HasInitVals()) {
os << "[";
if (g->IsFloat()) {
const auto& vals = g->GetInitValsF();
for (int i = 0; i < g->GetNumElements(); ++i) {
if (i > 0) os << ", ";
float fv = (i < (int)vals.size()) ? vals[i] : 0.0f;
double d = static_cast<double>(fv);
uint64_t bits;
std::memcpy(&bits, &d, sizeof(bits));
std::ostringstream oss;
oss << "float 0x" << std::hex << std::uppercase << bits;
os << oss.str();
}
} else {
const auto& vals = g->GetInitVals();
for (int i = 0; i < g->GetNumElements(); ++i) {
if (i > 0) os << ", ";
os << "i32 " << (i < (int)vals.size() ? vals[i] : 0);
}
}
os << "]\n";
} else {
os << "@" << g->GetName() << " = global [" << g->GetNumElements()
<< " x i32] zeroinitializer\n";
os << "zeroinitializer\n";
}
} else {
if (g->IsConst()) {
@ -209,8 +244,11 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
case Opcode::Gep: {
auto* gep = static_cast<const GepInst*>(inst);
bool is_float_ptr = gep->GetBasePtr()->GetType()->IsPtrFloat32();
const char* elem_type = is_float_ptr ? "float" : "i32";
const char* ptr_type = is_float_ptr ? "float*" : "i32*";
os << " %" << gep->GetName()
<< " = getelementptr i32, i32* "
<< " = getelementptr " << elem_type << ", " << ptr_type << " "
<< ValStr(gep->GetBasePtr()) << ", i32 "
<< ValStr(gep->GetIndex()) << "\n";
break;

@ -224,7 +224,11 @@ AllocaInst::AllocaInst(std::shared_ptr<Type> ptr_ty, int num_elements,
// ─── GepInst ──────────────────────────────────────────────────────────────────
GepInst::GepInst(Value* base_ptr, Value* index, std::string name)
: Instruction(Opcode::Gep, Type::GetPtrInt32Type(), std::move(name)) {
: Instruction(Opcode::Gep,
(base_ptr && base_ptr->GetType()->IsPtrFloat32())
? Type::GetPtrFloat32Type()
: Type::GetPtrInt32Type(),
std::move(name)) {
if (!base_ptr || !index) {
throw std::runtime_error(FormatError("ir", "GepInst 缺少操作数"));
}
@ -265,10 +269,15 @@ GlobalValue::GlobalValue(std::shared_ptr<Type> ty, std::string name)
// ─── GlobalVariable ────────────────────────────────────────────────────────────
GlobalVariable::GlobalVariable(std::string name, bool is_const, int init_val,
int num_elements)
: Value(Type::GetPtrInt32Type(), std::move(name)),
int num_elements, bool is_array_decl,
bool is_float)
: Value(is_float ? Type::GetPtrFloat32Type() : Type::GetPtrInt32Type(),
std::move(name)),
is_const_(is_const),
is_float_(is_float),
init_val_(init_val),
num_elements_(num_elements) {}
init_val_f_(0.0f),
num_elements_(num_elements),
is_array_decl_(is_array_decl) {}
} // namespace ir

@ -28,9 +28,10 @@ const std::vector<std::unique_ptr<Function>>& Module::GetFunctions() const {
// ─── 全局变量管理 ─────────────────────────────────────────────────────────────
GlobalVariable* Module::CreateGlobalVariable(const std::string& name,
bool is_const, int init_val,
int num_elements) {
int num_elements, bool is_array_decl,
bool is_float) {
globals_.push_back(
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements));
std::make_unique<GlobalVariable>(name, is_const, init_val, num_elements, is_array_decl, is_float));
GlobalVariable* g = globals_.back().get();
global_map_[name] = g;
return g;

@ -117,33 +117,101 @@ std::any IRGenImpl::visitConstDecl(SysYParser::ConstDeclContext* ctx) {
for (int d : dims) total *= (d > 0 ? d : 1);
if (in_global_scope_) {
auto* gv = module_.CreateGlobalVariable(name, true, 0, total);
auto* gv = module_.CreateGlobalVariable(name, true, 0, total, true);
global_storage_map_[constDef] = gv;
global_array_dims_[constDef] = dims;
// 计算初始值并存入全局变量
if (constDef->constInitVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::vector<int> flat(total, 0);
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) { flat[pos] = EvalConstExprInt(iv->constExp()); return; }
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos; fill(sub,a,sub_stride); cur=a+sub_stride; }
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) { flat[cur++] = EvalConstExprInt(sub->constExp()); }
else { int a = ((cur+top_stride-1)/top_stride)*top_stride; fill(sub,a,top_stride); cur=a+top_stride; }
}
}
gv->SetInitVals(flat);
}
} else {
auto* slot = builder_.CreateAllocaArray(total, name);
storage_map_[constDef] = slot;
array_dims_[constDef] = dims;
// 扁平化初始化
// 按 C 语义扁平化初始化(子列表对齐到维度边界)
if (constDef->constInitVal()) {
std::vector<int> flat;
flat.reserve(total);
std::function<void(SysYParser::ConstInitValContext*)> flatten =
[&](SysYParser::ConstInitValContext* iv) {
if (!iv) return;
if (iv->constExp()) {
flat.push_back(EvalConstExprInt(iv->constExp()));
} else {
for (auto* sub : iv->constInitVal()) flatten(sub);
}
};
flatten(constDef->constInitVal());
std::vector<int> flat(total, 0);
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
std::function<void(SysYParser::ConstInitValContext*, int, int)> fill;
fill = [&](SysYParser::ConstInitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->constExp()) {
flat[pos] = EvalConstExprInt(iv->constExp());
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->constInitVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
int cur = 0;
if (constDef->constInitVal()->constExp()) {
flat[0] = EvalConstExprInt(constDef->constInitVal()->constExp());
} else {
for (auto* sub : constDef->constInitVal()->constInitVal()) {
if (cur >= total) break;
if (sub->constExp()) {
flat[cur] = EvalConstExprInt(sub->constExp());
cur++;
} else {
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
for (int i = 0; i < total; ++i) {
int v = (i < (int)flat.size()) ? flat[i] : 0;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(v), ptr);
builder_.CreateStore(builder_.CreateConstInt(flat[i]), ptr);
}
}
}
@ -194,9 +262,97 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total);
auto* gv = module_.CreateGlobalVariable(name, false, 0, total, true, is_float);
global_storage_map_[ctx] = gv;
global_array_dims_[ctx] = dims;
// 计算初始值
if (ctx->initVal()) {
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0];
if (is_float) {
std::vector<float> flat(total, 0.0f);
std::function<void(SysYParser::InitValContext*, int, int)> fill_f;
fill_f = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<float>(sem::EvaluateExp(*iv->exp()->addExp()).float_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill_f(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<float>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).float_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<float>(sem::EvaluateExp(*sub->exp()->addExp()).float_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill_f(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitValsF(flat);
} else {
std::vector<int> flat(total, 0);
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
try { flat[pos] = static_cast<int>(sem::EvaluateExp(*iv->exp()->addExp()).int_val); } catch (...) {}
return;
}
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k)
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur-pos+sub_stride-1)/sub_stride)*sub_stride+pos;
fill(sub, a, sub_stride); cur = a + sub_stride;
}
}
};
int cur = 0;
if (ctx->initVal()->exp()) {
try { flat[0] = static_cast<int>(sem::EvaluateExp(*ctx->initVal()->exp()->addExp()).int_val); } catch (...) {}
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
try { flat[cur] = static_cast<int>(sem::EvaluateExp(*sub->exp()->addExp()).int_val); } catch (...) {}
cur++;
} else {
int a = ((cur+top_stride-1)/top_stride)*top_stride;
fill(sub, a, top_stride); cur = a + top_stride;
}
}
}
gv->SetInitVals(flat);
}
}
}
} else {
if (storage_map_.count(ctx)) {
@ -211,6 +367,14 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
ir::Value* init;
if (ctx->initVal() && ctx->initVal()->exp()) {
init = EvalExpr(*ctx->initVal()->exp());
// Coerce init value to slot type
if (!is_float && init->IsFloat32()) {
init = ToInt(init);
} else if (is_float && init->IsInt32()) {
init = ToFloat(init);
} else if (!is_float && init->IsInt1()) {
init = ToI32(init);
}
} else {
init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
@ -219,40 +383,95 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
} else {
int total = 1;
for (int d : dims) total *= (d > 0 ? d : 1);
auto* slot = builder_.CreateAllocaArray(total, name);
auto* slot = is_float ? builder_.CreateAllocaArrayF32(total, module_.GetContext().NextTemp())
: builder_.CreateAllocaArray(total, module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
array_dims_[ctx] = dims;
ir::Value* zero_init = is_float ? static_cast<ir::Value*>(builder_.CreateConstFloat(0.0f))
: static_cast<ir::Value*>(builder_.CreateConstInt(0));
if (ctx->initVal()) {
// 收集扁平化初始值
std::vector<ir::Value*> flat;
flat.reserve(total);
std::function<void(SysYParser::InitValContext*)> flatten =
[&](SysYParser::InitValContext* iv) {
if (!iv) return;
if (iv->exp()) {
flat.push_back(EvalExpr(*iv->exp()));
} else {
for (auto* sub : iv->initVal()) flatten(sub);
}
};
flatten(ctx->initVal());
for (int i = 0; i < total; ++i) {
ir::Value* v = (i < (int)flat.size()) ? flat[i]
: builder_.CreateConstInt(0);
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(v, ptr);
// 按 C 语义扁平化初始值:子列表对齐到对应维度边界
std::vector<ir::Value*> flat(total, zero_init);
// 计算各维度的 stridestride[i] = dims[i]*dims[i+1]*...*dims[n-1]
// 但我们需要「子列表对应第几维的 stride」
// 顶层stride = total / dims[0](即每行的元素数)
// 递归时 stride 继续除以当前维度大小
std::vector<int> strides(dims.size(), 1);
for (int i = (int)dims.size() - 2; i >= 0; --i)
strides[i] = strides[i + 1] * dims[i + 1];
int top_stride = strides[0]; // 每个顶层子列表占用的元素数
// fill(iv, pos, stride):将 iv 的内容填入 flat[pos..pos+stride)
// stride 表示当前层子列表对应的元素个数
std::function<void(SysYParser::InitValContext*, int, int)> fill;
fill = [&](SysYParser::InitValContext* iv, int pos, int stride) {
if (!iv || pos >= total) return;
if (iv->exp()) {
flat[pos] = EvalExpr(*iv->exp());
return;
}
// 子列表内的 stride = stride / (当前层首维大小)
// 找到对应的 strides 层stride == strides[k] → 子stride = strides[k+1]
int sub_stride = 1;
for (int k = 0; k < (int)strides.size() - 1; ++k) {
if (strides[k] == stride) { sub_stride = strides[k + 1]; break; }
}
int cur = pos;
for (auto* sub : iv->initVal()) {
if (cur >= pos + stride || cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 sub_stride 边界
int aligned = ((cur - pos + sub_stride - 1) / sub_stride) * sub_stride + pos;
fill(sub, aligned, sub_stride);
cur = aligned + sub_stride;
}
}
};
// 顶层扫描
int cur = 0;
if (ctx->initVal()->exp()) {
flat[0] = EvalExpr(*ctx->initVal()->exp());
} else {
for (auto* sub : ctx->initVal()->initVal()) {
if (cur >= total) break;
if (sub->exp()) {
flat[cur] = EvalExpr(*sub->exp());
cur++;
} else {
// 对齐到 top_stride 边界
int aligned = ((cur + top_stride - 1) / top_stride) * top_stride;
fill(sub, aligned, top_stride);
cur = aligned + top_stride;
}
}
}
} else {
// 零初始化
// 先 memset 归零,再只写入非零元素
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
for (int i = 0; i < total; ++i) {
bool is_zero = false;
if (auto* ci = dynamic_cast<ir::ConstantInt*>(flat[i])) {
is_zero = (ci->GetValue() == 0);
} else if (auto* cf = dynamic_cast<ir::ConstantFloat*>(flat[i])) {
is_zero = (cf->GetValue() == 0.0f);
}
if (is_zero) continue;
auto* ptr = builder_.CreateGep(
slot, builder_.CreateConstInt(i),
module_.GetContext().NextTemp());
builder_.CreateStore(builder_.CreateConstInt(0), ptr);
builder_.CreateStore(flat[i], ptr);
}
} else {
// 零初始化:用 memset 归零
builder_.CreateMemsetZero(slot, total, module_.GetContext(), module_);
(void)zero_init;
}
}
}

@ -10,10 +10,15 @@
// ─── 辅助 ─────────────────────────────────────────────────────────────────────
// 把 i32 值转成 i1icmp ne i32 v, 0
// 把 i32/float 值转成 i1
ir::Value* IRGenImpl::ToI1(ir::Value* v) {
if (!v) throw std::runtime_error(FormatError("irgen", "ToI1: null value"));
if (v->IsInt1()) return v;
if (v->IsFloat32()) {
return builder_.CreateFCmp(ir::FCmpPredicate::ONE, v,
builder_.CreateConstFloat(0.0f),
module_.GetContext().NextTemp());
}
return builder_.CreateICmp(ir::ICmpPredicate::NE, v,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
@ -87,7 +92,13 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
} else if (name == "getch") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(), {});
} else if (name == "getfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {}); // 近似
module_.DeclareExternalFunc(name, ir::Type::GetFloat32Type(), {});
} else if (name == "getarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrInt32Type()});
} else if (name == "getfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetInt32Type(),
{ir::Type::GetPtrFloat32Type()});
} else if (name == "putint") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -95,10 +106,16 @@ void IRGenImpl::EnsureExternalDecl(const std::string& name) {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
} else if (name == "putfloat") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(), {});
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetFloat32Type()});
} else if (name == "putarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
{ir::Type::GetInt32Type(),
ir::Type::GetPtrInt32Type()});
} else if (name == "putfarray") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type(),
ir::Type::GetPtrFloat32Type()});
} else if (name == "starttime" || name == "stoptime") {
module_.DeclareExternalFunc(name, ir::Type::GetVoidType(),
{ir::Type::GetInt32Type()});
@ -227,13 +244,113 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
std::vector<ir::Value*> args;
if (ctx->funcRParams()) {
for (auto* exp : ctx->funcRParams()->exp()) {
args.push_back(EvalExpr(*exp));
// 检查是否是数组变量(无索引的 lVar若是则传指针而非 load
ir::Value* arg = nullptr;
auto* add = exp->addExp();
if (add && add->mulExp().size() == 1) {
auto* mul = add->mulExp(0);
if (mul && mul->unaryExp().size() == 1) {
auto* unary = mul->unaryExp(0);
if (unary && !unary->unaryOp() && unary->primaryExp()) {
auto* primary = unary->primaryExp();
if (primary && primary->lVar() && primary->lVar()->exp().empty()) {
auto* lvar = primary->lVar();
auto* decl = sema_.ResolveVarUse(lvar->Ident());
if (decl) {
// 检查是否是数组参数storage_map_ 里存的是指针)
auto it = storage_map_.find(decl);
if (it != storage_map_.end()) {
auto* val = it->second;
if (val && (val->IsPtrInt32() || val->IsPtrFloat32())) {
// 检查是否是 Argument数组参数直接传指针
if (dynamic_cast<ir::Argument*>(val)) {
arg = val;
} else if (array_dims_.count(decl)) {
// 本地数组(含 dims 记录):传首元素地址
arg = builder_.CreateGep(val, builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
// 检查全局数组
if (!arg) {
auto git = global_storage_map_.find(decl);
if (git != global_storage_map_.end()) {
auto* gv = dynamic_cast<ir::GlobalVariable*>(git->second);
if (gv && gv->IsArray()) {
arg = builder_.CreateGep(git->second,
builder_.CreateConstInt(0),
module_.GetContext().NextTemp());
}
}
}
}
}
}
}
}
// Also handle partially-indexed multi-dim arrays: arr[i] where arr is
// int arr[M][N] should pass &arr[i*N] as i32*, not load arr[i] as i32.
if (!arg) {
auto* add2 = exp->addExp();
if (add2 && add2->mulExp().size() == 1) {
auto* mul2 = add2->mulExp(0);
if (mul2 && mul2->unaryExp().size() == 1) {
auto* unary2 = mul2->unaryExp(0);
if (unary2 && !unary2->unaryOp() && unary2->primaryExp()) {
auto* primary2 = unary2->primaryExp();
if (primary2 && primary2->lVar() && !primary2->lVar()->exp().empty()) {
auto* lvar2 = primary2->lVar();
auto* decl2 = sema_.ResolveVarUse(lvar2->Ident());
if (decl2) {
std::vector<int> dims2;
ir::Value* base2 = nullptr;
auto it2 = array_dims_.find(decl2);
if (it2 != array_dims_.end()) {
dims2 = it2->second;
auto sit = storage_map_.find(decl2);
if (sit != storage_map_.end()) base2 = sit->second;
} else {
auto git2 = global_array_dims_.find(decl2);
if (git2 != global_array_dims_.end()) {
dims2 = git2->second;
auto gsit = global_storage_map_.find(decl2);
if (gsit != global_storage_map_.end()) base2 = gsit->second;
}
}
// Partially indexed: fewer indices than dimensions -> pass pointer
bool is_param = !dims2.empty() && dims2[0] == -1;
size_t effective_dims = is_param ? dims2.size() - 1 : dims2.size();
if (base2 && !dims2.empty() &&
lvar2->exp().size() < effective_dims + (is_param ? 1 : 0)) {
arg = EvalLVarAddr(lvar2);
}
}
}
}
}
}
}
if (!arg) arg = EvalExpr(*exp);
args.push_back(arg);
}
}
// 模块内已知函数?
ir::Function* callee = module_.GetFunction(callee_name);
if (callee) {
// Coerce args to match parameter types
for (size_t i = 0; i < args.size() && i < callee->GetNumArgs(); ++i) {
auto* param = callee->GetArgument(i);
if (!param || !args[i]) continue;
if (param->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name =
callee->IsVoidReturn() ? "" : module_.GetContext().NextTemp();
auto* call =
@ -246,15 +363,28 @@ std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
// 外部函数
EnsureExternalDecl(callee_name);
// 获取返回类型
// 获取返回类型和参数类型
std::shared_ptr<ir::Type> ret_type = ir::Type::GetInt32Type();
std::vector<std::shared_ptr<ir::Type>> param_types;
for (const auto& decl : module_.GetExternalDecls()) {
if (decl.name == callee_name) {
ret_type = decl.ret_type;
param_types = decl.param_types;
break;
}
}
bool is_void = ret_type->IsVoid();
// Coerce args to match external function parameter types
for (size_t i = 0; i < args.size() && i < param_types.size(); ++i) {
if (!args[i]) continue;
if (param_types[i]->IsInt32() && args[i]->IsFloat32()) {
args[i] = ToInt(args[i]);
} else if (param_types[i]->IsFloat32() && args[i]->IsInt32()) {
args[i] = ToFloat(args[i]);
} else if (param_types[i]->IsInt32() && args[i]->IsInt1()) {
args[i] = ToI32(args[i]);
}
}
std::string ret_name = is_void ? "" : module_.GetContext().NextTemp();
auto* call = builder_.CreateCallExternal(callee_name, ret_type,
std::move(args), ret_name);
@ -331,40 +461,26 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
throw std::runtime_error(FormatError("irgen", "数组索引维度过多"));
}
ir::Value* offset = builder_.CreateConstInt(0);
ir::Value* offset = nullptr;
if (is_array_param) {
// 数组参数dims[0]=-1, dims[1..n]是已知维度
// 索引indices[0]对应第一维indices[1]对应第二维...
for (size_t i = 0; i < indices.size(); ++i) {
ir::Value* idx = EvalExpr(*indices[i]);
if (i == 0) {
// 第一维stride = dims[1] * dims[2] * ... (如果有的话)
int stride = 1;
for (size_t j = 1; j < dims.size(); ++j) {
stride *= dims[j];
}
if (stride > 1) {
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
} else {
offset = builder_.CreateAdd(offset, idx,
module_.GetContext().NextTemp());
}
int stride = 1;
size_t start = (i == 0) ? 1 : i + 1;
for (size_t j = start; j < dims.size(); ++j) stride *= dims[j];
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
// 后续维度
int stride = 1;
for (size_t j = i + 1; j < dims.size(); ++j) {
stride *= dims[j];
}
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
module_.GetContext().NextTemp());
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
} else {
@ -374,15 +490,24 @@ ir::Value* IRGenImpl::EvalLVarAddr(SysYParser::LVarContext* ctx) {
stride = (i == (int)dims.size() - 1) ? 1 : stride * dims[i + 1];
if (i < (int)indices.size()) {
ir::Value* idx = EvalExpr(*indices[i]);
ir::Value* scaled = builder_.CreateMul(
idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
offset = builder_.CreateAdd(offset, scaled,
ir::Value* term;
if (stride == 1) {
term = idx;
} else {
term = builder_.CreateMul(idx, builder_.CreateConstInt(stride),
module_.GetContext().NextTemp());
}
if (!offset) {
offset = term;
} else {
offset = builder_.CreateAdd(offset, term, module_.GetContext().NextTemp());
}
}
}
}
if (!offset) offset = builder_.CreateConstInt(0);
return builder_.CreateGep(base, offset, module_.GetContext().NextTemp());
}
@ -486,8 +611,8 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".or.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".or.end");
builder_.CreateCondBr(result, end_bb, rhs_bb);
builder_.SetInsertPoint(rhs_bb);
@ -498,6 +623,7 @@ std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}
@ -523,8 +649,8 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
ir::Value* res_ext = ToI32(result);
builder_.CreateStore(res_ext, res_slot);
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".and.end");
ir::BasicBlock* rhs_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.rhs");
ir::BasicBlock* end_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".and.end");
builder_.CreateCondBr(result, rhs_bb, end_bb);
builder_.SetInsertPoint(rhs_bb);
@ -535,6 +661,7 @@ std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
builder_.CreateBr(end_bb);
}
func_->MoveBlockToEnd(end_bb);
builder_.SetInsertPoint(end_bb);
result = ToI1(builder_.CreateLoad(res_slot, module_.GetContext().NextTemp()));
}

@ -25,6 +25,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->Return()) {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp());
// Coerce return value to function return type
if (func_->GetType()->IsInt32() && v->IsFloat32()) {
v = ToInt(v);
} else if (func_->GetType()->IsFloat32() && v->IsInt32()) {
v = ToFloat(v);
} else if (func_->GetType()->IsInt32() && v->IsInt1()) {
v = ToI32(v);
}
builder_.CreateRet(v);
} else {
builder_.CreateRetVoid();
@ -54,6 +62,14 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (ctx->lVar() && ctx->Assign()) {
ir::Value* rhs = EvalExpr(*ctx->exp());
ir::Value* addr = EvalLVarAddr(ctx->lVar());
// Coerce rhs to match slot type
if (addr->IsPtrInt32() && rhs->IsFloat32()) {
rhs = ToInt(rhs);
} else if (addr->IsPtrFloat32() && rhs->IsInt32()) {
rhs = ToFloat(rhs);
} else if (addr->IsPtrInt32() && rhs->IsInt1()) {
rhs = ToI32(rhs);
}
builder_.CreateStore(rhs, addr);
return BlockFlow::Continue;
}
@ -74,32 +90,47 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
auto stmts = ctx->stmt();
// Step 1: evaluate condition (may create short-circuit blocks with lower
// SSA numbers — must happen before any branch-target blocks are created).
ir::Value* cond_val = EvalCond(*ctx->cond());
// Step 2: create then_bb now (its label number will be >= all short-circuit
// block numbers allocated during EvalCond).
ir::BasicBlock* then_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.then");
module_.GetContext().NextLabel() + ".if.then");
// Step 3: create else_bb/merge_bb as placeholders. They will be moved to
// the end of the block list after their predecessors are filled in, so the
// block ordering in the output will be correct even though their label
// numbers are allocated here (before then-body sub-blocks are created).
ir::BasicBlock* else_bb = nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".if.end");
ir::BasicBlock* merge_bb = nullptr;
// 求值条件(可能创建短路求值块)
ir::Value* cond_val = EvalCond(*ctx->cond());
if (stmts.size() >= 2) {
else_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.else");
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
} else {
merge_bb = func_->CreateBlock(module_.GetContext().NextLabel() + ".if.end");
}
// 检查当前块是否已终结(短路求值可能导致)
// Check if current block already terminated (short-circuit may do this)
if (builder_.GetInsertBlock() && builder_.GetInsertBlock()->HasTerminator()) {
// 条件求值已经终结了当前块,无法继续
// 这种情况下我们需要在merge_bb继续
func_->MoveBlockToEnd(then_bb);
if (else_bb) func_->MoveBlockToEnd(else_bb);
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
if (stmts.size() >= 2) {
// if-else
else_bb = func_->CreateBlock(module_.GetContext().NextTemp() + ".if.else");
builder_.CreateCondBr(cond_val, then_bb, else_bb);
} else {
builder_.CreateCondBr(cond_val, then_bb, merge_bb);
}
// then 分支
// then 分支 — visit body (may create many sub-blocks with higher numbers)
func_->MoveBlockToEnd(then_bb);
builder_.SetInsertPoint(then_bb);
auto then_flow = VisitStmt(*stmts[0]);
if (then_flow != BlockFlow::Terminated) {
@ -108,6 +139,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
// else 分支
if (else_bb) {
func_->MoveBlockToEnd(else_bb);
builder_.SetInsertPoint(else_bb);
auto else_flow = VisitStmt(*stmts[1]);
if (else_flow != BlockFlow::Terminated) {
@ -115,6 +147,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
}
func_->MoveBlockToEnd(merge_bb);
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
@ -124,28 +157,32 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
if (!ctx->cond()) {
throw std::runtime_error(FormatError("irgen", "while 缺少条件"));
}
ir::BasicBlock* cond_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.cond");
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextTemp() + ".while.end");
module_.GetContext().NextLabel() + ".while.cond");
// 跳转到条件块
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateBr(cond_bb);
}
// 条件块
// EvalCond MUST come before creating body_bb/after_bb so that
// short-circuit blocks get lower SSA numbers than the loop body blocks.
builder_.SetInsertPoint(cond_bb);
ir::Value* cond_val = EvalCond(*ctx->cond());
ir::BasicBlock* body_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.body");
ir::BasicBlock* after_bb = func_->CreateBlock(
module_.GetContext().NextLabel() + ".while.end");
// 检查条件求值后是否已终结
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateCondBr(cond_val, body_bb, after_bb);
}
// 循环体(压入循环栈)
func_->MoveBlockToEnd(body_bb);
loop_stack_.push_back({cond_bb, after_bb});
builder_.SetInsertPoint(body_bb);
auto stmts = ctx->stmt();
@ -159,6 +196,7 @@ IRGenImpl::BlockFlow IRGenImpl::VisitStmt(SysYParser::StmtContext& s) {
}
loop_stack_.pop_back();
func_->MoveBlockToEnd(after_bb);
builder_.SetInsertPoint(after_bb);
return BlockFlow::Continue;
}

@ -1,5 +1,6 @@
#include "sem/func.h"
#include <cstring>
#include <stdexcept>
#include <string>
@ -7,6 +8,12 @@
namespace sem {
// Truncate double to float32 precision (mimics C float arithmetic)
static double ToFloat32(double v) {
float f = static_cast<float>(v);
return static_cast<double>(f);
}
// 编译时求值常量表达式
ConstValue EvaluateConstExp(SysYParser::ConstExpContext& ctx) {
return EvaluateExp(*ctx.addExp());
@ -73,14 +80,65 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
if (ctx.exp()) {
return EvaluateExp(*ctx.exp()->addExp());
} else if (ctx.lVar()) {
// 处理变量引用(必须是已定义的常量)
// 处理变量引用:向上遍历 AST 找到对应的常量定义并求值
auto* ident = ctx.lVar()->Ident();
if (!ident) {
throw std::runtime_error(FormatError("sema", "非法变量引用"));
}
std::string name = ident->getText();
// 这里简化处理,实际应该在符号表中查找常量
// 暂时假设常量已经在前面被处理过
// 向上遍历 AST 找到作用域内的 constDef
antlr4::ParserRuleContext* scope =
dynamic_cast<antlr4::ParserRuleContext*>(ctx.lVar()->parent);
while (scope) {
// 检查当前作用域中的所有 constDecl
for (auto* tree_child : scope->children) {
auto* child = dynamic_cast<antlr4::ParserRuleContext*>(tree_child);
if (!child) continue;
auto* block_item = dynamic_cast<SysYParser::BlockItemContext*>(child);
if (block_item && block_item->decl()) {
auto* decl = block_item->decl();
if (decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
// compUnit 级别的 constDecl
auto* decl = dynamic_cast<SysYParser::DeclContext*>(child);
if (decl && decl->constDecl()) {
for (auto* def : decl->constDecl()->constDef()) {
if (def->Ident() && def->Ident()->getText() == name) {
if (def->constInitVal() && def->constInitVal()->constExp()) {
ConstValue cv = EvaluateConstExp(*def->constInitVal()->constExp());
// If declared as int, truncate to integer
bool decl_is_int = decl->constDecl()->bType() &&
decl->constDecl()->bType()->Int();
if (decl_is_int) {
cv.is_int = true;
cv.int_val = static_cast<long long>(static_cast<int>(cv.float_val));
cv.float_val = static_cast<double>(cv.int_val);
}
return cv;
}
}
}
}
}
scope = dynamic_cast<antlr4::ParserRuleContext*>(scope->parent);
}
// 未找到常量定义,返回 0
ConstValue val;
val.is_int = true;
val.int_val = 0;
@ -94,11 +152,11 @@ ConstValue EvaluatePrimaryExp(SysYParser::PrimaryExpContext& ctx) {
ConstValue val;
if (int_const) {
val.is_int = true;
val.int_val = std::stoll(int_const->getText());
val.int_val = std::stoll(int_const->getText(), nullptr, 0);
val.float_val = static_cast<double>(val.int_val);
} else if (float_const) {
val.is_int = false;
val.float_val = std::stod(float_const->getText());
val.float_val = ToFloat32(std::stod(float_const->getText()));
val.int_val = static_cast<long long>(val.float_val);
} else {
throw std::runtime_error(FormatError("sema", "非法数字字面量"));
@ -127,8 +185,9 @@ ConstValue AddValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) +
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l + r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -143,8 +202,9 @@ ConstValue SubValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) -
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l - r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -159,8 +219,9 @@ ConstValue MulValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) *
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l * r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;
@ -175,8 +236,9 @@ ConstValue DivValues(const ConstValue& lhs, const ConstValue& rhs) {
result.float_val = static_cast<double>(result.int_val);
} else {
result.is_int = false;
result.float_val = (lhs.is_int ? lhs.int_val : lhs.float_val) /
(rhs.is_int ? rhs.int_val : rhs.float_val);
double l = lhs.is_int ? lhs.int_val : lhs.float_val;
double r = rhs.is_int ? rhs.int_val : rhs.float_val;
result.float_val = ToFloat32(l / r);
result.int_val = static_cast<long long>(result.float_val);
}
return result;

@ -2,3 +2,41 @@
// - 按实验/评测规范提供 I/O 等函数实现
// - 与编译器生成的目标代码链接,支撑运行时行为
#include<stdio.h>
#include"sylib.h"
/* Input & output functions */
int getint(){int t; scanf("%d",&t); return t; }
int getch(){char c; scanf("%c",&c); return (int)c; }
int getarray(int a[]){
int n;
scanf("%d",&n);
for(int i=0;i<n;i++)scanf("%d",&a[i]);
return n;
}
void putint(int a){ printf("%d",a);}
void putch(int a){ printf("%c",a); }
void putarray(int n,int a[]){
printf("%d:",n);
for(int i=0;i<n;i++)printf(" %d",a[i]);
printf("\n");
}
float getfloat(){ float f; scanf("%f",&f); return f; }
void putfloat(float a){ printf("%a",a); }
int getfarray(float a[]){
int n; scanf("%d",&n);
for(int i=0;i<n;i++) scanf("%a",&a[i]);
return n;
}
void putfarray(int n,float a[]){
printf("%d:",n);
for(int i=0;i<n;i++) printf(" %a",a[i]);
printf("\n");
}
static struct timeval _sysy_start, _sysy_end;
void starttime(){ gettimeofday(&_sysy_start, NULL); }
void stoptime(){
gettimeofday(&_sysy_end, NULL);
long us = (_sysy_end.tv_sec - _sysy_start.tv_sec) * 1000000L
+ (_sysy_end.tv_usec - _sysy_start.tv_usec);
fprintf(stderr, "Timer: %ldus\n", us);
}

@ -2,3 +2,20 @@
// - 声明运行库函数原型(供编译器生成 call 或链接阶段引用)
// - 与 sylib.c 配套,按规范逐步补齐声明
#ifndef __SYLIB_H_
#define __SYLIB_H_
#include<stdio.h>
#include<stdarg.h>
#include<sys/time.h>
/* Input & output functions */
int getint(),getch(),getarray(int a[]);
void putint(int a),putch(int a),putarray(int n,int a[]);
float getfloat();
void putfloat(float a);
int getfarray(float a[]);
void putfarray(int n,float a[]);
/* Timing functions */
void starttime();
void stoptime();
#endif

Loading…
Cancel
Save