语法通过

zjx 1 day ago
parent 52c4b75a9e
commit 058ac57a47

@ -457,6 +457,8 @@ class Function : public Value {
BasicBlock* entry_ = nullptr;
std::vector<std::unique_ptr<BasicBlock>> blocks_;
std::vector<Value*> params_; // 参数值(通常是 Argument 类型,后续可定义)
// Owned parameter storage to keep argument Values alive
std::vector<std::unique_ptr<Value>> owned_params_;
std::shared_ptr<FunctionType> func_type_; // 缓存函数类型
};

@ -35,6 +35,13 @@ class IRGenImpl final : public SysYBaseVisitor {
std::any visitNumber(SysYParser::NumberContext* ctx) override;
std::any visitLVal(SysYParser::LValContext* ctx) override;
std::any visitAddExp(SysYParser::AddExpContext* ctx) override;
std::any visitUnaryExp(SysYParser::UnaryExpContext* ctx) override;
std::any visitMulExp(SysYParser::MulExpContext* ctx) override;
std::any visitRelExp(SysYParser::RelExpContext* ctx) override;
std::any visitEqExp(SysYParser::EqExpContext* ctx) override;
std::any visitLAndExp(SysYParser::LAndExpContext* ctx) override;
std::any visitLOrExp(SysYParser::LOrExpContext* ctx) override;
std::any visitConstDef(SysYParser::ConstDefContext* ctx) override;
private:
enum class BlockFlow {
@ -50,7 +57,15 @@ class IRGenImpl final : public SysYBaseVisitor {
ir::Function* func_;
ir::IRBuilder builder_;
// 名称绑定由 Sema 负责IRGen 只维护“声明 -> 存储槽位”的代码生成状态。
std::unordered_map<SysYParser::VarDefContext*, ir::Value*> storage_map_;
std::unordered_map<antlr4::ParserRuleContext*, ir::Value*> storage_map_;
// 额外增加按名称的快速映射,以防有时无法直接通过声明节点指针匹配。
std::unordered_map<std::string, ir::Value*> name_map_;
// 常量名称到整数值的快速映射(供数组维度解析使用)
std::unordered_map<std::string, long> const_values_;
// 当前正在处理的声明基础类型(由 visitDecl 设置visitVarDef/visitConstDef 使用)
std::string current_btype_;
std::vector<ir::BasicBlock*> break_targets_;
std::vector<ir::BasicBlock*> continue_targets_;
};
std::unique_ptr<ir::Module> GenerateIR(SysYParser::CompUnitContext& tree,

@ -0,0 +1,60 @@
#!/usr/bin/env bash
set -euo pipefail
ROOT=$(cd "$(dirname "$0")/.." && pwd)
FUNC_DIR="$ROOT/test/test_case/functional"
OUT_BASE="$ROOT/test/test_result/function/ir"
LOG_DIR="$ROOT/test/test_result/function/ir_logs"
VERIFY="$ROOT/scripts/verify_ir.sh"
mkdir -p "$OUT_BASE"
mkdir -p "$LOG_DIR"
if [ ! -x "$VERIFY" ]; then
echo "verify script not executable, trying to run with bash: $VERIFY"
fi
files=("$FUNC_DIR"/*.sy)
if [ ${#files[@]} -eq 0 ]; then
echo "No .sy files found in $FUNC_DIR" >&2
exit 1
fi
total=0
pass=0
fail=0
failed_list=()
for f in "${files[@]}"; do
((total++))
name=$(basename "$f")
echo "=== Test: $name ==="
log="$LOG_DIR/${name%.sy}.log"
set +e
bash "$VERIFY" "$f" "$OUT_BASE" --run >"$log" 2>&1
rc=$?
set -e
if [ $rc -eq 0 ]; then
echo "PASS: $name"
((pass++))
else
echo "FAIL: $name (log: $log)"
failed_list+=("$name")
((fail++))
fi
done
echo
echo "Summary: total=$total pass=$pass fail=$fail"
if [ $fail -ne 0 ]; then
echo "Failed tests:"; for t in "${failed_list[@]}"; do echo " - $t"; done
echo "Tail of failure logs (last 200 lines each):"
for t in "${failed_list[@]}"; do
logfile="$LOG_DIR/${t%.sy}.log"
echo
echo "--- $t ---"
tail -n 200 "$logfile" || true
done
fi
exit $fail

@ -25,7 +25,7 @@ ConstantFloat* Context::GetConstFloat(float v) {
std::string Context::NextTemp() {
std::ostringstream oss;
oss << "%" << ++temp_index_;
oss << ++temp_index_;
return oss.str();
}

@ -9,8 +9,14 @@ Function::Function(std::string name, std::shared_ptr<Type> ret_type,
std::vector<std::shared_ptr<Type>> param_types)
: Value(std::move(ret_type), std::move(name)) {
func_type_ = std::static_pointer_cast<FunctionType>(
Type::GetFunctionType(GetType(), std::move(param_types)));
Type::GetFunctionType(GetType(), param_types));
entry_ = CreateBlock("entry");
// Create arguments
for (size_t i = 0; i < param_types.size(); ++i) {
owned_params_.push_back(std::make_unique<Value>(param_types[i], "arg" + std::to_string(i)));
params_.push_back(owned_params_.back().get());
// Note: arguments are owned in owned_params_ to ensure lifetime
}
}
BasicBlock* Function::CreateBlock(const std::string& name) {

@ -21,6 +21,10 @@ ConstantInt* IRBuilder::CreateConstInt(int v) {
return ctx_.GetConstInt(v);
}
ConstantFloat* IRBuilder::CreateConstFloat(float v) {
return ctx_.GetConstFloat(v);
}
BinaryInst* IRBuilder::CreateBinary(Opcode op, Value* lhs, Value* rhs,
const std::string& name) {
if (!insert_block_) {
@ -42,6 +46,7 @@ BinaryInst* IRBuilder::CreateAdd(Value* lhs, Value* rhs,
return CreateBinary(Opcode::Add, lhs, rhs, name);
}
AllocaInst* IRBuilder::CreateAllocaI32(const std::string& name) {
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
@ -161,8 +166,30 @@ GetElementPtrInst* IRBuilder::CreateGEP(Value* ptr, std::vector<Value*> indices,
if (!insert_block_) {
throw std::runtime_error(FormatError("ir", "IRBuilder 未设置插入点"));
}
// 计算结果类型(简化:假设为指针类型)
auto result_ty = Type::GetPointerType(ptr->GetType());
// 计算结果类型:根据传入的 indices 逐步从 pointee 类型走到目标元素类型。
// 例如 ptr 是指向数组的指针GEP 使用一个索引应返回指向数组元素的指针。
std::shared_ptr<Type> current;
if (ptr->GetType() && ptr->GetType()->IsPointer()) {
const PointerType* pty = static_cast<const PointerType*>(ptr->GetType().get());
current = pty->GetPointeeType();
} else {
current = ptr->GetType();
}
// 根据每个索引推进类型层次:数组 -> 元素类型,指针 -> 指向类型
for (size_t i = 0; i < indices.size(); ++i) {
if (!current) break;
if (current->IsArray()) {
const ArrayType* aty = static_cast<const ArrayType*>(current.get());
current = aty->GetElementType();
} else if (current->IsPointer()) {
const PointerType* ppty = static_cast<const PointerType*>(current.get());
current = ppty->GetPointeeType();
} else {
// 非数组/指针类型,无法继续下钻,保持当前类型
break;
}
}
auto result_ty = Type::GetPointerType(current);
return insert_block_->Append<GetElementPtrInst>(result_ty, ptr, indices, name);
}

@ -12,7 +12,30 @@
namespace ir {
static const char* TypeToString(const Type& ty) {
static std::string PredicateToString(CmpInst::Predicate pred, bool is_float) {
if (is_float) {
switch (pred) {
case CmpInst::EQ: return "oeq";
case CmpInst::NE: return "one";
case CmpInst::LT: return "olt";
case CmpInst::LE: return "ole";
case CmpInst::GT: return "ogt";
case CmpInst::GE: return "oge";
}
} else {
switch (pred) {
case CmpInst::EQ: return "eq";
case CmpInst::NE: return "ne";
case CmpInst::LT: return "slt";
case CmpInst::LE: return "sle";
case CmpInst::GT: return "sgt";
case CmpInst::GE: return "sge";
}
}
return "unknown";
}
static std::string TypeToString(const Type& ty) {
switch (ty.GetKind()) {
case Type::Kind::Void:
return "void";
@ -20,10 +43,14 @@ static const char* TypeToString(const Type& ty) {
return "i32";
case Type::Kind::Float32:
return "float";
case Type::Kind::Pointer:
return "i32*"; // 目前仅支持 i32* 指针打印
case Type::Kind::Array:
return "[array]";
case Type::Kind::Pointer: {
const PointerType* p = static_cast<const PointerType*>(&ty);
return TypeToString(*p->GetPointeeType()) + "*";
}
case Type::Kind::Array: {
const ArrayType* a = static_cast<const ArrayType*>(&ty);
return std::string("[") + std::to_string(a->GetSize()) + " x " + TypeToString(*a->GetElementType()) + "]";
}
case Type::Kind::Function:
return "[function]";
case Type::Kind::Label:
@ -60,17 +87,65 @@ static const char* OpcodeToString(Opcode op) {
return "?";
}
static std::string ConstantValueToString(const ConstantValue* cv);
static std::string ValueToString(const Value* v) {
if (!v) return "<null>";
if (auto* ci = dynamic_cast<const ConstantInt*>(v)) {
return std::to_string(ci->GetValue());
}
return v ? v->GetName() : "<null>";
if (auto* cf = dynamic_cast<const ConstantFloat*>(v)) {
// simple float literal
return std::to_string(cf->GetValue());
}
if (auto* ca = dynamic_cast<const ConstantArray*>(v)) {
return ConstantValueToString(ca);
}
// fallback to name for instructions/alloca/vars — prefix with '%'
return std::string("%") + v->GetName();
}
static std::string ConstantValueToString(const ConstantValue* cv) {
if (!cv) return "<null-const>";
if (auto* ci = dynamic_cast<const ConstantInt*>(cv)) return std::to_string(ci->GetValue());
if (auto* cf = dynamic_cast<const ConstantFloat*>(cv)) {
std::string s = std::to_string(cf->GetValue());
size_t dot = s.find('.');
if (dot != std::string::npos) {
size_t e = s.find('e');
if (e == std::string::npos) e = s.size();
while (e > dot + 1 && s[e-1] == '0') e--;
if (e == dot + 1) s = s.substr(0, dot + 1) + "0";
else s = s.substr(0, e);
}
return s;
}
if (auto* ca = dynamic_cast<const ConstantArray*>(cv)) {
// format: [ <elem_ty> <elem>, <elem_ty> <elem>, ... ]
const auto& elems = ca->GetElements();
std::string out = "[";
for (size_t i = 0; i < elems.size(); ++i) {
if (i) out += ", ";
// each element should be printed with its type and value
auto* e = elems[i];
std::string etype = TypeToString(*e->GetType());
out += etype + " " + ConstantValueToString(e);
}
out += "]";
return out;
}
return "<const-unk>";
}
void IRPrinter::Print(const Module& module, std::ostream& os) {
for (const auto& func : module.GetFunctions()) {
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName()
<< "() {\n";
os << "define " << TypeToString(*func->GetType()) << " @" << func->GetName() << "(";
const auto& params = func->GetParams();
for (size_t i = 0; i < params.size(); ++i) {
if (i) os << ", ";
os << TypeToString(*params[i]->GetType()) << " " << ValueToString(params[i]);
}
os << ") {\n";
for (const auto& bb : func->GetBlocks()) {
if (!bb) {
continue;
@ -91,18 +166,31 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
case Opcode::LShr:
case Opcode::AShr: {
auto* bin = static_cast<const BinaryInst*>(inst);
os << " " << bin->GetName() << " = "
<< OpcodeToString(bin->GetOpcode()) << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
// choose opcode name: integer ops use e.g. 'add','sub', float ops use 'fadd','fsub', etc.
std::string op_name = OpcodeToString(bin->GetOpcode());
bool is_float = bin->GetLhs()->GetType()->IsFloat32();
if (is_float) {
switch (bin->GetOpcode()) {
case Opcode::Add: op_name = "fadd"; break;
case Opcode::Sub: op_name = "fsub"; break;
case Opcode::Mul: op_name = "fmul"; break;
case Opcode::Div: op_name = "fdiv"; break;
case Opcode::Mod: op_name = "frem"; break;
default: break;
}
}
os << " %" << bin->GetName() << " = "
<< op_name << " "
<< TypeToString(*bin->GetLhs()->GetType()) << " "
<< ValueToString(bin->GetLhs()) << ", "
<< ValueToString(bin->GetRhs()) << "\n";
break;
}
case Opcode::ICmp:
case Opcode::FCmp: {
auto* cmp = static_cast<const CmpInst*>(inst);
os << " " << cmp->GetName() << " = "
<< OpcodeToString(cmp->GetOpcode()) << " eq "
os << " %" << cmp->GetName() << " = "
<< OpcodeToString(cmp->GetOpcode()) << " " << PredicateToString(cmp->GetPredicate(), cmp->GetOpcode() == Opcode::FCmp) << " "
<< TypeToString(*cmp->GetLhs()->GetType()) << " "
<< ValueToString(cmp->GetLhs()) << ", "
<< ValueToString(cmp->GetRhs()) << "\n";
@ -110,16 +198,16 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
case Opcode::Alloca: {
auto* alloca = static_cast<const AllocaInst*>(inst);
os << " " << alloca->GetName() << " = alloca "
<< TypeToString(*static_cast<const PointerType*>(alloca->GetType().get())->GetPointeeType()) << "\n";
os << " %" << alloca->GetName() << " = alloca "
<< TypeToString(*static_cast<const PointerType*>(alloca->GetType().get())->GetPointeeType()) << "\n";
break;
}
case Opcode::Load: {
auto* load = static_cast<const LoadInst*>(inst);
os << " " << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
os << " %" << load->GetName() << " = load "
<< TypeToString(*load->GetType()) << ", "
<< TypeToString(*load->GetPtr()->GetType()) << " "
<< ValueToString(load->GetPtr()) << "\n";
break;
}
case Opcode::Store: {
@ -152,7 +240,7 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
auto* call = static_cast<const CallInst*>(inst);
os << " ";
if (!call->GetName().empty()) {
os << call->GetName() << " = ";
os << "%" << call->GetName() << " = ";
}
os << "call " << TypeToString(*call->GetCallee()->GetType()) << " @"
<< call->GetCallee()->GetName() << "(";
@ -166,9 +254,12 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
case Opcode::GEP: {
auto* gep = static_cast<const GetElementPtrInst*>(inst);
os << " " << gep->GetName() << " = getelementptr "
<< TypeToString(*gep->GetPtr()->GetType()) << " "
<< ValueToString(gep->GetPtr());
os << " %" << gep->GetName() << " = getelementptr ";
// Print element type first, then the pointer type and pointer value
const auto ptrType = gep->GetPtr()->GetType();
const PointerType* pty = static_cast<const PointerType*>(ptrType.get());
os << TypeToString(*pty->GetPointeeType()) << ", "
<< TypeToString(*ptrType) << " " << ValueToString(gep->GetPtr());
for (auto* idx : gep->GetIndices()) {
os << ", " << TypeToString(*idx->GetType()) << " " << ValueToString(idx);
}
@ -177,8 +268,8 @@ void IRPrinter::Print(const Module& module, std::ostream& os) {
}
case Opcode::Phi: {
auto* phi = static_cast<const PhiInst*>(inst);
os << " " << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType());
os << " %" << phi->GetName() << " = phi "
<< TypeToString(*phi->GetType());
for (const auto& incoming : phi->GetIncomings()) {
os << " [ " << ValueToString(incoming.first) << ", %"
<< incoming.second->GetName() << " ]";

@ -63,8 +63,8 @@ void Instruction::SetParent(BasicBlock* parent) { parent_ = parent; }
BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
Value* rhs, std::string name)
: Instruction(op, std::move(ty), std::move(name)) {
if (op != Opcode::Add) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 Add"));
if (op != Opcode::Add && op != Opcode::Sub && op != Opcode::Mul && op != Opcode::Div && op != Opcode::Mod) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持算术操作"));
}
if (!lhs || !rhs) {
throw std::runtime_error(FormatError("ir", "BinaryInst 缺少操作数"));
@ -76,8 +76,8 @@ BinaryInst::BinaryInst(Opcode op, std::shared_ptr<Type> ty, Value* lhs,
type_->GetKind() != lhs->GetType()->GetKind()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 类型不匹配"));
}
if (!type_->IsInt32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32"));
if (!type_->IsInt32() && !type_->IsFloat32()) {
throw std::runtime_error(FormatError("ir", "BinaryInst 当前只支持 i32 和 float"));
}
AddOperand(lhs);
AddOperand(rhs);

@ -5,6 +5,7 @@
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
#include <functional>
// helper functions removed; VarDef uses Ident() directly per current grammar.
@ -22,6 +23,52 @@ std::any IRGenImpl::visitBlock(SysYParser::BlockContext* ctx) {
return {};
}
std::any IRGenImpl::visitConstDef(SysYParser::ConstDefContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "缺少常量定义"));
if (!ctx->Ident()) throw std::runtime_error(FormatError("irgen", "常量声明缺少名称"));
if (!ctx->constInitVal()) throw std::runtime_error(FormatError("irgen", "常量必须初始化"));
if (storage_map_.find(ctx) != storage_map_.end()) throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
auto* slot = (current_btype_ == "float") ?
static_cast<ir::AllocaInst*>(builder_.CreateAllocaFloat(module_.GetContext().NextTemp())) :
static_cast<ir::AllocaInst*>(builder_.CreateAllocaI32(module_.GetContext().NextTemp()));
storage_map_[ctx] = slot;
name_map_[ctx->Ident()->getText()] = slot;
// Try to evaluate a scalar const initializer
ir::ConstantValue* cinit = nullptr;
try {
auto* initval = ctx->constInitVal();
if (initval && initval->constExp() && initval->constExp()->addExp()) {
if (current_btype_ == "float") {
auto* add = initval->constExp()->addExp();
float fv = std::stof(add->getText());
cinit = module_.GetContext().GetConstFloat(fv);
} else {
auto* add = initval->constExp()->addExp();
int iv = std::stoi(add->getText());
cinit = module_.GetContext().GetConstInt(iv);
}
}
} catch(...) {
// fallback: try evaluate via visitor
try {
auto* add = ctx->constInitVal()->constExp()->addExp();
ir::Value* v = std::any_cast<ir::Value*>(add->accept(this));
if (auto* cv = dynamic_cast<ir::ConstantValue*>(v)) cinit = cv;
} catch(...) {}
}
if (cinit) builder_.CreateStore(cinit, slot);
else builder_.CreateStore((current_btype_=="float"? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0)), slot);
// record simple integer consts for dimension evaluation
try {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(cinit)) {
const_values_[ctx->Ident()->getText()] = ci->GetValue();
}
} catch(...) {}
return {};
}
IRGenImpl::BlockFlow IRGenImpl::VisitBlockItemResult(
SysYParser::BlockItemContext& item) {
return std::any_cast<BlockFlow>(item.accept(this));
@ -63,6 +110,15 @@ std::any IRGenImpl::visitDecl(SysYParser::DeclContext* ctx) {
}
return {};
}
if (ctx->constDecl()) {
auto* cdecl = ctx->constDecl();
if (!cdecl->bType()) throw std::runtime_error(FormatError("irgen", "缺少常量基类型"));
current_btype_ = cdecl->bType()->getText();
if (current_btype_ != "int" && current_btype_ != "float") throw std::runtime_error(FormatError("irgen", "当前仅支持局部 int/float 常量声明"));
for (auto* const_def : cdecl->constDef()) if (const_def) const_def->accept(this);
current_btype_.clear();
return {};
}
throw std::runtime_error(FormatError("irgen", "暂不支持的声明类型"));
}
@ -81,9 +137,156 @@ std::any IRGenImpl::visitVarDef(SysYParser::VarDefContext* ctx) {
if (storage_map_.find(ctx) != storage_map_.end()) {
throw std::runtime_error(FormatError("irgen", "声明重复生成存储槽位"));
}
// check if this is an array declaration (has constExp dimensions)
if (!ctx->constExp().empty()) {
// parse dims
std::vector<int> dims;
for (auto* ce : ctx->constExp()) {
try {
int v = 0;
auto anyv = sema_.GetConstVal(ce);
if (anyv.has_value()) {
if (anyv.type() == typeid(int)) v = std::any_cast<int>(anyv);
else if (anyv.type() == typeid(long)) v = (int)std::any_cast<long>(anyv);
else throw std::runtime_error("not-const-int");
} else {
// try simple patterns like NUM or IDENT+NUM or NUM+IDENT
std::string s = ce->addExp()->getText();
s.erase(std::remove_if(s.begin(), s.end(), ::isspace), s.end());
auto pos = s.find('+');
if (pos == std::string::npos) {
// plain number or identifier
try { v = std::stoi(s); }
catch(...) {
// try lookup identifier in recorded consts or symbol table
auto it = const_values_.find(s);
if (it != const_values_.end()) v = (int)it->second;
else {
VarInfo vi; void* declctx = nullptr;
if (sema_.GetSymbolTable().LookupVar(s, vi, declctx) && vi.const_val.has_value()) {
if (vi.const_val.type() == typeid(int)) v = std::any_cast<int>(vi.const_val);
else if (vi.const_val.type() == typeid(long)) v = (int)std::any_cast<long>(vi.const_val);
else throw std::runtime_error("not-const-int");
} else throw std::runtime_error("not-const-int");
}
}
} else {
// form A+B where A or B may be ident or number
std::string L = s.substr(0, pos);
std::string R = s.substr(pos + 1);
int lv = 0, rv = 0; bool ok = false;
// try left
try { lv = std::stoi(L); ok = true; } catch(...) {
auto it = const_values_.find(L);
if (it != const_values_.end()) { lv = (int)it->second; ok = true; }
else {
VarInfo vi; void* declctx = nullptr;
if (sema_.GetSymbolTable().LookupVar(L, vi, declctx) && vi.const_val.has_value()) {
if (vi.const_val.type() == typeid(int)) lv = std::any_cast<int>(vi.const_val);
else if (vi.const_val.type() == typeid(long)) lv = (int)std::any_cast<long>(vi.const_val);
ok = true;
}
}
}
// try right
try { rv = std::stoi(R); ok = ok && true; } catch(...) {
auto it2 = const_values_.find(R);
if (it2 != const_values_.end()) { rv = (int)it2->second; ok = ok && true; }
else {
VarInfo vi2; void* declctx2 = nullptr;
if (sema_.GetSymbolTable().LookupVar(R, vi2, declctx2) && vi2.const_val.has_value()) {
if (vi2.const_val.type() == typeid(int)) rv = std::any_cast<int>(vi2.const_val);
else if (vi2.const_val.type() == typeid(long)) rv = (int)std::any_cast<long>(vi2.const_val);
ok = ok && true;
} else ok = false;
}
}
if (!ok) throw std::runtime_error("not-const-int");
v = lv + rv;
}
}
dims.push_back(v);
} catch (...) {
throw std::runtime_error(FormatError("irgen", "数组维度必须为常量整数"));
}
}
std::shared_ptr<ir::Type> elemTy = (current_btype_ == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type();
// build nested array type
std::function<std::shared_ptr<ir::Type>(size_t)> makeArrayType = [&](size_t level) -> std::shared_ptr<ir::Type> {
if (level + 1 >= dims.size()) return ir::Type::GetArrayType(elemTy, dims[level]);
auto sub = makeArrayType(level + 1);
return ir::Type::GetArrayType(sub, dims[level]);
};
auto fullArrayTy = makeArrayType(0);
auto arr_ptr_ty = ir::Type::GetPointerType(fullArrayTy);
auto* array_slot = builder_.CreateAlloca(arr_ptr_ty, module_.GetContext().NextTemp());
storage_map_[ctx] = array_slot;
name_map_[ctx->Ident()->getText()] = array_slot;
// compute spans and total scalar slots
int nlevels = (int)dims.size();
std::vector<int> span(nlevels);
int total = 1;
for (int i = nlevels - 1; i >= 0; --i) {
if (i == nlevels - 1) span[i] = 1;
else span[i] = span[i + 1] * dims[i + 1];
total *= dims[i];
}
ir::Value* zero = elemTy->IsFloat32() ? (ir::Value*)module_.GetContext().GetConstFloat(0.0f) : (ir::Value*)module_.GetContext().GetConstInt(0);
std::vector<ir::Value*> slots(total, zero);
// process initializer (if any) into linear slots
if (auto* init_value = ctx->initVal()) {
std::function<void(SysYParser::InitValContext*, int, int&)> process_group;
process_group = [&](SysYParser::InitValContext* init, int level, int& pos) {
if (level >= nlevels) return;
int sub_span = span[level];
int elems = dims[level];
if (!init) { pos += elems * sub_span; return; }
for (auto* child : init->initVal()) {
if (pos >= total) break;
if (!child) { pos += 1; continue; }
if (!child->initVal().empty()) {
int subpos = pos;
int inner = subpos;
process_group(child, level + 1, inner);
pos = subpos + sub_span;
} else if (child->exp()) {
try { ir::Value* v = EvalExpr(*child->exp()); if (pos < total) slots[pos] = v; } catch(...) {}
pos += 1;
} else { pos += 1; }
}
};
int pos0 = 0;
process_group(init_value, 0, pos0);
}
// emit stores for each scalar slot in row-major order
for (int idx = 0; idx < total; ++idx) {
std::vector<int> indices;
int rem = idx;
for (int L = 0; L < nlevels; ++L) {
int ind = rem / span[L];
indices.push_back(ind % dims[L]);
rem = rem % span[L];
}
std::vector<ir::Value*> gep_inds;
gep_inds.push_back(module_.GetContext().GetConstInt(0));
for (int v : indices) gep_inds.push_back(module_.GetContext().GetConstInt(v));
while (gep_inds.size() < (size_t)(1 + nlevels)) gep_inds.push_back(module_.GetContext().GetConstInt(0));
auto* gep = builder_.CreateGEP(array_slot, gep_inds, module_.GetContext().NextTemp());
builder_.CreateStore(slots[idx], gep);
}
return {};
}
// scalar variable
auto* slot = builder_.CreateAllocaI32(module_.GetContext().NextTemp());
storage_map_[ctx] = slot;
ir::Value* init = nullptr;
if (auto* init_value = ctx->initVal()) {
if (!init_value->exp()) {

@ -35,11 +35,14 @@ std::any IRGenImpl::visitPrimaryExp(SysYParser::PrimaryExpContext* ctx) {
std::any IRGenImpl::visitNumber(SysYParser::NumberContext* ctx) {
if (!ctx || !ctx->IntConst()) {
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量"));
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法数字字面量"));
if (ctx->IntConst()) {
return static_cast<ir::Value*>(builder_.CreateConstInt(std::stoi(ctx->getText())));
}
return static_cast<ir::Value*>(
builder_.CreateConstInt(std::stoi(ctx->getText())));
if (ctx->FloatConst()) {
return static_cast<ir::Value*>(builder_.CreateConstFloat(std::stof(ctx->getText())));
}
throw std::runtime_error(FormatError("irgen", "当前仅支持整数字面量或浮点字面量"));
}
// 变量使用的处理流程:
@ -55,11 +58,59 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
// find storage by matching declaration node stored in Sema context
// Sema stores types/decl contexts in IRGenContext maps; here we search storage_map_ by name
std::string name = ctx->Ident()->getText();
// 优先使用按名称的快速映射
auto nit = name_map_.find(name);
if (nit != name_map_.end()) {
// 支持下标访问:若有索引表达式列表,则生成 GEP + load
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
// 首个索引用于穿过数组对象
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) {
indices.push_back(EvalExpr(*e));
}
auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
// 如果映射到的是常量,直接返回常量值;否则按原来行为从槽位 load
if (nit->second->IsConstant()) return nit->second;
return static_cast<ir::Value*>(builder_.CreateLoad(nit->second, module_.GetContext().NextTemp()));
}
for (auto& kv : storage_map_) {
// kv.first is VarDefContext*, try to get Ident text
if (kv.first && kv.first->Ident() && kv.first->Ident()->getText() == name) {
return static_cast<ir::Value*>(
builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
if (!kv.first) continue;
if (auto* vdef = dynamic_cast<SysYParser::VarDefContext*>(kv.first)) {
if (vdef->Ident() && vdef->Ident()->getText() == name) {
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
}
} else if (auto* fparam = dynamic_cast<SysYParser::FuncFParamContext*>(kv.first)) {
if (fparam->Ident() && fparam->Ident()->getText() == name) {
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
}
} else if (auto* cdef = dynamic_cast<SysYParser::ConstDefContext*>(kv.first)) {
if (cdef->Ident() && cdef->Ident()->getText() == name) {
if (ctx->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(kv.second, indices, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(builder_.CreateLoad(gep, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(builder_.CreateLoad(kv.second, module_.GetContext().NextTemp()));
}
}
}
throw std::runtime_error(FormatError("irgen", "变量声明缺少存储槽位: " + name));
@ -68,11 +119,190 @@ std::any IRGenImpl::visitLVal(SysYParser::LValContext* ctx) {
std::any IRGenImpl::visitAddExp(SysYParser::AddExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法加法表达式"));
// left-associative: evaluate first two mulExp as a simple binary add
try {
// left-associative: fold across all mulExp operands
if (ctx->mulExp().size() == 1) return ctx->mulExp(0)->accept(this);
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->mulExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp(1)->accept(this));
ir::Value* cur = std::any_cast<ir::Value*>(ctx->mulExp(0)->accept(this));
// extract operator sequence from text (in-order '+' or '-')
std::string text = ctx->getText();
std::vector<char> ops;
for (char c : text) if (c == '+' || c == '-') ops.push_back(c);
for (size_t i = 1; i < ctx->mulExp().size(); ++i) {
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->mulExp(i)->accept(this));
char opch = (i - 1 < ops.size()) ? ops[i - 1] : '+';
ir::Opcode op = (opch == '-') ? ir::Opcode::Sub : ir::Opcode::Add;
cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp());
}
return static_cast<ir::Value*>(cur);
} catch (const std::exception& e) {
LogInfo(std::string("[irgen] exception in visitAddExp text=") + ctx->getText() + ", err=" + e.what(), std::cerr);
throw;
}
}
std::any IRGenImpl::visitMulExp(SysYParser::MulExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法乘法表达式"));
if (ctx->unaryExp().size() == 1) return ctx->unaryExp(0)->accept(this);
ir::Value* cur = std::any_cast<ir::Value*>(ctx->unaryExp(0)->accept(this));
// extract operator sequence for '*', '/', '%'
std::string text = ctx->getText();
std::vector<char> ops;
for (char c : text) if (c == '*' || c == '/' || c == '%') ops.push_back(c);
for (size_t i = 1; i < ctx->unaryExp().size(); ++i) {
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->unaryExp(i)->accept(this));
char opch = (i - 1 < ops.size()) ? ops[i - 1] : '*';
ir::Opcode op = ir::Opcode::Mul;
if (opch == '/') op = ir::Opcode::Div;
else if (opch == '%') op = ir::Opcode::Mod;
cur = builder_.CreateBinary(op, cur, rhs, module_.GetContext().NextTemp());
}
return static_cast<ir::Value*>(cur);
}
std::any IRGenImpl::visitUnaryExp(SysYParser::UnaryExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法一元表达式"));
if (ctx->primaryExp()) return ctx->primaryExp()->accept(this);
// function call: Ident '(' funcRParams? ')'
if (ctx->Ident() && ctx->getText().find("(") != std::string::npos) {
std::string fname = ctx->Ident()->getText();
std::vector<ir::Value*> args;
if (ctx->funcRParams()) {
for (auto* e : ctx->funcRParams()->exp()) {
args.push_back(EvalExpr(*e));
}
}
// find existing function or create an external declaration (assume int return)
ir::Function* callee = nullptr;
for (auto &fup : module_.GetFunctions()) {
if (fup && fup->GetName() == fname) { callee = fup.get(); break; }
}
if (!callee) {
std::vector<std::shared_ptr<ir::Type>> param_types;
for (auto* a : args) {
if (a && a->IsFloat32()) param_types.push_back(ir::Type::GetFloat32Type());
else param_types.push_back(ir::Type::GetInt32Type());
}
callee = module_.CreateFunction(fname, ir::Type::GetInt32Type(), param_types);
}
return static_cast<ir::Value*>(builder_.CreateCall(callee, args, module_.GetContext().NextTemp()));
}
if (ctx->unaryExp()) {
ir::Value* val = std::any_cast<ir::Value*>(ctx->unaryExp()->accept(this));
if (ctx->unaryOp() && ctx->unaryOp()->getText() == "+") return static_cast<ir::Value*>(val);
else if (ctx->unaryOp() && ctx->unaryOp()->getText() == "-") {
// 负号0 - val区分整型/浮点
if (val->IsFloat32()) {
ir::Value* zero = builder_.CreateConstFloat(0.0f);
return static_cast<ir::Value*>(builder_.CreateSub(zero, val, module_.GetContext().NextTemp()));
} else {
ir::Value* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(builder_.CreateSub(zero, val, module_.GetContext().NextTemp()));
}
}
if (ctx->unaryOp() && ctx->unaryOp()->getText() == "!") {
// logical not: produce int 1 if val == 0, else 0
if (val->IsFloat32()) {
ir::Value* zerof = builder_.CreateConstFloat(0.0f);
return static_cast<ir::Value*>(builder_.CreateFCmp(ir::CmpInst::EQ, val, zerof, module_.GetContext().NextTemp()));
} else {
ir::Value* zero = builder_.CreateConstInt(0);
return static_cast<ir::Value*>(builder_.CreateICmp(ir::CmpInst::EQ, val, zero, module_.GetContext().NextTemp()));
}
}
}
throw std::runtime_error(FormatError("irgen", "不支持的一元运算"));
}
std::any IRGenImpl::visitRelExp(SysYParser::RelExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法关系表达式"));
if (ctx->addExp().size() == 1) return ctx->addExp(0)->accept(this);
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->addExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->addExp(1)->accept(this));
// 类型提升
if (lhs->IsFloat32() && rhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(rhs)) {
rhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
} else if (rhs->IsFloat32() && lhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(lhs)) {
lhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
}
ir::CmpInst::Predicate pred = ir::CmpInst::EQ;
std::string text = ctx->getText();
if (text.find("<=") != std::string::npos) pred = ir::CmpInst::LE;
else if (text.find(">=") != std::string::npos) pred = ir::CmpInst::GE;
else if (text.find("<") != std::string::npos) pred = ir::CmpInst::LT;
else if (text.find(">") != std::string::npos) pred = ir::CmpInst::GT;
if (lhs->IsFloat32() || rhs->IsFloat32()) {
return static_cast<ir::Value*>(
builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(
builder_.CreateBinary(ir::Opcode::Add, lhs, rhs,
module_.GetContext().NextTemp()));
builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitEqExp(SysYParser::EqExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法相等表达式"));
if (ctx->relExp().size() == 1) return ctx->relExp(0)->accept(this);
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->relExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->relExp(1)->accept(this));
// 类型提升
if (lhs->IsFloat32() && rhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(rhs)) {
rhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
} else if (rhs->IsFloat32() && lhs->IsInt32()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(lhs)) {
lhs = builder_.CreateConstFloat(static_cast<float>(ci->GetValue()));
} else {
throw std::runtime_error(FormatError("irgen", "不支持 int 到 float 的隐式转换"));
}
}
ir::CmpInst::Predicate pred = ir::CmpInst::EQ;
std::string text = ctx->getText();
if (text.find("==") != std::string::npos) pred = ir::CmpInst::EQ;
else if (text.find("!=") != std::string::npos) pred = ir::CmpInst::NE;
if (lhs->IsFloat32() || rhs->IsFloat32()) {
return static_cast<ir::Value*>(
builder_.CreateFCmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
return static_cast<ir::Value*>(
builder_.CreateICmp(pred, lhs, rhs, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitLAndExp(SysYParser::LAndExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑与表达式"));
if (ctx->eqExp().size() == 1) return ctx->eqExp(0)->accept(this);
// For simplicity, treat as int (0 or 1)
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->eqExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->eqExp(1)->accept(this));
// lhs && rhs : (lhs != 0) && (rhs != 0)
ir::Value* zero = builder_.CreateConstInt(0);
ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp());
ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp());
return static_cast<ir::Value*>(
builder_.CreateMul(lhs_ne, rhs_ne, module_.GetContext().NextTemp()));
}
std::any IRGenImpl::visitLOrExp(SysYParser::LOrExpContext* ctx) {
if (!ctx) throw std::runtime_error(FormatError("irgen", "非法逻辑或表达式"));
if (ctx->lAndExp().size() == 1) return ctx->lAndExp(0)->accept(this);
ir::Value* lhs = std::any_cast<ir::Value*>(ctx->lAndExp(0)->accept(this));
ir::Value* rhs = std::any_cast<ir::Value*>(ctx->lAndExp(1)->accept(this));
// lhs || rhs : (lhs != 0) || (rhs != 0)
ir::Value* zero = builder_.CreateConstInt(0);
ir::Value* lhs_ne = builder_.CreateICmp(ir::CmpInst::NE, lhs, zero, module_.GetContext().NextTemp());
ir::Value* rhs_ne = builder_.CreateICmp(ir::CmpInst::NE, rhs, zero, module_.GetContext().NextTemp());
ir::Value* or_val = builder_.CreateAdd(lhs_ne, rhs_ne, module_.GetContext().NextTemp());
ir::Value* one = builder_.CreateConstInt(1);
return static_cast<ir::Value*>(
builder_.CreateICmp(ir::CmpInst::GE, or_val, one, module_.GetContext().NextTemp()));
}

@ -5,6 +5,7 @@
#include "SysYParser.h"
#include "ir/IR.h"
#include "utils/Log.h"
#include <functional>
namespace {
@ -42,7 +43,119 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
if (ctx->funcDef().empty()) {
throw std::runtime_error(FormatError("irgen", "缺少函数定义"));
}
ctx->funcDef(0)->accept(this);
// 先处理顶层声明(仅支持简单的 const int 初始化作为全局常量)
for (auto* decl : ctx->decl()) {
if (!decl) continue;
if (decl->constDecl()) {
auto* cdecl = decl->constDecl();
for (auto* cdef : cdecl->constDef()) {
if (!cdef || !cdef->Ident() || !cdef->constInitVal() || !cdef->constInitVal()->constExp()) continue;
// 仅支持形如: const int a = 10; 的简单常量初始化(字面量)
auto* add = cdef->constInitVal()->constExp()->addExp();
if (!add) continue;
try {
int v = std::stoi(add->getText());
auto* cval = module_.GetContext().GetConstInt(v);
name_map_[cdef->Ident()->getText()] = cval;
} catch (...) {
// 无法解析则跳过,全局复杂常量暂不支持
}
}
}
// 支持简单的全局变量声明(数组或标量),初始化为零
if (decl->varDecl()) {
auto* vdecl = decl->varDecl();
if (!vdecl->bType()) continue;
std::string btype = vdecl->bType()->getText();
for (auto* vdef : vdecl->varDef()) {
if (!vdef) continue;
LogInfo(std::string("[irgen] global varDef text=") + vdef->getText() + std::string(" ident=") + (vdef->Ident() ? vdef->Ident()->getText() : std::string("<none>")) + std::string(" dims=") + std::to_string((int)vdef->constExp().size()), std::cerr);
if (!vdef || !vdef->Ident()) continue;
std::string name = vdef->Ident()->getText();
// array globals
if (!vdef->constExp().empty()) {
std::vector<int> dims;
bool ok = true;
for (auto* ce : vdef->constExp()) {
try {
int val = 0;
auto anyv = sema_.GetConstVal(ce);
if (anyv.has_value()) {
if (anyv.type() == typeid(int)) val = std::any_cast<int>(anyv);
else if (anyv.type() == typeid(long)) val = (int)std::any_cast<long>(anyv);
else throw std::runtime_error("not-const-int");
} else {
// try literal parse
try {
val = std::stoi(ce->addExp()->getText());
} catch (...) {
// try lookup in name_map_ for previously created const
std::string t = ce->addExp()->getText();
auto it = name_map_.find(t);
if (it != name_map_.end() && it->second && it->second->IsConstant()) {
if (auto* ci = dynamic_cast<ir::ConstantInt*>(it->second)) {
val = ci->GetValue();
} else {
ok = false; break;
}
} else {
ok = false; break;
}
}
}
dims.push_back(val);
} catch (...) { ok = false; break; }
}
if (!ok) continue;
// build zero constant array similar to visitVarDef
std::function<ir::ConstantValue*(const std::vector<int>&, size_t, std::shared_ptr<ir::Type>)> buildZero;
buildZero = [&](const std::vector<int>& ds, size_t idx, std::shared_ptr<ir::Type> elemTy) -> ir::ConstantValue* {
if (idx >= ds.size()) return nullptr;
std::vector<ir::ConstantValue*> elems;
if (idx + 1 == ds.size()) {
for (int i = 0; i < ds[idx]; ++i) {
if (elemTy->IsFloat32()) elems.push_back(module_.GetContext().GetConstFloat(0.0f));
else elems.push_back(module_.GetContext().GetConstInt(0));
}
} else {
for (int i = 0; i < ds[idx]; ++i) {
ir::ConstantValue* sub = buildZero(ds, idx + 1, elemTy);
if (sub) elems.push_back(sub);
else elems.push_back(module_.GetContext().GetConstInt(0));
}
}
std::function<std::shared_ptr<ir::Type>(size_t)> makeArrayType = [&](size_t level) -> std::shared_ptr<ir::Type> {
if (level + 1 >= ds.size()) return ir::Type::GetArrayType(elemTy, ds[level]);
auto sub = makeArrayType(level + 1);
return ir::Type::GetArrayType(sub, ds[level]);
};
auto at_real = makeArrayType(idx);
return new ir::ConstantArray(at_real, elems);
};
std::shared_ptr<ir::Type> elemTy = (btype == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type();
ir::ConstantValue* zero = buildZero(dims, 0, elemTy);
auto gvty = ir::Type::GetPointerType(zero ? zero->GetType() : ir::Type::GetPointerType(elemTy));
ir::GlobalValue* gv = module_.CreateGlobalVariable(name, gvty, zero);
name_map_[name] = gv;
LogInfo(std::string("[irgen] created global ") + name, std::cerr);
} else {
// scalar global
std::shared_ptr<ir::Type> elemTy = (btype == "float") ? ir::Type::GetFloat32Type() : ir::Type::GetInt32Type();
ir::ConstantValue* init = nullptr;
if (btype == "float") init = module_.GetContext().GetConstFloat(0.0f);
else init = module_.GetContext().GetConstInt(0);
ir::GlobalValue* gv = module_.CreateGlobalVariable(name, ir::Type::GetPointerType(elemTy), init);
name_map_[name] = gv;
LogInfo(std::string("[irgen] created global ") + name, std::cerr);
}
}
}
}
// 生成编译单元中所有函数定义(之前只生成第一个函数)
for (size_t i = 0; i < ctx->funcDef().size(); ++i) {
if (ctx->funcDef(i)) ctx->funcDef(i)->accept(this);
}
return {};
}
@ -61,6 +174,7 @@ std::any IRGenImpl::visitCompUnit(SysYParser::CompUnitContext* ctx) {
// - 入口块中的参数初始化逻辑。
// ...
// 因此这里目前只支持最小的“无参 int 函数”生成。
// 因此这里目前只支持最小的“无参 int 函数”生成。
std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx) {
@ -72,15 +186,50 @@ std::any IRGenImpl::visitFuncDef(SysYParser::FuncDefContext* ctx) {
if (!ctx->Ident()) {
throw std::runtime_error(FormatError("irgen", "缺少函数名"));
}
if (!ctx->funcType() || ctx->funcType()->getText() != "int") {
throw std::runtime_error(FormatError("irgen", "当前仅支持无参 int 函数"));
if (!ctx->funcType()) {
throw std::runtime_error(FormatError("irgen", "缺少函数返回类型"));
}
std::shared_ptr<ir::Type> ret_type;
if (ctx->funcType()->getText() == "int") ret_type = ir::Type::GetInt32Type();
else if (ctx->funcType()->getText() == "float") ret_type = ir::Type::GetFloat32Type();
else if (ctx->funcType()->getText() == "void") ret_type = ir::Type::GetVoidType();
else throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float/void 函数"));
std::vector<std::shared_ptr<ir::Type>> param_types;
if (ctx->funcFParams()) {
for (auto* param : ctx->funcFParams()->funcFParam()) {
if (param->bType()->getText() == "int") param_types.push_back(ir::Type::GetInt32Type());
else if (param->bType()->getText() == "float") param_types.push_back(ir::Type::GetFloat32Type());
else throw std::runtime_error(FormatError("irgen", "当前仅支持 int/float 参数"));
}
}
func_ = module_.CreateFunction(ctx->Ident()->getText(), ir::Type::GetInt32Type());
func_ = module_.CreateFunction(ctx->Ident()->getText(), ret_type, param_types);
builder_.SetInsertPoint(func_->GetEntry());
storage_map_.clear();
// Allocate storage for parameters
if (ctx->funcFParams()) {
int idx = 0;
for (auto* param : ctx->funcFParams()->funcFParam()) {
std::string param_name = param->Ident()->getText();
ir::AllocaInst* alloca = nullptr;
if (param->bType()->getText() == "float") alloca = builder_.CreateAllocaFloat(param_name);
else alloca = builder_.CreateAllocaI32(param_name);
storage_map_[param] = alloca;
name_map_[param_name] = alloca;
// Store the argument value
auto* arg = func_->GetParams()[idx];
builder_.CreateStore(arg, alloca);
idx++;
}
}
ctx->block()->accept(this);
// 如果函数体末尾没有显式终结(如 void 函数没有 return补一个隐式 return
if (!builder_.GetInsertBlock()->HasTerminator()) {
builder_.CreateRet(nullptr);
}
// 语义正确性主要由 sema 保证,这里只兜底检查 IR 结构是否合法。
VerifyFunctionStructure(*func_);
return {};

@ -19,8 +19,10 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
if (!ctx) {
throw std::runtime_error(FormatError("irgen", "缺少语句"));
}
// stmt can be many forms; handle return specifically
if (!ctx->getText().empty() && ctx->getText().find("return") != std::string::npos) {
std::string text = ctx->getText();
LogInfo("[irgen] visitStmt text='" + text + "' break_size=" + std::to_string(break_targets_.size()) + " cont_size=" + std::to_string(continue_targets_.size()), std::cerr);
// return
if (ctx->getStart()->getText() == "return") {
if (ctx->exp()) {
ir::Value* v = EvalExpr(*ctx->exp());
builder_.CreateRet(v);
@ -29,6 +31,158 @@ std::any IRGenImpl::visitStmt(SysYParser::StmtContext* ctx) {
}
return BlockFlow::Terminated;
}
// assignment: lVal '=' exp
if (ctx->lVal() && text.find("=") != std::string::npos) {
ir::Value* val = EvalExpr(*ctx->exp());
std::string name = ctx->lVal()->Ident()->getText();
// 优先检查按名称的快速映射(支持全局变量)
auto nit = name_map_.find(name);
if (nit != name_map_.end()) {
// 支持带索引的赋值
if (ctx->lVal()->exp().size() > 0) {
std::vector<ir::Value*> indices;
indices.push_back(builder_.CreateConstInt(0));
for (auto* e : ctx->lVal()->exp()) indices.push_back(EvalExpr(*e));
auto* gep = builder_.CreateGEP(nit->second, indices, module_.GetContext().NextTemp());
builder_.CreateStore(val, gep);
return BlockFlow::Continue;
}
builder_.CreateStore(val, nit->second);
return BlockFlow::Continue;
}
for (auto& kv : storage_map_) {
if (!kv.first) continue;
if (auto* vdef = dynamic_cast<SysYParser::VarDefContext*>(kv.first)) {
if (vdef->Ident() && vdef->Ident()->getText() == name) {
builder_.CreateStore(val, kv.second);
return BlockFlow::Continue;
}
} else if (auto* fparam = dynamic_cast<SysYParser::FuncFParamContext*>(kv.first)) {
if (fparam->Ident() && fparam->Ident()->getText() == name) {
builder_.CreateStore(val, kv.second);
return BlockFlow::Continue;
}
} else if (auto* cdef = dynamic_cast<SysYParser::ConstDefContext*>(kv.first)) {
if (cdef->Ident() && cdef->Ident()->getText() == name) {
builder_.CreateStore(val, kv.second);
return BlockFlow::Continue;
}
}
}
throw std::runtime_error(FormatError("irgen", "变量未声明: " + name));
}
// if
if (ctx->getStart()->getText() == "if" && ctx->cond()) {
ir::Value* condv = std::any_cast<ir::Value*>(ctx->cond()->lOrExp()->accept(this));
ir::BasicBlock* then_bb = func_->CreateBlock("if.then");
ir::BasicBlock* else_bb = (ctx->stmt().size() > 1) ? func_->CreateBlock("if.else") : nullptr;
ir::BasicBlock* merge_bb = func_->CreateBlock("if.merge");
if (else_bb) builder_.CreateCondBr(condv, then_bb, else_bb);
else builder_.CreateCondBr(condv, then_bb, merge_bb);
// then
builder_.SetInsertPoint(then_bb);
ctx->stmt(0)->accept(this);
if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(merge_bb);
// else
if (else_bb) {
builder_.SetInsertPoint(else_bb);
ctx->stmt(1)->accept(this);
if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(merge_bb);
}
builder_.SetInsertPoint(merge_bb);
return BlockFlow::Continue;
}
// while
if (ctx->getStart()->getText() == "while" && ctx->cond()) {
ir::BasicBlock* cond_bb = func_->CreateBlock("while.cond");
ir::BasicBlock* body_bb = func_->CreateBlock("while.body");
ir::BasicBlock* after_bb = func_->CreateBlock("while.after");
builder_.CreateBr(cond_bb);
// cond
builder_.SetInsertPoint(cond_bb);
ir::Value* condv = std::any_cast<ir::Value*>(ctx->cond()->lOrExp()->accept(this));
builder_.CreateCondBr(condv, body_bb, after_bb);
// body
builder_.SetInsertPoint(body_bb);
LogInfo("[irgen] while body about to push targets, before sizes: break=" + std::to_string(break_targets_.size()) + ", cont=" + std::to_string(continue_targets_.size()), std::cerr);
break_targets_.push_back(after_bb);
continue_targets_.push_back(cond_bb);
LogInfo("[irgen] after push: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
ctx->stmt(0)->accept(this);
LogInfo("[irgen] before pop: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
continue_targets_.pop_back();
break_targets_.pop_back();
LogInfo("[irgen] after pop: break_targets size=" + std::to_string(break_targets_.size()) + ", continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
if (!builder_.GetInsertBlock()->HasTerminator()) builder_.CreateBr(cond_bb);
builder_.SetInsertPoint(after_bb);
return BlockFlow::Continue;
}
// break
if (ctx->getStart()->getText() == "break") {
if (break_targets_.empty()) {
// fallback: 尝试通过函数块名找目标(不依赖 sema兼容因栈丢失导致的情况
ir::BasicBlock* fallback = nullptr;
for (auto &bb_up : func_->GetBlocks()) {
auto *bb = bb_up.get();
if (!bb) continue;
if (bb->GetName().find("while.after") != std::string::npos) fallback = bb;
}
if (fallback) {
LogInfo("[irgen] emit break (fallback), target=" + fallback->GetName(), std::cerr);
builder_.CreateBr(fallback);
return BlockFlow::Terminated;
}
throw std::runtime_error(FormatError("irgen", "break 不在循环内"));
}
LogInfo("[irgen] emit break, break_targets size=" + std::to_string(break_targets_.size()), std::cerr);
builder_.CreateBr(break_targets_.back());
return BlockFlow::Terminated;
}
// continue
if (ctx->getStart()->getText() == "continue") {
if (continue_targets_.empty()) {
ir::BasicBlock* fallback = nullptr;
for (auto &bb_up : func_->GetBlocks()) {
auto *bb = bb_up.get();
if (!bb) continue;
if (bb->GetName().find("while.cond") != std::string::npos) fallback = bb;
}
if (fallback) {
LogInfo("[irgen] emit continue (fallback), target=" + fallback->GetName(), std::cerr);
builder_.CreateBr(fallback);
return BlockFlow::Terminated;
}
throw std::runtime_error(FormatError("irgen", "continue 不在循环内"));
}
LogInfo("[irgen] emit continue, continue_targets size=" + std::to_string(continue_targets_.size()), std::cerr);
builder_.CreateBr(continue_targets_.back());
return BlockFlow::Terminated;
}
// block
if (ctx->block()) {
return ctx->block()->accept(this);
}
// expression statement
if (ctx->exp()) {
EvalExpr(*ctx->exp());
return BlockFlow::Continue;
}
throw std::runtime_error(FormatError("irgen", "暂不支持的语句类型"));
}

Loading…
Cancel
Save