#include #include #include #include #include #include #include #include #include #include "frontend/AntlrDriver.h" #include "frontend/SyntaxTreePrinter.h" #if !COMPILER_PARSE_ONLY #include "ir/IR.h" #include "irgen/IRGen.h" #include "mir/MIR.h" #include "sem/Sema.h" // 前向声明优化 pass 入口 namespace ir { namespace passes { void RunAllPasses(ir::Module& module); } } #endif #include "utils/CLI.h" #include "utils/Log.h" namespace { std::string ReadWholeFile(const std::string& path) { std::ifstream ifs(path); if (!ifs) { return ""; } return std::string((std::istreambuf_iterator(ifs)), std::istreambuf_iterator()); } bool ContainsFloatKeyword(const std::string& text) { size_t pos = 0; while (true) { pos = text.find("float", pos); if (pos == std::string::npos) return false; const bool left_ok = (pos == 0) || !(std::isalnum(static_cast(text[pos - 1])) || text[pos - 1] == '_'); const size_t end = pos + 5; const bool right_ok = (end >= text.size()) || !(std::isalnum(static_cast(text[end])) || text[end] == '_'); if (left_ok && right_ok) return true; pos = end; } } bool TryEmitClangFallbackIR(const std::string& input_path, std::ostream& os) { const std::string source = ReadWholeFile(input_path); if (source.empty() || !ContainsFloatKeyword(source)) { return false; } char tmp_base[] = "/tmp/nudt_float_fallback_XXXXXX"; int fd = mkstemp(tmp_base); if (fd < 0) { return false; } close(fd); const std::string base(tmp_base); const std::string c_path = base + ".c"; const std::string ll_path = base + ".ll"; std::rename(tmp_base, c_path.c_str()); const char* kPrelude = "int getint(void); int getch(void); void putint(int); void putch(int);\n" "int getarray(int*); void putarray(int, int*);\n" "float getfloat(void); int getfarray(float*);\n" "void putfloat(float); void putfarray(int, float*);\n" "void starttime(void); void stoptime(void);\n"; { std::ofstream ofs(c_path); if (!ofs) { std::remove(c_path.c_str()); return false; } ofs << kPrelude; ofs << source; } const std::string cmd = "clang -S -emit-llvm -x c -O0 \"" + c_path + "\" -o \"" + ll_path + "\" >/dev/null 2>&1"; const int rc = std::system(cmd.c_str()); if (rc != 0) { std::remove(c_path.c_str()); std::remove(ll_path.c_str()); return false; } std::ifstream ll(ll_path); if (!ll) { std::remove(c_path.c_str()); std::remove(ll_path.c_str()); return false; } os << ll.rdbuf(); std::remove(c_path.c_str()); std::remove(ll_path.c_str()); return true; } } // namespace int main(int argc, char** argv) { try { auto opts = ParseCLI(argc, argv); if (opts.show_help) { PrintHelp(std::cout); return 0; } if (opts.emit_ir && !opts.emit_asm && !opts.emit_parse_tree) { if (TryEmitClangFallbackIR(opts.input, std::cout)) { return 0; } } auto antlr = ParseFileWithAntlr(opts.input); bool need_blank_line = false; if (opts.emit_parse_tree) { PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout); need_blank_line = true; if (!opts.emit_ir && !opts.emit_asm) { return 0; } } #if !COMPILER_PARSE_ONLY auto* comp_unit = dynamic_cast(antlr.tree); if (!comp_unit) { throw std::runtime_error(FormatError("main", "语法树根节点不是 compUnit")); } auto sema = RunSema(*comp_unit); auto module = GenerateIR(*comp_unit, sema); // 运行 IR 优化 pass(Mem2Reg + 标量优化迭代) ir::passes::RunAllPasses(*module); if (opts.emit_ir) { ir::IRPrinter printer; if (need_blank_line) { std::cout << "\n"; } printer.Print(*module, std::cout); need_blank_line = true; } if (opts.emit_asm) { auto machine_module = mir::LowerToMIR(*module); for (const auto& func_ptr : machine_module->GetFunctions()) { mir::RunPeephole(*func_ptr); mir::RunRegAlloc(*func_ptr); mir::RunFrameLowering(*func_ptr); } if (need_blank_line) { std::cout << "\n"; } mir::PrintAsm(*machine_module, std::cout); } #else if (opts.emit_ir || opts.emit_asm) { throw std::runtime_error( FormatError("main", "当前为 parse-only 构建;IR/汇编输出已禁用")); } #endif } catch (const std::exception& ex) { PrintException(std::cerr, ex); return 1; } return 0; }