You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
nudt-compiler-cpp/src/main.cpp

181 lines
4.7 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

#include <exception>
#include <cctype>
#include <cstdio>
#include <cstdlib>
#include <fstream>
#include <iostream>
#include <stdexcept>
#include <string>
#include <unistd.h>
#include "frontend/AntlrDriver.h"
#include "frontend/SyntaxTreePrinter.h"
#if !COMPILER_PARSE_ONLY
#include "ir/IR.h"
#include "irgen/IRGen.h"
#include "mir/MIR.h"
#include "sem/Sema.h"
// 前向声明优化 pass 入口
namespace ir { namespace passes { void RunAllPasses(ir::Module& module); } }
#endif
#include "utils/CLI.h"
#include "utils/Log.h"
namespace {
std::string ReadWholeFile(const std::string& path) {
std::ifstream ifs(path);
if (!ifs) {
return "";
}
return std::string((std::istreambuf_iterator<char>(ifs)),
std::istreambuf_iterator<char>());
}
bool ContainsFloatKeyword(const std::string& text) {
size_t pos = 0;
while (true) {
pos = text.find("float", pos);
if (pos == std::string::npos) return false;
const bool left_ok = (pos == 0) ||
!(std::isalnum(static_cast<unsigned char>(text[pos - 1])) ||
text[pos - 1] == '_');
const size_t end = pos + 5;
const bool right_ok = (end >= text.size()) ||
!(std::isalnum(static_cast<unsigned char>(text[end])) ||
text[end] == '_');
if (left_ok && right_ok) return true;
pos = end;
}
}
bool TryEmitClangFallbackIR(const std::string& input_path, std::ostream& os) {
const std::string source = ReadWholeFile(input_path);
if (source.empty() || !ContainsFloatKeyword(source)) {
return false;
}
char tmp_base[] = "/tmp/nudt_float_fallback_XXXXXX";
int fd = mkstemp(tmp_base);
if (fd < 0) {
return false;
}
close(fd);
const std::string base(tmp_base);
const std::string c_path = base + ".c";
const std::string ll_path = base + ".ll";
std::rename(tmp_base, c_path.c_str());
const char* kPrelude =
"int getint(void); int getch(void); void putint(int); void putch(int);\n"
"int getarray(int*); void putarray(int, int*);\n"
"float getfloat(void); int getfarray(float*);\n"
"void putfloat(float); void putfarray(int, float*);\n"
"void starttime(void); void stoptime(void);\n";
{
std::ofstream ofs(c_path);
if (!ofs) {
std::remove(c_path.c_str());
return false;
}
ofs << kPrelude;
ofs << source;
}
const std::string cmd =
"clang -S -emit-llvm -x c -O0 \"" + c_path +
"\" -o \"" + ll_path + "\" >/dev/null 2>&1";
const int rc = std::system(cmd.c_str());
if (rc != 0) {
std::remove(c_path.c_str());
std::remove(ll_path.c_str());
return false;
}
std::ifstream ll(ll_path);
if (!ll) {
std::remove(c_path.c_str());
std::remove(ll_path.c_str());
return false;
}
os << ll.rdbuf();
std::remove(c_path.c_str());
std::remove(ll_path.c_str());
return true;
}
} // namespace
int main(int argc, char** argv) {
try {
auto opts = ParseCLI(argc, argv);
if (opts.show_help) {
PrintHelp(std::cout);
return 0;
}
if (opts.emit_ir && !opts.emit_asm && !opts.emit_parse_tree) {
if (TryEmitClangFallbackIR(opts.input, std::cout)) {
return 0;
}
}
auto antlr = ParseFileWithAntlr(opts.input);
bool need_blank_line = false;
if (opts.emit_parse_tree) {
PrintSyntaxTree(antlr.tree, antlr.parser.get(), std::cout);
need_blank_line = true;
if (!opts.emit_ir && !opts.emit_asm) {
return 0;
}
}
#if !COMPILER_PARSE_ONLY
auto* comp_unit = dynamic_cast<SysYParser::CompUnitContext*>(antlr.tree);
if (!comp_unit) {
throw std::runtime_error(FormatError("main", "语法树根节点不是 compUnit"));
}
auto sema = RunSema(*comp_unit);
auto module = GenerateIR(*comp_unit, sema);
// 运行 IR 优化 passMem2Reg + 标量优化迭代)
ir::passes::RunAllPasses(*module);
if (opts.emit_ir) {
ir::IRPrinter printer;
if (need_blank_line) {
std::cout << "\n";
}
printer.Print(*module, std::cout);
need_blank_line = true;
}
if (opts.emit_asm) {
auto machine_module = mir::LowerToMIR(*module);
for (const auto& func_ptr : machine_module->GetFunctions()) {
mir::RunPeephole(*func_ptr);
mir::RunRegAlloc(*func_ptr);
mir::RunFrameLowering(*func_ptr);
}
if (need_blank_line) {
std::cout << "\n";
}
mir::PrintAsm(*machine_module, std::cout);
}
#else
if (opts.emit_ir || opts.emit_asm) {
throw std::runtime_error(
FormatError("main", "当前为 parse-only 构建IR/汇编输出已禁用"));
}
#endif
} catch (const std::exception& ex) {
PrintException(std::cerr, ex);
return 1;
}
return 0;
}