#!/usr/bin/env python

"""
Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission
"""

# 导入需要的模块
from __future__ import print_function

import mimetypes  # 用于猜测文件的MIME类型
import gzip  # 用于gzip压缩
import os  # 操作系统相关功能
import re  # 正则表达式
import sys  # 系统相关功能
import threading  # 多线程支持
import time  # 时间相关功能
import traceback  # 异常追踪

# 将上级目录添加到Python路径中
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))

# 导入自定义模块和第三方库
from lib.core.enums import HTTP_HEADER  # HTTP头部常量
from lib.core.settings import UNICODE_ENCODING  # 编码设置
from lib.core.settings import VERSION_STRING  # 版本信息
from thirdparty import six  # Python 2/3 兼容库
from thirdparty.six.moves import BaseHTTPServer as _BaseHTTPServer  # HTTP服务器基类
from thirdparty.six.moves import http_client as _http_client  # HTTP客户端
from thirdparty.six.moves import socketserver as _socketserver  # Socket服务器
from thirdparty.six.moves import urllib as _urllib  # URL处理

# 服务器配置
HTTP_ADDRESS = "0.0.0.0"  # 监听所有网络接口
HTTP_PORT = 8951  # 服务器端口
DEBUG = True  # 调试模式开关
HTML_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "data", "html"))  # HTML文件目录
DISABLED_CONTENT_EXTENSIONS = (".py", ".pyc", ".md", ".txt", ".bak", ".conf", ".zip", "~")  # 禁止访问的文件扩展名

class ThreadingServer(_socketserver.ThreadingMixIn, _BaseHTTPServer.HTTPServer):
    """多线程HTTP服务器类"""
    def finish_request(self, *args, **kwargs):
        """处理请求完成时的回调"""
        try:
            _BaseHTTPServer.HTTPServer.finish_request(self, *args, **kwargs)
        except Exception:
            if DEBUG:
                traceback.print_exc()

class ReqHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
    """HTTP请求处理器类"""
    def do_GET(self):
        """处理GET请求"""
        # 解析URL和查询参数
        path, query = self.path.split('?', 1) if '?' in self.path else (self.path, "")
        params = {}
        content = None

        # 解析查询参数
        if query:
            params.update(_urllib.parse.parse_qs(query))

        # 只保留每个参数的最后一个值
        for key in params:
            if params[key]:
                params[key] = params[key][-1]

        self.url, self.params = path, params

        # 处理根路径请求
        if path == '/':
            path = "index.html"

        # 处理文件路径
        path = path.strip('/')
        path = path.replace('/', os.path.sep)
        path = os.path.abspath(os.path.join(HTML_DIR, path)).strip()

        # 如果文件不存在但存在同名的.html文件,则使用.html文件
        if not os.path.isfile(path) and os.path.isfile("%s.html" % path):
            path = "%s.html" % path

        # 检查文件是否可访问并返回相应内容
        if ".." not in os.path.relpath(path, HTML_DIR) and os.path.isfile(path) and not path.endswith(DISABLED_CONTENT_EXTENSIONS):
            content = open(path, "rb").read()
            self.send_response(_http_client.OK)
            self.send_header(HTTP_HEADER.CONNECTION, "close")
            self.send_header(HTTP_HEADER.CONTENT_TYPE, mimetypes.guess_type(path)[0] or "application/octet-stream")
        else:
            # 返回404错误页面
            content = ("<!DOCTYPE html><html lang=\"en\"><head><title>404 Not Found</title></head><body><h1>Not Found</h1><p>The requested URL %s was not found on this server.</p></body></html>" % self.path.split('?')[0]).encode(UNICODE_ENCODING)
            self.send_response(_http_client.NOT_FOUND)
            self.send_header(HTTP_HEADER.CONNECTION, "close")

        if content is not None:
            # 处理模板标记
            for match in re.finditer(b"<!(\\w+)!>", content):
                name = match.group(1)
                _ = getattr(self, "_%s" % name.lower(), None)
                if _:
                    content = self._format(content, **{name: _()})

            # 如果客户端支持gzip压缩,则压缩内容
            if "gzip" in self.headers.get(HTTP_HEADER.ACCEPT_ENCODING):
                self.send_header(HTTP_HEADER.CONTENT_ENCODING, "gzip")
                _ = six.BytesIO()
                compress = gzip.GzipFile("", "w+b", 9, _)
                compress._stream = _
                compress.write(content)
                compress.flush()
                compress.close()
                content = compress._stream.getvalue()

            self.send_header(HTTP_HEADER.CONTENT_LENGTH, str(len(content)))

        self.end_headers()

        # 发送响应内容
        if content:
            self.wfile.write(content)

        self.wfile.flush()

    def _format(self, content, **params):
        """格式化响应内容,替换模板标记"""
        if content:
            for key, value in params.items():
                content = content.replace("<!%s!>" % key, value)
        return content

    def version_string(self):
        """返回服务器版本信息"""
        return VERSION_STRING

    def log_message(self, format, *args):
        """禁用日志记录"""
        return

    def finish(self):
        """完成请求处理"""
        try:
            _BaseHTTPServer.BaseHTTPRequestHandler.finish(self)
        except Exception:
            if DEBUG:
                traceback.print_exc()

def start_httpd():
    """启动HTTP服务器"""
    server = ThreadingServer((HTTP_ADDRESS, HTTP_PORT), ReqHandler)
    thread = threading.Thread(target=server.serve_forever)
    thread.daemon = True  # 设置为守护线程
    thread.start()

    print("[i] running HTTP server at '%s:%d'" % (HTTP_ADDRESS, HTTP_PORT))

if __name__ == "__main__":
    """主程序入口"""
    try:
        start_httpd()
        # 保持程序运行
        while True:
            time.sleep(1)
    except KeyboardInterrupt:
        pass