lab4实验,批量测试脚本已写

pull/1/head
pu9sp4t32 2 weeks ago
parent 69f2cdf11a
commit 184d5c7cb5

@ -33,7 +33,7 @@ target_include_directories(build_options INTERFACE
option(COMPILER_ENABLE_WARNINGS "Enable common compiler warnings" ON)
if(COMPILER_ENABLE_WARNINGS)
if(MSVC)
target_compile_options(build_options INTERFACE /W4)
target_compile_options(build_options INTERFACE /W4 /utf-8)
else()
target_compile_options(build_options INTERFACE -Wall -Wextra -Wpedantic)
endif()
@ -52,6 +52,7 @@ set(ANTLR4_RUNTIME_SRC_DIR "${PROJECT_SOURCE_DIR}/third_party/antlr4-runtime-4.1
add_library(antlr4_runtime STATIC)
target_compile_features(antlr4_runtime PUBLIC cxx_std_17)
target_compile_definitions(antlr4_runtime PUBLIC ANTLR4CPP_STATIC)
target_include_directories(antlr4_runtime PUBLIC
"${ANTLR4_RUNTIME_SRC_DIR}"

@ -109,3 +109,8 @@ cmake --build build -j "$(nproc)"
目标:脚本自动读取同名 `.in`,并将程序输出与退出码和同名 `.out` 比对,确保优化后程序行为与优化前保持一致。
完成 Lab4 后,应对 `test/test_case` 下全部测试用例逐个回归;如有需要,也可以自行编写批量测试脚本统一执行。
批量测试脚本:
bash test/test_result/lab4_batch/run_all.sh

@ -228,10 +228,13 @@ class User : public Value {
size_t GetNumOperands() const;
Value* GetOperand(size_t index) const;
void SetOperand(size_t index, Value* value);
void RemoveOperand(size_t index);
protected:
// 统一的 operand 入口。
void AddOperand(Value* value);
virtual void OnOperandChanged(size_t index, Value* value);
virtual void OnOperandRemoving(size_t index);
private:
std::vector<Value*> operands_;
@ -360,6 +363,8 @@ class CallInst : public Instruction {
const std::vector<Value*>& GetArgs() const { return args_; }
private:
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> args_;
};
@ -367,10 +372,13 @@ class PhiInst : public Instruction {
public:
PhiInst(std::shared_ptr<Type> ty, std::string name);
void AddIncoming(Value* value, BasicBlock* block);
void RemoveIncomingFrom(BasicBlock* block);
const std::vector<Value*>& GetIncomingValues() const;
const std::vector<BasicBlock*>& GetIncomingBlocks() const;
private:
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> incoming_values_;
std::vector<BasicBlock*> incoming_blocks_;
};
@ -383,6 +391,8 @@ class GepInst : public Instruction {
const std::vector<Value*>& GetIndices() const { return indices_; }
private:
void OnOperandChanged(size_t index, Value* value) override;
void OnOperandRemoving(size_t index) override;
std::vector<Value*> indices_;
};
@ -394,10 +404,17 @@ class BasicBlock : public Value {
void SetParent(Function* parent);
bool HasTerminator() const;
const std::vector<std::unique_ptr<Instruction>>& GetInstructions() const;
std::vector<std::unique_ptr<Instruction>>& GetMutableInstructions();
const std::vector<BasicBlock*>& GetPredecessors() const;
const std::vector<BasicBlock*>& GetSuccessors() const;
void AddPredecessor(BasicBlock* pred);
void AddSuccessor(BasicBlock* succ);
void ClearPredecessors();
void ClearSuccessors();
void RemovePredecessor(BasicBlock* pred);
void RemoveSuccessor(BasicBlock* succ);
void EraseInstruction(Instruction* inst);
void ReplaceTerminator(std::unique_ptr<Instruction> inst);
template <typename T, typename... Args>
T* Append(Args&&... args) {
if (HasTerminator()) {
@ -443,6 +460,7 @@ class Function : public Value {
BasicBlock* GetEntry();
const BasicBlock* GetEntry() const;
const std::vector<std::unique_ptr<BasicBlock>>& GetBlocks() const;
std::vector<std::unique_ptr<BasicBlock>>& GetMutableBlocks();
const std::vector<std::unique_ptr<Argument>>& GetArguments() const;
size_t GetNumArgs() const;
Argument* GetArg(size_t index);
@ -548,4 +566,12 @@ class IRPrinter {
void Print(const Module& module, std::ostream& os);
};
bool RunMem2Reg(Module& module);
bool RunConstFold(Module& module);
bool RunConstProp(Module& module);
bool RunCSE(Module& module);
bool RunDCE(Module& module);
bool RunCFGSimplify(Module& module);
bool RunScalarOptimizationPipeline(Module& module);
} // namespace ir

@ -8,6 +8,7 @@ struct CLIOptions {
bool emit_parse_tree = false;
bool emit_ir = true;
bool emit_asm = false;
bool optimize_ir = true;
bool show_help = false;
};

@ -9,6 +9,7 @@
#include "ir/IR.h"
#include <algorithm>
#include <utility>
namespace ir {
@ -32,6 +33,10 @@ const std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetInstructions()
return instructions_;
}
std::vector<std::unique_ptr<Instruction>>& BasicBlock::GetMutableInstructions() {
return instructions_;
}
// 前驱/后继接口先保留给后续 CFG 扩展使用。
// 当前最小 IR 中还没有 branch 指令,因此这些列表通常为空。
const std::vector<BasicBlock*>& BasicBlock::GetPredecessors() const {
@ -58,6 +63,51 @@ void BasicBlock::AddSuccessor(BasicBlock* succ) {
successors_.push_back(succ);
}
void BasicBlock::ClearPredecessors() { predecessors_.clear(); }
void BasicBlock::ClearSuccessors() { successors_.clear(); }
void BasicBlock::RemovePredecessor(BasicBlock* pred) {
predecessors_.erase(std::remove(predecessors_.begin(), predecessors_.end(), pred),
predecessors_.end());
}
void BasicBlock::RemoveSuccessor(BasicBlock* succ) {
successors_.erase(std::remove(successors_.begin(), successors_.end(), succ),
successors_.end());
}
void BasicBlock::EraseInstruction(Instruction* inst) {
if (!inst) return;
auto it = std::find_if(instructions_.begin(), instructions_.end(),
[&](const auto& ptr) { return ptr.get() == inst; });
if (it == instructions_.end()) return;
for (size_t i = 0; i < inst->GetNumOperands(); ++i) {
if (auto* operand = inst->GetOperand(i)) {
operand->RemoveUse(inst, i);
}
}
inst->SetParent(nullptr);
instructions_.erase(it);
}
void BasicBlock::ReplaceTerminator(std::unique_ptr<Instruction> inst) {
if (!inst || !inst->IsTerminator()) return;
if (!instructions_.empty() && instructions_.back()->IsTerminator()) {
auto* old = instructions_.back().get();
for (size_t i = 0; i < old->GetNumOperands(); ++i) {
if (auto* operand = old->GetOperand(i)) {
operand->RemoveUse(old, i);
}
}
old->SetParent(nullptr);
instructions_.pop_back();
}
inst->SetParent(this);
instructions_.push_back(std::move(inst));
LinkSuccessorsIfNeeded(instructions_.back().get());
}
void BasicBlock::LinkSuccessorsIfNeeded(Instruction* inst) {
if (!inst) return;
if (auto* br = dynamic_cast<BranchInst*>(inst)) {

@ -48,6 +48,10 @@ const std::vector<std::unique_ptr<BasicBlock>>& Function::GetBlocks() const {
return blocks_;
}
std::vector<std::unique_ptr<BasicBlock>>& Function::GetMutableBlocks() {
return blocks_;
}
const std::vector<std::unique_ptr<Argument>>& Function::GetArguments() const {
return args_;
}

@ -3,6 +3,7 @@
// - 指令操作数与结果类型管理,支持打印与优化
#include "ir/IR.h"
#include <cstddef>
#include <stdexcept>
#include "utils/Log.h"
@ -36,6 +37,7 @@ void User::SetOperand(size_t index, Value* value) {
}
operands_[index] = value;
value->AddUse(this, index);
OnOperandChanged(index, value);
}
void User::AddOperand(Value* value) {
@ -47,6 +49,27 @@ void User::AddOperand(Value* value) {
value->AddUse(this, operand_index);
}
void User::RemoveOperand(size_t index) {
if (index >= operands_.size()) {
throw std::out_of_range("User operand index out of range");
}
OnOperandRemoving(index);
if (auto* old = operands_[index]) {
old->RemoveUse(this, index);
}
for (size_t i = index + 1; i < operands_.size(); ++i) {
if (auto* value = operands_[i]) {
value->RemoveUse(this, i);
value->AddUse(this, i - 1);
}
}
operands_.erase(operands_.begin() + static_cast<std::ptrdiff_t>(index));
}
void User::OnOperandChanged(size_t, Value*) {}
void User::OnOperandRemoving(size_t) {}
Instruction::Instruction(Opcode op, std::shared_ptr<Type> ty, std::string name)
: User(std::move(ty), std::move(name)), opcode_(op) {}
@ -312,6 +335,22 @@ CallInst::CallInst(std::shared_ptr<Type> ret_ty, Value* callee,
Value* CallInst::GetCallee() const { return GetOperand(0); }
void CallInst::OnOperandChanged(size_t index, Value* value) {
if (index == 0) return;
size_t arg_index = index - 1;
if (arg_index < args_.size()) {
args_[arg_index] = value;
}
}
void CallInst::OnOperandRemoving(size_t index) {
if (index == 0) return;
size_t arg_index = index - 1;
if (arg_index < args_.size()) {
args_.erase(args_.begin() + static_cast<std::ptrdiff_t>(arg_index));
}
}
PhiInst::PhiInst(std::shared_ptr<Type> ty, std::string name)
: Instruction(Opcode::Phi, std::move(ty), std::move(name)) {}
@ -328,6 +367,17 @@ void PhiInst::AddIncoming(Value* value, BasicBlock* block) {
AddOperand(block);
}
void PhiInst::RemoveIncomingFrom(BasicBlock* block) {
for (size_t i = 0; i < incoming_blocks_.size();) {
if (incoming_blocks_[i] != block) {
++i;
continue;
}
RemoveOperand(2 * i + 1);
RemoveOperand(2 * i);
}
}
const std::vector<Value*>& PhiInst::GetIncomingValues() const {
return incoming_values_;
}
@ -336,6 +386,34 @@ const std::vector<BasicBlock*>& PhiInst::GetIncomingBlocks() const {
return incoming_blocks_;
}
void PhiInst::OnOperandChanged(size_t index, Value* value) {
size_t incoming_index = index / 2;
if (index % 2 == 0) {
if (incoming_index < incoming_values_.size()) {
incoming_values_[incoming_index] = value;
}
return;
}
if (incoming_index < incoming_blocks_.size()) {
incoming_blocks_[incoming_index] = static_cast<BasicBlock*>(value);
}
}
void PhiInst::OnOperandRemoving(size_t index) {
size_t incoming_index = index / 2;
if (index % 2 == 0) {
if (incoming_index < incoming_values_.size()) {
incoming_values_.erase(incoming_values_.begin() +
static_cast<std::ptrdiff_t>(incoming_index));
}
return;
}
if (incoming_index < incoming_blocks_.size()) {
incoming_blocks_.erase(incoming_blocks_.begin() +
static_cast<std::ptrdiff_t>(incoming_index));
}
}
GepInst::GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
std::vector<Value*> indices, std::string name)
: Instruction(Opcode::Gep, std::move(result_ptr_ty), std::move(name)),
@ -360,4 +438,20 @@ GepInst::GepInst(std::shared_ptr<Type> result_ptr_ty, Value* base_ptr,
Value* GepInst::GetBasePtr() const { return GetOperand(0); }
void GepInst::OnOperandChanged(size_t index, Value* value) {
if (index == 0) return;
size_t idx = index - 1;
if (idx < indices_.size()) {
indices_[idx] = value;
}
}
void GepInst::OnOperandRemoving(size_t index) {
if (index == 0) return;
size_t idx = index - 1;
if (idx < indices_.size()) {
indices_.erase(indices_.begin() + static_cast<std::ptrdiff_t>(idx));
}
}
} // namespace ir

@ -1,4 +1,135 @@
// CFG 简化:
// - 删除不可达块、合并空块、简化分支等
// - 改善 IR 结构,便于后续优化与后端生成
#include "ir/IR.h"
#include <memory>
#include <queue>
#include <unordered_set>
namespace ir {
namespace {
Instruction* Terminator(BasicBlock& block) {
auto& insts = block.GetMutableInstructions();
if (insts.empty()) return nullptr;
auto* inst = insts.back().get();
return inst && inst->IsTerminator() ? inst : nullptr;
}
void AddCFGEdge(BasicBlock* from, BasicBlock* to) {
if (!from || !to) return;
from->AddSuccessor(to);
to->AddPredecessor(from);
}
void RebuildCFG(Function& func) {
for (const auto& block : func.GetBlocks()) {
if (!block) continue;
block->ClearPredecessors();
block->ClearSuccessors();
}
for (const auto& block : func.GetBlocks()) {
if (!block) continue;
auto* term = Terminator(*block);
if (auto* br = dynamic_cast<BranchInst*>(term)) {
AddCFGEdge(block.get(), br->GetDest());
} else if (auto* cbr = dynamic_cast<CondBrInst*>(term)) {
AddCFGEdge(block.get(), cbr->GetTrueDest());
AddCFGEdge(block.get(), cbr->GetFalseDest());
}
}
}
bool SimplifyBranches(Function& func) {
bool changed = false;
for (const auto& block : func.GetBlocks()) {
if (!block) continue;
auto* cbr = dynamic_cast<CondBrInst*>(Terminator(*block));
if (!cbr) continue;
BasicBlock* dest = nullptr;
if (cbr->GetTrueDest() == cbr->GetFalseDest()) {
dest = cbr->GetTrueDest();
} else if (auto* cond = dynamic_cast<ConstantInt*>(cbr->GetCond())) {
dest = cond->GetValue() != 0 ? cbr->GetTrueDest() : cbr->GetFalseDest();
}
if (dest) {
block->ReplaceTerminator(std::make_unique<BranchInst>(dest));
changed = true;
}
}
return changed;
}
std::unordered_set<BasicBlock*> ReachableBlocks(Function& func) {
std::unordered_set<BasicBlock*> reachable;
std::queue<BasicBlock*> work;
if (auto* entry = func.GetEntry()) {
reachable.insert(entry);
work.push(entry);
}
while (!work.empty()) {
auto* block = work.front();
work.pop();
for (auto* succ : block->GetSuccessors()) {
if (succ && reachable.insert(succ).second) {
work.push(succ);
}
}
}
return reachable;
}
bool RemoveUnreachable(Function& func) {
auto reachable = ReachableBlocks(func);
bool changed = false;
for (const auto& block : func.GetBlocks()) {
if (!block || reachable.count(block.get()) == 0) continue;
for (const auto& inst : block->GetInstructions()) {
auto* phi = dynamic_cast<PhiInst*>(inst.get());
if (!phi) continue;
for (const auto& other : func.GetBlocks()) {
if (other && reachable.count(other.get()) == 0) {
phi->RemoveIncomingFrom(other.get());
}
}
}
}
auto& blocks = func.GetMutableBlocks();
for (auto it = blocks.begin(); it != blocks.end();) {
auto* block = it->get();
if (!block || reachable.count(block) != 0) {
++it;
continue;
}
auto& insts = block->GetMutableInstructions();
while (!insts.empty()) {
block->EraseInstruction(insts.back().get());
}
it = blocks.erase(it);
changed = true;
}
return changed;
}
} // namespace
bool RunCFGSimplify(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
RebuildCFG(*func);
bool local_changed = SimplifyBranches(*func);
if (local_changed) {
RebuildCFG(*func);
}
local_changed |= RemoveUnreachable(*func);
if (local_changed) {
RebuildCFG(*func);
}
changed |= local_changed;
}
return changed;
}
} // namespace ir

@ -1,4 +1,94 @@
// 公共子表达式消除CSE
// - 识别并复用重复计算的等价表达式
// - 典型放置在 ConstFold 之后、DCE 之前
// - 当前为 Lab4 的框架占位,具体算法由实验实现
#include "ir/IR.h"
#include <algorithm>
#include <cstdint>
#include <sstream>
#include <string>
#include <unordered_map>
namespace ir {
namespace {
bool IsCommutative(Opcode op) {
return op == Opcode::Add || op == Opcode::Mul || op == Opcode::FAdd ||
op == Opcode::FMul;
}
std::string ValueId(Value* value) {
std::ostringstream oss;
oss << reinterpret_cast<std::uintptr_t>(value);
return oss.str();
}
std::string KeyFor(const Instruction& inst) {
std::ostringstream oss;
oss << static_cast<int>(inst.GetOpcode()) << ":";
if (auto* icmp = dynamic_cast<const ICmpInst*>(&inst)) {
oss << static_cast<int>(icmp->GetPredicate()) << ":";
} else if (auto* fcmp = dynamic_cast<const FCmpInst*>(&inst)) {
oss << static_cast<int>(fcmp->GetPredicate()) << ":";
}
std::vector<std::string> operands;
for (size_t i = 0; i < inst.GetNumOperands(); ++i) {
operands.push_back(ValueId(inst.GetOperand(i)));
}
if (IsCommutative(inst.GetOpcode()) && operands.size() == 2) {
std::sort(operands.begin(), operands.end());
}
for (const auto& operand : operands) {
oss << operand << ",";
}
return oss.str();
}
bool IsCSECandidate(const Instruction& inst) {
switch (inst.GetOpcode()) {
case Opcode::Add:
case Opcode::Sub:
case Opcode::Mul:
case Opcode::SDiv:
case Opcode::SRem:
case Opcode::FAdd:
case Opcode::FSub:
case Opcode::FMul:
case Opcode::FDiv:
case Opcode::ICmp:
case Opcode::FCmp:
case Opcode::SIToFP:
case Opcode::FPToSI:
case Opcode::ZExt:
case Opcode::Gep:
return true;
default:
return false;
}
}
} // namespace
bool RunCSE(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
std::unordered_map<std::string, Instruction*> available;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst || !IsCSECandidate(*inst)) continue;
std::string key = KeyFor(*inst);
auto it = available.find(key);
if (it != available.end()) {
inst->ReplaceAllUsesWith(it->second);
changed = true;
} else {
available.emplace(std::move(key), inst);
}
}
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,239 @@
// IR 常量折叠:
// - 折叠可判定的常量表达式
// - 简化常量控制流分支(按实现范围裁剪)
#include "ir/IR.h"
#include <cmath>
namespace ir {
namespace {
ConstantInt* AsConstInt(Value* value) {
return dynamic_cast<ConstantInt*>(value);
}
ConstantFloat* AsConstFloat(Value* value) {
return dynamic_cast<ConstantFloat*>(value);
}
bool IsZero(Value* value) {
if (auto* i = AsConstInt(value)) return i->GetValue() == 0;
if (auto* f = AsConstFloat(value)) return f->GetValue() == 0.0f;
return false;
}
bool IsOne(Value* value) {
if (auto* i = AsConstInt(value)) return i->GetValue() == 1;
if (auto* f = AsConstFloat(value)) return f->GetValue() == 1.0f;
return false;
}
ConstantValue* FoldBinary(BinaryInst& inst, Context& ctx) {
auto* li = AsConstInt(inst.GetLhs());
auto* ri = AsConstInt(inst.GetRhs());
if (li && ri) {
int lhs = li->GetValue();
int rhs = ri->GetValue();
switch (inst.GetOpcode()) {
case Opcode::Add:
return ctx.GetConstInt(lhs + rhs);
case Opcode::Sub:
return ctx.GetConstInt(lhs - rhs);
case Opcode::Mul:
return ctx.GetConstInt(lhs * rhs);
case Opcode::SDiv:
if (rhs != 0) return ctx.GetConstInt(lhs / rhs);
break;
case Opcode::SRem:
if (rhs != 0) return ctx.GetConstInt(lhs % rhs);
break;
default:
break;
}
}
auto* lf = AsConstFloat(inst.GetLhs());
auto* rf = AsConstFloat(inst.GetRhs());
if (lf && rf) {
float lhs = lf->GetValue();
float rhs = rf->GetValue();
switch (inst.GetOpcode()) {
case Opcode::FAdd:
return ctx.GetConstFloat(lhs + rhs);
case Opcode::FSub:
return ctx.GetConstFloat(lhs - rhs);
case Opcode::FMul:
return ctx.GetConstFloat(lhs * rhs);
case Opcode::FDiv:
return ctx.GetConstFloat(lhs / rhs);
default:
break;
}
}
return nullptr;
}
Value* SimplifyBinary(BinaryInst& inst) {
auto op = inst.GetOpcode();
auto* lhs = inst.GetLhs();
auto* rhs = inst.GetRhs();
switch (op) {
case Opcode::Add:
case Opcode::FAdd:
if (IsZero(rhs)) return lhs;
if (IsZero(lhs)) return rhs;
break;
case Opcode::Sub:
case Opcode::FSub:
if (IsZero(rhs)) return lhs;
break;
case Opcode::Mul:
if (IsOne(rhs)) return lhs;
if (IsOne(lhs)) return rhs;
if (IsZero(rhs)) return rhs;
if (IsZero(lhs)) return lhs;
break;
case Opcode::FMul:
if (IsOne(rhs)) return lhs;
if (IsOne(lhs)) return rhs;
break;
case Opcode::SDiv:
case Opcode::FDiv:
if (IsOne(rhs)) return lhs;
break;
case Opcode::SRem:
if (IsOne(rhs)) return rhs;
break;
default:
break;
}
return nullptr;
}
ConstantInt* FoldICmp(ICmpInst& inst, Context& ctx) {
auto* lhs = AsConstInt(inst.GetLhs());
auto* rhs = AsConstInt(inst.GetRhs());
if (!lhs || !rhs) return nullptr;
int l = lhs->GetValue();
int r = rhs->GetValue();
bool result = false;
switch (inst.GetPredicate()) {
case ICmpPredicate::Eq:
result = l == r;
break;
case ICmpPredicate::Ne:
result = l != r;
break;
case ICmpPredicate::Slt:
result = l < r;
break;
case ICmpPredicate::Sle:
result = l <= r;
break;
case ICmpPredicate::Sgt:
result = l > r;
break;
case ICmpPredicate::Sge:
result = l >= r;
break;
}
return ctx.GetConstBool(result);
}
ConstantInt* FoldFCmp(FCmpInst& inst, Context& ctx) {
auto* lhs = AsConstFloat(inst.GetLhs());
auto* rhs = AsConstFloat(inst.GetRhs());
if (!lhs || !rhs) return nullptr;
float l = lhs->GetValue();
float r = rhs->GetValue();
bool ordered = !std::isnan(l) && !std::isnan(r);
bool result = false;
switch (inst.GetPredicate()) {
case FCmpPredicate::Oeq:
result = ordered && l == r;
break;
case FCmpPredicate::One:
result = ordered && l != r;
break;
case FCmpPredicate::Olt:
result = ordered && l < r;
break;
case FCmpPredicate::Ole:
result = ordered && l <= r;
break;
case FCmpPredicate::Ogt:
result = ordered && l > r;
break;
case FCmpPredicate::Oge:
result = ordered && l >= r;
break;
}
return ctx.GetConstBool(result);
}
ConstantValue* FoldCast(CastInst& inst, Context& ctx) {
switch (inst.GetOpcode()) {
case Opcode::SIToFP:
if (auto* c = AsConstInt(inst.GetValue())) {
return ctx.GetConstFloat(static_cast<float>(c->GetValue()));
}
break;
case Opcode::FPToSI:
if (auto* c = AsConstFloat(inst.GetValue())) {
return ctx.GetConstInt(static_cast<int>(c->GetValue()));
}
break;
case Opcode::ZExt:
if (auto* c = AsConstInt(inst.GetValue())) {
return ctx.GetConstInt(c->GetValue() != 0 ? 1 : 0);
}
break;
default:
break;
}
return nullptr;
}
Value* SimplifyPhi(PhiInst& phi) {
const auto& values = phi.GetIncomingValues();
if (values.empty()) return nullptr;
auto* first = values.front();
for (auto* value : values) {
if (value != first) return nullptr;
}
return first;
}
} // namespace
bool RunConstFold(Module& module) {
bool changed = false;
auto& ctx = module.GetContext();
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
Value* replacement = nullptr;
if (auto* bin = dynamic_cast<BinaryInst*>(inst)) {
replacement = FoldBinary(*bin, ctx);
if (!replacement) replacement = SimplifyBinary(*bin);
} else if (auto* icmp = dynamic_cast<ICmpInst*>(inst)) {
replacement = FoldICmp(*icmp, ctx);
} else if (auto* fcmp = dynamic_cast<FCmpInst*>(inst)) {
replacement = FoldFCmp(*fcmp, ctx);
} else if (auto* cast = dynamic_cast<CastInst*>(inst)) {
replacement = FoldCast(*cast, ctx);
} else if (auto* phi = dynamic_cast<PhiInst*>(inst)) {
replacement = SimplifyPhi(*phi);
}
if (replacement && replacement != inst) {
inst->ReplaceAllUsesWith(replacement);
changed = true;
}
}
}
}
return changed;
}
} // namespace ir

@ -1,5 +1,77 @@
// 常量传播Constant Propagation
// - 沿 use-def 关系传播已知常量
// - 将可替换的 SSA 值改写为常量,暴露更多折叠机会
// - 常与 ConstFold、DCE、CFGSimplify 迭代配合使用
#include "ir/IR.h"
#include <unordered_map>
namespace ir {
namespace {
ConstantValue* AsConstant(Value* value) {
return dynamic_cast<ConstantValue*>(value);
}
ConstantValue* ScalarConstInitializer(Value* ptr) {
auto* global = dynamic_cast<GlobalVariable*>(ptr);
if (!global || !global->IsConst()) return nullptr;
auto* init = global->GetInitializer();
if (!init) return nullptr;
if (dynamic_cast<ConstantArray*>(init)) return nullptr;
return init;
}
bool IsScalarStackSlot(Value* ptr) {
auto* alloca = dynamic_cast<AllocaInst*>(ptr);
if (!alloca) return false;
const auto& ty = alloca->GetAllocatedType();
return ty && (ty->IsInt1() || ty->IsInt32() || ty->IsFloat());
}
} // namespace
bool RunConstProp(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
std::unordered_map<Value*, ConstantValue*> known_memory;
for (const auto& inst_ptr : block->GetInstructions()) {
auto* inst = inst_ptr.get();
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
auto* ptr = load->GetPtr();
ConstantValue* known = nullptr;
auto it = known_memory.find(ptr);
if (it != known_memory.end()) {
known = it->second;
} else {
known = ScalarConstInitializer(ptr);
}
if (known) {
load->ReplaceAllUsesWith(known);
changed = true;
}
continue;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
auto* ptr = store->GetPtr();
if (IsScalarStackSlot(ptr)) {
if (auto* c = AsConstant(store->GetValue())) {
known_memory[ptr] = c;
} else {
known_memory.erase(ptr);
}
}
continue;
}
if (inst->GetOpcode() == Opcode::Call) {
known_memory.clear();
}
}
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,55 @@
// 死代码删除DCE
// - 删除无用指令与无用基本块
// - 通常与 CFG 简化配合使用
#include "ir/IR.h"
namespace ir {
namespace {
bool HasSideEffect(const Instruction& inst) {
switch (inst.GetOpcode()) {
case Opcode::Store:
case Opcode::Ret:
case Opcode::Br:
case Opcode::CondBr:
case Opcode::Call:
return true;
default:
return false;
}
}
bool IsDead(const Instruction& inst) {
if (inst.IsVoid()) return false;
if (HasSideEffect(inst)) return false;
return inst.GetUses().empty();
}
} // namespace
bool RunDCE(Module& module) {
bool changed = false;
bool local_changed = true;
while (local_changed) {
local_changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
auto& insts = block->GetMutableInstructions();
for (auto it = insts.begin(); it != insts.end();) {
auto* inst = it->get();
if (inst && IsDead(*inst)) {
block->EraseInstruction(inst);
local_changed = true;
changed = true;
it = insts.begin();
} else {
++it;
}
}
}
}
}
return changed;
}
} // namespace ir

@ -1,4 +1,109 @@
// Mem2RegSSA 构造):
// - 将局部变量的 alloca/load/store 提升为 SSA 形式
// - 插入 PHI 并重写使用,依赖支配树等分析
#include "ir/IR.h"
#include <vector>
namespace ir {
namespace {
bool IsPromotableType(const AllocaInst& alloca) {
const auto& ty = alloca.GetAllocatedType();
return ty && (ty->IsInt1() || ty->IsInt32() || ty->IsFloat());
}
bool IsDirectUseOf(Value* ptr, Instruction* inst) {
if (auto* load = dynamic_cast<LoadInst*>(inst)) {
return load->GetPtr() == ptr;
}
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
return store->GetPtr() == ptr;
}
return false;
}
BasicBlock* SingleUseBlock(AllocaInst& alloca) {
BasicBlock* use_block = nullptr;
for (const auto& use : alloca.GetUses()) {
auto* inst = dynamic_cast<Instruction*>(use.GetUser());
if (!inst || !IsDirectUseOf(&alloca, inst)) return nullptr;
auto* parent = inst->GetParent();
if (!parent) return nullptr;
if (!use_block) {
use_block = parent;
} else if (use_block != parent) {
return nullptr;
}
}
return use_block;
}
bool CanPromoteInBlock(AllocaInst& alloca, BasicBlock& block) {
bool has_value = false;
bool saw_use = false;
for (const auto& inst_ptr : block.GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst || !IsDirectUseOf(&alloca, inst)) continue;
saw_use = true;
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
if (store->GetValue() == &alloca) return false;
has_value = true;
} else if (dynamic_cast<LoadInst*>(inst)) {
if (!has_value) return false;
}
}
return saw_use;
}
bool PromoteInBlock(AllocaInst& alloca, BasicBlock& block) {
Value* current = nullptr;
std::vector<Instruction*> erase;
for (const auto& inst_ptr : block.GetInstructions()) {
auto* inst = inst_ptr.get();
if (!inst || !IsDirectUseOf(&alloca, inst)) continue;
if (auto* store = dynamic_cast<StoreInst*>(inst)) {
current = store->GetValue();
erase.push_back(store);
} else if (auto* load = dynamic_cast<LoadInst*>(inst)) {
load->ReplaceAllUsesWith(current);
erase.push_back(load);
}
}
for (auto* inst : erase) {
block.EraseInstruction(inst);
}
if (alloca.GetUses().empty()) {
if (auto* parent = alloca.GetParent()) {
parent->EraseInstruction(&alloca);
}
}
return !erase.empty();
}
} // namespace
bool RunMem2Reg(Module& module) {
bool changed = false;
for (const auto& func : module.GetFunctions()) {
if (!func || func->IsDeclaration()) continue;
std::vector<AllocaInst*> allocas;
for (const auto& block : func->GetBlocks()) {
if (!block) continue;
for (const auto& inst : block->GetInstructions()) {
if (auto* alloca = dynamic_cast<AllocaInst*>(inst.get())) {
if (IsPromotableType(*alloca)) {
allocas.push_back(alloca);
}
}
}
}
for (auto* alloca : allocas) {
if (!alloca || alloca->GetUses().empty()) continue;
auto* block = SingleUseBlock(*alloca);
if (!block || !CanPromoteInBlock(*alloca, *block)) continue;
changed |= PromoteInBlock(*alloca, *block);
}
}
return changed;
}
} // namespace ir

@ -1 +1,23 @@
// IR Pass 管理骨架。
#include "ir/IR.h"
namespace ir {
bool RunScalarOptimizationPipeline(Module& module) {
bool changed = false;
changed |= RunMem2Reg(module);
for (int i = 0; i < 8; ++i) {
bool iter_changed = false;
iter_changed |= RunConstFold(module);
iter_changed |= RunConstProp(module);
iter_changed |= RunCSE(module);
iter_changed |= RunDCE(module);
iter_changed |= RunCFGSimplify(module);
changed |= iter_changed;
if (!iter_changed) break;
}
return changed;
}
} // namespace ir

@ -36,6 +36,9 @@ int main(int argc, char** argv) {
auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema);
if (opts.optimize_ir) {
ir::RunScalarOptimizationPipeline(*module);
}
if (opts.emit_ir) {
ir::IRPrinter printer;
if (need_blank_line) {

@ -1,5 +1,6 @@
#include "mir/MIR.h"
#include <chrono>
#include <cstdio>
#include <cstdlib>
#include <filesystem>
@ -18,6 +19,18 @@ namespace {
std::string ShellQuote(const std::filesystem::path& path) {
std::string raw = path.string();
#if defined(_WIN32)
std::string quoted = "\"";
for (char ch : raw) {
if (ch == '"') {
quoted += "\\\"";
} else {
quoted += ch;
}
}
quoted += "\"";
return quoted;
#else
std::string quoted = "'";
for (char ch : raw) {
if (ch == '\'') {
@ -28,6 +41,7 @@ std::string ShellQuote(const std::filesystem::path& path) {
}
quoted += "'";
return quoted;
#endif
}
std::string ReadTextFile(const std::filesystem::path& path) {
@ -41,20 +55,35 @@ std::string ReadTextFile(const std::filesystem::path& path) {
return oss.str();
}
} // namespace
void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os) {
auto tmp_dir = std::filesystem::temp_directory_path();
std::string pattern = (tmp_dir / "nudt_lab3_XXXXXX").string();
std::filesystem::path CreateTempDir() {
auto base = std::filesystem::temp_directory_path();
#if defined(_WIN32)
auto seed = std::chrono::high_resolution_clock::now().time_since_epoch().count();
for (int i = 0; i < 100; ++i) {
auto candidate =
base / ("nudt_lab3_" + std::to_string(seed) + "_" + std::to_string(i));
std::error_code ec;
if (std::filesystem::create_directory(candidate, ec)) {
return candidate;
}
}
throw std::runtime_error(FormatError("mir", "创建临时目录失败"));
#else
std::string pattern = (base / "nudt_lab3_XXXXXX").string();
std::vector<char> dir_template(pattern.begin(), pattern.end());
dir_template.push_back('\0');
char* created = mkdtemp(dir_template.data());
if (!created) {
throw std::runtime_error(FormatError("mir", "创建临时目录失败"));
}
return std::filesystem::path(created);
#endif
}
} // namespace
std::filesystem::path work_dir(created);
void PrintAArch64AsmFromIR(const ir::Module& module, std::ostream& os) {
std::filesystem::path work_dir = CreateTempDir();
const auto ir_file = work_dir / "module.ll";
const auto asm_file = work_dir / "module.s";
const auto err_file = work_dir / "clang.err";

@ -15,7 +15,7 @@ CLIOptions ParseCLI(int argc, char** argv) {
if (argc <= 1) {
throw std::runtime_error(FormatError(
"cli",
"用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>"));
"用法: compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] <input.sy>"));
}
for (int i = 1; i < argc; ++i) {
@ -58,6 +58,11 @@ CLIOptions ParseCLI(int argc, char** argv) {
continue;
}
if (std::strcmp(arg, "--no-opt") == 0) {
opt.optimize_ir = false;
continue;
}
if (arg[0] == '-') {
throw std::runtime_error(
FormatError("cli", std::string("未知参数: ") + arg +

@ -50,13 +50,14 @@ void PrintHelp(std::ostream& os) {
os << "SysY Compiler\n"
<< "\n"
<< "用法:\n"
<< " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] <input.sy>\n"
<< " compiler [--help] [--emit-parse-tree] [--emit-ir] [--emit-asm] [--no-opt] <input.sy>\n"
<< "\n"
<< "选项:\n"
<< " -h, --help 打印帮助信息并退出\n"
<< " --emit-parse-tree 仅在显式模式下启用语法树输出\n"
<< " --emit-ir 仅在显式模式下启用 IR 输出\n"
<< " --emit-asm 仅在显式模式下启用 AArch64 汇编输出\n"
<< " --no-opt 关闭 Lab4 IR 标量优化管线\n"
<< "\n"
<< "说明:\n"
<< " - 默认输出 IR\n"

@ -6,7 +6,17 @@ int getint() {
return v;
}
int getch() { return getchar(); }
int getch() {
int c = getchar();
if (c == '\r') {
int next = getchar();
if (next != '\n' && next != EOF) {
ungetc(next, stdin);
}
return '\n';
}
return c;
}
int getarray(int a[]) {
int n = 0;

@ -11,6 +11,8 @@
#include "atn/ProfilingATNSimulator.h"
#include <chrono>
using namespace antlr4;
using namespace antlr4::atn;
using namespace antlr4::dfa;

Loading…
Cancel
Save