add comments to utils

pull/3/head
wang 3 months ago
parent d9b7712f9d
commit 4881e3c4e1

@ -128,6 +128,7 @@ class Database(object):
class Task(object): class Task(object):
def __init__(self, taskid, remote_addr): def __init__(self, taskid, remote_addr):
# 初始化任务对象,设置远程地址、进程、输出目录、选项等属性
self.remote_addr = remote_addr self.remote_addr = remote_addr
self.process = None self.process = None
self.output_directory = None self.output_directory = None
@ -136,6 +137,7 @@ class Task(object):
self.initialize_options(taskid) self.initialize_options(taskid)
def initialize_options(self, taskid): def initialize_options(self, taskid):
# 初始化选项,设置默认值
datatype = {"boolean": False, "string": None, "integer": None, "float": None} datatype = {"boolean": False, "string": None, "integer": None, "float": None}
self.options = AttribDict() self.options = AttribDict()
@ -158,18 +160,23 @@ class Task(object):
self._original_options = AttribDict(self.options) self._original_options = AttribDict(self.options)
def set_option(self, option, value): def set_option(self, option, value):
# 设置选项
self.options[option] = value self.options[option] = value
def get_option(self, option): def get_option(self, option):
# 获取选项
return self.options[option] return self.options[option]
def get_options(self): def get_options(self):
# 获取所有选项
return self.options return self.options
def reset_options(self): def reset_options(self):
# 重置选项为初始值
self.options = AttribDict(self._original_options) self.options = AttribDict(self._original_options)
def engine_start(self): def engine_start(self):
# 启动sqlmap引擎
handle, configFile = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.CONFIG, text=True) handle, configFile = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.CONFIG, text=True)
os.close(handle) os.close(handle)
saveConfig(self.options, configFile) saveConfig(self.options, configFile)
@ -184,6 +191,7 @@ class Task(object):
self.process = Popen(["sqlmap", "--api", "-c", configFile], shell=False, close_fds=not IS_WIN) self.process = Popen(["sqlmap", "--api", "-c", configFile], shell=False, close_fds=not IS_WIN)
def engine_stop(self): def engine_stop(self):
# 停止sqlmap引擎
if self.process: if self.process:
self.process.terminate() self.process.terminate()
return self.process.wait() return self.process.wait()
@ -191,9 +199,11 @@ class Task(object):
return None return None
def engine_process(self): def engine_process(self):
# 获取sqlmap引擎进程
return self.process return self.process
def engine_kill(self): def engine_kill(self):
# 杀死sqlmap引擎进程
if self.process: if self.process:
try: try:
self.process.kill() self.process.kill()
@ -203,12 +213,14 @@ class Task(object):
return None return None
def engine_get_id(self): def engine_get_id(self):
# 获取sqlmap引擎进程ID
if self.process: if self.process:
return self.process.pid return self.process.pid
else: else:
return None return None
def engine_get_returncode(self): def engine_get_returncode(self):
# 获取sqlmap引擎进程返回码
if self.process: if self.process:
self.process.poll() self.process.poll()
return self.process.returncode return self.process.returncode
@ -216,6 +228,7 @@ class Task(object):
return None return None
def engine_has_terminated(self): def engine_has_terminated(self):
# 判断sqlmap引擎进程是否已经终止
return isinstance(self.engine_get_returncode(), int) return isinstance(self.engine_get_returncode(), int)
# Wrapper functions for sqlmap engine # Wrapper functions for sqlmap engine
@ -279,11 +292,14 @@ class LogRecorder(logging.StreamHandler):
conf.databaseCursor.execute("INSERT INTO logs VALUES(NULL, ?, ?, ?, ?)", (conf.taskid, time.strftime("%X"), record.levelname, str(record.msg % record.args if record.args else record.msg))) conf.databaseCursor.execute("INSERT INTO logs VALUES(NULL, ?, ?, ?, ?)", (conf.taskid, time.strftime("%X"), record.levelname, str(record.msg % record.args if record.args else record.msg)))
def setRestAPILog(): def setRestAPILog():
# 如果配置文件中api字段为真
if conf.api: if conf.api:
try: try:
# 连接数据库
conf.databaseCursor = Database(conf.database) conf.databaseCursor = Database(conf.database)
conf.databaseCursor.connect("client") conf.databaseCursor.connect("client")
except sqlite3.OperationalError as ex: except sqlite3.OperationalError as ex:
# 如果连接数据库失败,抛出异常
raise SqlmapConnectionException("%s ('%s')" % (ex, conf.database)) raise SqlmapConnectionException("%s ('%s')" % (ex, conf.database))
# Set a logging handler that writes log messages to a IPC database # Set a logging handler that writes log messages to a IPC database
@ -297,24 +313,33 @@ def is_admin(token):
@hook('before_request') @hook('before_request')
def check_authentication(): def check_authentication():
# 检查是否已经认证
if not any((DataStore.username, DataStore.password)): if not any((DataStore.username, DataStore.password)):
return return
# 获取请求头中的Authorization字段
authorization = request.headers.get("Authorization", "") authorization = request.headers.get("Authorization", "")
# 使用正则表达式匹配Authorization字段中的Basic认证信息
match = re.search(r"(?i)\ABasic\s+([^\s]+)", authorization) match = re.search(r"(?i)\ABasic\s+([^\s]+)", authorization)
# 如果没有匹配到Basic认证信息则将请求路径设置为错误页面
if not match: if not match:
request.environ["PATH_INFO"] = "/error/401" request.environ["PATH_INFO"] = "/error/401"
try: try:
# 解码Basic认证信息
creds = decodeBase64(match.group(1), binary=False) creds = decodeBase64(match.group(1), binary=False)
except: except:
# 如果解码失败,则将请求路径设置为错误页面
request.environ["PATH_INFO"] = "/error/401" request.environ["PATH_INFO"] = "/error/401"
else: else:
# 如果解码后的认证信息中冒号的数量不等于1则将请求路径设置为错误页面
if creds.count(':') != 1: if creds.count(':') != 1:
request.environ["PATH_INFO"] = "/error/401" request.environ["PATH_INFO"] = "/error/401"
else: else:
# 将认证信息分割为用户名和密码
username, password = creds.split(':') username, password = creds.split(':')
# 如果用户名或密码不匹配,则将请求路径设置为错误页面
if username.strip() != (DataStore.username or "") or password.strip() != (DataStore.password or ""): if username.strip() != (DataStore.username or "") or password.strip() != (DataStore.password or ""):
request.environ["PATH_INFO"] = "/error/401" request.environ["PATH_INFO"] = "/error/401"
@ -480,18 +505,28 @@ def option_set(taskid):
Set value of option(s) for a certain task ID Set value of option(s) for a certain task ID
""" """
# Check if the task ID exists in the DataStore
if taskid not in DataStore.tasks: if taskid not in DataStore.tasks:
# Log a warning if the task ID does not exist
logger.warning("[%s] Invalid task ID provided to option_set()" % taskid) logger.warning("[%s] Invalid task ID provided to option_set()" % taskid)
# Return a JSON response indicating failure
return jsonize({"success": False, "message": "Invalid task ID"}) return jsonize({"success": False, "message": "Invalid task ID"})
# Check if the request JSON is None
if request.json is None: if request.json is None:
# Log a warning if the request JSON is None
logger.warning("[%s] Invalid JSON options provided to option_set()" % taskid) logger.warning("[%s] Invalid JSON options provided to option_set()" % taskid)
# Return a JSON response indicating failure
return jsonize({"success": False, "message": "Invalid JSON options"}) return jsonize({"success": False, "message": "Invalid JSON options"})
# Iterate through the request JSON
for option, value in request.json.items(): for option, value in request.json.items():
# Set the option for the task ID in the DataStore
DataStore.tasks[taskid].set_option(option, value) DataStore.tasks[taskid].set_option(option, value)
# Log a debug message indicating the options have been set
logger.debug("(%s) Requested to set options" % taskid) logger.debug("(%s) Requested to set options" % taskid)
# Return a JSON response indicating success
return jsonize({"success": True}) return jsonize({"success": True})
# Handle scans # Handle scans
@ -530,13 +565,18 @@ def scan_stop(taskid):
Stop a scan Stop a scan
""" """
# 检查任务ID是否有效
if (taskid not in DataStore.tasks or DataStore.tasks[taskid].engine_process() is None or DataStore.tasks[taskid].engine_has_terminated()): if (taskid not in DataStore.tasks or DataStore.tasks[taskid].engine_process() is None or DataStore.tasks[taskid].engine_has_terminated()):
# 如果任务ID无效记录警告日志并返回错误信息
logger.warning("[%s] Invalid task ID provided to scan_stop()" % taskid) logger.warning("[%s] Invalid task ID provided to scan_stop()" % taskid)
return jsonize({"success": False, "message": "Invalid task ID"}) return jsonize({"success": False, "message": "Invalid task ID"})
# 停止任务
DataStore.tasks[taskid].engine_stop() DataStore.tasks[taskid].engine_stop()
# 记录调试日志
logger.debug("(%s) Stopped scan" % taskid) logger.debug("(%s) Stopped scan" % taskid)
# 返回成功信息
return jsonize({"success": True}) return jsonize({"success": True})
@get("/scan/<taskid>/kill") @get("/scan/<taskid>/kill")
@ -609,14 +649,17 @@ def scan_log_limited(taskid, start, end):
json_log_messages = list() json_log_messages = list()
# Check if the taskid exists in the DataStore
if taskid not in DataStore.tasks: if taskid not in DataStore.tasks:
logger.warning("[%s] Invalid task ID provided to scan_log_limited()" % taskid) logger.warning("[%s] Invalid task ID provided to scan_log_limited()" % taskid)
return jsonize({"success": False, "message": "Invalid task ID"}) return jsonize({"success": False, "message": "Invalid task ID"})
# Check if the start and end values are digits and if the end value is greater than the start value
if not start.isdigit() or not end.isdigit() or int(end) < int(start): if not start.isdigit() or not end.isdigit() or int(end) < int(start):
logger.warning("[%s] Invalid start or end value provided to scan_log_limited()" % taskid) logger.warning("[%s] Invalid start or end value provided to scan_log_limited()" % taskid)
return jsonize({"success": False, "message": "Invalid start or end value, must be digits"}) return jsonize({"success": False, "message": "Invalid start or end value, must be digits"})
# Set the start and end values to a minimum of 1
start = max(1, int(start)) start = max(1, int(start))
end = max(1, int(end)) end = max(1, int(end))
@ -635,6 +678,7 @@ def scan_log(taskid):
json_log_messages = list() json_log_messages = list()
# Check if the taskid exists in the DataStore
if taskid not in DataStore.tasks: if taskid not in DataStore.tasks:
logger.warning("[%s] Invalid task ID provided to scan_log()" % taskid) logger.warning("[%s] Invalid task ID provided to scan_log()" % taskid)
return jsonize({"success": False, "message": "Invalid task ID"}) return jsonize({"success": False, "message": "Invalid task ID"})
@ -685,21 +729,27 @@ def server(host=RESTAPI_DEFAULT_ADDRESS, port=RESTAPI_DEFAULT_PORT, adapter=REST
REST-JSON API server REST-JSON API server
""" """
# 生成一个随机的16字节的admin_token
DataStore.admin_token = encodeHex(os.urandom(16), binary=False) DataStore.admin_token = encodeHex(os.urandom(16), binary=False)
# 设置用户名和密码
DataStore.username = username DataStore.username = username
DataStore.password = password DataStore.password = password
# 如果没有指定数据库,则创建一个临时数据库
if not database: if not database:
_, Database.filepath = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.IPC, text=False) _, Database.filepath = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.IPC, text=False)
os.close(_) os.close(_)
else: else:
# 否则使用指定的数据库
Database.filepath = database Database.filepath = database
# 如果端口为0则随机生成一个端口
if port == 0: # random if port == 0: # random
with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s: with contextlib.closing(socket.socket(socket.AF_INET, socket.SOCK_STREAM)) as s:
s.bind((host, 0)) s.bind((host, 0))
port = s.getsockname()[1] port = s.getsockname()[1]
# 打印运行信息
logger.info("Running REST-JSON API server at '%s:%d'.." % (host, port)) logger.info("Running REST-JSON API server at '%s:%d'.." % (host, port))
logger.info("Admin (secret) token: %s" % DataStore.admin_token) logger.info("Admin (secret) token: %s" % DataStore.admin_token)
logger.debug("IPC database: '%s'" % Database.filepath) logger.debug("IPC database: '%s'" % Database.filepath)
@ -737,25 +787,35 @@ def server(host=RESTAPI_DEFAULT_ADDRESS, port=RESTAPI_DEFAULT_PORT, adapter=REST
logger.critical(errMsg) logger.critical(errMsg)
def _client(url, options=None): def _client(url, options=None):
# 打印正在调用的url
logger.debug("Calling '%s'" % url) logger.debug("Calling '%s'" % url)
try: try:
# 设置请求头
headers = {"Content-Type": "application/json"} headers = {"Content-Type": "application/json"}
# 如果options不为空则将options转换为json格式
if options is not None: if options is not None:
data = getBytes(jsonize(options)) data = getBytes(jsonize(options))
else: else:
data = None data = None
# 如果DataStore中有用户名或密码则将用户名和密码进行base64编码并添加到请求头中
if DataStore.username or DataStore.password: if DataStore.username or DataStore.password:
headers["Authorization"] = "Basic %s" % encodeBase64("%s:%s" % (DataStore.username or "", DataStore.password or ""), binary=False) headers["Authorization"] = "Basic %s" % encodeBase64("%s:%s" % (DataStore.username or "", DataStore.password or ""), binary=False)
# 创建请求对象
req = _urllib.request.Request(url, data, headers) req = _urllib.request.Request(url, data, headers)
# 发送请求并获取响应
response = _urllib.request.urlopen(req) response = _urllib.request.urlopen(req)
# 将响应内容转换为文本
text = getText(response.read()) text = getText(response.read())
except: except:
# 如果options不为空则打印错误信息
if options: if options:
logger.error("Failed to load and parse %s" % url) logger.error("Failed to load and parse %s" % url)
# 抛出异常
raise raise
# 返回文本
return text return text
def client(host=RESTAPI_DEFAULT_ADDRESS, port=RESTAPI_DEFAULT_PORT, username=None, password=None): def client(host=RESTAPI_DEFAULT_ADDRESS, port=RESTAPI_DEFAULT_PORT, username=None, password=None):

@ -62,60 +62,77 @@ def _addPageTextWords():
@stackedmethod @stackedmethod
def tableExists(tableFile, regex=None): def tableExists(tableFile, regex=None):
# 检查是否需要使用表存在性检查
if kb.choices.tableExists is None and not any(_ for _ in kb.injection.data if _ not in (PAYLOAD.TECHNIQUE.TIME, PAYLOAD.TECHNIQUE.STACKED)) and not conf.direct: if kb.choices.tableExists is None and not any(_ for _ in kb.injection.data if _ not in (PAYLOAD.TECHNIQUE.TIME, PAYLOAD.TECHNIQUE.STACKED)) and not conf.direct:
# 如果使用PAYLOAD.TECHNIQUE.TIME和PAYLOAD.TECHNIQUE.STACKED进行表存在性检查则发出警告
warnMsg = "it's not recommended to use '%s' and/or '%s' " % (PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.TIME], PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.STACKED]) warnMsg = "it's not recommended to use '%s' and/or '%s' " % (PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.TIME], PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.STACKED])
warnMsg += "for common table existence check" warnMsg += "for common table existence check"
logger.warning(warnMsg) logger.warning(warnMsg)
# 提示用户是否继续
message = "are you sure you want to continue? [y/N] " message = "are you sure you want to continue? [y/N] "
kb.choices.tableExists = readInput(message, default='N', boolean=True) kb.choices.tableExists = readInput(message, default='N', boolean=True)
# 如果用户选择不继续则返回None
if not kb.choices.tableExists: if not kb.choices.tableExists:
return None return None
# 检查表存在性
result = inject.checkBooleanExpression("%s" % safeStringFormat(BRUTE_TABLE_EXISTS_TEMPLATE, (randomInt(1), randomStr()))) result = inject.checkBooleanExpression("%s" % safeStringFormat(BRUTE_TABLE_EXISTS_TEMPLATE, (randomInt(1), randomStr())))
# 如果检查结果无效,则抛出异常
if result: if result:
errMsg = "can't use table existence check because of detected invalid results " errMsg = "can't use table existence check because of detected invalid results "
errMsg += "(most likely caused by inability of the used injection " errMsg += "(most likely caused by inability of the used injection "
errMsg += "to distinguish erroneous results)" errMsg += "to distinguish erroneous results)"
raise SqlmapDataException(errMsg) raise SqlmapDataException(errMsg)
# 将数据库信息推送到kb.injection.data中
pushValue(conf.db) pushValue(conf.db)
# 如果数据库信息存在,并且数据库类型为大写,则将数据库信息转换为大写
if conf.db and Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES: if conf.db and Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES:
conf.db = conf.db.upper() conf.db = conf.db.upper()
# 提示用户选择表存在性检查的文件
message = "which common tables (wordlist) file do you want to use?\n" message = "which common tables (wordlist) file do you want to use?\n"
message += "[1] default '%s' (press Enter)\n" % tableFile message += "[1] default '%s' (press Enter)\n" % tableFile
message += "[2] custom" message += "[2] custom"
choice = readInput(message, default='1') choice = readInput(message, default='1')
# 如果用户选择自定义文件,则提示用户输入文件路径
if choice == '2': if choice == '2':
message = "what's the custom common tables file location?\n" message = "what's the custom common tables file location?\n"
tableFile = readInput(message) or tableFile tableFile = readInput(message) or tableFile
# 打印信息,表示正在使用文件进行表存在性检查
infoMsg = "performing table existence using items from '%s'" % tableFile infoMsg = "performing table existence using items from '%s'" % tableFile
logger.info(infoMsg) logger.info(infoMsg)
# 获取文件中的表名
tables = getFileItems(tableFile, lowercase=Backend.getIdentifiedDbms() in (DBMS.ACCESS,), unique=True) tables = getFileItems(tableFile, lowercase=Backend.getIdentifiedDbms() in (DBMS.ACCESS,), unique=True)
tables.extend(_addPageTextWords()) tables.extend(_addPageTextWords())
tables = filterListValue(tables, regex) tables = filterListValue(tables, regex)
# 遍历数据库信息
for conf.db in (conf.db.split(',') if conf.db else [conf.db]): for conf.db in (conf.db.split(',') if conf.db else [conf.db]):
# 如果数据库信息存在,并且不是元数据库,则打印信息
if conf.db and METADB_SUFFIX not in conf.db: if conf.db and METADB_SUFFIX not in conf.db:
infoMsg = "checking database '%s'" % conf.db infoMsg = "checking database '%s'" % conf.db
logger.info(infoMsg) logger.info(infoMsg)
# 获取当前线程数据
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.shared.count = 0 threadData.shared.count = 0
threadData.shared.limit = len(tables) threadData.shared.limit = len(tables)
threadData.shared.files = [] threadData.shared.files = []
threadData.shared.unique = set() threadData.shared.unique = set()
# 定义线程函数
def tableExistsThread(): def tableExistsThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
# 循环检查表存在性
while kb.threadContinue: while kb.threadContinue:
kb.locks.count.acquire() kb.locks.count.acquire()
if threadData.shared.count < threadData.shared.limit: if threadData.shared.count < threadData.shared.limit:
@ -126,28 +143,33 @@ def tableExists(tableFile, regex=None):
kb.locks.count.release() kb.locks.count.release()
break break
# 如果数据库信息存在并且不是元数据库并且数据库类型不是SQLite、Access、Firebird则构建完整的表名
if conf.db and METADB_SUFFIX not in conf.db and Backend.getIdentifiedDbms() not in (DBMS.SQLITE, DBMS.ACCESS, DBMS.FIREBIRD): if conf.db and METADB_SUFFIX not in conf.db and Backend.getIdentifiedDbms() not in (DBMS.SQLITE, DBMS.ACCESS, DBMS.FIREBIRD):
fullTableName = "%s.%s" % (conf.db, table) fullTableName = "%s.%s" % (conf.db, table)
else: else:
fullTableName = table fullTableName = table
# 根据数据库类型构建表存在性检查的SQL语句
if Backend.isDbms(DBMS.MCKOI): if Backend.isDbms(DBMS.MCKOI):
_ = randomInt(1) _ = randomInt(1)
result = inject.checkBooleanExpression("%s" % safeStringFormat("%d=(SELECT %d FROM %s)", (_, _, fullTableName))) result = inject.checkBooleanExpression("%s" % safeStringFormat("%d=(SELECT %d FROM %s)", (_, _, fullTableName)))
else: else:
result = inject.checkBooleanExpression("%s" % safeStringFormat(BRUTE_TABLE_EXISTS_TEMPLATE, (randomInt(1), fullTableName))) result = inject.checkBooleanExpression("%s" % safeStringFormat(BRUTE_TABLE_EXISTS_TEMPLATE, (randomInt(1), fullTableName)))
# 将结果添加到线程数据中
kb.locks.io.acquire() kb.locks.io.acquire()
if result and table.lower() not in threadData.shared.unique: if result and table.lower() not in threadData.shared.unique:
threadData.shared.files.append(table) threadData.shared.files.append(table)
threadData.shared.unique.add(table.lower()) threadData.shared.unique.add(table.lower())
# 如果verbose级别为1或2并且不是API调用则打印信息
if conf.verbose in (1, 2) and not conf.api: if conf.verbose in (1, 2) and not conf.api:
clearConsoleLine(True) clearConsoleLine(True)
infoMsg = "[%s] [INFO] retrieved: %s\n" % (time.strftime("%X"), unsafeSQLIdentificatorNaming(table)) infoMsg = "[%s] [INFO] retrieved: %s\n" % (time.strftime("%X"), unsafeSQLIdentificatorNaming(table))
dataToStdout(infoMsg, True) dataToStdout(infoMsg, True)
# 如果verbose级别为1或2则打印状态信息
if conf.verbose in (1, 2): if conf.verbose in (1, 2):
status = '%d/%d items (%d%%)' % (threadData.shared.count, threadData.shared.limit, round(100.0 * threadData.shared.count / threadData.shared.limit)) status = '%d/%d items (%d%%)' % (threadData.shared.count, threadData.shared.limit, round(100.0 * threadData.shared.count / threadData.shared.limit))
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
@ -155,225 +177,342 @@ def tableExists(tableFile, regex=None):
kb.locks.io.release() kb.locks.io.release()
try: try:
# 尝试运行线程
runThreads(conf.threads, tableExistsThread, threadChoice=True) runThreads(conf.threads, tableExistsThread, threadChoice=True)
except KeyboardInterrupt: except KeyboardInterrupt:
# 捕获用户中断
warnMsg = "user aborted during table existence " warnMsg = "user aborted during table existence "
warnMsg += "check. sqlmap will display partial output" warnMsg += "check. sqlmap will display partial output"
logger.warning(warnMsg) logger.warning(warnMsg)
# 清除控制台行
clearConsoleLine(True) clearConsoleLine(True)
# 输出换行符
dataToStdout("\n") dataToStdout("\n")
# 如果没有找到表
if not threadData.shared.files: if not threadData.shared.files:
warnMsg = "no table(s) found" warnMsg = "no table(s) found"
# 如果指定了数据库
if conf.db: if conf.db:
warnMsg += " for database '%s'" % conf.db warnMsg += " for database '%s'" % conf.db
logger.warning(warnMsg) logger.warning(warnMsg)
else: else:
# 遍历找到的表
for item in threadData.shared.files: for item in threadData.shared.files:
# 如果数据库不在缓存表中
if conf.db not in kb.data.cachedTables: if conf.db not in kb.data.cachedTables:
# 将表添加到缓存表中
kb.data.cachedTables[conf.db] = [item] kb.data.cachedTables[conf.db] = [item]
else: else:
# 否则将表添加到缓存表的列表中
kb.data.cachedTables[conf.db].append(item) kb.data.cachedTables[conf.db].append(item)
# 遍历找到的表
for _ in ((conf.db, item) for item in threadData.shared.files): for _ in ((conf.db, item) for item in threadData.shared.files):
# 如果表不在暴力破解表中
if _ not in kb.brute.tables: if _ not in kb.brute.tables:
# 将表添加到暴力破解表中
kb.brute.tables.append(_) kb.brute.tables.append(_)
# 从配置中弹出数据库
conf.db = popValue() conf.db = popValue()
# 将暴力破解表写入哈希数据库
hashDBWrite(HASHDB_KEYS.KB_BRUTE_TABLES, kb.brute.tables, True) hashDBWrite(HASHDB_KEYS.KB_BRUTE_TABLES, kb.brute.tables, True)
# 返回缓存表
return kb.data.cachedTables return kb.data.cachedTables
def columnExists(columnFile, regex=None): def columnExists(columnFile, regex=None):
# 如果没有指定列存在性检查
if kb.choices.columnExists is None and not any(_ for _ in kb.injection.data if _ not in (PAYLOAD.TECHNIQUE.TIME, PAYLOAD.TECHNIQUE.STACKED)) and not conf.direct: if kb.choices.columnExists is None and not any(_ for _ in kb.injection.data if _ not in (PAYLOAD.TECHNIQUE.TIME, PAYLOAD.TECHNIQUE.STACKED)) and not conf.direct:
# 警告信息
warnMsg = "it's not recommended to use '%s' and/or '%s' " % (PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.TIME], PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.STACKED]) warnMsg = "it's not recommended to use '%s' and/or '%s' " % (PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.TIME], PAYLOAD.SQLINJECTION[PAYLOAD.TECHNIQUE.STACKED])
warnMsg += "for common column existence check" warnMsg += "for common column existence check"
logger.warning(warnMsg) logger.warning(warnMsg)
# 提示用户是否继续
message = "are you sure you want to continue? [y/N] " message = "are you sure you want to continue? [y/N] "
kb.choices.columnExists = readInput(message, default='N', boolean=True) kb.choices.columnExists = readInput(message, default='N', boolean=True)
# 如果用户选择不继续
if not kb.choices.columnExists: if not kb.choices.columnExists:
return None return None
# 如果没有指定表
if not conf.tbl: if not conf.tbl:
# 抛出缺少表参数异常
errMsg = "missing table parameter" errMsg = "missing table parameter"
raise SqlmapMissingMandatoryOptionException(errMsg) raise SqlmapMissingMandatoryOptionException(errMsg)
# 如果指定了数据库并且数据库管理系统是大写的
if conf.db and Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES: if conf.db and Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES:
# 将数据库转换为大写
conf.db = conf.db.upper() conf.db = conf.db.upper()
# 注入检查布尔表达式
result = inject.checkBooleanExpression(safeStringFormat(BRUTE_COLUMN_EXISTS_TEMPLATE, (randomStr(), randomStr()))) result = inject.checkBooleanExpression(safeStringFormat(BRUTE_COLUMN_EXISTS_TEMPLATE, (randomStr(), randomStr())))
# 如果结果无效
if result: if result:
# 抛出数据异常
errMsg = "can't use column existence check because of detected invalid results " errMsg = "can't use column existence check because of detected invalid results "
errMsg += "(most likely caused by inability of the used injection " errMsg += "(most likely caused by inability of the used injection "
errMsg += "to distinguish erroneous results)" errMsg += "to distinguish erroneous results)"
raise SqlmapDataException(errMsg) raise SqlmapDataException(errMsg)
# 提示用户选择列存在性检查文件
message = "which common columns (wordlist) file do you want to use?\n" message = "which common columns (wordlist) file do you want to use?\n"
message += "[1] default '%s' (press Enter)\n" % columnFile message += "[1] default '%s' (press Enter)\n" % columnFile
message += "[2] custom" message += "[2] custom"
choice = readInput(message, default='1') choice = readInput(message, default='1')
# 如果用户选择自定义文件
if choice == '2': if choice == '2':
# 提示用户输入自定义文件位置
message = "what's the custom common columns file location?\n" message = "what's the custom common columns file location?\n"
columnFile = readInput(message) or columnFile columnFile = readInput(message) or columnFile
# 输出信息
infoMsg = "checking column existence using items from '%s'" % columnFile infoMsg = "checking column existence using items from '%s'" % columnFile
logger.info(infoMsg) logger.info(infoMsg)
# 获取文件项
columns = getFileItems(columnFile, unique=True) columns = getFileItems(columnFile, unique=True)
# 添加页面文本单词
columns.extend(_addPageTextWords()) columns.extend(_addPageTextWords())
# 过滤列表值
columns = filterListValue(columns, regex) columns = filterListValue(columns, regex)
# 获取表名
table = safeSQLIdentificatorNaming(conf.tbl, True) table = safeSQLIdentificatorNaming(conf.tbl, True)
# 如果指定了数据库并且数据库后缀不在配置中并且数据库管理系统不是SQLite、Access或Firebird
if conf.db and METADB_SUFFIX not in conf.db and Backend.getIdentifiedDbms() not in (DBMS.SQLITE, DBMS.ACCESS, DBMS.FIREBIRD): if conf.db and METADB_SUFFIX not in conf.db and Backend.getIdentifiedDbms() not in (DBMS.SQLITE, DBMS.ACCESS, DBMS.FIREBIRD):
# 将表名转换为数据库表名
table = "%s.%s" % (safeSQLIdentificatorNaming(conf.db), table) table = "%s.%s" % (safeSQLIdentificatorNaming(conf.db), table)
# 设置线程继续
kb.threadContinue = True kb.threadContinue = True
# 设置暴力破解模式
kb.bruteMode = True kb.bruteMode = True
# 获取当前线程数据
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.shared.count = 0 threadData.shared.count = 0
threadData.shared.limit = len(columns) threadData.shared.limit = len(columns)
threadData.shared.files = [] threadData.shared.files = []
def columnExistsThread(): def columnExistsThread():
# 获取当前线程的数据
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
# 当kb.threadContinue为True时循环执行
while kb.threadContinue: while kb.threadContinue:
# 获取count锁
kb.locks.count.acquire() kb.locks.count.acquire()
# 如果threadData.shared.count小于threadData.shared.limit
if threadData.shared.count < threadData.shared.limit: if threadData.shared.count < threadData.shared.limit:
# 获取列名
column = safeSQLIdentificatorNaming(columns[threadData.shared.count]) column = safeSQLIdentificatorNaming(columns[threadData.shared.count])
# 增加计数
threadData.shared.count += 1 threadData.shared.count += 1
# 释放count锁
kb.locks.count.release() kb.locks.count.release()
else: else:
# 释放count锁
kb.locks.count.release() kb.locks.count.release()
# 跳出循环
break break
# 如果数据库类型是MCKOI
if Backend.isDbms(DBMS.MCKOI): if Backend.isDbms(DBMS.MCKOI):
# 检查列是否存在
result = inject.checkBooleanExpression(safeStringFormat("0<(SELECT COUNT(%s) FROM %s)", (column, table))) result = inject.checkBooleanExpression(safeStringFormat("0<(SELECT COUNT(%s) FROM %s)", (column, table)))
else: else:
# 检查列是否存在
result = inject.checkBooleanExpression(safeStringFormat(BRUTE_COLUMN_EXISTS_TEMPLATE, (column, table))) result = inject.checkBooleanExpression(safeStringFormat(BRUTE_COLUMN_EXISTS_TEMPLATE, (column, table)))
# 获取io锁
kb.locks.io.acquire() kb.locks.io.acquire()
# 如果列存在
if result: if result:
# 将列名添加到threadData.shared.files中
threadData.shared.files.append(column) threadData.shared.files.append(column)
# 如果verbose为1或2且不使用api
if conf.verbose in (1, 2) and not conf.api: if conf.verbose in (1, 2) and not conf.api:
# 清除控制台行
clearConsoleLine(True) clearConsoleLine(True)
# 输出信息
infoMsg = "[%s] [INFO] retrieved: %s\n" % (time.strftime("%X"), unsafeSQLIdentificatorNaming(column)) infoMsg = "[%s] [INFO] retrieved: %s\n" % (time.strftime("%X"), unsafeSQLIdentificatorNaming(column))
dataToStdout(infoMsg, True) dataToStdout(infoMsg, True)
# 如果verbose为1或2
if conf.verbose in (1, 2): if conf.verbose in (1, 2):
# 计算状态
status = "%d/%d items (%d%%)" % (threadData.shared.count, threadData.shared.limit, round(100.0 * threadData.shared.count / threadData.shared.limit)) status = "%d/%d items (%d%%)" % (threadData.shared.count, threadData.shared.limit, round(100.0 * threadData.shared.count / threadData.shared.limit))
# 输出状态
dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True) dataToStdout("\r[%s] [INFO] tried %s" % (time.strftime("%X"), status), True)
# 释放io锁
kb.locks.io.release() kb.locks.io.release()
try: try:
# 运行线程
runThreads(conf.threads, columnExistsThread, threadChoice=True) runThreads(conf.threads, columnExistsThread, threadChoice=True)
except KeyboardInterrupt: except KeyboardInterrupt:
# 如果用户中断,输出警告信息
warnMsg = "user aborted during column existence " warnMsg = "user aborted during column existence "
warnMsg += "check. sqlmap will display partial output" warnMsg += "check. sqlmap will display partial output"
logger.warning(warnMsg) logger.warning(warnMsg)
finally: finally:
# 将bruteMode设置为False
kb.bruteMode = False kb.bruteMode = False
# 清除控制台行
clearConsoleLine(True) clearConsoleLine(True)
# 输出换行
dataToStdout("\n") dataToStdout("\n")
# 如果没有找到列
if not threadData.shared.files: if not threadData.shared.files:
# 输出警告信息
warnMsg = "no column(s) found" warnMsg = "no column(s) found"
logger.warning(warnMsg) logger.warning(warnMsg)
else: else:
# 初始化columns字典
columns = {} columns = {}
# 遍历threadData.shared.files中的列名
for column in threadData.shared.files: for column in threadData.shared.files:
# 如果数据库类型是MySQL
if Backend.getIdentifiedDbms() in (DBMS.MYSQL,): if Backend.getIdentifiedDbms() in (DBMS.MYSQL,):
# 检查列是否为数字
result = not inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE %s REGEXP '[^0-9]')", (column, table, column))) result = not inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE %s REGEXP '[^0-9]')", (column, table, column)))
# 如果数据库类型是SQLite
elif Backend.getIdentifiedDbms() in (DBMS.SQLITE,): elif Backend.getIdentifiedDbms() in (DBMS.SQLITE,):
# 检查列是否为数字
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE %s NOT GLOB '*[^0-9]*')", (column, table, column))) result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE %s NOT GLOB '*[^0-9]*')", (column, table, column)))
# 如果数据库类型是MCKOI
elif Backend.getIdentifiedDbms() in (DBMS.MCKOI,): elif Backend.getIdentifiedDbms() in (DBMS.MCKOI,):
# 检查列是否为数字
result = inject.checkBooleanExpression("%s" % safeStringFormat("0=(SELECT MAX(%s)-MAX(%s) FROM %s)", (column, column, table))) result = inject.checkBooleanExpression("%s" % safeStringFormat("0=(SELECT MAX(%s)-MAX(%s) FROM %s)", (column, column, table)))
else: else:
# 检查列是否为数字
result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column))) result = inject.checkBooleanExpression("%s" % safeStringFormat("EXISTS(SELECT %s FROM %s WHERE ROUND(%s)=ROUND(%s))", (column, table, column, column)))
# 如果列是数字
if result: if result:
# 将列名和类型添加到columns字典中
columns[column] = "numeric" columns[column] = "numeric"
else: else:
# 将列名和类型添加到columns字典中
columns[column] = "non-numeric" columns[column] = "non-numeric"
# 将columns字典添加到kb.data.cachedColumns中
kb.data.cachedColumns[conf.db] = {conf.tbl: columns} kb.data.cachedColumns[conf.db] = {conf.tbl: columns}
# 遍历columns字典中的列名和类型
for _ in ((conf.db, conf.tbl, item[0], item[1]) for item in columns.items()): for _ in ((conf.db, conf.tbl, item[0], item[1]) for item in columns.items()):
# 如果列名和类型不在kb.brute.columns中
if _ not in kb.brute.columns: if _ not in kb.brute.columns:
# 将列名和类型添加到kb.brute.columns中
kb.brute.columns.append(_) kb.brute.columns.append(_)
# 将kb.brute.columns写入hashDB
hashDBWrite(HASHDB_KEYS.KB_BRUTE_COLUMNS, kb.brute.columns, True) hashDBWrite(HASHDB_KEYS.KB_BRUTE_COLUMNS, kb.brute.columns, True)
# 返回kb.data.cachedColumns
return kb.data.cachedColumns return kb.data.cachedColumns
@stackedmethod @stackedmethod
def fileExists(pathFile): def fileExists(pathFile):
# 定义一个空列表,用于存储文件路径
retVal = [] retVal = []
# 提示用户选择要使用的公共文件
message = "which common files file do you want to use?\n" message = "which common files file do you want to use?\n"
message += "[1] default '%s' (press Enter)\n" % pathFile message += "[1] default '%s' (press Enter)\n" % pathFile
message += "[2] custom" message += "[2] custom"
# 读取用户输入默认为1
choice = readInput(message, default='1') choice = readInput(message, default='1')
# 如果用户选择自定义文件
if choice == '2': if choice == '2':
# 提示用户输入自定义文件路径
message = "what's the custom common files file location?\n" message = "what's the custom common files file location?\n"
pathFile = readInput(message) or pathFile pathFile = readInput(message) or pathFile
# 打印检查文件存在的信息
infoMsg = "checking files existence using items from '%s'" % pathFile infoMsg = "checking files existence using items from '%s'" % pathFile
logger.info(infoMsg) logger.info(infoMsg)
# 获取文件路径列表
paths = getFileItems(pathFile, unique=True) paths = getFileItems(pathFile, unique=True)
# 设置暴力模式为True
kb.bruteMode = True kb.bruteMode = True
try: try:
# 读取随机字符串
conf.dbmsHandler.readFile(randomStr()) conf.dbmsHandler.readFile(randomStr())
except SqlmapNoneDataException: except SqlmapNoneDataException:
pass pass
except: except:
# 如果发生异常将暴力模式设置为False
kb.bruteMode = False kb.bruteMode = False
raise raise
# 获取当前线程数据
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
# 设置计数器为0
threadData.shared.count = 0 threadData.shared.count = 0
# 设置限制为路径列表的长度
threadData.shared.limit = len(paths) threadData.shared.limit = len(paths)
# 创建一个空列表,用于存储文件
threadData.shared.files = [] threadData.shared.files = []
# 定义一个线程函数,用于检查文件是否存在
def fileExistsThread(): def fileExistsThread():
# 获取当前线程数据
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
# 当线程继续时
while kb.threadContinue: while kb.threadContinue:
# 获取计数器的锁
kb.locks.count.acquire() kb.locks.count.acquire()
# 如果计数器小于限制
if threadData.shared.count < threadData.shared.limit: if threadData.shared.count < threadData.shared.limit:
# 获取路径
path = ntToPosixSlashes(paths[threadData.shared.count]) path = ntToPosixSlashes(paths[threadData.shared.count])
# 计数器加1
threadData.shared.count += 1 threadData.shared.count += 1
# 释放计数器的锁
kb.locks.count.release() kb.locks.count.release()
else: else:
# 释放计数器的锁
kb.locks.count.release() kb.locks.count.release()
# 跳出循环
break break
try: try:
# 读取路径
result = unArrayizeValue(conf.dbmsHandler.readFile(path)) result = unArrayizeValue(conf.dbmsHandler.readFile(path))
except SqlmapNoneDataException: except SqlmapNoneDataException:
# 如果没有数据将结果设置为None
result = None result = None
# 获取IO的锁
kb.locks.io.acquire() kb.locks.io.acquire()
# 如果结果不是None
if not isNoneValue(result): if not isNoneValue(result):
# 将结果添加到文件列表中
threadData.shared.files.append(result) threadData.shared.files.append(result)
# 如果不是API模式
if not conf.api: if not conf.api:
clearConsoleLine(True) clearConsoleLine(True)
infoMsg = "[%s] [INFO] retrieved: '%s'\n" % (time.strftime("%X"), path) infoMsg = "[%s] [INFO] retrieved: '%s'\n" % (time.strftime("%X"), path)

