# coding=utf-8 """ 模块基类,用于子域收集模块的基础功能实现 """ import json import re import threading import time import requests import client.subdomain.oneforall.config as config from client.subdomain.oneforall.config import logger from . import utils from .domain import Domain from client.subdomain.oneforall.common.database import Database # 定义线程锁,确保数据库操作的线程安全 lock = threading.Lock() class Module(object): """ 基础模块类,所有子域收集模块将继承此类并实现自己的查询逻辑。 """ def __init__(self): """ 初始化模块基础参数。 """ self.module = 'Module' # 模块名称 self.source = 'BaseModule' # 模块源 self.cookie = None # Cookie,用于HTTP请求 self.header = dict() # 请求头 self.proxy = None # 代理设置 self.delay = config.request_delay # 请求的延时 self.timeout = config.request_timeout # 请求的超时时间 self.verify = config.request_verify # SSL证书验证 self.domain = str() # 当前操作的主域名 self.type = 'A' # DNS查询的记录类型(默认是A记录) self.subdomains = set() # 存放发现的子域 self.records = dict() # 存放子域的解析记录 self.results = list() # 存放模块的查询结果 self.start = time.time() # 模块开始执行时间 self.end = None # 模块结束执行时间 self.elapse = None # 模块执行时间(秒) def check(self, *apis): """ 检查是否配置了API信息。 :param apis: API信息的元组 :return: 如果所有API都已配置,返回True,否则返回False """ if not all(apis): logger.log('ALERT', f'{self.source}模块没有配置API,跳过执行') return False return True def begin(self): """ 输出模块开始执行的信息。 """ logger.log('DEBUG', f'开始执行{self.source}模块,收集{self.domain}的子域') def finish(self): """ 输出模块执行完成的信息,并计算执行耗时。 """ self.end = time.time() # 记录结束时间 self.elapse = round(self.end - self.start, 1) # 计算执行耗时 logger.log('DEBUG', f'结束执行{self.source}模块,收集{self.domain}的子域') logger.log('INFOR', f'{self.source}模块耗时{self.elapse}秒,发现子域{len(self.subdomains)}个') logger.log('DEBUG', f'{self.source}模块发现{self.domain}的子域:{self.subdomains}') def head(self, url, params=None, check=True, **kwargs): """ 自定义的HTTP HEAD请求。 :param str url: 请求的URL :param dict params: 请求的参数 :param bool check: 是否检查响应 :param kwargs: 其他参数 :return: 返回requests的响应对象 """ try: resp = requests.head(url, params=params, cookies=self.cookie, headers=self.header, proxies=self.proxy, timeout=self.timeout, verify=self.verify, **kwargs) except Exception as e: logger.log('ERROR', e.args) return None if not check: return resp if utils.check_response('HEAD', resp): return resp return None def get(self, url, params=None, check=True, **kwargs): """ 自定义的HTTP GET请求。 :param str url: 请求的URL :param dict params: 请求的参数 :param bool check: 是否检查响应 :param kwargs: 其他参数 :return: 返回requests的响应对象 """ try: resp = requests.get(url, params=params, cookies=self.cookie, headers=self.header, proxies=self.proxy, timeout=self.timeout, verify=self.verify, **kwargs) except Exception as e: logger.log('ERROR', e.args) return None if not check: return resp if utils.check_response('GET', resp): return resp return None def post(self, url, data=None, check=True, **kwargs): """ 自定义的HTTP POST请求。 :param str url: 请求的URL :param dict data: 请求的数据 :param bool check: 是否检查响应 :param kwargs: 其他参数 :return: 返回requests的响应对象 """ try: resp = requests.post(url, data=data, cookies=self.cookie, headers=self.header, proxies=self.proxy, timeout=self.timeout, verify=self.verify, **kwargs) except Exception as e: logger.log('ERROR', e.args) return None if not check: return resp if utils.check_response('POST', resp): return resp return None def get_header(self): """ 获取请求头部。 :return: 请求头部 """ if config.enable_fake_header: return utils.gen_fake_header() # 如果启用了伪造请求头,生成伪造请求头 else: return self.header # 否则返回原始的请求头 def get_proxy(self, module): """ 获取代理设置。 :param str module: 模块名 :return: 代理字典 """ if not config.enable_proxy: logger.log('TRACE', f'所有模块不使用代理') return self.proxy if config.proxy_all_module: logger.log('TRACE', f'{module}模块使用代理') return utils.get_random_proxy() # 获取随机代理 if module in config.proxy_partial_module: logger.log('TRACE', f'{module}模块使用代理') return utils.get_random_proxy() # 获取随机代理 else: logger.log('TRACE', f'{module}模块不使用代理') return self.proxy @staticmethod def match(domain, html, distinct=True): """ 使用正则匹配HTML响应体中的子域。 :param str domain: 要匹配的主域 :param str html: 要匹配的HTML响应体 :param bool distinct: 是否去除重复项 :return: 返回匹配到的子域集合或列表 :rtype: set or list """ logger.log('TRACE', f'正则匹配响应体中的子域') regexp = r'(?:\>|\"|\'|\=|\,)(?:http\:\/\/|https\:\/\/)?' \ r'(?:[a-z0-9](?:[a-z0-9\-]{0,61}[a-z0-9])?\.){0,}' \ + domain.replace('.', r'\.') # 构造正则表达式匹配子域 result = re.findall(regexp, html, re.I) if not result: return set() # 如果没有匹配到,返回空集合 regexp = r'(?:http://|https://)' # 对匹配结果进行处理,去掉URL的协议部分 deal = map(lambda s: re.sub(regexp, '', s[1:].lower()), result) if distinct: return set(deal) # 如果去重,返回集合 else: return list(deal) # 否则返回列表 @staticmethod def register(domain): """ 获取域名的注册信息。 :param str domain: 要查询的域名 :return: 注册信息 """ return Domain(domain).registered() # 获取域名注册信息 def save_json(self): """ 将模块结果保存为JSON文件。 :return: 是否保存成功 """ if not config.save_module_result: return False # 如果不保存结果,直接返回False logger.log('TRACE', f'将{self.source}模块发现的子域结果保存为json文件') path = config.result_save_dir.joinpath(self.domain, self.module) path.mkdir(parents=True, exist_ok=True) # 创建保存路径 name = self.source + '.json' # 文件名 path = path.joinpath(name) # 将结果保存到JSON文件 with open(path, mode='w', encoding='utf-8', errors='ignore') as file: result = {'domain': self.domain, 'name': self.module, 'source': self.source, 'elapse': self.elapse, 'find': len(self.subdomains), 'subdomains': list(self.subdomains), 'records': self.records} json.dump(result, file, ensure_ascii=False, indent=4) return True def gen_record(self, subdomains, record): """ 生成子域解析记录。 :param subdomains: 子域集合 :param record: DNS记录 """ item = dict() item['content'] = record for subdomain in subdomains: self.records[subdomain] = item # 将记录添加到字典 def gen_result(self, find=0, brute=None, valid=0): """ 生成模块的最终结果。 :param find: 找到的子域数量 :param brute: 是否为暴力破解结果 :param valid: 有效子域数量 """ logger.log('DEBUG', f'正在生成最终结果') if not len(self.subdomains): # 如果没有发现任何子域 result = {'id': None, 'type': self.type, 'alive': None, 'request': None, 'resolve': None, 'new': None, 'url': None, 'subdomain': None, 'level': None, 'cname': None, 'content': None, 'public': None, 'port': None, 'status': None, 'reason': None, 'title': None, 'banner': None, 'header': None, 'response': None, 'times': None, 'ttl': None, 'resolver': None, 'module': self.module, 'source': self.source, 'elapse': self.elapse, 'find': find, 'brute': brute, 'valid': valid} self.results.append(result) # 将结果添加到结果列表 else: for subdomain in self.subdomains: url = 'http://' + subdomain level = subdomain.count('.') - self.domain.count('.') record = self.records.get(subdomain) if record is None: record = dict() # 如果没有解析记录,则使用空字典 resolve = record.get('resolve') request = record.get('request') alive = record.get('alive') if self.type != 'A': # 如果不是A记录查询,默认为有效 resolve = 1 request = 1 alive = 1 reason = record.get('reason') resolver = record.get('resolver') cname = record.get('cname') content = record.get('content') times = record.get('times') ttl = record.get('ttl') public = record.get('public') if isinstance(cname, list): cname = ','.join(cname) content = ','.join(content) times = ','.join([str(num) for num in times]) ttl = ','.join([str(num) for num in ttl]) public = ','.join([str(num) for num in public]) # 生成最终结果 result = {'id': None, 'type': self.type, 'alive': alive, 'request': request, 'resolve': resolve, 'new': None, 'url': url, 'subdomain': subdomain, 'level': level, 'cname': cname, 'content': content, 'public': public, 'port': 80, 'status': None, 'reason': reason, 'title': None, 'banner': None, 'header': None, 'response': None, 'times': times, 'ttl': ttl, 'resolver': resolver, 'module': self.module, 'source': self.source, 'elapse': self.elapse, 'find': find, 'brute': brute, 'valid': valid} self.results.append(result) # 将每个子域的结果添加到结果列表 def save_db(self): """ 将模块结果保存到数据库。 """ logger.log('DEBUG', f'正在将结果存入数据库') lock.acquire() # 加锁,确保线程安全 db = Database() # 创建数据库实例 db.create_table(self.domain) # 为该域名创建表格 db.save_db(self.domain, self.results, self.source) # 将结果保存到数据库 db.close() # 关闭数据库连接 lock.release() # 释放锁