From d08b23276ac5328d3a2bb8f458de94e7d753603c Mon Sep 17 00:00:00 2001 From: jing <3030349106@qq.com> Date: Sat, 28 Feb 2026 21:03:40 +0800 Subject: [PATCH] =?UTF-8?q?fix(ast):=20=E8=A7=84=E8=8C=83AST=E8=BE=93?= =?UTF-8?q?=E5=87=BA?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/ast/AstNodes.h | 3 + src/ast/AstPrinter.cpp | 182 ++++++++++++++++++++++++++++++++++------- src/main.cpp | 11 ++- src/utils/CLI.cpp | 16 +++- src/utils/CLI.h | 3 +- src/utils/Log.cpp | 3 +- 6 files changed, 182 insertions(+), 36 deletions(-) diff --git a/src/ast/AstNodes.h b/src/ast/AstNodes.h index 5433cd1..55efccc 100644 --- a/src/ast/AstNodes.h +++ b/src/ast/AstNodes.h @@ -1,6 +1,7 @@ #pragma once +#include #include #include #include @@ -66,5 +67,7 @@ struct CompUnit { // 调试打印 void PrintAST(const CompUnit& cu); +// 导出 AST 为 Graphviz DOT。 +void PrintASTDot(const CompUnit& cu, std::ostream& os); } // namespace ast diff --git a/src/ast/AstPrinter.cpp b/src/ast/AstPrinter.cpp index d3e71b1..5aec665 100644 --- a/src/ast/AstPrinter.cpp +++ b/src/ast/AstPrinter.cpp @@ -3,70 +3,194 @@ #include "ast/AstNodes.h" #include +#include +#include namespace ast { -static void PrintExpr(const Expr* expr); - static void PrintIndent(int depth) { for (int i = 0; i < depth; ++i) std::cout << " "; } -static void PrintExpr(const Expr* expr) { +static void PrintLine(int depth, const std::string& text) { + PrintIndent(depth); + std::cout << text << "\n"; +} + +static void DumpExpr(const Expr* expr, int depth) { + if (!expr) { + PrintLine(depth, "NullExpr"); + return; + } + if (auto num = dynamic_cast(expr)) { - std::cout << num->value; - } else if (auto var = dynamic_cast(expr)) { - std::cout << var->name; - } else if (auto bin = dynamic_cast(expr)) { - std::cout << "("; - PrintExpr(bin->lhs.get()); + PrintLine(depth, "NumberExpr value=" + std::to_string(num->value)); + return; + } + if (auto var = dynamic_cast(expr)) { + PrintLine(depth, "VarExpr name=" + var->name); + return; + } + if (auto bin = dynamic_cast(expr)) { const char* op = "?"; switch (bin->op) { case BinaryOp::Add: - op = "+"; + op = "Add"; break; case BinaryOp::Sub: - op = "-"; + op = "Sub"; break; case BinaryOp::Mul: - op = "*"; + op = "Mul"; break; case BinaryOp::Div: - op = "/"; + op = "Div"; break; } - std::cout << " " << op << " "; - PrintExpr(bin->rhs.get()); - std::cout << ")"; + PrintLine(depth, std::string("BinaryExpr op=") + op); + DumpExpr(bin->lhs.get(), depth + 1); + DumpExpr(bin->rhs.get(), depth + 1); + return; } + + PrintLine(depth, "UnknownExpr"); } void PrintAST(const CompUnit& cu) { - if (!cu.func) return; - std::cout << "func " << cu.func->name << " () {\n"; + PrintLine(0, "CompUnit"); + if (!cu.func) { + PrintLine(1, ""); + return; + } + + PrintLine(1, "FuncDef name=" + cu.func->name); const auto& body = cu.func->body; if (!body) { - std::cout << "}\n"; + PrintLine(2, "Block "); return; } + + PrintLine(2, "Block"); for (const auto& decl : body->varDecls) { - PrintIndent(1); - std::cout << "var " << decl->name; + PrintLine(3, "VarDecl name=" + decl->name); if (decl->init) { - std::cout << " = "; - PrintExpr(decl->init.get()); + PrintLine(4, "Init"); + DumpExpr(decl->init.get(), 5); } - std::cout << ";\n"; } for (const auto& stmt : body->stmts) { if (auto ret = dynamic_cast(stmt.get())) { - PrintIndent(1); - std::cout << "return "; - PrintExpr(ret->value.get()); - std::cout << ";\n"; + PrintLine(3, "ReturnStmt"); + DumpExpr(ret->value.get(), 4); + } else { + PrintLine(3, "UnknownStmt"); + } + } +} + +namespace { + +std::string EscapeDotLabel(const std::string& s) { + std::string out; + out.reserve(s.size()); + for (char c : s) { + if (c == '\\' || c == '"') out.push_back('\\'); + out.push_back(c); + } + return out; +} + +class DotWriter { + public: + explicit DotWriter(std::ostream& os) : os_(os) {} + + int AddNode(const std::string& label) { + int id = next_id_++; + os_ << " n" << id << " [label=\"" << EscapeDotLabel(label) << "\"];\n"; + return id; + } + + void AddEdge(int from, int to) { os_ << " n" << from << " -> n" << to << ";\n"; } + + private: + std::ostream& os_; + int next_id_ = 0; +}; + +int EmitExpr(const Expr* expr, DotWriter& dot) { + if (!expr) return dot.AddNode("NullExpr"); + + if (auto num = dynamic_cast(expr)) { + return dot.AddNode("Number(" + std::to_string(num->value) + ")"); + } + if (auto var = dynamic_cast(expr)) { + return dot.AddNode("Var(" + var->name + ")"); + } + if (auto bin = dynamic_cast(expr)) { + const char* op = "?"; + switch (bin->op) { + case BinaryOp::Add: + op = "Add"; + break; + case BinaryOp::Sub: + op = "Sub"; + break; + case BinaryOp::Mul: + op = "Mul"; + break; + case BinaryOp::Div: + op = "Div"; + break; + } + int id = dot.AddNode(std::string("Binary(") + op + ")"); + int lhs = EmitExpr(bin->lhs.get(), dot); + int rhs = EmitExpr(bin->rhs.get(), dot); + dot.AddEdge(id, lhs); + dot.AddEdge(id, rhs); + return id; + } + return dot.AddNode("UnknownExpr"); +} + +} // namespace + +void PrintASTDot(const CompUnit& cu, std::ostream& os) { + os << "digraph AST {\n"; + os << " rankdir=TB;\n"; + os << " node [shape=box, fontname=\"Courier\"];\n"; + + DotWriter dot(os); + int cu_id = dot.AddNode("CompUnit"); + if (cu.func) { + int func_id = dot.AddNode("FuncDef(" + cu.func->name + ")"); + dot.AddEdge(cu_id, func_id); + if (cu.func->body) { + int block_id = dot.AddNode("Block"); + dot.AddEdge(func_id, block_id); + + for (const auto& decl : cu.func->body->varDecls) { + int decl_id = dot.AddNode("VarDecl(" + decl->name + ")"); + dot.AddEdge(block_id, decl_id); + if (decl->init) { + int init_id = EmitExpr(decl->init.get(), dot); + dot.AddEdge(decl_id, init_id); + } + } + + for (const auto& stmt : cu.func->body->stmts) { + if (auto ret = dynamic_cast(stmt.get())) { + int ret_id = dot.AddNode("Return"); + dot.AddEdge(block_id, ret_id); + int value_id = EmitExpr(ret->value.get(), dot); + dot.AddEdge(ret_id, value_id); + } else { + int stmt_id = dot.AddNode("Stmt"); + dot.AddEdge(block_id, stmt_id); + } + } } } - std::cout << "}\n"; + os << "}\n"; } } // namespace ast diff --git a/src/main.cpp b/src/main.cpp index 78d0d72..816409d 100644 --- a/src/main.cpp +++ b/src/main.cpp @@ -1,5 +1,7 @@ #include +#include #include +#include #include "frontend/AntlrDriver.h" #include "frontend/AstBuilder.h" @@ -20,8 +22,15 @@ int main(int argc, char** argv) { auto antlr = ParseFileWithAntlr(opts.input); auto ast = BuildAst(antlr.tree); ast::PrintAST(*ast); // 调试 AST + if (!opts.ast_dot_output.empty()) { + std::ofstream dot_out(opts.ast_dot_output); + if (!dot_out) { + throw std::runtime_error("无法写入 AST DOT 文件: " + opts.ast_dot_output); + } + ast::PrintASTDot(*ast, dot_out); + } ast = RunSema(std::move(ast)); - // IR 生成:当前使用 IRGenDriver.cpp 里的最小参考实现, + // IR 生成:当前使用 IRGenDriver.cpp 里的最小实现, // 同学们可以在 irgen/ 目录下按提示自行完善或替换实现。 auto module = GenerateIR(*ast); diff --git a/src/utils/CLI.cpp b/src/utils/CLI.cpp index f8120bc..ef8f325 100644 --- a/src/utils/CLI.cpp +++ b/src/utils/CLI.cpp @@ -5,15 +5,15 @@ #include "utils/CLI.h" -#include -#include #include +#include +#include CLIOptions ParseCLI(int argc, char** argv) { CLIOptions opt; if (argc <= 1) { - throw std::runtime_error("用法: compiler [--help] "); + throw std::runtime_error("用法: compiler [--help] [--ast-dot ] "); } for (int i = 1; i < argc; ++i) { @@ -23,6 +23,14 @@ CLIOptions ParseCLI(int argc, char** argv) { return opt; } + if (std::strcmp(arg, "--ast-dot") == 0) { + if (i + 1 >= argc) { + throw std::runtime_error("参数错误: --ast-dot 需要一个输出路径"); + } + opt.ast_dot_output = argv[++i]; + continue; + } + if (arg[0] == '-') { throw std::runtime_error(std::string("未知参数: ") + arg + "(使用 --help 查看用法)"); @@ -35,7 +43,7 @@ CLIOptions ParseCLI(int argc, char** argv) { opt.input = arg; } - if (opt.input.empty()) { + if (opt.input.empty() && !opt.show_help) { throw std::runtime_error("缺少输入文件:请提供 (使用 --help 查看用法)"); } return opt; diff --git a/src/utils/CLI.h b/src/utils/CLI.h index 618b9e1..de5440c 100644 --- a/src/utils/CLI.h +++ b/src/utils/CLI.h @@ -1,10 +1,11 @@ -// 简易命令行解析:仅支持输入文件路径。 +// 简易命令行解析:支持帮助、输入文件和 AST DOT 导出。 #pragma once #include struct CLIOptions { std::string input; + std::string ast_dot_output; bool show_help = false; }; diff --git a/src/utils/Log.cpp b/src/utils/Log.cpp index e1702a3..58a7c26 100644 --- a/src/utils/Log.cpp +++ b/src/utils/Log.cpp @@ -10,10 +10,11 @@ void PrintHelp(std::ostream& os) { os << "SysY Compiler (课程实验最小可运行示例)\n" << "\n" << "用法:\n" - << " compiler [--help] \n" + << " compiler [--help] [--ast-dot ] \n" << "\n" << "选项:\n" << " -h, --help 打印帮助信息并退出\n" + << " --ast-dot 导出 AST Graphviz DOT 到指定文件\n" << "\n" << "说明:\n" << " - 当前默认将 IR 输出到标准输出,可使用重定向写入文件:\n"