@ -42,10 +42,12 @@ from thirdparty.six.moves import http_client as _http_client
from thirdparty.six.moves import urllib as _urllib from thirdparty.six.moves import urllib as _urllib
def crawl(target, post=None, cookie=None): def crawl(target, post=None, cookie=None):
# 如果目标为空,直接返回
if not target: if not target:
return return
try: try:
# 创建一个已访问集合
visited = set() visited = set()
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.shared.value = OrderedSet() threadData.shared.value = OrderedSet()
@ -54,12 +56,16 @@ def crawl(target, post=None, cookie=None):
def crawlThread(): def crawlThread():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
# 当线程继续时
while kb.threadContinue: while kb.threadContinue:
with kb.locks.limit: with kb.locks.limit:
# 如果还有未处理的链接
if threadData.shared.unprocessed: if threadData.shared.unprocessed:
current = threadData.shared.unprocessed.pop() current = threadData.shared.unprocessed.pop()
# 如果已经访问过,跳过
if current in visited: if current in visited:
continue continue
# 如果有排除规则且当前链接符合排除规则,跳过
elif conf.crawlExclude and re.search(conf.crawlExclude, current): elif conf.crawlExclude and re.search(conf.crawlExclude, current):
dbgMsg = "skipping '%s'" % current dbgMsg = "skipping '%s'" % current
logger.debug(dbgMsg) logger.debug(dbgMsg)
@ -71,8 +77,10 @@ def crawl(target, post=None, cookie=None):
content = None content = None
try: try:
# 发送请求获取页面内容
if current: if current:
content = Request.getPage(url=current, post=post, cookie=None, crawling=True, raise404=False)[0] content = Request.getPage(url=current, post=post, cookie=None, crawling=True, raise404=False)[0]
# 处理不同的异常
except SqlmapConnectionException as ex: except SqlmapConnectionException as ex:
errMsg = "connection exception detected ('%s'). skipping " % getSafeExString(ex) errMsg = "connection exception detected ('%s'). skipping " % getSafeExString(ex)
errMsg += "URL '%s'" % current errMsg += "URL '%s'" % current
@ -88,8 +96,10 @@ def crawl(target, post=None, cookie=None):
if not kb.threadContinue: if not kb.threadContinue:
break break
# 如果内容是文本类型
if isinstance(content, six.text_type): if isinstance(content, six.text_type):
try: try:
# 提取 HTML 内容
match = re.search(r"(?si)<html[^>]*>(.+)</html>", content) match = re.search(r"(?si)<html[^>]*>(.+)</html>", content)
if match: if match:
content = "<html>%s</html>" % match.group(1) content = "<html>%s</html>" % match.group(1)
@ -97,6 +107,7 @@ def crawl(target, post=None, cookie=None):
soup = BeautifulSoup(content) soup = BeautifulSoup(content)
tags = soup('a') tags = soup('a')
# 查找其他可能的链接
tags += re.finditer(r'(?i)\s(href|src)=["\'](?P<href>[^>"\']+)', content) tags += re.finditer(r'(?i)\s(href|src)=["\'](?P<href>[^>"\']+)', content)
tags += re.finditer(r'(?i)window\.open\(["\'](?P<href>[^)"\']+)["\']', content) tags += re.finditer(r'(?i)window\.open\(["\'](?P<href>[^)"\']+)["\']', content)
@ -108,41 +119,47 @@ def crawl(target, post=None, cookie=None):
current = threadData.lastRedirectURL[1] current = threadData.lastRedirectURL[1]
url = _urllib.parse.urljoin(current, htmlUnescape(href)) url = _urllib.parse.urljoin(current, htmlUnescape(href))
# flag to know if we are dealing with the same target host # 检查是否是同一主机
_ = checkSameHost(url, target) _ = checkSameHost(url, target)
# 检查是否在范围中
if conf.scope: if conf.scope:
if not re.search(conf.scope, url, re.I): if not re.search(conf.scope, url, re.I):
continue continue
elif not _: elif not _:
continue continue
# 检查扩展是否在排除列表中
if (extractRegexResult(r"\A[^?]+\.(?P<result>\w+)(\?|\Z)", url) or "").lower() not in CRAWL_EXCLUDE_EXTENSIONS: if (extractRegexResult(r"\A[^?]+\.(?P<result>\w+)(\?|\Z)", url) or "").lower() not in CRAWL_EXCLUDE_EXTENSIONS:
with kb.locks.value: with kb.locks.value:
threadData.shared.deeper.add(url) threadData.shared.deeper.add(url)
# 筛选链接添加到不同集合
if re.search(r"(.*?)\?(.+)", url) and not re.search(r"\?(v=)?\d+\Z", url) and not re.search(r"(?i)\.(js|css)(\?|\Z)", url): if re.search(r"(.*?)\?(.+)", url) and not re.search(r"\?(v=)?\d+\Z", url) and not re.search(r"(?i)\.(js|css)(\?|\Z)", url):
threadData.shared.value.add(url) threadData.shared.value.add(url)
except UnicodeEncodeError: # for non-HTML files except UnicodeEncodeError: # 处理非 HTML 文件异常
pass pass
except ValueError: # for non-valid links except ValueError: # 处理无效链接异常
pass pass
except AssertionError: # for invalid HTML except AssertionError: # 处理无效 HTML 异常
pass pass
finally: finally:
# 检查是否找到表单
if conf.forms: if conf.forms:
threadData.shared.formsFound |= len(findPageForms(content, current, False, True)) > 0 threadData.shared.formsFound |= len(findPageForms(content, current, False, True)) > 0
if conf.verbose in (1, 2): if conf.verbose in (1, 2):
threadData.shared.count += 1 threadData.shared.count += 1
# 输出状态信息
status = '%d/%d links visited (%d%%)' % (threadData.shared.count, threadData.shared.length, round(100.0 * threadData.shared.count / threadData.shared.length)) status = '%d/%d links visited (%d%%)' % (threadData.shared.count, threadData.shared.length, round(100.0 * threadData.shared.count / threadData.shared.length))
dataToStdout("\r[%s] [INFO] %s" % (time.strftime("%X"), status), True) dataToStdout("\r[%s] [INFO] %s" % (time.strftime("%X"), status), True)
threadData.shared.deeper = set() threadData.shared.deeper = set()
threadData.shared.unprocessed = set([target]) threadData.shared.unprocessed = set([target])
# 处理目标 URL
_ = re.sub(r"(?<!/)/(?!/).*", "", target) _ = re.sub(r"(?<!/)/(?!/).*", "", target)
if _: if _:
if target.strip('/') != _.strip('/'): if target.strip('/')!= _.strip('/'):
threadData.shared.unprocessed.add(_) threadData.shared.unprocessed.add(_)
if re.search(r"\?.*\b\w+=", target): if re.search(r"\?.*\b\w+=", target):
@ -158,6 +175,7 @@ def crawl(target, post=None, cookie=None):
items = None items = None
url = _urllib.parse.urljoin(target, "/sitemap.xml") url = _urllib.parse.urljoin(target, "/sitemap.xml")
try: try:
# 解析站点地图
items = parseSitemap(url) items = parseSitemap(url)
except SqlmapConnectionException as ex: except SqlmapConnectionException as ex:
if "page not found" in getSafeExString(ex): if "page not found" in getSafeExString(ex):
@ -179,6 +197,7 @@ def crawl(target, post=None, cookie=None):
infoMsg = "starting crawler for target URL '%s'" % target infoMsg = "starting crawler for target URL '%s'" % target
logger.info(infoMsg) logger.info(infoMsg)
# 启动多个线程进行爬取
for i in xrange(conf.crawlDepth): for i in xrange(conf.crawlDepth):
threadData.shared.count = 0 threadData.shared.count = 0
threadData.shared.length = len(threadData.shared.unprocessed) threadData.shared.length = len(threadData.shared.unprocessed)
@ -201,25 +220,32 @@ def crawl(target, post=None, cookie=None):
logger.warning(warnMsg) logger.warning(warnMsg)
finally: finally:
# 清除控制台行
clearConsoleLine(True) clearConsoleLine(True)
# 如果没有找到可用链接
if not threadData.shared.value: if not threadData.shared.value:
# 如果没有找到表单
if not (conf.forms and threadData.shared.formsFound): if not (conf.forms and threadData.shared.formsFound):
# 输出警告信息
warnMsg = "no usable links found (with GET parameters)" warnMsg = "no usable links found (with GET parameters)"
if conf.forms: if conf.forms:
warnMsg += " or forms" warnMsg += " or forms"
logger.warning(warnMsg) logger.warning(warnMsg)
else: else:
# 遍历找到的链接添加到 kb.targets 中
for url in threadData.shared.value: for url in threadData.shared.value:
kb.targets.add((urldecode(url, kb.pageEncoding), None, None, None, None)) kb.targets.add((urldecode(url, kb.pageEncoding), None, None, None, None))
# 如果 kb.targets 中有链接
if kb.targets: if kb.targets:
# 如果未选择规范化选项
if kb.normalizeCrawlingChoice is None: if kb.normalizeCrawlingChoice is None:
message = "do you want to normalize " message = "do you want to normalize "
message += "crawling results [Y/n] " message += "crawling results [Y/n] "
kb.normalizeCrawlingChoice = readInput(message, default='Y', boolean=True) kb.normalizeCrawlingChoice = readInput(message, default='Y', boolean=True)
# 如果用户选择规范化
if kb.normalizeCrawlingChoice: if kb.normalizeCrawlingChoice:
seen = set() seen = set()
results = OrderedSet() results = OrderedSet()
@ -235,29 +261,40 @@ def crawl(target, post=None, cookie=None):
kb.targets = results kb.targets = results
# 存储结果到文件
storeResultsToFile(kb.targets) storeResultsToFile(kb.targets)
def storeResultsToFile(results): def storeResultsToFile(results):
# 如果结果为空,则返回
if not results: if not results:
return return
# 如果kb.storeCrawlingChoice为空则提示用户是否将爬取结果存储到临时文件中
if kb.storeCrawlingChoice is None: if kb.storeCrawlingChoice is None:
message = "do you want to store crawling results to a temporary file " message = "do you want to store crawling results to a temporary file "
message += "for eventual further processing with other tools [y/N] " message += "for eventual further processing with other tools [y/N] "
# 读取用户输入默认为N返回布尔值
kb.storeCrawlingChoice = readInput(message, default='N', boolean=True) kb.storeCrawlingChoice = readInput(message, default='N', boolean=True)
# 如果用户选择存储,则创建临时文件
if kb.storeCrawlingChoice: if kb.storeCrawlingChoice:
# 创建临时文件,返回文件句柄和文件名
handle, filename = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.CRAWLER, suffix=".csv" if conf.forms else ".txt") handle, filename = tempfile.mkstemp(prefix=MKSTEMP_PREFIX.CRAWLER, suffix=".csv" if conf.forms else ".txt")
# 关闭文件句柄
os.close(handle) os.close(handle)
# 记录日志,表示将爬取结果写入临时文件
infoMsg = "writing crawling results to a temporary file '%s' " % filename infoMsg = "writing crawling results to a temporary file '%s' " % filename
logger.info(infoMsg) logger.info(infoMsg)
# 打开文件,以二进制写模式
with openFile(filename, "w+b") as f: with openFile(filename, "w+b") as f:
# 如果配置了表单,则写入表单标题
if conf.forms: if conf.forms:
f.write("URL,POST\n") f.write("URL,POST\n")
# 遍历结果将URL和POST写入文件
for url, _, data, _, _ in results: for url, _, data, _, _ in results:
if conf.forms: if conf.forms:
f.write("%s,%s\n" % (safeCSValue(url), safeCSValue(data or ""))) f.write("%s,%s\n" % (safeCSValue(url), safeCSValue(data or "")))

