You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
MiaCTFer/client-1/subdomain/oneforall/common/module.py

359 lines
12 KiB

This file contains ambiguous Unicode characters!

This file contains ambiguous Unicode characters that may be confused with others in your current locale. If your use case is intentional and legitimate, you can safely ignore this warning. Use the Escape button to highlight these characters.

# coding=utf-8
"""
模块基类
"""
import json
import re
import threading
import time
import requests
import client.subdomain.oneforall.config as config
from client.subdomain.oneforall.config import logger
from . import utils
from .domain import Domain
from client.subdomain.oneforall.common.database import Database
# Module-level lock shared by every Module instance: save_db holds it while
# writing so that modules running concurrently do not write to the database
# at the same time.
lock = threading.Lock()


class Module(object):
    """
    Base class for subdomain-collection modules.

    Holds per-module request settings (cookies, headers, proxy, delay,
    timeout, SSL verification), the subdomains and DNS records collected for
    ``self.domain``, and helpers to send HTTP requests, extract subdomains
    from response bodies, build result rows and persist them as JSON or to
    the database.
    """

    def __init__(self):
        self.module = 'Module'
        self.source = 'BaseModule'
        self.cookie = None
        self.header = {}
        self.proxy = None
        self.delay = config.request_delay  # sleep delay between requests
        self.timeout = config.request_timeout  # request timeout
        self.verify = config.request_verify  # SSL certificate verification
        self.domain = ''  # main domain currently being enumerated
        # DNS record type used to enumerate subdomains of the main domain
        # (A records by default); see gen_result for how it is used.
        self.type = 'A'
        self.subdomains = set()  # discovered subdomains
        self.records = {}  # resolution records, keyed by subdomain
        self.results = []  # result rows produced by gen_result
        self.start = time.time()  # module start timestamp
        self.end = None  # module end timestamp (set by finish)
        self.elapse = None  # module run time in seconds (set by finish)

    def check(self, *apis):
        """
        Check that the API settings this module needs are configured.

        :param apis: API configuration values; every one must be truthy
        :return: True if all are set; otherwise an alert is logged and
                 False is returned
        """
        if not all(apis):
            logger.log('ALERT', f'{self.source}模块没有配置API跳过执行')
            return False
        return True

    def begin(self):
        """
        Log that the module has started collecting subdomains.
        """
        logger.log('DEBUG', f'开始执行{self.source}模块收集{self.domain}的子域')

    def finish(self):
        """
        Record the end time and elapsed time, then log the module summary.
        """
        self.end = time.time()
        self.elapse = round(self.end - self.start, 1)
        logger.log('DEBUG', f'结束执行{self.source}模块收集{self.domain}的子域')
        logger.log('INFOR', f'{self.source}模块耗时{self.elapse}秒发现子域'
                            f'{len(self.subdomains)}')
        logger.log('DEBUG', f'{self.source}模块发现{self.domain}的子域\n'
                            f'{self.subdomains}')

    def _request(self, method, url, check=True, **kwargs):
        """
        Send an HTTP request using this module's default settings.

        The instance's cookies, headers, proxy, timeout and verify settings
        are used as defaults; any of them may be overridden through
        ``kwargs``. Previously, passing one of these raised a duplicate
        keyword error that was swallowed, so the call always returned None.

        :param str method: 'HEAD', 'GET' or 'POST'
        :param str url: request URL
        :param bool check: validate the response with utils.check_response
        :param kwargs: extra arguments forwarded to requests
        :return: the requests response object, or None on error or when the
                 check fails
        """
        options = {'cookies': self.cookie,
                   'headers': self.header,
                   'proxies': self.proxy,
                   'timeout': self.timeout,
                   'verify': self.verify}
        options.update(kwargs)
        # Dispatch to requests.head/get/post rather than requests.request so
        # each method keeps its own defaults (requests.head does not follow
        # redirects by default; get and post do).
        send = getattr(requests, method.lower())
        try:
            resp = send(url, **options)
        except Exception as e:
            # Best effort: a failed request is logged and reported as None.
            logger.log('ERROR', e.args)
            return None
        if not check:
            return resp
        if utils.check_response(method, resp):
            return resp
        return None

    def head(self, url, params=None, check=True, **kwargs):
        """
        Send a HEAD request with this module's settings.

        :param str url: request URL
        :param dict params: query-string parameters
        :param bool check: validate the response
        :param kwargs: extra arguments forwarded to requests (may override
                       cookies/headers/proxies/timeout/verify)
        :return: requests response object or None
        """
        return self._request('HEAD', url, check, params=params, **kwargs)

    def get(self, url, params=None, check=True, **kwargs):
        """
        Send a GET request with this module's settings.

        :param str url: request URL
        :param dict params: query-string parameters
        :param bool check: validate the response
        :param kwargs: extra arguments forwarded to requests (may override
                       cookies/headers/proxies/timeout/verify)
        :return: requests response object or None
        """
        return self._request('GET', url, check, params=params, **kwargs)

    def post(self, url, data=None, check=True, **kwargs):
        """
        Send a POST request with this module's settings.

        :param str url: request URL
        :param dict data: request body data
        :param bool check: validate the response
        :param kwargs: extra arguments forwarded to requests (may override
                       cookies/headers/proxies/timeout/verify)
        :return: requests response object or None
        """
        return self._request('POST', url, check, data=data, **kwargs)

    def get_header(self):
        """
        Return the request headers to use.

        :return: a generated fake header when config.enable_fake_header is
                 set, otherwise self.header
        """
        if config.enable_fake_header:
            return utils.gen_fake_header()
        return self.header

    def get_proxy(self, module):
        """
        Return the proxy settings for the given module.

        :param str module: module name
        :return: a random proxy from utils.get_random_proxy when proxying is
                 enabled for this module, otherwise self.proxy
        """
        if not config.enable_proxy:
            logger.log('TRACE', '所有模块不使用代理')
            return self.proxy
        if config.proxy_all_module or module in config.proxy_partial_module:
            logger.log('TRACE', f'{module}模块使用代理')
            return utils.get_random_proxy()
        logger.log('TRACE', f'{module}模块不使用代理')
        return self.proxy

    @staticmethod
    def match(domain, html, distinct=True):
        """
        Extract subdomains of ``domain`` from a response body with a regex.

        A match must be preceded by one of ``> " ' = ,`` and may carry an
        http:// or https:// prefix; the delimiter and scheme are stripped and
        the result is lower-cased.

        :param str domain: domain whose subdomains are searched for
        :param str html: response body to search
        :param bool distinct: return a deduplicated set instead of a list
        :return: matched subdomains (a set if distinct, else a list; empty
                 of the same type when nothing matches)
        :rtype: set or list
        """
        logger.log('TRACE', '正则匹配响应体中的子域')
        if not html:
            return set() if distinct else []
        regexp = (r'(?:\>|\"|\'|\=|\,)(?:http\:\/\/|https\:\/\/)?'
                  r'(?:[a-z0-9](?:[a-z0-9\-]{0,61}[a-z0-9])?\.){0,}'
                  + re.escape(domain))
        found = re.findall(regexp, html, re.I)
        scheme = re.compile(r'(?:http://|https://)')
        # item[1:] drops the leading delimiter character matched above.
        cleaned = [scheme.sub('', item[1:].lower()) for item in found]
        if distinct:
            return set(cleaned)
        return cleaned

    @staticmethod
    def register(domain):
        """
        Return the registered domain of ``domain``.

        :param str domain: domain name
        :return: result of Domain(domain).registered()
        """
        return Domain(domain).registered()

    def save_json(self):
        """
        Save this module's results as a JSON file.

        The file is written to
        ``config.result_save_dir/<domain>/<module>/<source>.json``.

        :return: False if config.save_module_result is disabled, else True
        """
        if not config.save_module_result:
            return False
        logger.log('TRACE', f'{self.source}模块发现的子域结果保存为json文件')
        directory = config.result_save_dir.joinpath(self.domain, self.module)
        directory.mkdir(parents=True, exist_ok=True)
        path = directory.joinpath(self.source + '.json')
        with open(path, mode='w', encoding='utf-8', errors='ignore') as file:
            result = {'domain': self.domain,
                      'name': self.module,
                      'source': self.source,
                      'elapse': self.elapse,
                      'find': len(self.subdomains),
                      'subdomains': list(self.subdomains),
                      'records': self.records}
            json.dump(result, file, ensure_ascii=False, indent=4)
        return True

    def gen_record(self, subdomains, record):
        """
        Store ``record`` as the 'content' entry for each given subdomain.

        :param subdomains: iterable of subdomains
        :param record: value stored under the 'content' key
        """
        for subdomain in subdomains:
            # Give every subdomain its own dict so that updating one record
            # later does not silently change the records of the others.
            self.records[subdomain] = {'content': record}

    @staticmethod
    def _join(values):
        """Comma-join a list/tuple (items converted with str); else return it."""
        if isinstance(values, (list, tuple)):
            return ','.join(str(value) for value in values)
        return values

    def _result_row(self, find, brute, valid, **fields):
        """Build one result row; columns not given in ``fields`` are None."""
        # Column order matches the rows produced before this refactor.
        row = dict.fromkeys(('id', 'type', 'alive', 'request', 'resolve',
                             'new', 'url', 'subdomain', 'level', 'cname',
                             'content', 'public', 'port', 'status', 'reason',
                             'title', 'banner', 'header', 'response', 'times',
                             'ttl', 'resolver'))
        row['type'] = self.type
        row.update(fields)
        row['module'] = self.module
        row['source'] = self.source
        row['elapse'] = self.elapse
        row['find'] = find
        row['brute'] = brute
        row['valid'] = valid
        return row

    def gen_result(self, find=0, brute=None, valid=0):
        """
        Append one result row per discovered subdomain to self.results.

        When no subdomain was found, a single placeholder row (all subdomain
        columns None) is appended so the module's statistics are kept.

        :param find: value stored in the 'find' column
        :param brute: value stored in the 'brute' column
        :param valid: value stored in the 'valid' column
        """
        logger.log('DEBUG', '正在生成最终结果')
        if not self.subdomains:
            self.results.append(self._result_row(find, brute, valid))
            return
        for subdomain in self.subdomains:
            record = self.records.get(subdomain)
            if record is None:
                record = {}
            if self.type != 'A':
                # Subdomains found via non-A record queries are treated as
                # valid by default.
                resolve = request = alive = 1
            else:
                resolve = record.get('resolve')
                request = record.get('request')
                alive = record.get('alive')
            cname = record.get('cname')
            content = record.get('content')
            times = record.get('times')
            ttl = record.get('ttl')
            public = record.get('public')
            if isinstance(cname, list):
                # List-valued records are flattened to comma-separated text;
                # missing (None) values are left as None instead of crashing.
                cname = self._join(cname)
                content = self._join(content)
                times = self._join(times)
                ttl = self._join(ttl)
                public = self._join(public)
            row = self._result_row(
                find, brute, valid,
                alive=alive,
                request=request,
                resolve=resolve,
                url='http://' + subdomain,
                subdomain=subdomain,
                level=subdomain.count('.') - self.domain.count('.'),
                cname=cname,
                content=content,
                public=public,
                port=80,
                reason=record.get('reason'),
                times=times,
                ttl=ttl,
                resolver=record.get('resolver'))
            self.results.append(row)

    def save_db(self):
        """
        Save this module's results to the database.

        Holds the module-level lock for the whole write and always closes the
        database connection and releases the lock, even if a step raises.
        """
        logger.log('DEBUG', '正在将结果存入到数据库')
        with lock:
            db = Database()
            try:
                db.create_table(self.domain)
                db.save_db(self.domain, self.results, self.source)
            finally:
                db.close()