@ -11,17 +11,22 @@ from lib.core.enums import DBMS
from lib.core.settings import IS_WIN from lib.core.settings import IS_WIN
def checkDependencies(): def checkDependencies():
# 定义一个集合,用于存储缺失的库
missing_libraries = set() missing_libraries = set()
# 遍历DBMS_DICT字典获取数据库类型和对应的库信息
for dbmsName, data in DBMS_DICT.items(): for dbmsName, data in DBMS_DICT.items():
# 如果库信息为空,则跳过
if data[1] is None: if data[1] is None:
continue continue
try: try:
# 根据数据库类型导入对应的库
if dbmsName in (DBMS.MSSQL, DBMS.SYBASE): if dbmsName in (DBMS.MSSQL, DBMS.SYBASE):
__import__("_mssql") __import__("_mssql")
pymssql = __import__("pymssql") pymssql = __import__("pymssql")
# 如果库版本低于1.0.2,则发出警告
if not hasattr(pymssql, "__version__") or pymssql.__version__ < "1.0.2": if not hasattr(pymssql, "__version__") or pymssql.__version__ < "1.0.2":
warnMsg = "'%s' third-party library must be " % data[1] warnMsg = "'%s' third-party library must be " % data[1]
warnMsg += "version >= 1.0.2 to work properly. " warnMsg += "version >= 1.0.2 to work properly. "
@ -61,6 +66,7 @@ def checkDependencies():
elif dbmsName == DBMS.CLICKHOUSE: elif dbmsName == DBMS.CLICKHOUSE:
__import__("clickhouse_connect") __import__("clickhouse_connect")
except: except:
# 如果导入库失败,则发出警告,并将库添加到缺失库集合中
warnMsg = "sqlmap requires '%s' third-party library " % data[1] warnMsg = "sqlmap requires '%s' third-party library " % data[1]
warnMsg += "in order to directly connect to the DBMS " warnMsg += "in order to directly connect to the DBMS "
warnMsg += "'%s'. Download from '%s'" % (dbmsName, data[2]) warnMsg += "'%s'. Download from '%s'" % (dbmsName, data[2])
@ -69,14 +75,17 @@ def checkDependencies():
continue continue
# 如果导入库成功,则发出调试信息
debugMsg = "'%s' third-party library is found" % data[1] debugMsg = "'%s' third-party library is found" % data[1]
logger.debug(debugMsg) logger.debug(debugMsg)
try: try:
# 导入impacket库
__import__("impacket") __import__("impacket")
debugMsg = "'python-impacket' third-party library is found" debugMsg = "'python-impacket' third-party library is found"
logger.debug(debugMsg) logger.debug(debugMsg)
except ImportError: except ImportError:
# 如果导入失败,则发出警告,并将库添加到缺失库集合中
warnMsg = "sqlmap requires 'python-impacket' third-party library for " warnMsg = "sqlmap requires 'python-impacket' third-party library for "
warnMsg += "out-of-band takeover feature. Download from " warnMsg += "out-of-band takeover feature. Download from "
warnMsg += "'https://github.com/coresecurity/impacket'" warnMsg += "'https://github.com/coresecurity/impacket'"
@ -84,10 +93,12 @@ def checkDependencies():
missing_libraries.add('python-impacket') missing_libraries.add('python-impacket')
try: try:
# 导入ntlm库
__import__("ntlm") __import__("ntlm")
debugMsg = "'python-ntlm' third-party library is found" debugMsg = "'python-ntlm' third-party library is found"
logger.debug(debugMsg) logger.debug(debugMsg)
except ImportError: except ImportError:
# 如果导入失败,则发出警告,并将库添加到缺失库集合中
warnMsg = "sqlmap requires 'python-ntlm' third-party library " warnMsg = "sqlmap requires 'python-ntlm' third-party library "
warnMsg += "if you plan to attack a web application behind NTLM " warnMsg += "if you plan to attack a web application behind NTLM "
warnMsg += "authentication. Download from 'https://github.com/mullender/python-ntlm'" warnMsg += "authentication. Download from 'https://github.com/mullender/python-ntlm'"
@ -95,10 +106,12 @@ def checkDependencies():
missing_libraries.add('python-ntlm') missing_libraries.add('python-ntlm')
try: try:
# 导入websocket._abnf库
__import__("websocket._abnf") __import__("websocket._abnf")
debugMsg = "'websocket-client' library is found" debugMsg = "'websocket-client' library is found"
logger.debug(debugMsg) logger.debug(debugMsg)
except ImportError: except ImportError:
# 如果导入失败,则发出警告,并将库添加到缺失库集合中
warnMsg = "sqlmap requires 'websocket-client' third-party library " warnMsg = "sqlmap requires 'websocket-client' third-party library "
warnMsg += "if you plan to attack a web application using WebSocket. " warnMsg += "if you plan to attack a web application using WebSocket. "
warnMsg += "Download from 'https://pypi.python.org/pypi/websocket-client/'" warnMsg += "Download from 'https://pypi.python.org/pypi/websocket-client/'"
@ -106,31 +119,37 @@ def checkDependencies():
missing_libraries.add('websocket-client') missing_libraries.add('websocket-client')
try: try:
# 导入tkinter库
__import__("tkinter") __import__("tkinter")
debugMsg = "'tkinter' library is found" debugMsg = "'tkinter' library is found"
logger.debug(debugMsg) logger.debug(debugMsg)
except ImportError: except ImportError:
# 如果导入失败,则发出警告,并将库添加到缺失库集合中
warnMsg = "sqlmap requires 'tkinter' library " warnMsg = "sqlmap requires 'tkinter' library "
warnMsg += "if you plan to run a GUI" warnMsg += "if you plan to run a GUI"
logger.warning(warnMsg) logger.warning(warnMsg)
missing_libraries.add('tkinter') missing_libraries.add('tkinter')
try: try:
# 导入tkinter.ttk库
__import__("tkinter.ttk") __import__("tkinter.ttk")
debugMsg = "'tkinter.ttk' library is found" debugMsg = "'tkinter.ttk' library is found"
logger.debug(debugMsg) logger.debug(debugMsg)
except ImportError: except ImportError:
# 如果导入失败,则发出警告,并将库添加到缺失库集合中
warnMsg = "sqlmap requires 'tkinter.ttk' library " warnMsg = "sqlmap requires 'tkinter.ttk' library "
warnMsg += "if you plan to run a GUI" warnMsg += "if you plan to run a GUI"
logger.warning(warnMsg) logger.warning(warnMsg)
missing_libraries.add('tkinter.ttk') missing_libraries.add('tkinter.ttk')
# 如果是Windows系统则导入pyreadline库
if IS_WIN: if IS_WIN:
try: try:
__import__("pyreadline") __import__("pyreadline")
debugMsg = "'python-pyreadline' third-party library is found" debugMsg = "'python-pyreadline' third-party library is found"
logger.debug(debugMsg) logger.debug(debugMsg)
except ImportError: except ImportError:
# 如果导入失败,则发出警告,并将库添加到缺失库集合中
warnMsg = "sqlmap requires 'pyreadline' third-party library to " warnMsg = "sqlmap requires 'pyreadline' third-party library to "
warnMsg += "be able to take advantage of the sqlmap TAB " warnMsg += "be able to take advantage of the sqlmap TAB "
warnMsg += "completion and history support features in the SQL " warnMsg += "completion and history support features in the SQL "
@ -139,6 +158,7 @@ def checkDependencies():
logger.warning(warnMsg) logger.warning(warnMsg)
missing_libraries.add('python-pyreadline') missing_libraries.add('python-pyreadline')
# 如果缺失库集合为空,则发出信息,表示所有依赖都已安装
if len(missing_libraries) == 0: if len(missing_libraries) == 0:
infoMsg = "all dependencies are installed" infoMsg = "all dependencies are installed"
logger.info(infoMsg) logger.info(infoMsg)

@ -11,41 +11,62 @@ class _Getch(object):
the screen (reference: http://code.activestate.com/recipes/134892/) the screen (reference: http://code.activestate.com/recipes/134892/)
""" """
def __init__(self): def __init__(self):
# 尝试使用Windows系统的方法获取字符
try: try:
self.impl = _GetchWindows() self.impl = _GetchWindows()
# 如果Windows系统的方法不可用则尝试使用Mac系统的方法获取字符
except ImportError: except ImportError:
try: try:
self.impl = _GetchMacCarbon() self.impl = _GetchMacCarbon()
# 如果Mac系统的方法不可用则使用Unix系统的方法获取字符
except(AttributeError, ImportError): except(AttributeError, ImportError):
self.impl = _GetchUnix() self.impl = _GetchUnix()
def __call__(self): def __call__(self):
# 调用获取字符的方法
return self.impl() return self.impl()
class _GetchUnix(object): class _GetchUnix(object):
"""
Unix implementation of _Getch
"""
def __init__(self): def __init__(self):
# 导入tty模块
__import__("tty") __import__("tty")
def __call__(self): def __call__(self):
# 导入sys、termios、tty模块
import sys import sys
import termios import termios
import tty import tty
# 获取标准输入的文件描述符
fd = sys.stdin.fileno() fd = sys.stdin.fileno()
# 获取当前终端的属性
old_settings = termios.tcgetattr(fd) old_settings = termios.tcgetattr(fd)
try: try:
# 设置终端为原始模式
tty.setraw(sys.stdin.fileno()) tty.setraw(sys.stdin.fileno())
# 读取一个字符
ch = sys.stdin.read(1) ch = sys.stdin.read(1)
finally: finally:
# 恢复终端的属性
termios.tcsetattr(fd, termios.TCSADRAIN, old_settings) termios.tcsetattr(fd, termios.TCSADRAIN, old_settings)
# 返回读取的字符
return ch return ch
class _GetchWindows(object): class _GetchWindows(object):
"""
Windows implementation of _Getch
"""
def __init__(self): def __init__(self):
# 导入msvcrt模块
__import__("msvcrt") __import__("msvcrt")
def __call__(self): def __call__(self):
# 导入msvcrt模块
import msvcrt import msvcrt
# 调用msvcrt模块的getch函数获取键盘输入
return msvcrt.getch() return msvcrt.getch()
class _GetchMacCarbon(object): class _GetchMacCarbon(object):
@ -56,13 +77,17 @@ class _GetchMacCarbon(object):
very helpful in figuring out how to do this. very helpful in figuring out how to do this.
""" """
def __init__(self): def __init__(self):
# 导入Carbon模块
import Carbon import Carbon
# 检查Carbon模块中是否有Evt属性
getattr(Carbon, "Evt") # see if it has this (in Unix, it doesn't) getattr(Carbon, "Evt") # see if it has this (in Unix, it doesn't)
def __call__(self): def __call__(self):
# 导入Carbon模块
import Carbon import Carbon
# 检查是否有按键按下
if Carbon.Evt.EventAvail(0x0008)[0] == 0: # 0x0008 is the keyDownMask if Carbon.Evt.EventAvail(0x0008)[0] == 0: # 0x0008 is the keyDownMask
return '' return ''
else: else:
@ -71,7 +96,7 @@ class _GetchMacCarbon(object):
# (what,msg,when,where,mod)=Carbon.Evt.GetNextEvent(0x0008)[1] # (what,msg,when,where,mod)=Carbon.Evt.GetNextEvent(0x0008)[1]
# #
# The message (msg) contains the ASCII char which is # The message (msg) contains the ASCII char which is
# extracted with the 0x000000FF charCodeMask; this # extracted with the 0x000000FF charCodeMask; this number is
# number is converted to an ASCII character with chr() and # number is converted to an ASCII character with chr() and
# returned # returned
# #

@ -5,43 +5,52 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import base64 # 导入所需的Python标准库
import datetime import base64 # 用于Base64编码解码
import io import datetime # 处理日期和时间
import re import io # 处理流式IO操作
import time import re # 正则表达式支持
import time # 时间相关功能
from lib.core.bigarray import BigArray
from lib.core.convert import getBytes # 导入自定义和第三方库
from lib.core.convert import getText from lib.core.bigarray import BigArray # 用于处理大型数组
from lib.core.settings import VERSION from lib.core.convert import getBytes # 字符串转字节函数
from thirdparty.six.moves import BaseHTTPServer as _BaseHTTPServer from lib.core.convert import getText # 字节转字符串函数
from thirdparty.six.moves import http_client as _http_client from lib.core.settings import VERSION # 获取版本信息
from thirdparty.six.moves import BaseHTTPServer as _BaseHTTPServer # HTTP服务器基类
from thirdparty.six.moves import http_client as _http_client # HTTP客户端
# HAR(HTTP Archive)格式参考文档
# Reference: https://dvcs.w3.org/hg/webperf/raw-file/tip/specs/HAR/Overview.html # Reference: https://dvcs.w3.org/hg/webperf/raw-file/tip/specs/HAR/Overview.html
# http://www.softwareishard.com/har/viewer/ # http://www.softwareishard.com/har/viewer/
class HTTPCollectorFactory(object): class HTTPCollectorFactory(object):
"""HTTP收集器工厂类,用于创建HTTP收集器实例"""
def __init__(self, harFile=False): def __init__(self, harFile=False):
self.harFile = harFile self.harFile = harFile
def create(self): def create(self):
"""创建并返回一个新的HTTP收集器实例"""
return HTTPCollector() return HTTPCollector()
class HTTPCollector(object): class HTTPCollector(object):
"""HTTP收集器类,用于收集HTTP请求和响应信息"""
def __init__(self): def __init__(self):
self.messages = BigArray() self.messages = BigArray() # 存储请求-响应对
self.extendedArguments = {} self.extendedArguments = {} # 存储扩展参数
def setExtendedArguments(self, arguments): def setExtendedArguments(self, arguments):
"""设置扩展参数"""
self.extendedArguments = arguments self.extendedArguments = arguments
def collectRequest(self, requestMessage, responseMessage, startTime=None, endTime=None): def collectRequest(self, requestMessage, responseMessage, startTime=None, endTime=None):
"""收集一对请求-响应消息"""
self.messages.append(RawPair(requestMessage, responseMessage, self.messages.append(RawPair(requestMessage, responseMessage,
startTime=startTime, endTime=endTime, startTime=startTime, endTime=endTime,
extendedArguments=self.extendedArguments)) extendedArguments=self.extendedArguments))
def obtain(self): def obtain(self):
"""获取HAR格式的日志数据"""
return {"log": { return {"log": {
"version": "1.2", "version": "1.2",
"creator": {"name": "sqlmap", "version": VERSION}, "creator": {"name": "sqlmap", "version": VERSION},
@ -49,54 +58,60 @@ class HTTPCollector(object):
}} }}
class RawPair(object): class RawPair(object):
"""原始请求-响应对类"""
def __init__(self, request, response, startTime=None, endTime=None, extendedArguments=None): def __init__(self, request, response, startTime=None, endTime=None, extendedArguments=None):
self.request = getBytes(request) self.request = getBytes(request) # 请求数据
self.response = getBytes(response) self.response = getBytes(response) # 响应数据
self.startTime = startTime self.startTime = startTime # 开始时间
self.endTime = endTime self.endTime = endTime # 结束时间
self.extendedArguments = extendedArguments or {} self.extendedArguments = extendedArguments or {} # 扩展参数
def toEntry(self): def toEntry(self):
"""转换为Entry对象"""
return Entry(request=Request.parse(self.request), response=Response.parse(self.response), return Entry(request=Request.parse(self.request), response=Response.parse(self.response),
startTime=self.startTime, endTime=self.endTime, startTime=self.startTime, endTime=self.endTime,
extendedArguments=self.extendedArguments) extendedArguments=self.extendedArguments)
class Entry(object): class Entry(object):
"""HAR条目类,表示一个完整的请求-响应交互"""
def __init__(self, request, response, startTime, endTime, extendedArguments): def __init__(self, request, response, startTime, endTime, extendedArguments):
self.request = request self.request = request # 请求对象
self.response = response self.response = response # 响应对象
self.startTime = startTime or 0 self.startTime = startTime or 0 # 开始时间
self.endTime = endTime or 0 self.endTime = endTime or 0 # 结束时间
self.extendedArguments = extendedArguments self.extendedArguments = extendedArguments # 扩展参数
def toDict(self): def toDict(self):
"""转换为字典格式"""
out = { out = {
"request": self.request.toDict(), "request": self.request.toDict(),
"response": self.response.toDict(), "response": self.response.toDict(),
"cache": {}, "cache": {}, # 缓存信息
"timings": { "timings": { # 时间统计
"send": -1, "send": -1,
"wait": -1, "wait": -1,
"receive": -1, "receive": -1,
}, },
"time": int(1000 * (self.endTime - self.startTime)), "time": int(1000 * (self.endTime - self.startTime)), # 总耗时(毫秒)
"startedDateTime": "%s%s" % (datetime.datetime.fromtimestamp(self.startTime).isoformat(), time.strftime("%z")) if self.startTime else None "startedDateTime": "%s%s" % (datetime.datetime.fromtimestamp(self.startTime).isoformat(), time.strftime("%z")) if self.startTime else None
} }
out.update(self.extendedArguments) out.update(self.extendedArguments)
return out return out
class Request(object): class Request(object):
"""HTTP请求类"""
def __init__(self, method, path, httpVersion, headers, postBody=None, raw=None, comment=None): def __init__(self, method, path, httpVersion, headers, postBody=None, raw=None, comment=None):
self.method = method self.method = method # 请求方法(GET/POST等)
self.path = path self.path = path # 请求路径
self.httpVersion = httpVersion self.httpVersion = httpVersion # HTTP版本
self.headers = headers or {} self.headers = headers or {} # 请求头
self.postBody = postBody self.postBody = postBody # POST请求体
self.comment = comment.strip() if comment else comment self.comment = comment.strip() if comment else comment # 注释
self.raw = raw self.raw = raw # 原始请求数据
@classmethod @classmethod
def parse(cls, raw): def parse(cls, raw):
"""解析原始请求数据"""
request = HTTPRequest(raw) request = HTTPRequest(raw)
return cls(method=request.command, return cls(method=request.command,
path=request.path, path=request.path,
@ -108,10 +123,12 @@ class Request(object):
@property @property
def url(self): def url(self):
"""构建完整URL"""
host = self.headers.get("Host", "unknown") host = self.headers.get("Host", "unknown")
return "http://%s%s" % (host, self.path) return "http://%s%s" % (host, self.path)
def toDict(self): def toDict(self):
"""转换为字典格式"""
out = { out = {
"httpVersion": self.httpVersion, "httpVersion": self.httpVersion,
"method": self.method, "method": self.method,
@ -134,22 +151,25 @@ class Request(object):
return out return out
class Response(object): class Response(object):
extract_status = re.compile(b'\\((\\d{3}) (.*)\\)') """HTTP响应类"""
extract_status = re.compile(b'\\((\\d{3}) (.*)\\)') # 用于提取状态码的正则表达式
def __init__(self, httpVersion, status, statusText, headers, content, raw=None, comment=None): def __init__(self, httpVersion, status, statusText, headers, content, raw=None, comment=None):
self.raw = raw self.raw = raw # 原始响应数据
self.httpVersion = httpVersion self.httpVersion = httpVersion # HTTP版本
self.status = status self.status = status # 状态码
self.statusText = statusText self.statusText = statusText # 状态描述
self.headers = headers self.headers = headers # 响应头
self.content = content self.content = content # 响应内容
self.comment = comment.strip() if comment else comment self.comment = comment.strip() if comment else comment # 注释
@classmethod @classmethod
def parse(cls, raw): def parse(cls, raw):
"""解析原始响应数据"""
altered = raw altered = raw
comment = b"" comment = b""
# 处理特殊格式的响应
if altered.startswith(b"HTTP response [") or altered.startswith(b"HTTP redirect ["): if altered.startswith(b"HTTP response [") or altered.startswith(b"HTTP redirect ["):
stream = io.BytesIO(raw) stream = io.BytesIO(raw)
first_line = stream.readline() first_line = stream.readline()
@ -176,12 +196,14 @@ class Response(object):
raw=raw) raw=raw)
def toDict(self): def toDict(self):
"""转换为字典格式"""
content = { content = {
"mimeType": self.headers.get("Content-Type"), "mimeType": self.headers.get("Content-Type"),
"text": self.content, "text": self.content,
"size": len(self.content or "") "size": len(self.content or "")
} }
# 检测是否为二进制内容
binary = set([b'\0', b'\1']) binary = set([b'\0', b'\1'])
if any(c in binary for c in self.content): if any(c in binary for c in self.content):
content["encoding"] = "base64" content["encoding"] = "base64"
@ -203,9 +225,9 @@ class Response(object):
} }
class FakeSocket(object): class FakeSocket(object):
# Original source: """模拟Socket类,用于HTTP响应解析
# https://stackoverflow.com/questions/24728088/python-parse-http-response-string 原始来源: https://stackoverflow.com/questions/24728088/python-parse-http-response-string
"""
def __init__(self, response_text): def __init__(self, response_text):
self._file = io.BytesIO(response_text) self._file = io.BytesIO(response_text)
@ -213,14 +235,15 @@ class FakeSocket(object):
return self._file return self._file
class HTTPRequest(_BaseHTTPServer.BaseHTTPRequestHandler): class HTTPRequest(_BaseHTTPServer.BaseHTTPRequestHandler):
# Original source: """HTTP请求处理类
# https://stackoverflow.com/questions/4685217/parse-raw-http-headers 原始来源: https://stackoverflow.com/questions/4685217/parse-raw-http-headers
"""
def __init__(self, request_text): def __init__(self, request_text):
self.comment = None self.comment = None
self.rfile = io.BytesIO(request_text) self.rfile = io.BytesIO(request_text)
self.raw_requestline = self.rfile.readline() self.raw_requestline = self.rfile.readline()
# 处理特殊格式的请求
if self.raw_requestline.startswith(b"HTTP request ["): if self.raw_requestline.startswith(b"HTTP request ["):
self.comment = self.raw_requestline self.comment = self.raw_requestline
self.raw_requestline = self.rfile.readline() self.raw_requestline = self.rfile.readline()
@ -229,5 +252,6 @@ class HTTPRequest(_BaseHTTPServer.BaseHTTPRequestHandler):
self.parse_request() self.parse_request()
def send_error(self, code, message): def send_error(self, code, message):
"""记录错误信息"""
self.error_code = code self.error_code = code
self.error_message = message self.error_message = message

@ -226,22 +226,28 @@ def oracle_old_passwd(password, username, uppercase=True): # prior to version '
IV, pad = b"\0" * 8, b"\0" IV, pad = b"\0" * 8, b"\0"
# 将用户名和密码转换为大写并编码为UNICODE_ENCODING
unistr = b"".join((b"\0" + _.encode(UNICODE_ENCODING)) if ord(_) < 256 else _.encode(UNICODE_ENCODING) for _ in (username + password).upper()) unistr = b"".join((b"\0" + _.encode(UNICODE_ENCODING)) if ord(_) < 256 else _.encode(UNICODE_ENCODING) for _ in (username + password).upper())
# 如果使用的是Crypto.Cipher.DES模块则进行加密
if des.__module__ == "Crypto.Cipher.DES": if des.__module__ == "Crypto.Cipher.DES":
# 如果unistr的长度不是8的倍数则进行填充
unistr += b"\0" * ((8 - len(unistr) % 8) & 7) unistr += b"\0" * ((8 - len(unistr) % 8) & 7)
cipher = des(decodeHex("0123456789ABCDEF"), CBC, iv=IV) cipher = des(decodeHex("0123456789ABCDEF"), CBC, iv=IV)
encrypted = cipher.encrypt(unistr) encrypted = cipher.encrypt(unistr)
cipher = des(encrypted[-8:], CBC, iv=IV) cipher = des(encrypted[-8:], CBC, iv=IV)
encrypted = cipher.encrypt(unistr) encrypted = cipher.encrypt(unistr)
else: else:
# 否则使用其他模块进行加密
cipher = des(decodeHex("0123456789ABCDEF"), CBC, IV, pad) cipher = des(decodeHex("0123456789ABCDEF"), CBC, IV, pad)
encrypted = cipher.encrypt(unistr) encrypted = cipher.encrypt(unistr)
cipher = des(encrypted[-8:], CBC, IV, pad) cipher = des(encrypted[-8:], CBC, IV, pad)
encrypted = cipher.encrypt(unistr) encrypted = cipher.encrypt(unistr)
# 将加密后的结果转换为十六进制字符串
retVal = encodeHex(encrypted[-8:], binary=False) retVal = encodeHex(encrypted[-8:], binary=False)
# 如果uppercase为True则将结果转换为大写否则转换为小写
return retVal.upper() if uppercase else retVal.lower() return retVal.upper() if uppercase else retVal.lower()
def md5_generic_passwd(password, uppercase=False): def md5_generic_passwd(password, uppercase=False):
@ -373,8 +379,10 @@ def unix_md5_passwd(password, salt, magic="$1$", **kwargs):
>>> unix_md5_passwd(password='testpass', salt='aD9ZLmkp') >>> unix_md5_passwd(password='testpass', salt='aD9ZLmkp')
'$1$aD9ZLmkp$DRM5a7rRZGyuuOPOjTEk61' '$1$aD9ZLmkp$DRM5a7rRZGyuuOPOjTEk61'
""" """
# 将value转换为64进制字符串
def _encode64(value, count): def _encode64(value, count):
# 将value转换为64进制字符串
output = "" output = ""
while (count - 1 >= 0): while (count - 1 >= 0):
@ -388,16 +396,21 @@ def unix_md5_passwd(password, salt, magic="$1$", **kwargs):
magic = getBytes(magic) magic = getBytes(magic)
salt = getBytes(salt) salt = getBytes(salt)
# 取salt的前8个字节
salt = salt[:8] salt = salt[:8]
# 将password、magic、salt拼接成ctx
ctx = password + magic + salt ctx = password + magic + salt
# 计算password + salt + password的md5值
final = md5(password + salt + password).digest() final = md5(password + salt + password).digest()
# 将final的前16个字节与ctx拼接
for pl in xrange(len(password), 0, -16): for pl in xrange(len(password), 0, -16):
if pl > 16: if pl > 16:
ctx = ctx + final[:16] ctx = ctx + final[:16]
else: else:
ctx = ctx + final[:pl] ctx = ctx + final[:pl]
# 将password转换为二进制
i = len(password) i = len(password)
while i: while i:
if i & 1: if i & 1:
@ -580,11 +593,13 @@ __functions__ = {
} }
def _finalize(retVal, results, processes, attack_info=None): def _finalize(retVal, results, processes, attack_info=None):
# 如果使用多进程,则启用垃圾回收
if _multiprocessing: if _multiprocessing:
gc.enable() gc.enable()
# NOTE: https://github.com/sqlmapproject/sqlmap/issues/4367 # NOTE: https://github.com/sqlmapproject/sqlmap/issues/4367
# NOTE: https://dzone.com/articles/python-101-creating-multiple-processes # NOTE: https://dzone.com/articles/python-101-creating-multiple-processes
# 遍历所有进程,尝试终止并加入
for process in processes: for process in processes:
try: try:
process.terminate() process.terminate()
@ -592,25 +607,31 @@ def _finalize(retVal, results, processes, attack_info=None):
except (OSError, AttributeError): except (OSError, AttributeError):
pass pass
# 如果retVal不为空则执行以下操作
if retVal: if retVal:
removals = set() removals = set()
# 如果使用哈希数据库,则开始事务
if conf.hashDB: if conf.hashDB:
conf.hashDB.beginTransaction() conf.hashDB.beginTransaction()
# 从retVal中获取数据并添加到results中
while not retVal.empty(): while not retVal.empty():
user, hash_, word = item = retVal.get(block=False) user, hash_, word = item = retVal.get(block=False)
results.append(item) results.append(item)
removals.add((user, hash_)) removals.add((user, hash_))
hashDBWrite(hash_, word) hashDBWrite(hash_, word)
# 如果attack_info不为空则从attack_info中移除已经添加到results中的数据
for item in attack_info or []: for item in attack_info or []:
if (item[0][0], item[0][1]) in removals: if (item[0][0], item[0][1]) in removals:
attack_info.remove(item) attack_info.remove(item)
# 如果使用哈希数据库,则结束事务
if conf.hashDB: if conf.hashDB:
conf.hashDB.endTransaction() conf.hashDB.endTransaction()
# 如果retVal有close方法则调用close方法
if hasattr(retVal, "close"): if hasattr(retVal, "close"):
retVal.close() retVal.close()
@ -654,55 +675,82 @@ def storeHashesToFile(attack_dict):
pass pass
def attackCachedUsersPasswords(): def attackCachedUsersPasswords():
# 如果缓存的用户密码不为空
if kb.data.cachedUsersPasswords: if kb.data.cachedUsersPasswords:
# 使用字典攻击函数对缓存的用户密码进行攻击
results = dictionaryAttack(kb.data.cachedUsersPasswords) results = dictionaryAttack(kb.data.cachedUsersPasswords)
# 创建一个空字典
lut = {} lut = {}
# 遍历攻击结果
for (_, hash_, password) in results: for (_, hash_, password) in results:
# 将哈希值转换为小写,并将其作为键,密码作为值存入字典
lut[hash_.lower()] = password lut[hash_.lower()] = password
# 遍历缓存的用户密码
for user in kb.data.cachedUsersPasswords: for user in kb.data.cachedUsersPasswords:
# 遍历每个用户的密码
for i in xrange(len(kb.data.cachedUsersPasswords[user])): for i in xrange(len(kb.data.cachedUsersPasswords[user])):
# 如果密码不为空
if (kb.data.cachedUsersPasswords[user][i] or "").strip(): if (kb.data.cachedUsersPasswords[user][i] or "").strip():
# 将密码转换为小写,并取第一个单词作为值
value = kb.data.cachedUsersPasswords[user][i].lower().split()[0] value = kb.data.cachedUsersPasswords[user][i].lower().split()[0]
# 如果值在字典中
if value in lut: if value in lut:
# 将密码和对应的明文密码添加到缓存的用户密码中
kb.data.cachedUsersPasswords[user][i] += "%s clear-text password: %s" % ('\n' if kb.data.cachedUsersPasswords[user][i][-1] != '\n' else '', lut[value]) kb.data.cachedUsersPasswords[user][i] += "%s clear-text password: %s" % ('\n' if kb.data.cachedUsersPasswords[user][i][-1] != '\n' else '', lut[value])
def attackDumpedTable(): def attackDumpedTable():
# 如果kb.data.dumpedTable存在
if kb.data.dumpedTable: if kb.data.dumpedTable:
# 获取dumpedTable
table = kb.data.dumpedTable table = kb.data.dumpedTable
# 获取table的键值
columns = list(table.keys()) columns = list(table.keys())
# 获取table中__infos__键的值
count = table["__infos__"]["count"] count = table["__infos__"]["count"]
# 如果count不存在则返回
if not count: if not count:
return return
# 打印debug信息
debugMsg = "analyzing table dump for possible password hashes" debugMsg = "analyzing table dump for possible password hashes"
logger.debug(debugMsg) logger.debug(debugMsg)
# 初始化found为False
found = False found = False
# 初始化col_user为空字符串
col_user = '' col_user = ''
# 初始化col_passwords为空集合
col_passwords = set() col_passwords = set()
# 初始化attack_dict为空字典
attack_dict = {} attack_dict = {}
# 初始化binary_fields为空集合
binary_fields = OrderedSet() binary_fields = OrderedSet()
# 初始化replacements为空字典
replacements = {} replacements = {}
# 遍历columns找到col_user
for column in sorted(columns, key=len, reverse=True): for column in sorted(columns, key=len, reverse=True):
if column and column.lower() in COMMON_USER_COLUMNS: if column and column.lower() in COMMON_USER_COLUMNS:
col_user = column col_user = column
break break
# 遍历columns找到binary_fields
for column in columns: for column in columns:
if column != "__infos__" and table[column]["values"]: if column != "__infos__" and table[column]["values"]:
if all(INVALID_UNICODE_CHAR_FORMAT.split('%')[0] in (value or "") for value in table[column]["values"]): if all(INVALID_UNICODE_CHAR_FORMAT.split('%')[0] in (value or "") for value in table[column]["values"]):
binary_fields.add(column) binary_fields.add(column)
# 如果binary_fields存在则打印警告信息
if binary_fields: if binary_fields:
_ = ','.join(binary_fields) _ = ','.join(binary_fields)
warnMsg = "potential binary fields detected ('%s'). In case of any problems you are " % _ warnMsg = "potential binary fields detected ('%s'). In case of any problems you are " % _
warnMsg += "advised to rerun table dump with '--fresh-queries --binary-fields=\"%s\"'" % _ warnMsg += "advised to rerun table dump with '--fresh-queries --binary-fields=\"%s\"'" % _
logger.warning(warnMsg) logger.warning(warnMsg)
# 遍历count找到found
for i in xrange(count): for i in xrange(count):
if not found and i > HASH_RECOGNITION_QUIT_THRESHOLD: if not found and i > HASH_RECOGNITION_QUIT_THRESHOLD:
break break
@ -719,22 +767,27 @@ def attackDumpedTable():
value = table[column]["values"][i] value = table[column]["values"][i]
# 如果column在binary_fields中并且column符合HASH_BINARY_COLUMNS_REGEX则进行编码
if column in binary_fields and re.search(HASH_BINARY_COLUMNS_REGEX, column) is not None: if column in binary_fields and re.search(HASH_BINARY_COLUMNS_REGEX, column) is not None:
previous = value previous = value
value = encodeHex(getBytes(value), binary=False) value = encodeHex(getBytes(value), binary=False)
replacements[value] = previous replacements[value] = previous
# 如果value符合hashRecognition则进行攻击
if hashRecognition(value): if hashRecognition(value):
found = True found = True
# 如果col_user存在并且i小于len(table[col_user]["values"])则将value添加到attack_dict中
if col_user and i < len(table[col_user]["values"]): if col_user and i < len(table[col_user]["values"]):
if table[col_user]["values"][i] not in attack_dict: if table[col_user]["values"][i] not in attack_dict:
attack_dict[table[col_user]["values"][i]] = [] attack_dict[table[col_user]["values"][i]] = []
attack_dict[table[col_user]["values"][i]].append(value) attack_dict[table[col_user]["values"][i]].append(value)
# 否则将value添加到attack_dict中
else: else:
attack_dict["%s%d" % (DUMMY_USER_PREFIX, i)] = [value] attack_dict["%s%d" % (DUMMY_USER_PREFIX, i)] = [value]
# 将column添加到col_passwords中
col_passwords.add(column) col_passwords.add(column)
if attack_dict: if attack_dict:
@ -812,67 +865,91 @@ def hashRecognition(value):
return retVal return retVal
def _bruteProcessVariantA(attack_info, hash_regex, suffix, retVal, proc_id, proc_count, wordlists, custom_wordlist, api): def _bruteProcessVariantA(attack_info, hash_regex, suffix, retVal, proc_id, proc_count, wordlists, custom_wordlist, api):
# 初始化颜色
if IS_WIN: if IS_WIN:
coloramainit() coloramainit()
count = 0 count = 0
rotator = 0 rotator = 0
# 获取所有哈希值
hashes = set(item[0][1] for item in attack_info) hashes = set(item[0][1] for item in attack_info)
# 创建Wordlist对象
wordlist = Wordlist(wordlists, proc_id, getattr(proc_count, "value", 0), custom_wordlist) wordlist = Wordlist(wordlists, proc_id, getattr(proc_count, "value", 0), custom_wordlist)
try: try:
# 遍历Wordlist中的每个单词
for word in wordlist: for word in wordlist:
# 如果attack_info为空则跳出循环
if not attack_info: if not attack_info:
break break
count += 1 count += 1
# 如果单词是二进制类型则转换为Unicode
if isinstance(word, six.binary_type): if isinstance(word, six.binary_type):
word = getUnicode(word) word = getUnicode(word)
# 如果单词不是字符串类型,则跳过
elif not isinstance(word, six.string_types): elif not isinstance(word, six.string_types):
continue continue
# 如果suffix不为空则将suffix添加到单词后面
if suffix: if suffix:
word = word + suffix word = word + suffix
try: try:
# 使用__functions__中的hash_regex函数对单词进行哈希
current = __functions__[hash_regex](password=word, uppercase=False) current = __functions__[hash_regex](password=word, uppercase=False)
# 如果哈希值在hashes中则说明找到了匹配的密码
if current in hashes: if current in hashes:
# 遍历attack_info中的每个元素
for item in attack_info[:]: for item in attack_info[:]:
((user, hash_), _) = item ((user, hash_), _) = item
# 如果哈希值匹配则将用户名、哈希值和密码放入retVal中
if hash_ == current: if hash_ == current:
retVal.put((user, hash_, word)) retVal.put((user, hash_, word))
# 清除控制台行
clearConsoleLine() clearConsoleLine()
# 输出破解的密码信息
infoMsg = "\r[%s] [INFO] cracked password '%s'" % (time.strftime("%X"), word) infoMsg = "\r[%s] [INFO] cracked password '%s'" % (time.strftime("%X"), word)
# 如果用户名存在且不是DUMMY_USER_PREFIX则输出用户名
if user and not user.startswith(DUMMY_USER_PREFIX): if user and not user.startswith(DUMMY_USER_PREFIX):
infoMsg += " for user '%s'\n" % user infoMsg += " for user '%s'\n" % user
# 否则输出哈希值
else: else:
infoMsg += " for hash '%s'\n" % hash_ infoMsg += " for hash '%s'\n" % hash_
# 输出信息
dataToStdout(infoMsg, True) dataToStdout(infoMsg, True)
# 从attack_info中移除该元素
attack_info.remove(item) attack_info.remove(item)
# 如果proc_id为0或者proc_count为1并且count能被HASH_MOD_ITEM_DISPLAY整除或者hash_regex为HASH.ORACLE_OLD或者HASH.CRYPT_GENERIC且IS_WIN为True则输出当前状态
elif (proc_id == 0 or getattr(proc_count, "value", 0) == 1) and count % HASH_MOD_ITEM_DISPLAY == 0 or hash_regex == HASH.ORACLE_OLD or hash_regex == HASH.CRYPT_GENERIC and IS_WIN: elif (proc_id == 0 or getattr(proc_count, "value", 0) == 1) and count % HASH_MOD_ITEM_DISPLAY == 0 or hash_regex == HASH.ORACLE_OLD or hash_regex == HASH.CRYPT_GENERIC and IS_WIN:
rotator += 1 rotator += 1
# 如果rotator大于等于ROTATING_CHARS的长度则重置为0
if rotator >= len(ROTATING_CHARS): if rotator >= len(ROTATING_CHARS):
rotator = 0 rotator = 0
# 输出当前状态
status = "current status: %s... %s" % (word.ljust(5)[:5], ROTATING_CHARS[rotator]) status = "current status: %s... %s" % (word.ljust(5)[:5], ROTATING_CHARS[rotator])
# 如果api为False则输出状态
if not api: if not api:
dataToStdout("\r[%s] [INFO] %s" % (time.strftime("%X"), status)) dataToStdout("\r[%s] [INFO] %s" % (time.strftime("%X"), status))
# 捕获KeyboardInterrupt异常
except KeyboardInterrupt: except KeyboardInterrupt:
raise raise
# 捕获UnicodeEncodeError和UnicodeDecodeError异常
except (UnicodeEncodeError, UnicodeDecodeError): except (UnicodeEncodeError, UnicodeDecodeError):
pass # ignore possible encoding problems caused by some words in custom dictionaries pass # ignore possible encoding problems caused by some words in custom dictionaries
@ -969,8 +1046,10 @@ def _bruteProcessVariantB(user, hash_, kwargs, hash_regex, suffix, retVal, found
proc_count.value -= 1 proc_count.value -= 1
def dictionaryAttack(attack_dict): def dictionaryAttack(attack_dict):
# 定义一个全局变量_multiprocessing
global _multiprocessing global _multiprocessing
# 定义一些变量
suffix_list = [""] suffix_list = [""]
custom_wordlist = [""] custom_wordlist = [""]
hash_regexes = [] hash_regexes = []
@ -980,11 +1059,13 @@ def dictionaryAttack(attack_dict):
processException = False processException = False
foundHash = False foundHash = False
# 如果禁用了多进程则将_multiprocessing设置为None
if conf.disableMulti: if conf.disableMulti:
_multiprocessing = None _multiprocessing = None
else: else:
# Note: https://github.com/sqlmapproject/sqlmap/issues/4367 # Note: https://github.com/sqlmapproject/sqlmap/issues/4367
try: try:
# 尝试导入multiprocessing模块
import multiprocessing import multiprocessing
# problems on FreeBSD (Reference: https://web.archive.org/web/20110710041353/http://www.eggheadcafe.com/microsoft/Python/35880259/multiprocessing-on-freebsd.aspx) # problems on FreeBSD (Reference: https://web.archive.org/web/20110710041353/http://www.eggheadcafe.com/microsoft/Python/35880259/multiprocessing-on-freebsd.aspx)
@ -996,46 +1077,58 @@ def dictionaryAttack(attack_dict):
pass pass
else: else:
try: try:
# 如果CPU数量大于1则将_multiprocessing设置为multiprocessing
if multiprocessing.cpu_count() > 1: if multiprocessing.cpu_count() > 1:
_multiprocessing = multiprocessing _multiprocessing = multiprocessing
except NotImplementedError: except NotImplementedError:
pass pass
# 遍历attack_dict中的每个hash
for (_, hashes) in attack_dict.items(): for (_, hashes) in attack_dict.items():
for hash_ in hashes: for hash_ in hashes:
if not hash_: if not hash_:
continue continue
# 将hash_转换为字符串
hash_ = hash_.split()[0] if hash_ and hash_.strip() else hash_ hash_ = hash_.split()[0] if hash_ and hash_.strip() else hash_
# 调用hashRecognition函数获取hash_的正则表达式
regex = hashRecognition(hash_) regex = hashRecognition(hash_)
# 如果正则表达式存在且不在hash_regexes中则将其添加到hash_regexes中
if regex and regex not in hash_regexes: if regex and regex not in hash_regexes:
hash_regexes.append(regex) hash_regexes.append(regex)
infoMsg = "using hash method '%s'" % __functions__[regex].__name__ infoMsg = "using hash method '%s'" % __functions__[regex].__name__
logger.info(infoMsg) logger.info(infoMsg)
# 遍历hash_regexes中的每个正则表达式
for hash_regex in hash_regexes: for hash_regex in hash_regexes:
keys = set() keys = set()
attack_info = [] attack_info = []
# 遍历attack_dict中的每个hash
for (user, hashes) in attack_dict.items(): for (user, hashes) in attack_dict.items():
for hash_ in hashes: for hash_ in hashes:
if not hash_: if not hash_:
continue continue
# 将hash_转换为字符串
foundHash = True foundHash = True
hash_ = hash_.split()[0] if hash_ and hash_.strip() else hash_ hash_ = hash_.split()[0] if hash_ and hash_.strip() else hash_
# 如果hash_匹配hash_regex则执行以下操作
if re.match(hash_regex, hash_): if re.match(hash_regex, hash_):
try: try:
item = None item = None
# 如果hash_regex不在以下列表中则将hash_转换为小写
if hash_regex not in (HASH.CRYPT_GENERIC, HASH.JOOMLA, HASH.PHPASS, HASH.UNIX_MD5_CRYPT, HASH.APACHE_MD5_CRYPT, HASH.APACHE_SHA1, HASH.VBULLETIN, HASH.VBULLETIN_OLD, HASH.SSHA, HASH.SSHA256, HASH.SSHA512, HASH.DJANGO_MD5, HASH.DJANGO_SHA1, HASH.MD5_BASE64, HASH.SHA1_BASE64, HASH.SHA256_BASE64, HASH.SHA512_BASE64): if hash_regex not in (HASH.CRYPT_GENERIC, HASH.JOOMLA, HASH.PHPASS, HASH.UNIX_MD5_CRYPT, HASH.APACHE_MD5_CRYPT, HASH.APACHE_SHA1, HASH.VBULLETIN, HASH.VBULLETIN_OLD, HASH.SSHA, HASH.SSHA256, HASH.SSHA512, HASH.DJANGO_MD5, HASH.DJANGO_SHA1, HASH.MD5_BASE64, HASH.SHA1_BASE64, HASH.SHA256_BASE64, HASH.SHA512_BASE64):
hash_ = hash_.lower() hash_ = hash_.lower()
# 如果hash_regex在以下列表中则执行以下操作
if hash_regex in (HASH.MD5_BASE64, HASH.SHA1_BASE64, HASH.SHA256_BASE64, HASH.SHA512_BASE64): if hash_regex in (HASH.MD5_BASE64, HASH.SHA1_BASE64, HASH.SHA256_BASE64, HASH.SHA512_BASE64):
item = [(user, encodeHex(decodeBase64(hash_, binary=True))), {}] item = [(user, encodeHex(decodeBase64(hash_, binary=True))), {}]
elif hash_regex in (HASH.MYSQL, HASH.MYSQL_OLD, HASH.MD5_GENERIC, HASH.SHA1_GENERIC, HASH.SHA224_GENERIC, HASH.SHA256_GENERIC, HASH.SHA384_GENERIC, HASH.SHA512_GENERIC, HASH.APACHE_SHA1): elif hash_regex in (HASH.MYSQL, HASH.MYSQL_OLD, HASH.MD5_GENERIC, HASH.SHA1_GENERIC, HASH.SHA224_GENERIC, HASH.SHA256_GENERIC, HASH.SHA384_GENERIC, HASH.SHA512_GENERIC, HASH.APACHE_SHA1):
# 如果hash_以"0x"开头,则将其去掉
if hash_.startswith("0x"): # Reference: https://docs.microsoft.com/en-us/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-2017 if hash_.startswith("0x"): # Reference: https://docs.microsoft.com/en-us/sql/t-sql/functions/hashbytes-transact-sql?view=sql-server-2017
hash_ = hash_[2:] hash_ = hash_[2:]
item = [(user, hash_), {}] item = [(user, hash_), {}]
@ -1060,23 +1153,29 @@ def dictionaryAttack(attack_dict):
elif hash_regex in (HASH.DJANGO_MD5, HASH.DJANGO_SHA1): elif hash_regex in (HASH.DJANGO_MD5, HASH.DJANGO_SHA1):
item = [(user, hash_), {"salt": hash_.split('$')[1]}] item = [(user, hash_), {"salt": hash_.split('$')[1]}]
elif hash_regex in (HASH.PHPASS,): elif hash_regex in (HASH.PHPASS,):
# 如果hash_的第四个字符在ITOA64中索引小于32则执行以下操作
if ITOA64.index(hash_[3]) < 32: if ITOA64.index(hash_[3]) < 32:
item = [(user, hash_), {"salt": hash_[4:12], "count": 1 << ITOA64.index(hash_[3]), "prefix": hash_[:3]}] item = [(user, hash_), {"salt": hash_[4:12], "count": 1 << ITOA64.index(hash_[3]), "prefix": hash_[:3]}]
else: else:
warnMsg = "invalid hash '%s'" % hash_ warnMsg = "invalid hash '%s'" % hash_
logger.warning(warnMsg) logger.warning(warnMsg)
# 如果item存在且hash_不在keys中则执行以下操作
if item and hash_ not in keys: if item and hash_ not in keys:
# 调用hashDBRetrieve函数获取hash_的密码
resumed = hashDBRetrieve(hash_) resumed = hashDBRetrieve(hash_)
# 如果resumed不存在则将item添加到attack_info中并将user_hash添加到user_hash中
if not resumed: if not resumed:
attack_info.append(item) attack_info.append(item)
user_hash.append(item[0]) user_hash.append(item[0])
else: else:
# 如果user存在且不以DUMMY_USER_PREFIX开头则将resumed添加到resumes中
infoMsg = "resuming password '%s' for hash '%s'" % (resumed, hash_) infoMsg = "resuming password '%s' for hash '%s'" % (resumed, hash_)
if user and not user.startswith(DUMMY_USER_PREFIX): if user and not user.startswith(DUMMY_USER_PREFIX):
infoMsg += " for user '%s'" % user infoMsg += " for user '%s'" % user
logger.info(infoMsg) logger.info(infoMsg)
resumes.append((user, hash_, resumed)) resumes.append((user, hash_, resumed))
# 将hash_添加到keys中
keys.add(hash_) keys.add(hash_)
except (binascii.Error, TypeError, IndexError): except (binascii.Error, TypeError, IndexError):
@ -1298,15 +1397,23 @@ def dictionaryAttack(attack_dict):
return results return results
# 定义一个函数,用于破解哈希文件
def crackHashFile(hashFile): def crackHashFile(hashFile):
# 初始化计数器
i = 0 i = 0
# 初始化一个空字典,用于存储用户名和哈希值
attack_dict = {} attack_dict = {}
# 遍历哈希文件中的每一行
for line in getFileItems(conf.hashFile): for line in getFileItems(conf.hashFile):
# 如果行中包含冒号
if ':' in line: if ':' in line:
# 将行按照冒号分割,得到用户名和哈希值
user, hash_ = line.split(':', 1) user, hash_ = line.split(':', 1)
# 将用户名和哈希值存入字典中
attack_dict[user] = [hash_] attack_dict[user] = [hash_]
else: else:
# 如果行中不包含冒号则将行存入字典中用户名为DUMMY_USER_PREFIX加上计数器的值
attack_dict["%s%d" % (DUMMY_USER_PREFIX, i)] = [line] attack_dict["%s%d" % (DUMMY_USER_PREFIX, i)] = [line]
i += 1 i += 1

@ -5,44 +5,60 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import hashlib # 导入所需的标准库和自定义模块
import os import hashlib # 用于计算哈希值
import sqlite3 import os # 用于文件和路径操作
import threading import sqlite3 # SQLite数据库操作
import time import threading # 多线程支持
import time # 时间相关操作
from lib.core.common import getSafeExString
from lib.core.common import serializeObject # 导入自定义工具函数
from lib.core.common import singleTimeWarnMessage from lib.core.common import getSafeExString # 安全地获取异常字符串
from lib.core.common import unserializeObject from lib.core.common import serializeObject # 对象序列化
from lib.core.compat import xrange from lib.core.common import singleTimeWarnMessage # 单次警告消息
from lib.core.convert import getBytes from lib.core.common import unserializeObject # 对象反序列化
from lib.core.convert import getUnicode from lib.core.compat import xrange # 兼容Python2/3的range函数
from lib.core.data import logger from lib.core.convert import getBytes # 转换为字节
from lib.core.exception import SqlmapConnectionException from lib.core.convert import getUnicode # 转换为Unicode
from lib.core.settings import HASHDB_END_TRANSACTION_RETRIES from lib.core.data import logger # 日志记录器
from lib.core.settings import HASHDB_FLUSH_RETRIES from lib.core.exception import SqlmapConnectionException # 自定义连接异常
from lib.core.settings import HASHDB_FLUSH_THRESHOLD from lib.core.settings import HASHDB_END_TRANSACTION_RETRIES # 事务结束重试次数
from lib.core.settings import HASHDB_RETRIEVE_RETRIES from lib.core.settings import HASHDB_FLUSH_RETRIES # 刷新重试次数
from lib.core.threads import getCurrentThreadData from lib.core.settings import HASHDB_FLUSH_THRESHOLD # 刷新阈值
from lib.core.threads import getCurrentThreadName from lib.core.settings import HASHDB_RETRIEVE_RETRIES # 检索重试次数
from thirdparty import six from lib.core.threads import getCurrentThreadData # 获取当前线程数据
from lib.core.threads import getCurrentThreadName # 获取当前线程名称
from thirdparty import six # Python 2/3 兼容库
class HashDB(object): class HashDB(object):
"""
哈希数据库类,用于管理SQLite数据库中的键值存储
"""
def __init__(self, filepath): def __init__(self, filepath):
self.filepath = filepath """
self._write_cache = {} 初始化哈希数据库
self._cache_lock = threading.Lock() @param filepath: 数据库文件路径
self._connections = [] """
self.filepath = filepath # 数据库文件路径
self._write_cache = {} # 写入缓存字典
self._cache_lock = threading.Lock() # 缓存锁,用于线程同步
self._connections = [] # 数据库连接列表
def _get_cursor(self): def _get_cursor(self):
threadData = getCurrentThreadData() """
获取数据库游标
@return: SQLite游标对象
"""
threadData = getCurrentThreadData() # 获取当前线程数据
# 如果当前线程没有游标,则创建新的连接和游标
if threadData.hashDBCursor is None: if threadData.hashDBCursor is None:
try: try:
# 创建SQLite连接,禁用事务自动提交
connection = sqlite3.connect(self.filepath, timeout=3, isolation_level=None) connection = sqlite3.connect(self.filepath, timeout=3, isolation_level=None)
self._connections.append(connection) self._connections.append(connection)
threadData.hashDBCursor = connection.cursor() threadData.hashDBCursor = connection.cursor()
# 创建存储表(如果不存在)
threadData.hashDBCursor.execute("CREATE TABLE IF NOT EXISTS storage (id INTEGER PRIMARY KEY, value TEXT)") threadData.hashDBCursor.execute("CREATE TABLE IF NOT EXISTS storage (id INTEGER PRIMARY KEY, value TEXT)")
connection.commit() connection.commit()
except Exception as ex: except Exception as ex:
@ -53,12 +69,20 @@ class HashDB(object):
return threadData.hashDBCursor return threadData.hashDBCursor
def _set_cursor(self, cursor): def _set_cursor(self, cursor):
"""
设置数据库游标
@param cursor: SQLite游标对象
"""
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
threadData.hashDBCursor = cursor threadData.hashDBCursor = cursor
# 游标属性
cursor = property(_get_cursor, _set_cursor) cursor = property(_get_cursor, _set_cursor)
def close(self): def close(self):
"""
关闭当前线程的数据库连接
"""
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
try: try:
if threadData.hashDBCursor: if threadData.hashDBCursor:
@ -70,6 +94,9 @@ class HashDB(object):
pass pass
def closeAll(self): def closeAll(self):
"""
关闭所有数据库连接
"""
for connection in self._connections: for connection in self._connections:
try: try:
connection.commit() connection.commit()
@ -79,17 +106,30 @@ class HashDB(object):
@staticmethod @staticmethod
def hashKey(key): def hashKey(key):
"""
计算键的哈希值
@param key: 要哈希的键
@return: 64位整数哈希值
"""
key = getBytes(key if isinstance(key, six.text_type) else repr(key), errors="xmlcharrefreplace") key = getBytes(key if isinstance(key, six.text_type) else repr(key), errors="xmlcharrefreplace")
retVal = int(hashlib.md5(key).hexdigest(), 16) & 0x7fffffffffffffff # Reference: http://stackoverflow.com/a/4448400 retVal = int(hashlib.md5(key).hexdigest(), 16) & 0x7fffffffffffffff # 确保返回64位正整数
return retVal return retVal
def retrieve(self, key, unserialize=False): def retrieve(self, key, unserialize=False):
"""
从数据库检索值
@param key: 要检索的键
@param unserialize: 是否需要反序列化
@return: 检索到的值
"""
retVal = None retVal = None
if key and (self._write_cache or os.path.isfile(self.filepath)): if key and (self._write_cache or os.path.isfile(self.filepath)):
hash_ = HashDB.hashKey(key) hash_ = HashDB.hashKey(key)
# 首先检查缓存
retVal = self._write_cache.get(hash_) retVal = self._write_cache.get(hash_)
if not retVal: if not retVal:
# 多次尝试从数据库检索
for _ in xrange(HASHDB_RETRIEVE_RETRIES): for _ in xrange(HASHDB_RETRIEVE_RETRIES):
try: try:
for row in self.cursor.execute("SELECT value FROM storage WHERE id=?", (hash_,)): for row in self.cursor.execute("SELECT value FROM storage WHERE id=?", (hash_,)):
@ -109,6 +149,7 @@ class HashDB(object):
time.sleep(1) time.sleep(1)
# 如果需要反序列化
if retVal and unserialize: if retVal and unserialize:
try: try:
retVal = unserializeObject(retVal) retVal = unserializeObject(retVal)
@ -121,22 +162,35 @@ class HashDB(object):
return retVal return retVal
def write(self, key, value, serialize=False): def write(self, key, value, serialize=False):
"""
写入值到数据库
@param key:
@param value:
@param serialize: 是否需要序列化
"""
if key: if key:
hash_ = HashDB.hashKey(key) hash_ = HashDB.hashKey(key)
self._cache_lock.acquire() self._cache_lock.acquire()
self._write_cache[hash_] = getUnicode(value) if not serialize else serializeObject(value) self._write_cache[hash_] = getUnicode(value) if not serialize else serializeObject(value)
self._cache_lock.release() self._cache_lock.release()
# 主线程自动刷新缓存
if getCurrentThreadName() in ('0', "MainThread"): if getCurrentThreadName() in ('0', "MainThread"):
self.flush() self.flush()
def flush(self, forced=False): def flush(self, forced=False):
"""
将缓存刷新到数据库
@param forced: 是否强制刷新
"""
if not self._write_cache: if not self._write_cache:
return return
# 如果未强制刷新且缓存未达到阈值,则不刷新
if not forced and len(self._write_cache) < HASHDB_FLUSH_THRESHOLD: if not forced and len(self._write_cache) < HASHDB_FLUSH_THRESHOLD:
return return
# 获取并清空缓存
self._cache_lock.acquire() self._cache_lock.acquire()
_ = self._write_cache _ = self._write_cache
self._write_cache = {} self._write_cache = {}
@ -144,6 +198,7 @@ class HashDB(object):
try: try:
self.beginTransaction() self.beginTransaction()
# 遍历缓存项写入数据库
for hash_, value in _.items(): for hash_, value in _.items():
retries = 0 retries = 0
while True: while True:
@ -152,7 +207,7 @@ class HashDB(object):
self.cursor.execute("INSERT INTO storage VALUES (?, ?)", (hash_, value,)) self.cursor.execute("INSERT INTO storage VALUES (?, ?)", (hash_, value,))
except sqlite3.IntegrityError: except sqlite3.IntegrityError:
self.cursor.execute("UPDATE storage SET value=? WHERE id=?", (value, hash_,)) self.cursor.execute("UPDATE storage SET value=? WHERE id=?", (value, hash_,))
except (UnicodeError, OverflowError): # e.g. surrogates not allowed (Issue #3851) except (UnicodeError, OverflowError): # 处理编码错误
break break
except sqlite3.DatabaseError as ex: except sqlite3.DatabaseError as ex:
if not os.path.exists(self.filepath): if not os.path.exists(self.filepath):
@ -176,13 +231,15 @@ class HashDB(object):
self.endTransaction() self.endTransaction()
def beginTransaction(self): def beginTransaction(self):
"""
开始数据库事务
"""
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
if not threadData.inTransaction: if not threadData.inTransaction:
try: try:
self.cursor.execute("BEGIN TRANSACTION") self.cursor.execute("BEGIN TRANSACTION")
except: except:
try: try:
# Reference: http://stackoverflow.com/a/25245731
self.cursor.close() self.cursor.close()
except sqlite3.ProgrammingError: except sqlite3.ProgrammingError:
pass pass
@ -192,6 +249,9 @@ class HashDB(object):
threadData.inTransaction = True threadData.inTransaction = True
def endTransaction(self): def endTransaction(self):
"""
结束数据库事务
"""
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
if threadData.inTransaction: if threadData.inTransaction:
retries = 0 retries = 0

@ -5,36 +5,42 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
# 导入需要的模块
from __future__ import print_function from __future__ import print_function
import mimetypes import mimetypes # 用于猜测文件的MIME类型
import gzip import gzip # 用于gzip压缩
import os import os # 操作系统相关功能
import re import re # 正则表达式
import sys import sys # 系统相关功能
import threading import threading # 多线程支持
import time import time # 时间相关功能
import traceback import traceback # 异常追踪
# 将上级目录添加到Python路径中
sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", ".."))) sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..")))
from lib.core.enums import HTTP_HEADER # 导入自定义模块和第三方库
from lib.core.settings import UNICODE_ENCODING from lib.core.enums import HTTP_HEADER # HTTP头部常量
from lib.core.settings import VERSION_STRING from lib.core.settings import UNICODE_ENCODING # 编码设置
from thirdparty import six from lib.core.settings import VERSION_STRING # 版本信息
from thirdparty.six.moves import BaseHTTPServer as _BaseHTTPServer from thirdparty import six # Python 2/3 兼容库
from thirdparty.six.moves import http_client as _http_client from thirdparty.six.moves import BaseHTTPServer as _BaseHTTPServer # HTTP服务器基类
from thirdparty.six.moves import socketserver as _socketserver from thirdparty.six.moves import http_client as _http_client # HTTP客户端
from thirdparty.six.moves import urllib as _urllib from thirdparty.six.moves import socketserver as _socketserver # Socket服务器
from thirdparty.six.moves import urllib as _urllib # URL处理
HTTP_ADDRESS = "0.0.0.0"
HTTP_PORT = 8951 # 服务器配置
DEBUG = True HTTP_ADDRESS = "0.0.0.0" # 监听所有网络接口
HTML_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "data", "html")) HTTP_PORT = 8951 # 服务器端口
DISABLED_CONTENT_EXTENSIONS = (".py", ".pyc", ".md", ".txt", ".bak", ".conf", ".zip", "~") DEBUG = True # 调试模式开关
HTML_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..", "..", "data", "html")) # HTML文件目录
DISABLED_CONTENT_EXTENSIONS = (".py", ".pyc", ".md", ".txt", ".bak", ".conf", ".zip", "~") # 禁止访问的文件扩展名
class ThreadingServer(_socketserver.ThreadingMixIn, _BaseHTTPServer.HTTPServer): class ThreadingServer(_socketserver.ThreadingMixIn, _BaseHTTPServer.HTTPServer):
"""多线程HTTP服务器类"""
def finish_request(self, *args, **kwargs): def finish_request(self, *args, **kwargs):
"""处理请求完成时的回调"""
try: try:
_BaseHTTPServer.HTTPServer.finish_request(self, *args, **kwargs) _BaseHTTPServer.HTTPServer.finish_request(self, *args, **kwargs)
except Exception: except Exception:
@ -42,48 +48,59 @@ class ThreadingServer(_socketserver.ThreadingMixIn, _BaseHTTPServer.HTTPServer):
traceback.print_exc() traceback.print_exc()
class ReqHandler(_BaseHTTPServer.BaseHTTPRequestHandler): class ReqHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
"""HTTP请求处理器类"""
def do_GET(self): def do_GET(self):
"""处理GET请求"""
# 解析URL和查询参数
path, query = self.path.split('?', 1) if '?' in self.path else (self.path, "") path, query = self.path.split('?', 1) if '?' in self.path else (self.path, "")
params = {} params = {}
content = None content = None
# 解析查询参数
if query: if query:
params.update(_urllib.parse.parse_qs(query)) params.update(_urllib.parse.parse_qs(query))
# 只保留每个参数的最后一个值
for key in params: for key in params:
if params[key]: if params[key]:
params[key] = params[key][-1] params[key] = params[key][-1]
self.url, self.params = path, params self.url, self.params = path, params
# 处理根路径请求
if path == '/': if path == '/':
path = "index.html" path = "index.html"
# 处理文件路径
path = path.strip('/') path = path.strip('/')
path = path.replace('/', os.path.sep) path = path.replace('/', os.path.sep)
path = os.path.abspath(os.path.join(HTML_DIR, path)).strip() path = os.path.abspath(os.path.join(HTML_DIR, path)).strip()
# 如果文件不存在但存在同名的.html文件则使用.html文件
if not os.path.isfile(path) and os.path.isfile("%s.html" % path): if not os.path.isfile(path) and os.path.isfile("%s.html" % path):
path = "%s.html" % path path = "%s.html" % path
# 检查文件是否可访问并返回相应内容
if ".." not in os.path.relpath(path, HTML_DIR) and os.path.isfile(path) and not path.endswith(DISABLED_CONTENT_EXTENSIONS): if ".." not in os.path.relpath(path, HTML_DIR) and os.path.isfile(path) and not path.endswith(DISABLED_CONTENT_EXTENSIONS):
content = open(path, "rb").read() content = open(path, "rb").read()
self.send_response(_http_client.OK) self.send_response(_http_client.OK)
self.send_header(HTTP_HEADER.CONNECTION, "close") self.send_header(HTTP_HEADER.CONNECTION, "close")
self.send_header(HTTP_HEADER.CONTENT_TYPE, mimetypes.guess_type(path)[0] or "application/octet-stream") self.send_header(HTTP_HEADER.CONTENT_TYPE, mimetypes.guess_type(path)[0] or "application/octet-stream")
else: else:
# 返回404错误页面
content = ("<!DOCTYPE html><html lang=\"en\"><head><title>404 Not Found</title></head><body><h1>Not Found</h1><p>The requested URL %s was not found on this server.</p></body></html>" % self.path.split('?')[0]).encode(UNICODE_ENCODING) content = ("<!DOCTYPE html><html lang=\"en\"><head><title>404 Not Found</title></head><body><h1>Not Found</h1><p>The requested URL %s was not found on this server.</p></body></html>" % self.path.split('?')[0]).encode(UNICODE_ENCODING)
self.send_response(_http_client.NOT_FOUND) self.send_response(_http_client.NOT_FOUND)
self.send_header(HTTP_HEADER.CONNECTION, "close") self.send_header(HTTP_HEADER.CONNECTION, "close")
if content is not None: if content is not None:
# 处理模板标记
for match in re.finditer(b"<!(\\w+)!>", content): for match in re.finditer(b"<!(\\w+)!>", content):
name = match.group(1) name = match.group(1)
_ = getattr(self, "_%s" % name.lower(), None) _ = getattr(self, "_%s" % name.lower(), None)
if _: if _:
content = self._format(content, **{name: _()}) content = self._format(content, **{name: _()})
# 如果客户端支持gzip压缩则压缩内容
if "gzip" in self.headers.get(HTTP_HEADER.ACCEPT_ENCODING): if "gzip" in self.headers.get(HTTP_HEADER.ACCEPT_ENCODING):
self.send_header(HTTP_HEADER.CONTENT_ENCODING, "gzip") self.send_header(HTTP_HEADER.CONTENT_ENCODING, "gzip")
_ = six.BytesIO() _ = six.BytesIO()
@ -98,25 +115,29 @@ class ReqHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
self.end_headers() self.end_headers()
# 发送响应内容
if content: if content:
self.wfile.write(content) self.wfile.write(content)
self.wfile.flush() self.wfile.flush()
def _format(self, content, **params): def _format(self, content, **params):
"""格式化响应内容,替换模板标记"""
if content: if content:
for key, value in params.items(): for key, value in params.items():
content = content.replace("<!%s!>" % key, value) content = content.replace("<!%s!>" % key, value)
return content return content
def version_string(self): def version_string(self):
"""返回服务器版本信息"""
return VERSION_STRING return VERSION_STRING
def log_message(self, format, *args): def log_message(self, format, *args):
"""禁用日志记录"""
return return
def finish(self): def finish(self):
"""完成请求处理"""
try: try:
_BaseHTTPServer.BaseHTTPRequestHandler.finish(self) _BaseHTTPServer.BaseHTTPRequestHandler.finish(self)
except Exception: except Exception:
@ -124,17 +145,19 @@ class ReqHandler(_BaseHTTPServer.BaseHTTPRequestHandler):
traceback.print_exc() traceback.print_exc()
def start_httpd(): def start_httpd():
"""启动HTTP服务器"""
server = ThreadingServer((HTTP_ADDRESS, HTTP_PORT), ReqHandler) server = ThreadingServer((HTTP_ADDRESS, HTTP_PORT), ReqHandler)
thread = threading.Thread(target=server.serve_forever) thread = threading.Thread(target=server.serve_forever)
thread.daemon = True thread.daemon = True # 设置为守护线程
thread.start() thread.start()
print("[i] running HTTP server at '%s:%d'" % (HTTP_ADDRESS, HTTP_PORT)) print("[i] running HTTP server at '%s:%d'" % (HTTP_ADDRESS, HTTP_PORT))
if __name__ == "__main__": if __name__ == "__main__":
"""主程序入口"""
try: try:
start_httpd() start_httpd()
# 保持程序运行
while True: while True:
time.sleep(1) time.sleep(1)
except KeyboardInterrupt: except KeyboardInterrupt:

@ -5,73 +5,98 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import re # 导入所需的模块
import re # 正则表达式模块
from lib.core.agent import agent
from lib.core.bigarray import BigArray # 导入sqlmap自定义模块
from lib.core.common import Backend from lib.core.agent import agent # SQL语句处理代理
from lib.core.common import filterNone from lib.core.bigarray import BigArray # 大数组数据结构
from lib.core.common import getSafeExString from lib.core.common import Backend # 数据库后端
from lib.core.common import isNoneValue from lib.core.common import filterNone # 过滤None值
from lib.core.common import isNumPosStrValue from lib.core.common import getSafeExString # 安全获取异常字符串
from lib.core.common import singleTimeWarnMessage from lib.core.common import isNoneValue # 判断是否为None值
from lib.core.common import unArrayizeValue from lib.core.common import isNumPosStrValue # 判断是否为正数字符串
from lib.core.common import unsafeSQLIdentificatorNaming from lib.core.common import singleTimeWarnMessage # 单次警告消息
from lib.core.compat import xrange from lib.core.common import unArrayizeValue # 数组值转换
from lib.core.convert import getUnicode from lib.core.common import unsafeSQLIdentificatorNaming # SQL标识符命名
from lib.core.data import conf from lib.core.compat import xrange # 兼容Python2/3的range
from lib.core.data import kb from lib.core.convert import getUnicode # 转换为Unicode
from lib.core.data import logger from lib.core.data import conf # 配置数据
from lib.core.data import queries from lib.core.data import kb # 知识库数据
from lib.core.dicts import DUMP_REPLACEMENTS from lib.core.data import logger # 日志记录器
from lib.core.enums import CHARSET_TYPE from lib.core.data import queries # SQL查询语句
from lib.core.enums import EXPECTED from lib.core.dicts import DUMP_REPLACEMENTS # 转储替换字典
from lib.core.exception import SqlmapConnectionException from lib.core.enums import CHARSET_TYPE # 字符集类型
from lib.core.exception import SqlmapNoneDataException from lib.core.enums import EXPECTED # 期望值类型
from lib.core.settings import MAX_INT from lib.core.exception import SqlmapConnectionException # 连接异常
from lib.core.settings import NULL from lib.core.exception import SqlmapNoneDataException # 空数据异常
from lib.core.settings import SINGLE_QUOTE_MARKER from lib.core.settings import MAX_INT # 最大整数值
from lib.core.unescaper import unescaper from lib.core.settings import NULL # NULL值
from lib.request import inject from lib.core.settings import SINGLE_QUOTE_MARKER # 单引号标记
from lib.utils.safe2bin import safechardecode from lib.core.unescaper import unescaper # SQL转义处理
from thirdparty.six import unichr as _unichr from lib.request import inject # SQL注入
from lib.utils.safe2bin import safechardecode # 安全字符解码
from thirdparty.six import unichr as _unichr # Unicode字符处理
def pivotDumpTable(table, colList, count=None, blind=True, alias=None): def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
lengths = {} """
entries = {} 数据透视表转储函数
参数:
table - 要转储的表名
colList - 列名列表
count - 记录数(可选)
blind - 是否使用盲注(默认True)
alias - 表别名(可选)
"""
# 初始化存储结构
lengths = {} # 存储每列的最大长度
entries = {} # 存储每列的数据
# 获取当前数据库对应的转储查询语句
dumpNode = queries[Backend.getIdentifiedDbms()].dump_table.blind dumpNode = queries[Backend.getIdentifiedDbms()].dump_table.blind
validColumnList = False # 初始化验证标志
validPivotValue = False validColumnList = False # 列表是否有效
validPivotValue = False # 透视值是否有效
# 如果没有提供记录数,则查询获取
if count is None: if count is None:
query = dumpNode.count % table query = dumpNode.count % table
query = agent.whereQuery(query) query = agent.whereQuery(query) # 添加WHERE子句
# 根据注入方式获取记录数
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) if blind else inject.getValue(query, blind=False, time=False, expected=EXPECTED.INT) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) if blind else inject.getValue(query, blind=False, time=False, expected=EXPECTED.INT)
# 验证记录数的有效性
if hasattr(count, "isdigit") and count.isdigit(): if hasattr(count, "isdigit") and count.isdigit():
count = int(count) count = int(count)
# 处理空表情况
if count == 0: if count == 0:
infoMsg = "table '%s' appears to be empty" % unsafeSQLIdentificatorNaming(table) infoMsg = "table '%s' appears to be empty" % unsafeSQLIdentificatorNaming(table)
logger.info(infoMsg) logger.info(infoMsg)
# 初始化空表的返回结构
for column in colList: for column in colList:
lengths[column] = len(column) lengths[column] = len(column)
entries[column] = [] entries[column] = []
return entries, lengths return entries, lengths
# 检查记录数是否有效
elif not isNumPosStrValue(count): elif not isNumPosStrValue(count):
return None return None
# 初始化每列的数据结构
for column in colList: for column in colList:
lengths[column] = 0 lengths[column] = 0
entries[column] = BigArray() entries[column] = BigArray()
# 按列名长度排序列表
colList = filterNone(sorted(colList, key=lambda x: len(x) if x else MAX_INT)) colList = filterNone(sorted(colList, key=lambda x: len(x) if x else MAX_INT))
# 处理指定的透视列
if conf.pivotColumn: if conf.pivotColumn:
for _ in colList: for _ in colList:
if re.search(r"(.+\.)?%s" % re.escape(conf.pivotColumn), _, re.I): if re.search(r"(.+\.)?%s" % re.escape(conf.pivotColumn), _, re.I):
@ -79,23 +104,27 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
infoMsg += "for retrieving row data" infoMsg += "for retrieving row data"
logger.info(infoMsg) logger.info(infoMsg)
# 将透视列移到列表首位
colList.remove(_) colList.remove(_)
colList.insert(0, _) colList.insert(0, _)
validPivotValue = True validPivotValue = True
break break
# 透视列未找到时发出警告
if not validPivotValue: if not validPivotValue:
warnMsg = "column '%s' not " % conf.pivotColumn warnMsg = "column '%s' not " % conf.pivotColumn
warnMsg += "found in table '%s'" % table warnMsg += "found in table '%s'" % table
logger.warning(warnMsg) logger.warning(warnMsg)
# 如果没有有效的透视值,则自动选择合适的透视列
if not validPivotValue: if not validPivotValue:
for column in colList: for column in colList:
infoMsg = "fetching number of distinct " infoMsg = "fetching number of distinct "
infoMsg += "values for column '%s'" % column.replace(("%s." % alias) if alias else "", "") infoMsg += "values for column '%s'" % column.replace(("%s." % alias) if alias else "", "")
logger.info(infoMsg) logger.info(infoMsg)
# 查询列的唯一值数量
query = dumpNode.count2 % (column, table) query = dumpNode.count2 % (column, table)
query = agent.whereQuery(query) query = agent.whereQuery(query)
value = inject.getValue(query, blind=blind, union=not blind, error=not blind, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) value = inject.getValue(query, blind=blind, union=not blind, error=not blind, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS)
@ -103,6 +132,7 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
if isNumPosStrValue(value): if isNumPosStrValue(value):
validColumnList = True validColumnList = True
# 如果唯一值数量等于记录数,则选为透视列
if value == count: if value == count:
infoMsg = "using column '%s' as a pivot " % column.replace(("%s." % alias) if alias else "", "") infoMsg = "using column '%s' as a pivot " % column.replace(("%s." % alias) if alias else "", "")
infoMsg += "for retrieving row data" infoMsg += "for retrieving row data"
@ -113,37 +143,53 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
colList.insert(0, column) colList.insert(0, column)
break break
# 处理无效列名情况
if not validColumnList: if not validColumnList:
errMsg = "all provided column name(s) are non-existent" errMsg = "all provided column name(s) are non-existent"
raise SqlmapNoneDataException(errMsg) raise SqlmapNoneDataException(errMsg)
# 没有合适的透视列时发出警告
if not validPivotValue: if not validPivotValue:
warnMsg = "no proper pivot column provided (with unique values)." warnMsg = "no proper pivot column provided (with unique values)."
warnMsg += " It won't be possible to retrieve all rows" warnMsg += " It won't be possible to retrieve all rows"
logger.warning(warnMsg) logger.warning(warnMsg)
pivotValue = " " pivotValue = " " # 初始透视值
breakRetrieval = False breakRetrieval = False # 中断检索标志
# 内部查询函数
def _(column, pivotValue): def _(column, pivotValue):
"""
执行单个值的查询
参数:
column - 列名
pivotValue - 透视值
"""
if column == colList[0]: if column == colList[0]:
# 透视列的查询
query = dumpNode.query.replace("'%s'" if unescaper.escape(pivotValue, False) != pivotValue else "%s", "%s") % (agent.preprocessField(table, column), table, agent.preprocessField(table, column), unescaper.escape(pivotValue, False)) query = dumpNode.query.replace("'%s'" if unescaper.escape(pivotValue, False) != pivotValue else "%s", "%s") % (agent.preprocessField(table, column), table, agent.preprocessField(table, column), unescaper.escape(pivotValue, False))
else: else:
# 非透视列的查询
query = dumpNode.query2.replace("'%s'" if unescaper.escape(pivotValue, False) != pivotValue else "%s", "%s") % (agent.preprocessField(table, column), table, agent.preprocessField(table, colList[0]), unescaper.escape(pivotValue, False) if SINGLE_QUOTE_MARKER not in dumpNode.query2 else pivotValue) query = dumpNode.query2.replace("'%s'" if unescaper.escape(pivotValue, False) != pivotValue else "%s", "%s") % (agent.preprocessField(table, column), table, agent.preprocessField(table, colList[0]), unescaper.escape(pivotValue, False) if SINGLE_QUOTE_MARKER not in dumpNode.query2 else pivotValue)
query = agent.whereQuery(query) query = agent.whereQuery(query)
return unArrayizeValue(inject.getValue(query, blind=blind, time=blind, union=not blind, error=not blind)) return unArrayizeValue(inject.getValue(query, blind=blind, time=blind, union=not blind, error=not blind))
try: try:
# 主循环:遍历所有记录
for i in xrange(count): for i in xrange(count):
if breakRetrieval: if breakRetrieval:
break break
# 遍历每一列
for column in colList: for column in colList:
value = _(column, pivotValue) value = _(column, pivotValue)
if column == colList[0]: if column == colList[0]:
# 处理透视列的特殊情况
if isNoneValue(value): if isNoneValue(value):
try: try:
# 尝试不同的透视值
for pivotValue in filterNone((" " if pivotValue == " " else None, "%s%s" % (pivotValue[0], _unichr(ord(pivotValue[1]) + 1)) if len(pivotValue) > 1 else None, _unichr(ord(pivotValue[0]) + 1))): for pivotValue in filterNone((" " if pivotValue == " " else None, "%s%s" % (pivotValue[0], _unichr(ord(pivotValue[1]) + 1)) if len(pivotValue) > 1 else None, _unichr(ord(pivotValue[0]) + 1))):
value = _(column, pivotValue) value = _(column, pivotValue)
if not isNoneValue(value): if not isNoneValue(value):
@ -151,12 +197,14 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
except ValueError: except ValueError:
pass pass
# 检查是否需要中断检索
if isNoneValue(value) or value == NULL: if isNoneValue(value) or value == NULL:
breakRetrieval = True breakRetrieval = True
break break
pivotValue = safechardecode(value) pivotValue = safechardecode(value)
# 处理LIMIT语句
if conf.limitStart or conf.limitStop: if conf.limitStart or conf.limitStop:
if conf.limitStart and (i + 1) < conf.limitStart: if conf.limitStart and (i + 1) < conf.limitStart:
warnMsg = "skipping first %d pivot " % conf.limitStart warnMsg = "skipping first %d pivot " % conf.limitStart
@ -167,12 +215,16 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
breakRetrieval = True breakRetrieval = True
break break
# 处理获取到的值
value = "" if isNoneValue(value) else unArrayizeValue(value) value = "" if isNoneValue(value) else unArrayizeValue(value)
# 更新列的最大长度
lengths[column] = max(lengths[column], len(DUMP_REPLACEMENTS.get(getUnicode(value), getUnicode(value)))) lengths[column] = max(lengths[column], len(DUMP_REPLACEMENTS.get(getUnicode(value), getUnicode(value))))
# 添加值到结果集
entries[column].append(value) entries[column].append(value)
except KeyboardInterrupt: except KeyboardInterrupt:
# 处理用户中断
kb.dumpKeyboardInterrupt = True kb.dumpKeyboardInterrupt = True
warnMsg = "user aborted during enumeration. sqlmap " warnMsg = "user aborted during enumeration. sqlmap "
@ -180,9 +232,11 @@ def pivotDumpTable(table, colList, count=None, blind=True, alias=None):
logger.warning(warnMsg) logger.warning(warnMsg)
except SqlmapConnectionException as ex: except SqlmapConnectionException as ex:
# 处理连接异常
errMsg = "connection exception detected ('%s'). sqlmap " % getSafeExString(ex) errMsg = "connection exception detected ('%s'). sqlmap " % getSafeExString(ex)
errMsg += "will display partial output" errMsg += "will display partial output"
logger.critical(errMsg) logger.critical(errMsg)
# 返回结果
return entries, lengths return entries, lengths

@ -5,42 +5,57 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
from __future__ import division # 导入必要的模块
from __future__ import division # 确保除法运算返回浮点数
import time import time # 用于计时功能
from lib.core.common import dataToStdout # 导入自定义模块
from lib.core.convert import getUnicode from lib.core.common import dataToStdout # 用于向标准输出写入数据
from lib.core.data import conf from lib.core.convert import getUnicode # 用于字符串Unicode转换
from lib.core.data import kb from lib.core.data import conf # 配置相关
from lib.core.data import kb # 知识库相关
class ProgressBar(object): class ProgressBar(object):
""" """
This class defines methods to update and draw a progress bar 这个类定义了更新和绘制进度条的方法
""" """
def __init__(self, minValue=0, maxValue=10, totalWidth=None): def __init__(self, minValue=0, maxValue=10, totalWidth=None):
self._progBar = "[]" """
self._min = int(minValue) 初始化进度条
self._max = int(maxValue) 参数:
self._span = max(self._max - self._min, 0.001) minValue: 最小值(默认0)
self._width = totalWidth if totalWidth else conf.progressWidth maxValue: 最大值(默认10)
self._amount = 0 totalWidth: 进度条总宽度(如果为None则使用配置中的宽度)
self._start = None """
self.update() self._progBar = "[]" # 进度条的基本形状
self._min = int(minValue) # 最小值
self._max = int(maxValue) # 最大值
self._span = max(self._max - self._min, 0.001) # 计算范围(避免除以0)
self._width = totalWidth if totalWidth else conf.progressWidth # 设置进度条宽度
self._amount = 0 # 当前进度值
self._start = None # 开始时间
self.update() # 初始化时更新一次
def _convertSeconds(self, value): def _convertSeconds(self, value):
"""
将秒数转换为分:秒格式
例如: 90 -> "01:30"
"""
seconds = value seconds = value
minutes = seconds // 60 minutes = seconds // 60 # 计算分钟数
seconds = seconds - (minutes * 60) seconds = seconds - (minutes * 60) # 计算剩余秒数
return "%.2d:%.2d" % (minutes, seconds) return "%.2d:%.2d" % (minutes, seconds) # 返回格式化的时间字符串
def update(self, newAmount=0): def update(self, newAmount=0):
""" """
This method updates the progress bar 更新进度条的状态
参数:
newAmount: 新的进度值
""" """
# 确保进度值在合理范围内
if newAmount < self._min: if newAmount < self._min:
newAmount = self._min newAmount = self._min
elif newAmount > self._max: elif newAmount > self._max:
@ -48,57 +63,67 @@ class ProgressBar(object):
self._amount = newAmount self._amount = newAmount
# Figure out the new percent done, round to an integer # 计算完成百分比
diffFromMin = float(self._amount - self._min) diffFromMin = float(self._amount - self._min)
percentDone = (diffFromMin / float(self._span)) * 100.0 percentDone = (diffFromMin / float(self._span)) * 100.0
percentDone = round(percentDone) percentDone = round(percentDone)
percentDone = min(100, int(percentDone)) percentDone = min(100, int(percentDone))
# Figure out how many hash bars the percentage should be # 计算需要显示多少个等号作为进度标记
allFull = self._width - len("100%% [] %s/%s (ETA 00:00)" % (self._max, self._max)) allFull = self._width - len("100%% [] %s/%s (ETA 00:00)" % (self._max, self._max))
numHashes = (percentDone / 100.0) * allFull numHashes = (percentDone / 100.0) * allFull
numHashes = int(round(numHashes)) numHashes = int(round(numHashes))
# Build a progress bar with an arrow of equal signs # 构建进度条字符串
if numHashes == 0: if numHashes == 0: # 0%进度
self._progBar = "[>%s]" % (" " * (allFull - 1)) self._progBar = "[>%s]" % (" " * (allFull - 1))
elif numHashes == allFull: elif numHashes == allFull: # 100%进度
self._progBar = "[%s]" % ("=" * allFull) self._progBar = "[%s]" % ("=" * allFull)
else: else: # 部分进度
self._progBar = "[%s>%s]" % ("=" * (numHashes - 1), " " * (allFull - numHashes)) self._progBar = "[%s>%s]" % ("=" * (numHashes - 1), " " * (allFull - numHashes))
# Add the percentage at the beginning of the progress bar # 在进度条开头添加百分比
percentString = getUnicode(percentDone) + "%" percentString = getUnicode(percentDone) + "%"
self._progBar = "%s %s" % (percentString, self._progBar) self._progBar = "%s %s" % (percentString, self._progBar)
def progress(self, newAmount): def progress(self, newAmount):
""" """
This method saves item delta time and shows updated progress bar with calculated eta 保存时间增量并显示更新后的进度条及预计完成时间
参数:
newAmount: 新的进度值
""" """
# 计算预计完成时间(ETA)
if self._start is None or newAmount > self._max: if self._start is None or newAmount > self._max:
self._start = time.time() self._start = time.time()
eta = None eta = None
else: else:
delta = time.time() - self._start delta = time.time() - self._start # 计算已经过时间
eta = (self._max - self._min) * (1.0 * delta / newAmount) - delta eta = (self._max - self._min) * (1.0 * delta / newAmount) - delta # 估算剩余时间
self.update(newAmount) self.update(newAmount) # 更新进度
self.draw(eta) self.draw(eta) # 绘制进度条
def draw(self, eta=None): def draw(self, eta=None):
""" """
This method draws the progress bar if it has changed 如果进度发生变化则绘制进度条
参数:
eta: 预计完成时间()
""" """
# 在同一行更新进度条显示
dataToStdout("\r%s %d/%d%s" % (self._progBar, self._amount, self._max, (" (ETA %s)" % (self._convertSeconds(int(eta)) if eta is not None else "??:??")))) dataToStdout("\r%s %d/%d%s" % (
self._progBar, # 进度条
self._amount, # 当前进度
self._max, # 最大值
(" (ETA %s)" % (self._convertSeconds(int(eta)) if eta is not None else "??:??")) # 显示预计完成时间
))
# 当进度达到最大值时清除进度条
if self._amount >= self._max: if self._amount >= self._max:
dataToStdout("\r%s\r" % (" " * self._width)) dataToStdout("\r%s\r" % (" " * self._width))
kb.prependFlag = False kb.prependFlag = False
def __str__(self): def __str__(self):
""" """
This method returns the progress bar string 返回进度条的字符串表示
""" """
return getUnicode(self._progBar) return getUnicode(self._progBar)

@ -5,40 +5,49 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import functools # 导入所需的标准库和自定义模块
import os import functools # 用于高阶函数和操作
import random import os # 提供与操作系统交互的功能
import shutil import random # 用于生成随机数
import stat import shutil # 提供高级文件操作
import string import stat # 提供文件属性相关常量
import string # 提供字符串常量
from lib.core.common import getSafeExString # 导入自定义工具函数
from lib.core.common import openFile from lib.core.common import getSafeExString # 安全地获取异常字符串
from lib.core.compat import xrange from lib.core.common import openFile # 安全地打开文件
from lib.core.convert import getUnicode from lib.core.compat import xrange # 兼容Python2/3的range函数
from lib.core.data import logger from lib.core.convert import getUnicode # 转换为Unicode字符串
from thirdparty.six import unichr as _unichr from lib.core.data import logger # 日志记录器
from thirdparty.six import unichr as _unichr # 兼容Python2/3的unichr函数
def purge(directory): def purge(directory):
""" """
Safely removes content from a given directory 安全地删除指定目录的所有内容
参数:
directory: 要清理的目录路径
""" """
# 检查目录是否存在
if not os.path.isdir(directory): if not os.path.isdir(directory):
warnMsg = "skipping purging of directory '%s' as it does not exist" % directory warnMsg = "skipping purging of directory '%s' as it does not exist" % directory
logger.warning(warnMsg) logger.warning(warnMsg)
return return
# 输出开始清理的信息
infoMsg = "purging content of directory '%s'..." % directory infoMsg = "purging content of directory '%s'..." % directory
logger.info(infoMsg) logger.info(infoMsg)
# 初始化存储文件路径和目录路径的列表
filepaths = [] filepaths = []
dirpaths = [] dirpaths = []
# 遍历目录树,收集所有文件和目录的绝对路径
for rootpath, directories, filenames in os.walk(directory): for rootpath, directories, filenames in os.walk(directory):
dirpaths.extend(os.path.abspath(os.path.join(rootpath, _)) for _ in directories) dirpaths.extend(os.path.abspath(os.path.join(rootpath, _)) for _ in directories)
filepaths.extend(os.path.abspath(os.path.join(rootpath, _)) for _ in filenames) filepaths.extend(os.path.abspath(os.path.join(rootpath, _)) for _ in filenames)
# 第一步:修改所有文件的权限为可读可写
logger.debug("changing file attributes") logger.debug("changing file attributes")
for filepath in filepaths: for filepath in filepaths:
try: try:
@ -46,6 +55,7 @@ def purge(directory):
except: except:
pass pass
# 第二步:用随机数据覆盖文件内容
logger.debug("writing random data to files") logger.debug("writing random data to files")
for filepath in filepaths: for filepath in filepaths:
try: try:
@ -55,6 +65,7 @@ def purge(directory):
except: except:
pass pass
# 第三步:清空所有文件
logger.debug("truncating files") logger.debug("truncating files")
for filepath in filepaths: for filepath in filepaths:
try: try:
@ -63,6 +74,7 @@ def purge(directory):
except: except:
pass pass
# 第四步:将文件名替换为随机字母组合
logger.debug("renaming filenames to random values") logger.debug("renaming filenames to random values")
for filepath in filepaths: for filepath in filepaths:
try: try:
@ -70,8 +82,10 @@ def purge(directory):
except: except:
pass pass
# 按目录深度排序,确保先处理最深的目录
dirpaths.sort(key=functools.cmp_to_key(lambda x, y: y.count(os.path.sep) - x.count(os.path.sep))) dirpaths.sort(key=functools.cmp_to_key(lambda x, y: y.count(os.path.sep) - x.count(os.path.sep)))
# 第五步:将目录名替换为随机字母组合
logger.debug("renaming directory names to random values") logger.debug("renaming directory names to random values")
for dirpath in dirpaths: for dirpath in dirpaths:
try: try:
@ -79,6 +93,7 @@ def purge(directory):
except: except:
pass pass
# 最后一步:删除整个目录树
logger.debug("deleting the whole directory tree") logger.debug("deleting the whole directory tree")
try: try:
shutil.rmtree(directory) shutil.rmtree(directory)

@ -5,44 +5,55 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import binascii # 导入所需的标准库
import re import binascii # 用于二进制和ASCII转换
import string import re # 用于正则表达式操作
import sys import string # 提供字符串常量和工具
import sys # 提供系统相关的功能
# 判断Python版本是否为Python 3及以上
PY3 = sys.version_info >= (3, 0) PY3 = sys.version_info >= (3, 0)
# 根据Python版本设置不同的类型别名
if PY3: if PY3:
xrange = range xrange = range # Python 3中range替代了xrange
text_type = str text_type = str # Python 3中str即为文本类型
string_types = (str,) string_types = (str,) # Python 3中字符串类型只有str
unichr = chr unichr = chr # Python 3中chr替代了unichr
else: else:
text_type = unicode text_type = unicode # Python 2中使用unicode作为文本类型
string_types = (basestring,) string_types = (basestring,) # Python 2中的字符串基类
# Regex used for recognition of hex encoded characters # 用于识别十六进制编码字符的正则表达式
HEX_ENCODED_CHAR_REGEX = r"(?P<result>\\x[0-9A-Fa-f]{2})" HEX_ENCODED_CHAR_REGEX = r"(?P<result>\\x[0-9A-Fa-f]{2})"
# Raw chars that will be safe encoded to their slash (\) representations (e.g. newline to \n) # 需要被安全编码为反斜杠表示的原始字符(如换行符编码为\n)
SAFE_ENCODE_SLASH_REPLACEMENTS = "\t\n\r\x0b\x0c" SAFE_ENCODE_SLASH_REPLACEMENTS = "\t\n\r\x0b\x0c"
# Characters that don't need to be safe encoded # 不需要安全编码的字符集合
# 包含所有可打印字符,但不包含反斜杠和上面定义的需要特殊处理的字符
SAFE_CHARS = "".join([_ for _ in string.printable.replace('\\', '') if _ not in SAFE_ENCODE_SLASH_REPLACEMENTS]) SAFE_CHARS = "".join([_ for _ in string.printable.replace('\\', '') if _ not in SAFE_ENCODE_SLASH_REPLACEMENTS])
# Prefix used for hex encoded values # 十六进制编码值的前缀
HEX_ENCODED_PREFIX = r"\x" HEX_ENCODED_PREFIX = r"\x"
# Strings used for temporary marking of hex encoded prefixes (to prevent double encoding) # 用于临时标记十六进制编码前缀的字符串(防止重复编码)
HEX_ENCODED_PREFIX_MARKER = "__HEX_ENCODED_PREFIX__" HEX_ENCODED_PREFIX_MARKER = "__HEX_ENCODED_PREFIX__"
# String used for temporary marking of slash characters # 用于临时标记反斜杠字符的字符串
SLASH_MARKER = "__SLASH__" SLASH_MARKER = "__SLASH__"
def safecharencode(value): def safecharencode(value):
""" """
Returns safe representation of a given basestring value 将给定的字符串值转换为安全的表示形式
参数:
value: 需要编码的值可以是字符串或列表
返回:
编码后的安全字符串或处理后的列表
示例:
>>> safecharencode(u'test123') == u'test123' >>> safecharencode(u'test123') == u'test123'
True True
>>> safecharencode(u'test\x01\x02\xaf') == u'test\\\\x01\\\\x02\\xaf' >>> safecharencode(u'test\x01\x02\xaf') == u'test\\\\x01\\\\x02\\xaf'
@ -51,35 +62,48 @@ def safecharencode(value):
retVal = value retVal = value
if isinstance(value, string_types): if isinstance(value, string_types): # 如果输入是字符串类型
if any(_ not in SAFE_CHARS for _ in value): if any(_ not in SAFE_CHARS for _ in value): # 检查是否包含不安全字符
# 临时替换已有的十六进制编码前缀和反斜杠
retVal = retVal.replace(HEX_ENCODED_PREFIX, HEX_ENCODED_PREFIX_MARKER) retVal = retVal.replace(HEX_ENCODED_PREFIX, HEX_ENCODED_PREFIX_MARKER)
retVal = retVal.replace('\\', SLASH_MARKER) retVal = retVal.replace('\\', SLASH_MARKER)
# 处理特殊字符(如\n, \t等)
for char in SAFE_ENCODE_SLASH_REPLACEMENTS: for char in SAFE_ENCODE_SLASH_REPLACEMENTS:
retVal = retVal.replace(char, repr(char).strip('\'')) retVal = retVal.replace(char, repr(char).strip('\''))
# 将不可打印字符转换为十六进制表示
for char in set(retVal): for char in set(retVal):
if not (char in string.printable or isinstance(value, text_type) and ord(char) >= 160): if not (char in string.printable or isinstance(value, text_type) and ord(char) >= 160):
retVal = retVal.replace(char, '\\x%02x' % ord(char)) retVal = retVal.replace(char, '\\x%02x' % ord(char))
# 恢复临时替换的标记
retVal = retVal.replace(SLASH_MARKER, "\\\\") retVal = retVal.replace(SLASH_MARKER, "\\\\")
retVal = retVal.replace(HEX_ENCODED_PREFIX_MARKER, HEX_ENCODED_PREFIX) retVal = retVal.replace(HEX_ENCODED_PREFIX_MARKER, HEX_ENCODED_PREFIX)
elif isinstance(value, list): elif isinstance(value, list): # 如果输入是列表
for i in xrange(len(value)): for i in xrange(len(value)):
retVal[i] = safecharencode(value[i]) retVal[i] = safecharencode(value[i]) # 递归处理列表中的每个元素
return retVal return retVal
def safechardecode(value, binary=False): def safechardecode(value, binary=False):
""" """
Reverse function to safecharencode safecharencode的反向函数将安全编码的字符串解码回原始形式
参数:
value: 需要解码的值可以是字符串或列表
binary: 是否返回二进制格式默认为False
返回:
解码后的原始字符串或处理后的列表
""" """
retVal = value retVal = value
if isinstance(value, string_types): if isinstance(value, string_types): # 如果输入是字符串类型
# 临时替换反斜杠
retVal = retVal.replace('\\\\', SLASH_MARKER) retVal = retVal.replace('\\\\', SLASH_MARKER)
# 解码所有十六进制编码的字符
while True: while True:
match = re.search(HEX_ENCODED_CHAR_REGEX, retVal) match = re.search(HEX_ENCODED_CHAR_REGEX, retVal)
if match: if match:
@ -87,17 +111,20 @@ def safechardecode(value, binary=False):
else: else:
break break
# 还原特殊字符
for char in SAFE_ENCODE_SLASH_REPLACEMENTS[::-1]: for char in SAFE_ENCODE_SLASH_REPLACEMENTS[::-1]:
retVal = retVal.replace(repr(char).strip('\''), char) retVal = retVal.replace(repr(char).strip('\''), char)
# 恢复反斜杠
retVal = retVal.replace(SLASH_MARKER, '\\') retVal = retVal.replace(SLASH_MARKER, '\\')
# 如果需要二进制格式
if binary: if binary:
if isinstance(retVal, text_type): if isinstance(retVal, text_type):
retVal = retVal.encode("utf8", errors="surrogatepass" if PY3 else "strict") retVal = retVal.encode("utf8", errors="surrogatepass" if PY3 else "strict")
elif isinstance(value, (list, tuple)): elif isinstance(value, (list, tuple)): # 如果输入是列表或元组
for i in xrange(len(value)): for i in xrange(len(value)):
retVal[i] = safechardecode(value[i]) retVal[i] = safechardecode(value[i]) # 递归处理每个元素
return retVal return retVal

@ -5,9 +5,11 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
# 导入所需的标准库和第三方库
import re import re
import socket import socket
# 导入自定义工具函数
from lib.core.common import getSafeExString from lib.core.common import getSafeExString
from lib.core.common import popValue from lib.core.common import popValue
from lib.core.common import pushValue from lib.core.common import pushValue
@ -39,50 +41,63 @@ from thirdparty.socks import socks
def _search(dork): def _search(dork):
""" """
This method performs the effective search on Google providing 该方法使用Google搜索引擎执行实际的搜索操作
the google dork and the Google session cookie 参数:
dork: 搜索关键词
返回:
搜索结果URL列表,如果搜索失败则返回None
""" """
# 如果搜索关键词为空则返回None
if not dork: if not dork:
return None return None
# 初始化变量
page = None page = None
data = None data = None
requestHeaders = {} requestHeaders = {}
responseHeaders = {} responseHeaders = {}
# 设置HTTP请求头
requestHeaders[HTTP_HEADER.USER_AGENT] = dict(conf.httpHeaders).get(HTTP_HEADER.USER_AGENT, DUMMY_SEARCH_USER_AGENT) requestHeaders[HTTP_HEADER.USER_AGENT] = dict(conf.httpHeaders).get(HTTP_HEADER.USER_AGENT, DUMMY_SEARCH_USER_AGENT)
requestHeaders[HTTP_HEADER.ACCEPT_ENCODING] = HTTP_ACCEPT_ENCODING_HEADER_VALUE requestHeaders[HTTP_HEADER.ACCEPT_ENCODING] = HTTP_ACCEPT_ENCODING_HEADER_VALUE
requestHeaders[HTTP_HEADER.COOKIE] = GOOGLE_CONSENT_COOKIE requestHeaders[HTTP_HEADER.COOKIE] = GOOGLE_CONSENT_COOKIE
try: try:
# 首先访问Google的NCR(No Country Redirect)页面
req = _urllib.request.Request("https://www.google.com/ncr", headers=requestHeaders) req = _urllib.request.Request("https://www.google.com/ncr", headers=requestHeaders)
conn = _urllib.request.urlopen(req) conn = _urllib.request.urlopen(req)
except Exception as ex: except Exception as ex:
errMsg = "unable to connect to Google ('%s')" % getSafeExString(ex) errMsg = "unable to connect to Google ('%s')" % getSafeExString(ex)
raise SqlmapConnectionException(errMsg) raise SqlmapConnectionException(errMsg)
# 获取搜索结果页码
gpage = conf.googlePage if conf.googlePage > 1 else 1 gpage = conf.googlePage if conf.googlePage > 1 else 1
logger.info("using search result page #%d" % gpage) logger.info("using search result page #%d" % gpage)
url = "https://www.google.com/search?" # NOTE: if consent fails, try to use the "http://" # 构造Google搜索URL
url = "https://www.google.com/search?" # 如果consent失败,尝试使用"http://"
url += "q=%s&" % urlencode(dork, convall=True) url += "q=%s&" % urlencode(dork, convall=True)
url += "num=100&hl=en&complete=0&safe=off&filter=0&btnG=Search" url += "num=100&hl=en&complete=0&safe=off&filter=0&btnG=Search"
url += "&start=%d" % ((gpage - 1) * 100) url += "&start=%d" % ((gpage - 1) * 100)
try: try:
# 发送搜索请求
req = _urllib.request.Request(url, headers=requestHeaders) req = _urllib.request.Request(url, headers=requestHeaders)
conn = _urllib.request.urlopen(req) conn = _urllib.request.urlopen(req)
# 记录请求日志
requestMsg = "HTTP request:\nGET %s" % url requestMsg = "HTTP request:\nGET %s" % url
requestMsg += " %s" % _http_client.HTTPConnection._http_vsn_str requestMsg += " %s" % _http_client.HTTPConnection._http_vsn_str
logger.log(CUSTOM_LOGGING.TRAFFIC_OUT, requestMsg) logger.log(CUSTOM_LOGGING.TRAFFIC_OUT, requestMsg)
# 获取响应内容
page = conn.read() page = conn.read()
code = conn.code code = conn.code
status = conn.msg status = conn.msg
responseHeaders = conn.info() responseHeaders = conn.info()
# 记录响应日志
responseMsg = "HTTP response (%s - %d):\n" % (status, code) responseMsg = "HTTP response (%s - %d):\n" % (status, code)
if conf.verbose <= 4: if conf.verbose <= 4:
@ -93,6 +108,7 @@ def _search(dork):
logger.log(CUSTOM_LOGGING.TRAFFIC_IN, responseMsg) logger.log(CUSTOM_LOGGING.TRAFFIC_IN, responseMsg)
except _urllib.error.HTTPError as ex: except _urllib.error.HTTPError as ex:
try: try:
# 处理HTTP错误
page = ex.read() page = ex.read()
responseHeaders = ex.info() responseHeaders = ex.info()
except Exception as _: except Exception as _:
@ -104,12 +120,15 @@ def _search(dork):
errMsg = "unable to connect to Google" errMsg = "unable to connect to Google"
raise SqlmapConnectionException(errMsg) raise SqlmapConnectionException(errMsg)
# 解码页面内容
page = decodePage(page, responseHeaders.get(HTTP_HEADER.CONTENT_ENCODING), responseHeaders.get(HTTP_HEADER.CONTENT_TYPE)) page = decodePage(page, responseHeaders.get(HTTP_HEADER.CONTENT_ENCODING), responseHeaders.get(HTTP_HEADER.CONTENT_TYPE))
page = getUnicode(page) # Note: if decodePage call fails (Issue #4202) page = getUnicode(page) # 如果decodePage调用失败的处理(Issue #4202)
# 使用正则表达式提取搜索结果URL
retVal = [_urllib.parse.unquote(match.group(1) or match.group(2)) for match in re.finditer(GOOGLE_REGEX, page, re.I)] retVal = [_urllib.parse.unquote(match.group(1) or match.group(2)) for match in re.finditer(GOOGLE_REGEX, page, re.I)]
# 检查是否被Google检测为异常流量
if not retVal and "detected unusual traffic" in page: if not retVal and "detected unusual traffic" in page:
warnMsg = "Google has detected 'unusual' traffic from " warnMsg = "Google has detected 'unusual' traffic from "
warnMsg += "used IP address disabling further searches" warnMsg += "used IP address disabling further searches"
@ -119,6 +138,7 @@ def _search(dork):
else: else:
logger.critical(warnMsg) logger.critical(warnMsg)
# 如果Google搜索失败,提供备选搜索引擎选项
if not retVal: if not retVal:
message = "no usable links found. What do you want to do?" message = "no usable links found. What do you want to do?"
message += "\n[1] (re)try with DuckDuckGo (default)" message += "\n[1] (re)try with DuckDuckGo (default)"
@ -129,27 +149,33 @@ def _search(dork):
if choice == '3': if choice == '3':
raise SqlmapUserQuitException raise SqlmapUserQuitException
elif choice == '2': elif choice == '2':
# 使用Bing搜索
url = "https://www.bing.com/search?q=%s&first=%d" % (urlencode(dork, convall=True), (gpage - 1) * 10 + 1) url = "https://www.bing.com/search?q=%s&first=%d" % (urlencode(dork, convall=True), (gpage - 1) * 10 + 1)
regex = BING_REGEX regex = BING_REGEX
else: else:
# 使用DuckDuckGo搜索
url = "https://html.duckduckgo.com/html/" url = "https://html.duckduckgo.com/html/"
data = "q=%s&s=%d" % (urlencode(dork, convall=True), (gpage - 1) * 30) data = "q=%s&s=%d" % (urlencode(dork, convall=True), (gpage - 1) * 30)
regex = DUCKDUCKGO_REGEX regex = DUCKDUCKGO_REGEX
try: try:
# 发送备选搜索引擎请求
req = _urllib.request.Request(url, data=getBytes(data), headers=requestHeaders) req = _urllib.request.Request(url, data=getBytes(data), headers=requestHeaders)
conn = _urllib.request.urlopen(req) conn = _urllib.request.urlopen(req)
# 记录请求日志
requestMsg = "HTTP request:\nGET %s" % url requestMsg = "HTTP request:\nGET %s" % url
requestMsg += " %s" % _http_client.HTTPConnection._http_vsn_str requestMsg += " %s" % _http_client.HTTPConnection._http_vsn_str
logger.log(CUSTOM_LOGGING.TRAFFIC_OUT, requestMsg) logger.log(CUSTOM_LOGGING.TRAFFIC_OUT, requestMsg)
# 获取响应内容
page = conn.read() page = conn.read()
code = conn.code code = conn.code
status = conn.msg status = conn.msg
responseHeaders = conn.info() responseHeaders = conn.info()
page = decodePage(page, responseHeaders.get("Content-Encoding"), responseHeaders.get("Content-Type")) page = decodePage(page, responseHeaders.get("Content-Encoding"), responseHeaders.get("Content-Type"))
# 记录响应日志
responseMsg = "HTTP response (%s - %d):\n" % (status, code) responseMsg = "HTTP response (%s - %d):\n" % (status, code)
if conf.verbose <= 4: if conf.verbose <= 4:
@ -160,6 +186,7 @@ def _search(dork):
logger.log(CUSTOM_LOGGING.TRAFFIC_IN, responseMsg) logger.log(CUSTOM_LOGGING.TRAFFIC_IN, responseMsg)
except _urllib.error.HTTPError as ex: except _urllib.error.HTTPError as ex:
try: try:
# 处理HTTP错误
page = ex.read() page = ex.read()
page = decodePage(page, ex.headers.get("Content-Encoding"), ex.headers.get("Content-Type")) page = decodePage(page, ex.headers.get("Content-Encoding"), ex.headers.get("Content-Type"))
except socket.timeout: except socket.timeout:
@ -171,10 +198,12 @@ def _search(dork):
errMsg = "unable to connect" errMsg = "unable to connect"
raise SqlmapConnectionException(errMsg) raise SqlmapConnectionException(errMsg)
page = getUnicode(page) # Note: if decodePage call fails (Issue #4202) page = getUnicode(page) # 如果decodePage调用失败的处理(Issue #4202)
# 使用相应的正则表达式提取搜索结果URL
retVal = [_urllib.parse.unquote(match.group(1).replace("&amp;", "&")) for match in re.finditer(regex, page, re.I | re.S)] retVal = [_urllib.parse.unquote(match.group(1).replace("&amp;", "&")) for match in re.finditer(regex, page, re.I | re.S)]
# 检查是否被DuckDuckGo检测为异常流量
if not retVal and "issue with the Tor Exit Node you are currently using" in page: if not retVal and "issue with the Tor Exit Node you are currently using" in page:
warnMsg = "DuckDuckGo has detected 'unusual' traffic from " warnMsg = "DuckDuckGo has detected 'unusual' traffic from "
warnMsg += "used (Tor) IP address" warnMsg += "used (Tor) IP address"
@ -188,6 +217,13 @@ def _search(dork):
@stackedmethod @stackedmethod
def search(dork): def search(dork):
"""
搜索函数的包装器,处理重定向和代理相关的逻辑
参数:
dork: 搜索关键词
返回:
搜索结果URL列表
"""
pushValue(kb.choices.redirect) pushValue(kb.choices.redirect)
kb.choices.redirect = REDIRECTION.YES kb.choices.redirect = REDIRECTION.YES
@ -209,5 +245,5 @@ def search(dork):
finally: finally:
kb.choices.redirect = popValue() kb.choices.redirect = popValue()
def setHTTPHandlers(): # Cross-referenced function def setHTTPHandlers(): # 交叉引用的函数
raise NotImplementedError raise NotImplementedError

@ -1,141 +1,165 @@
"""A parser for SGML, using the derived class as a static DTD.""" """一个用于解析SGML的解析器,使用派生类作为静态DTD(文档类型定义)。"""
# Note: missing in Python3 # 注意:Python3中已移除此模块
# XXX This only supports those SGML features used by HTML. # XXX 这个解析器只支持HTML中使用的SGML特性
# XXX There should be a way to distinguish between PCDATA (parsed # XXX 应该有一种方法来区分以下三种数据类型:
# character data -- the normal case), RCDATA (replaceable character # PCDATA(解析字符数据 - 正常情况)
# data -- only char and entity references and end tags are special) # RCDATA(可替换字符数据 - 只有字符、实体引用和结束标签是特殊的)
# and CDATA (character data -- only end tags are special). RCDATA is # CDATA(字符数据 - 只有结束标签是特殊的)
# not supported at all. # 目前不支持RCDATA
from __future__ import print_function from __future__ import print_function
try: try:
import _markupbase as markupbase import _markupbase as markupbase # 尝试导入_markupbase模块
except: except:
import markupbase import markupbase # 如果失败则导入markupbase模块
import re import re # 导入正则表达式模块
__all__ = ["SGMLParser", "SGMLParseError"] __all__ = ["SGMLParser", "SGMLParseError"] # 指定可被导入的公共接口
# Regular expressions used for parsing
# 用于解析的正则表达式定义
# 匹配有趣的字符(&和<)
interesting = re.compile('[&<]') interesting = re.compile('[&<]')
# 匹配不完整的标签或实体引用
incomplete = re.compile('&([a-zA-Z][a-zA-Z0-9]*|#[0-9]*)?|' incomplete = re.compile('&([a-zA-Z][a-zA-Z0-9]*|#[0-9]*)?|'
'<([a-zA-Z][^<>]*|' '<([a-zA-Z][^<>]*|'
'/([a-zA-Z][^<>]*)?|' '/([a-zA-Z][^<>]*)?|'
'![^<>]*)?') '![^<>]*)?')
# 匹配实体引用,如&amp;
entityref = re.compile('&([a-zA-Z][-.a-zA-Z0-9]*)[^a-zA-Z0-9]') entityref = re.compile('&([a-zA-Z][-.a-zA-Z0-9]*)[^a-zA-Z0-9]')
# 匹配字符引用,如&#160;
charref = re.compile('&#([0-9]+)[^0-9]') charref = re.compile('&#([0-9]+)[^0-9]')
# 匹配开始标签的开头
starttagopen = re.compile('<[>a-zA-Z]') starttagopen = re.compile('<[>a-zA-Z]')
# 匹配简写标签的开头,如<tag/
shorttagopen = re.compile('<[a-zA-Z][-.a-zA-Z0-9]*/') shorttagopen = re.compile('<[a-zA-Z][-.a-zA-Z0-9]*/')
# 匹配完整的简写标签,如<tag/data/
shorttag = re.compile('<([a-zA-Z][-.a-zA-Z0-9]*)/([^/]*)/') shorttag = re.compile('<([a-zA-Z][-.a-zA-Z0-9]*)/([^/]*)/')
# 匹配处理指令的结束符>
piclose = re.compile('>') piclose = re.compile('>')
# 匹配尖括号
endbracket = re.compile('[<>]') endbracket = re.compile('[<>]')
# 匹配标签名
tagfind = re.compile('[a-zA-Z][-_.a-zA-Z0-9]*') tagfind = re.compile('[a-zA-Z][-_.a-zA-Z0-9]*')
# 匹配属性
attrfind = re.compile( attrfind = re.compile(
r'\s*([a-zA-Z_][-:.a-zA-Z_0-9]*)(\s*=\s*' r'\s*([a-zA-Z_][-:.a-zA-Z_0-9]*)(\s*=\s*'
r'(\'[^\']*\'|"[^"]*"|[][\-a-zA-Z0-9./,:;+*%?!&$\(\)_#=~\'"@]*))?') r'(\'[^\']*\'|"[^"]*"|[][\-a-zA-Z0-9./,:;+*%?!&$\(\)_#=~\'"@]*))?')
class SGMLParseError(RuntimeError): class SGMLParseError(RuntimeError):
"""Exception raised for all parse errors.""" """解析错误时抛出的异常类"""
pass pass
# SGML parser base class -- find tags and call handler functions.
# Usage: p = SGMLParser(); p.feed(data); ...; p.close().
# The dtd is defined by deriving a class which defines methods
# with special names to handle tags: start_foo and end_foo to handle
# <foo> and </foo>, respectively, or do_foo to handle <foo> by itself.
# (Tags are converted to lower case for this purpose.) The data
# between tags is passed to the parser by calling self.handle_data()
# with some data as argument (the data may be split up in arbitrary
# chunks). Entity references are passed by calling
# self.handle_entityref() with the entity reference as argument.
class SGMLParser(markupbase.ParserBase): class SGMLParser(markupbase.ParserBase):
# Definition of entities -- derived classes may override """SGML解析器基类 - 查找标签并调用处理函数
用法: p = SGMLParser(); p.feed(data); ...; p.close()
DTD通过派生类定义,派生类需要定义特殊名称的方法来处理标签:
- start_foo和end_foo分别处理<foo></foo>
- 或者do_foo单独处理<foo>
(标签名会被转换为小写)
标签之间的数据通过调用self.handle_data(data)传递给解析器
实体引用通过调用self.handle_entityref(name)传递
"""
# 实体或字符引用的正则表达式
entity_or_charref = re.compile('&(?:' entity_or_charref = re.compile('&(?:'
'([a-zA-Z][-.a-zA-Z0-9]*)|#([0-9]+)' '([a-zA-Z][-.a-zA-Z0-9]*)|#([0-9]+)'
')(;?)') ')(;?)')
def __init__(self, verbose=0): def __init__(self, verbose=0):
"""Initialize and reset this instance.""" """初始化并重置实例"""
self.verbose = verbose self.verbose = verbose # 是否输出详细信息
self.reset() self.reset()
def reset(self): def reset(self):
"""Reset this instance. Loses all unprocessed data.""" """重置实例状态,丢弃所有未处理的数据"""
self.__starttag_text = None self.__starttag_text = None # 开始标签的原始文本
self.rawdata = '' self.rawdata = '' # 原始数据
self.stack = [] self.stack = [] # 标签栈
self.lasttag = '???' self.lasttag = '???' # 最后处理的标签
self.nomoretags = 0 self.nomoretags = 0 # 是否停止处理标签
self.literal = 0 self.literal = 0 # 是否处于文字模式
markupbase.ParserBase.reset(self) markupbase.ParserBase.reset(self)
def setnomoretags(self): def setnomoretags(self):
"""Enter literal mode (CDATA) till EOF. """进入文字模式(CDATA)直到文件结束
Intended for derived classes only. 仅供派生类使用
""" """
self.nomoretags = self.literal = 1 self.nomoretags = self.literal = 1
def setliteral(self, *args): def setliteral(self, *args):
"""Enter literal mode (CDATA). """进入文字模式(CDATA)
Intended for derived classes only. 仅供派生类使用
""" """
self.literal = 1 self.literal = 1
def feed(self, data): def feed(self, data):
"""Feed some data to the parser. """向解析器提供数据
Call this as often as you want, with as little or as much text 可以多次调用,每次提供任意长度的文本(可以包含换行符)
as you want (may include '\n'). (This just saves the text, 这个方法只是保存文本,实际处理由goahead()完成
all the processing is done by goahead().)
""" """
self.rawdata = self.rawdata + data self.rawdata = self.rawdata + data
self.goahead(0) self.goahead(0)
def close(self): def close(self):
"""Handle the remaining data.""" """处理剩余数据"""
self.goahead(1) self.goahead(1)
def error(self, message): def error(self, message):
"""抛出解析错误异常"""
raise SGMLParseError(message) raise SGMLParseError(message)
# Internal -- handle data as far as reasonable. May leave state
# and data to be processed by a subsequent call. If 'end' is
# true, force handling all data as if followed by EOF marker.
def goahead(self, end): def goahead(self, end):
"""内部方法 - 尽可能处理数据
可能会留下状态和数据等待后续调用处理
如果end为True,则强制处理所有数据
"""
rawdata = self.rawdata rawdata = self.rawdata
i = 0 i = 0 # 当前处理位置
n = len(rawdata) n = len(rawdata)
while i < n: while i < n:
if self.nomoretags: if self.nomoretags: # 如果在文字模式下
self.handle_data(rawdata[i:n]) self.handle_data(rawdata[i:n])
i = n i = n
break break
# 查找下一个特殊字符(&或<)
match = interesting.search(rawdata, i) match = interesting.search(rawdata, i)
if match: if match:
j = match.start() j = match.start()
else: else:
j = n j = n
# 处理普通文本
if i < j: if i < j:
self.handle_data(rawdata[i:j]) self.handle_data(rawdata[i:j])
i = j i = j
if i == n: if i == n:
break break
# 处理标签和实体引用
if rawdata[i] == '<': if rawdata[i] == '<':
if starttagopen.match(rawdata, i): if starttagopen.match(rawdata, i): # 开始标签
if self.literal: if self.literal:
self.handle_data(rawdata[i]) self.handle_data(rawdata[i])
i = i + 1 i = i + 1
@ -145,7 +169,7 @@ class SGMLParser(markupbase.ParserBase):
break break
i = k i = k
continue continue
if rawdata.startswith("</", i): if rawdata.startswith("</", i): # 结束标签
k = self.parse_endtag(i) k = self.parse_endtag(i)
if k < 0: if k < 0:
break break
@ -157,40 +181,32 @@ class SGMLParser(markupbase.ParserBase):
self.handle_data("<") self.handle_data("<")
i = i + 1 i = i + 1
else: else:
# incomplete
break break
continue continue
if rawdata.startswith("<!--", i): if rawdata.startswith("<!--", i): # 注释
# Strictly speaking, a comment is --.*--
# within a declaration tag <!...>.
# This should be removed,
# and comments handled only in parse_declaration.
k = self.parse_comment(i) k = self.parse_comment(i)
if k < 0: if k < 0:
break break
i = k i = k
continue continue
if rawdata.startswith("<?", i): if rawdata.startswith("<?", i): # 处理指令
k = self.parse_pi(i) k = self.parse_pi(i)
if k < 0: if k < 0:
break break
i = i + k i = i + k
continue continue
if rawdata.startswith("<!", i): if rawdata.startswith("<!", i): # 声明(如DOCTYPE)
# This is some sort of declaration; in "HTML as
# deployed," this should only be the document type
# declaration ("<!DOCTYPE html...>").
k = self.parse_declaration(i) k = self.parse_declaration(i)
if k < 0: if k < 0:
break break
i = k i = k
continue continue
elif rawdata[i] == '&': elif rawdata[i] == '&': # 处理实体引用
if self.literal: if self.literal:
self.handle_data(rawdata[i]) self.handle_data(rawdata[i])
i = i + 1 i = i + 1
continue continue
match = charref.match(rawdata, i) match = charref.match(rawdata, i) # 字符引用
if match: if match:
name = match.group(1) name = match.group(1)
self.handle_charref(name) self.handle_charref(name)
@ -198,7 +214,7 @@ class SGMLParser(markupbase.ParserBase):
if rawdata[i - 1] != ';': if rawdata[i - 1] != ';':
i = i - 1 i = i - 1
continue continue
match = entityref.match(rawdata, i) match = entityref.match(rawdata, i) # 实体引用
if match: if match:
name = match.group(1) name = match.group(1)
self.handle_entityref(name) self.handle_entityref(name)
@ -208,8 +224,7 @@ class SGMLParser(markupbase.ParserBase):
continue continue
else: else:
self.error('neither < nor & ??') self.error('neither < nor & ??')
# We get here only if incomplete matches but # 处理不完整的匹配
# nothing else
match = incomplete.match(rawdata, i) match = incomplete.match(rawdata, i)
if not match: if not match:
self.handle_data(rawdata[i]) self.handle_data(rawdata[i])
@ -217,21 +232,20 @@ class SGMLParser(markupbase.ParserBase):
continue continue
j = match.end(0) j = match.end(0)
if j == n: if j == n:
break # Really incomplete break
self.handle_data(rawdata[i:j]) self.handle_data(rawdata[i:j])
i = j i = j
# end while # 处理剩余数据
if end and i < n: if end and i < n:
self.handle_data(rawdata[i:n]) self.handle_data(rawdata[i:n])
i = n i = n
self.rawdata = rawdata[i:] self.rawdata = rawdata[i:]
# XXX if end: check for empty stack
# Extensions for the DOCTYPE scanner:
_decl_otherchars = '='
# Internal -- parse processing instr, return length or -1 if not terminated
def parse_pi(self, i): def parse_pi(self, i):
"""内部方法 - 解析处理指令
返回处理的字符数,如果未结束则返回-1
"""
rawdata = self.rawdata rawdata = self.rawdata
if rawdata[i:i + 2] != '<?': if rawdata[i:i + 2] != '<?':
self.error('unexpected call to parse_pi()') self.error('unexpected call to parse_pi()')
@ -244,18 +258,18 @@ class SGMLParser(markupbase.ParserBase):
return j - i return j - i
def get_starttag_text(self): def get_starttag_text(self):
"""获取最近处理的开始标签文本"""
return self.__starttag_text return self.__starttag_text
# Internal -- handle starttag, return length or -1 if not terminated
def parse_starttag(self, i): def parse_starttag(self, i):
"""内部方法 - 处理开始标签
返回处理的字符数,如果未结束则返回-1
"""
self.__starttag_text = None self.__starttag_text = None
start_pos = i start_pos = i
rawdata = self.rawdata rawdata = self.rawdata
if shorttagopen.match(rawdata, i): if shorttagopen.match(rawdata, i): # 简写标签
# SGML shorthand: <tag/data/ == <tag>data</tag>
# XXX Can data contain &... (entity or char refs)?
# XXX Can data contain < or > (tag characters)?
# XXX Can there be whitespace before the first /?
match = shorttag.match(rawdata, i) match = shorttag.match(rawdata, i)
if not match: if not match:
return -1 return -1
@ -266,18 +280,16 @@ class SGMLParser(markupbase.ParserBase):
self.finish_shorttag(tag, data) self.finish_shorttag(tag, data)
self.__starttag_text = rawdata[start_pos:match.end(1) + 1] self.__starttag_text = rawdata[start_pos:match.end(1) + 1]
return k return k
# XXX The following should skip matching quotes (' or ")
# As a shortcut way to exit, this isn't so bad, but shouldn't # 查找标签结束位置
# be used to locate the actual end of the start tag since the
# < or > characters may be embedded in an attribute value.
match = endbracket.search(rawdata, i + 1) match = endbracket.search(rawdata, i + 1)
if not match: if not match:
return -1 return -1
j = match.start(0) j = match.start(0)
# Now parse the data between i + 1 and j into a tag and attrs
# 解析标签名和属性
attrs = [] attrs = []
if rawdata[i:i + 2] == '<>': if rawdata[i:i + 2] == '<>': # <>表示重复上一个开始标签
# SGML shorthand: <> == <last open tag seen>
k = j k = j
tag = self.lasttag tag = self.lasttag
else: else:
@ -287,6 +299,8 @@ class SGMLParser(markupbase.ParserBase):
k = match.end(0) k = match.end(0)
tag = rawdata[i + 1:k].lower() tag = rawdata[i + 1:k].lower()
self.lasttag = tag self.lasttag = tag
# 解析属性
while k < j: while k < j:
match = attrfind.match(rawdata, k) match = attrfind.match(rawdata, k)
if not match: if not match:
@ -297,31 +311,31 @@ class SGMLParser(markupbase.ParserBase):
else: else:
if (attrvalue[:1] == "'" == attrvalue[-1:] or if (attrvalue[:1] == "'" == attrvalue[-1:] or
attrvalue[:1] == '"' == attrvalue[-1:]): attrvalue[:1] == '"' == attrvalue[-1:]):
# strip quotes attrvalue = attrvalue[1:-1] # 去掉引号
attrvalue = attrvalue[1:-1]
attrvalue = self.entity_or_charref.sub( attrvalue = self.entity_or_charref.sub(
self._convert_ref, attrvalue) self._convert_ref, attrvalue)
attrs.append((attrname.lower(), attrvalue)) attrs.append((attrname.lower(), attrvalue))
k = match.end(0) k = match.end(0)
if rawdata[j] == '>': if rawdata[j] == '>':
j = j + 1 j = j + 1
self.__starttag_text = rawdata[start_pos:j] self.__starttag_text = rawdata[start_pos:j]
self.finish_starttag(tag, attrs) self.finish_starttag(tag, attrs)
return j return j
# Internal -- convert entity or character reference
def _convert_ref(self, match): def _convert_ref(self, match):
if match.group(2): """内部方法 - 转换实体引用或字符引用"""
if match.group(2): # 字符引用
return self.convert_charref(match.group(2)) or \ return self.convert_charref(match.group(2)) or \
'&#%s%s' % match.groups()[1:] '&#%s%s' % match.groups()[1:]
elif match.group(3): elif match.group(3): # 实体引用
return self.convert_entityref(match.group(1)) or \ return self.convert_entityref(match.group(1)) or \
'&%s;' % match.group(1) '&%s;' % match.group(1)
else: else:
return '&%s' % match.group(1) return '&%s' % match.group(1)
# Internal -- parse endtag
def parse_endtag(self, i): def parse_endtag(self, i):
"""内部方法 - 解析结束标签"""
rawdata = self.rawdata rawdata = self.rawdata
match = endbracket.search(rawdata, i + 1) match = endbracket.search(rawdata, i + 1)
if not match: if not match:
@ -333,15 +347,23 @@ class SGMLParser(markupbase.ParserBase):
self.finish_endtag(tag) self.finish_endtag(tag)
return j return j
# Internal -- finish parsing of <tag/data/ (same as <tag>data</tag>)
def finish_shorttag(self, tag, data): def finish_shorttag(self, tag, data):
"""内部方法 - 完成简写标签的处理
<tag/data/> 等同于 <tag>data</tag>
"""
self.finish_starttag(tag, []) self.finish_starttag(tag, [])
self.handle_data(data) self.handle_data(data)
self.finish_endtag(tag) self.finish_endtag(tag)
# Internal -- finish processing of start tag
# Return -1 for unknown tag, 0 for open-only tag, 1 for balanced tag
def finish_starttag(self, tag, attrs): def finish_starttag(self, tag, attrs):
"""内部方法 - 完成开始标签的处理
返回:
-1: 未知标签
0: 仅开始标签
1: 平衡标签
"""
try: try:
method = getattr(self, 'start_' + tag) method = getattr(self, 'start_' + tag)
except AttributeError: except AttributeError:
@ -358,15 +380,15 @@ class SGMLParser(markupbase.ParserBase):
self.handle_starttag(tag, method, attrs) self.handle_starttag(tag, method, attrs)
return 1 return 1
# Internal -- finish processing of end tag
def finish_endtag(self, tag): def finish_endtag(self, tag):
if not tag: """内部方法 - 完成结束标签的处理"""
if not tag: # 空标签
found = len(self.stack) - 1 found = len(self.stack) - 1
if found < 0: if found < 0:
self.unknown_endtag(tag) self.unknown_endtag(tag)
return return
else: else:
if tag not in self.stack: if tag not in self.stack: # 未匹配的结束标签
try: try:
method = getattr(self, 'end_' + tag) method = getattr(self, 'end_' + tag)
except AttributeError: except AttributeError:
@ -378,6 +400,8 @@ class SGMLParser(markupbase.ParserBase):
for i in range(found): for i in range(found):
if self.stack[i] == tag: if self.stack[i] == tag:
found = i found = i
# 处理所有未闭合的标签
while len(self.stack) > found: while len(self.stack) > found:
tag = self.stack[-1] tag = self.stack[-1]
try: try:
@ -390,22 +414,24 @@ class SGMLParser(markupbase.ParserBase):
self.unknown_endtag(tag) self.unknown_endtag(tag)
del self.stack[-1] del self.stack[-1]
# Overridable -- handle start tag # 以下方法可被派生类重写
def handle_starttag(self, tag, method, attrs): def handle_starttag(self, tag, method, attrs):
"""处理开始标签"""
method(attrs) method(attrs)
# Overridable -- handle end tag
def handle_endtag(self, tag, method): def handle_endtag(self, tag, method):
"""处理结束标签"""
method() method()
# Example -- report an unbalanced </...> tag.
def report_unbalanced(self, tag): def report_unbalanced(self, tag):
"""报告未匹配的结束标签"""
if self.verbose: if self.verbose:
print('*** Unbalanced </' + tag + '>') print('*** Unbalanced </' + tag + '>')
print('*** Stack:', self.stack) print('*** Stack:', self.stack)
def convert_charref(self, name): def convert_charref(self, name):
"""Convert character reference, may be overridden.""" """转换字符引用"""
try: try:
n = int(name) n = int(name)
except ValueError: except ValueError:
@ -415,25 +441,25 @@ class SGMLParser(markupbase.ParserBase):
return self.convert_codepoint(n) return self.convert_codepoint(n)
def convert_codepoint(self, codepoint): def convert_codepoint(self, codepoint):
"""转换代码点为字符"""
return chr(codepoint) return chr(codepoint)
def handle_charref(self, name): def handle_charref(self, name):
"""Handle character reference, no need to override.""" """处理字符引用"""
replacement = self.convert_charref(name) replacement = self.convert_charref(name)
if replacement is None: if replacement is None:
self.unknown_charref(name) self.unknown_charref(name)
else: else:
self.handle_data(replacement) self.handle_data(replacement)
# Definition of entities -- derived classes may override # 实体定义 - 派生类可以重写
entitydefs = \ entitydefs = \
{'lt': '<', 'gt': '>', 'amp': '&', 'quot': '"', 'apos': '\''} {'lt': '<', 'gt': '>', 'amp': '&', 'quot': '"', 'apos': '\''}
def convert_entityref(self, name): def convert_entityref(self, name):
"""Convert entity references. """转换实体引用
As an alternative to overriding this method; one can tailor the 可以通过设置self.entitydefs来自定义转换规则
results by setting up the self.entitydefs mapping appropriately.
""" """
table = self.entitydefs table = self.entitydefs
if name in table: if name in table:
@ -442,61 +468,72 @@ class SGMLParser(markupbase.ParserBase):
return return
def handle_entityref(self, name): def handle_entityref(self, name):
"""Handle entity references, no need to override.""" """处理实体引用"""
replacement = self.convert_entityref(name) replacement = self.convert_entityref(name)
if replacement is None: if replacement is None:
self.unknown_entityref(name) self.unknown_entityref(name)
else: else:
self.handle_data(replacement) self.handle_data(replacement)
# Example -- handle data, should be overridden # 以下是示例处理方法 - 应该被重写
def handle_data(self, data): def handle_data(self, data):
"""处理文本数据"""
pass pass
# Example -- handle comment, could be overridden
def handle_comment(self, data): def handle_comment(self, data):
"""处理注释"""
pass pass
# Example -- handle declaration, could be overridden
def handle_decl(self, decl): def handle_decl(self, decl):
"""处理声明"""
pass pass
# Example -- handle processing instruction, could be overridden
def handle_pi(self, data): def handle_pi(self, data):
"""处理处理指令"""
pass pass
# To be overridden -- handlers for unknown objects # 处理未知对象的方法 - 需要重写
def unknown_starttag(self, tag, attrs): def unknown_starttag(self, tag, attrs):
"""处理未知开始标签"""
pass pass
def unknown_endtag(self, tag): def unknown_endtag(self, tag):
"""处理未知结束标签"""
pass pass
def unknown_charref(self, ref): def unknown_charref(self, ref):
"""处理未知字符引用"""
pass pass
def unknown_entityref(self, ref): def unknown_entityref(self, ref):
"""处理未知实体引用"""
pass pass
class TestSGMLParser(SGMLParser): class TestSGMLParser(SGMLParser):
"""用于测试的SGML解析器"""
def __init__(self, verbose=0): def __init__(self, verbose=0):
self.testdata = "" self.testdata = ""
SGMLParser.__init__(self, verbose) SGMLParser.__init__(self, verbose)
def handle_data(self, data): def handle_data(self, data):
"""收集并打印文本数据"""
self.testdata = self.testdata + data self.testdata = self.testdata + data
if len(repr(self.testdata)) >= 70: if len(repr(self.testdata)) >= 70:
self.flush() self.flush()
def flush(self): def flush(self):
"""打印收集的数据"""
data = self.testdata data = self.testdata
if data: if data:
self.testdata = "" self.testdata = ""
print('data:', repr(data)) print('data:', repr(data))
def handle_comment(self, data): def handle_comment(self, data):
"""打印注释"""
self.flush() self.flush()
r = repr(data) r = repr(data)
if len(r) > 68: if len(r) > 68:
@ -504,6 +541,7 @@ class TestSGMLParser(SGMLParser):
print('comment:', r) print('comment:', r)
def unknown_starttag(self, tag, attrs): def unknown_starttag(self, tag, attrs):
"""打印未知开始标签"""
self.flush() self.flush()
if not attrs: if not attrs:
print('start tag: <' + tag + '>') print('start tag: <' + tag + '>')
@ -514,27 +552,33 @@ class TestSGMLParser(SGMLParser):
print('>') print('>')
def unknown_endtag(self, tag): def unknown_endtag(self, tag):
"""打印未知结束标签"""
self.flush() self.flush()
print('end tag: </' + tag + '>') print('end tag: </' + tag + '>')
def unknown_entityref(self, ref): def unknown_entityref(self, ref):
"""打印未知实体引用"""
self.flush() self.flush()
print('*** unknown entity ref: &' + ref + ';') print('*** unknown entity ref: &' + ref + ';')
def unknown_charref(self, ref): def unknown_charref(self, ref):
"""打印未知字符引用"""
self.flush() self.flush()
print('*** unknown char ref: &#' + ref + ';') print('*** unknown char ref: &#' + ref + ';')
def unknown_decl(self, data): def unknown_decl(self, data):
"""打印未知声明"""
self.flush() self.flush()
print('*** unknown decl: [' + data + ']') print('*** unknown decl: [' + data + ']')
def close(self): def close(self):
"""关闭解析器并打印剩余数据"""
SGMLParser.close(self) SGMLParser.close(self)
self.flush() self.flush()
def test(args=None): def test(args=None):
"""测试函数"""
import sys import sys
if args is None: if args is None:

@ -5,33 +5,37 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import importlib # 导入所需的Python标准库
import logging import importlib # 用于动态导入模块
import os import logging # 用于日志记录
import re import os # 用于操作系统相关功能
import sys import re # 用于正则表达式操作
import traceback import sys # 用于系统相关功能
import warnings import traceback # 用于异常追踪
import warnings # 用于警告控制
_path = list(sys.path)
# 尝试导入SQLAlchemy
_path = list(sys.path) # 保存原始系统路径
_sqlalchemy = None _sqlalchemy = None
try: try:
sys.path = sys.path[1:] sys.path = sys.path[1:] # 修改系统路径以避免命名冲突
module = importlib.import_module("sqlalchemy") module = importlib.import_module("sqlalchemy") # 动态导入SQLAlchemy
if hasattr(module, "dialects"): if hasattr(module, "dialects"):
_sqlalchemy = module _sqlalchemy = module
warnings.simplefilter(action="ignore", category=_sqlalchemy.exc.SAWarning) warnings.simplefilter(action="ignore", category=_sqlalchemy.exc.SAWarning) # 忽略SQLAlchemy的警告
except: except:
pass pass
finally: finally:
sys.path = _path sys.path = _path # 恢复原始系统路径
# 尝试导入MySQL-python驱动
try: try:
import MySQLdb # used by SQLAlchemy in case of MySQL import MySQLdb # SQLAlchemy在使用MySQL时需要
warnings.filterwarnings("error", category=MySQLdb.Warning) warnings.filterwarnings("error", category=MySQLdb.Warning) # 将MySQL警告转换为错误
except (ImportError, AttributeError): except (ImportError, AttributeError):
pass pass
# 导入项目内部模块
from lib.core.data import conf from lib.core.data import conf
from lib.core.data import logger from lib.core.data import logger
from lib.core.exception import SqlmapConnectionException from lib.core.exception import SqlmapConnectionException
@ -41,32 +45,44 @@ from plugins.generic.connector import Connector as GenericConnector
from thirdparty import six from thirdparty import six
from thirdparty.six.moves import urllib as _urllib from thirdparty.six.moves import urllib as _urllib
def getSafeExString(ex, encoding=None): # Cross-referenced function def getSafeExString(ex, encoding=None): # 交叉引用的函数
raise NotImplementedError raise NotImplementedError
class SQLAlchemy(GenericConnector): class SQLAlchemy(GenericConnector):
"""SQLAlchemy连接器类用于处理数据库连接和查询"""
def __init__(self, dialect=None): def __init__(self, dialect=None):
"""初始化SQLAlchemy连接器
Args:
dialect: 数据库方言(如mysql, postgresql等)
"""
GenericConnector.__init__(self) GenericConnector.__init__(self)
self.dialect = dialect self.dialect = dialect
self.address = conf.direct self.address = conf.direct # 从配置获取数据库连接地址
# 处理数据库用户名中的特殊字符
if conf.dbmsUser: if conf.dbmsUser:
self.address = self.address.replace("'%s':" % conf.dbmsUser, "%s:" % _urllib.parse.quote(conf.dbmsUser)) self.address = self.address.replace("'%s':" % conf.dbmsUser, "%s:" % _urllib.parse.quote(conf.dbmsUser))
self.address = self.address.replace("%s:" % conf.dbmsUser, "%s:" % _urllib.parse.quote(conf.dbmsUser)) self.address = self.address.replace("%s:" % conf.dbmsUser, "%s:" % _urllib.parse.quote(conf.dbmsUser))
# 处理数据库密码中的特殊字符
if conf.dbmsPass: if conf.dbmsPass:
self.address = self.address.replace(":'%s'@" % conf.dbmsPass, ":%s@" % _urllib.parse.quote(conf.dbmsPass)) self.address = self.address.replace(":'%s'@" % conf.dbmsPass, ":%s@" % _urllib.parse.quote(conf.dbmsPass))
self.address = self.address.replace(":%s@" % conf.dbmsPass, ":%s@" % _urllib.parse.quote(conf.dbmsPass)) self.address = self.address.replace(":%s@" % conf.dbmsPass, ":%s@" % _urllib.parse.quote(conf.dbmsPass))
# 设置数据库方言
if self.dialect: if self.dialect:
self.address = re.sub(r"\A.+://", "%s://" % self.dialect, self.address) self.address = re.sub(r"\A.+://", "%s://" % self.dialect, self.address)
def connect(self): def connect(self):
"""建立数据库连接"""
if _sqlalchemy: if _sqlalchemy:
self.initConnection() self.initConnection()
try: try:
# 处理SQLite数据库文件路径
if not self.port and self.db: if not self.port and self.db:
if not os.path.exists(self.db): if not os.path.exists(self.db):
raise SqlmapFilePathException("the provided database file '%s' does not exist" % self.db) raise SqlmapFilePathException("the provided database file '%s' does not exist" % self.db)
@ -74,6 +90,7 @@ class SQLAlchemy(GenericConnector):
_ = self.address.split("//", 1) _ = self.address.split("//", 1)
self.address = "%s////%s" % (_[0], os.path.abspath(self.db)) self.address = "%s////%s" % (_[0], os.path.abspath(self.db))
# 根据不同数据库类型创建引擎
if self.dialect == "sqlite": if self.dialect == "sqlite":
engine = _sqlalchemy.create_engine(self.address, connect_args={"check_same_thread": False}) engine = _sqlalchemy.create_engine(self.address, connect_args={"check_same_thread": False})
elif self.dialect == "oracle": elif self.dialect == "oracle":
@ -81,8 +98,9 @@ class SQLAlchemy(GenericConnector):
else: else:
engine = _sqlalchemy.create_engine(self.address, connect_args={}) engine = _sqlalchemy.create_engine(self.address, connect_args={})
self.connector = engine.connect() self.connector = engine.connect() # 建立连接
except (TypeError, ValueError): except (TypeError, ValueError):
# 处理特定的连接错误
if "_get_server_version_info" in traceback.format_exc(): if "_get_server_version_info" in traceback.format_exc():
try: try:
import pymssql import pymssql
@ -104,6 +122,11 @@ class SQLAlchemy(GenericConnector):
raise SqlmapMissingDependence("SQLAlchemy not available (e.g. 'pip%s install SQLAlchemy')" % ('3' if six.PY3 else "")) raise SqlmapMissingDependence("SQLAlchemy not available (e.g. 'pip%s install SQLAlchemy')" % ('3' if six.PY3 else ""))
def fetchall(self): def fetchall(self):
"""获取所有查询结果
Returns:
list: 查询结果列表每个元素为一个元组
"""
try: try:
retVal = [] retVal = []
for row in self.cursor.fetchall(): for row in self.cursor.fetchall():
@ -114,9 +137,17 @@ class SQLAlchemy(GenericConnector):
return None return None
def execute(self, query): def execute(self, query):
"""执行SQL查询
Args:
query: SQL查询语句
Returns:
bool: 查询是否执行成功
"""
retVal = False retVal = False
# Reference: https://stackoverflow.com/a/69491015 # 将查询转换为SQLAlchemy的text对象
if hasattr(_sqlalchemy, "text"): if hasattr(_sqlalchemy, "text"):
query = _sqlalchemy.text(query) query = _sqlalchemy.text(query)
@ -131,6 +162,14 @@ class SQLAlchemy(GenericConnector):
return retVal return retVal
def select(self, query): def select(self, query):
"""执行SELECT查询
Args:
query: SELECT查询语句
Returns:
list: 查询结果列表如果查询失败返回None
"""
retVal = None retVal = None
if self.execute(query): if self.execute(query):

@ -1,10 +1,74 @@
#!/usr/bin/env python #!/usr/bin/env python
"""
Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission
"""#!/usr/bin/env python
""" """
Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/) Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import threading
from lib.core.data import logger
from lib.core.enums import CUSTOM_LOGGING
from lib.core.enums import TIMEOUT_STATE
def timeout(func, args=None, kwargs=None, duration=1, default=None):
"""
带超时控制的函数执行装饰器
参数说明:
func: 要执行的目标函数
args: 传递给目标函数的位置参数,默认为None
kwargs: 传递给目标函数的关键字参数,默认为None
duration: 超时时间,单位为秒,默认为1秒
default: 超时后的默认返回值,默认为None
"""
class InterruptableThread(threading.Thread):
"""
可中断的线程类,继承自threading.Thread
用于在独立线程中执行目标函数
"""
def __init__(self):
threading.Thread.__init__(self)
self.result = None # 存储函数执行结果
self.timeout_state = None # 存储执行状态
def run(self):
"""
线程执行的主要逻辑:
1. 尝试执行目标函数并保存结果
2. 如果执行成功,设置状态为NORMAL
3. 如果发生异常,记录日志并设置状态为EXCEPTION
"""
try:
# 执行目标函数,处理参数
self.result = func(*(args or ()), **(kwargs or {}))
self.timeout_state = TIMEOUT_STATE.NORMAL
except Exception as ex:
# 异常处理:记录日志,返回默认值
logger.log(CUSTOM_LOGGING.TRAFFIC_IN, ex)
self.result = default
self.timeout_state = TIMEOUT_STATE.EXCEPTION
# 创建并启动线程
thread = InterruptableThread()
thread.start()
# 等待线程执行,最多等待duration秒
thread.join(duration)
# 判断线程是否还在运行
if thread.is_alive():
# 如果超时,返回默认值和超时状态
return default, TIMEOUT_STATE.TIMEOUT
else:
# 如果执行完成,返回执行结果和执行状态
return thread.result, thread.timeout_state
import threading import threading
from lib.core.data import logger from lib.core.data import logger

@ -5,24 +5,40 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
# 导入所需的系统模块和时间模块
import sys import sys
import time import time
# 获取当前Python版本号(如"3.8.0")
PYVERSION = sys.version.split()[0] PYVERSION = sys.version.split()[0]
# 检查Python版本是否低于2.6
# 如果是,则输出错误信息并退出程序
# 错误信息包含当前时间和检测到的Python版本
if PYVERSION < "2.6": if PYVERSION < "2.6":
sys.exit("[%s] [CRITICAL] incompatible Python version detected ('%s'). To successfully run sqlmap you'll have to use version 2.6, 2.7 or 3.x (visit 'https://www.python.org/downloads/')" % (time.strftime("%X"), PYVERSION)) sys.exit("[%s] [CRITICAL] incompatible Python version detected ('%s'). To successfully run sqlmap you'll have to use version 2.6, 2.7 or 3.x (visit 'https://www.python.org/downloads/')" % (time.strftime("%X"), PYVERSION))
# 初始化错误列表
errors = [] errors = []
# 定义需要检查的核心扩展模块元组
extensions = ("bz2", "gzip", "pyexpat", "ssl", "sqlite3", "zlib") extensions = ("bz2", "gzip", "pyexpat", "ssl", "sqlite3", "zlib")
# 遍历所有需要的扩展模块
for _ in extensions: for _ in extensions:
try: try:
# 尝试导入每个扩展模块
__import__(_) __import__(_)
except ImportError: except ImportError:
# 如果导入失败,将该模块名添加到错误列表中
errors.append(_) errors.append(_)
# 如果存在任何缺失的扩展模块
if errors: if errors:
# 构建错误信息,包含:
# 1. 当前时间
# 2. 所有缺失的模块名称(以逗号分隔)
errMsg = "[%s] [CRITICAL] missing one or more core extensions (%s) " % (time.strftime("%X"), ", ".join("'%s'" % _ for _ in errors)) errMsg = "[%s] [CRITICAL] missing one or more core extensions (%s) " % (time.strftime("%X"), ", ".join("'%s'" % _ for _ in errors))
# 补充说明可能的原因:Python安装时缺少相应的开发包
errMsg += "most likely because current version of Python has been " errMsg += "most likely because current version of Python has been "
errMsg += "built without appropriate dev packages" errMsg += "built without appropriate dev packages"
# 输出错误信息并退出程序
sys.exit(errMsg) sys.exit(errMsg)

@ -9,8 +9,8 @@ import numbers
class xrange(object): class xrange(object):
""" """
Advanced (re)implementation of xrange (supports slice/copy/etc.) xrange的高级(重新)实现(支持切片/复制等操作)
Reference: http://code.activestate.com/recipes/521885-a-pythonic-implementation-of-xrange/ 参考: http://code.activestate.com/recipes/521885-a-pythonic-implementation-of-xrange/
>>> list(xrange(1, 9)) == list(range(1, 9)) >>> list(xrange(1, 9)) == list(range(1, 9))
True True
@ -35,58 +35,77 @@ class xrange(object):
1 1
""" """
# 使用__slots__来限制类的属性,只允许_slice属性,可以节省内存
__slots__ = ['_slice'] __slots__ = ['_slice']
def __init__(self, *args): def __init__(self, *args):
# 如果第一个参数是xrange对象,则复制其属性
if args and isinstance(args[0], type(self)): if args and isinstance(args[0], type(self)):
self._slice = slice(args[0].start, args[0].stop, args[0].step) self._slice = slice(args[0].start, args[0].stop, args[0].step)
else: else:
# 否则创建新的slice对象
self._slice = slice(*args) self._slice = slice(*args)
# 确保stop参数不为None
if self._slice.stop is None: if self._slice.stop is None:
raise TypeError("xrange stop must not be None") raise TypeError("xrange stop must not be None")
@property @property
def start(self): def start(self):
# 返回起始值,如果未指定则默认为0
if self._slice.start is not None: if self._slice.start is not None:
return self._slice.start return self._slice.start
return 0 return 0
@property @property
def stop(self): def stop(self):
# 返回结束值
return self._slice.stop return self._slice.stop
@property @property
def step(self): def step(self):
# 返回步长,如果未指定则默认为1
if self._slice.step is not None: if self._slice.step is not None:
return self._slice.step return self._slice.step
return 1 return 1
def __hash__(self): def __hash__(self):
# 返回slice对象的哈希值
return hash(self._slice) return hash(self._slice)
def __repr__(self): def __repr__(self):
# 返回对象的字符串表示
return '%s(%r, %r, %r)' % (type(self).__name__, self.start, self.stop, self.step) return '%s(%r, %r, %r)' % (type(self).__name__, self.start, self.stop, self.step)
def __len__(self): def __len__(self):
# 返回序列的长度
return self._len() return self._len()
def _len(self): def _len(self):
# 计算序列的长度: (stop-1-start)//step + 1,确保结果不小于0
return max(0, 1 + int((self.stop - 1 - self.start) // self.step)) return max(0, 1 + int((self.stop - 1 - self.start) // self.step))
def __contains__(self, value): def __contains__(self, value):
# 判断value是否在序列中
# 条件1: value在start和stop范围内
# 条件2: value与start的差值能被step整除
return (self.start <= value < self.stop) and (value - self.start) % self.step == 0 return (self.start <= value < self.stop) and (value - self.start) % self.step == 0
def __getitem__(self, index): def __getitem__(self, index):
# 支持通过索引或切片获取元素
if isinstance(index, slice): if isinstance(index, slice):
# 如果是切片,返回新的xrange对象
start, stop, step = index.indices(self._len()) start, stop, step = index.indices(self._len())
return xrange(self._index(start), return xrange(self._index(start),
self._index(stop), step * self.step) self._index(stop), step * self.step)
elif isinstance(index, numbers.Integral): elif isinstance(index, numbers.Integral):
# 如果是整数索引
if index < 0: if index < 0:
# 处理负数索引
fixed_index = index + self._len() fixed_index = index + self._len()
else: else:
fixed_index = index fixed_index = index
# 检查索引是否越界
if not 0 <= fixed_index < self._len(): if not 0 <= fixed_index < self._len():
raise IndexError("Index %d out of %r" % (index, self)) raise IndexError("Index %d out of %r" % (index, self))
@ -95,9 +114,11 @@ class xrange(object):
raise TypeError("xrange indices must be slices or integers") raise TypeError("xrange indices must be slices or integers")
def _index(self, i): def _index(self, i):
# 计算第i个元素的实际值
return self.start + self.step * i return self.start + self.step * i
def index(self, i): def index(self, i):
# 返回值为i的元素在序列中的索引位置
if self.start <= i < self.stop: if self.start <= i < self.stop:
return i - self.start return i - self.start
else: else:

Loading…
Cancel
Save