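# NOTE: the following lines are residue copied from a repository web viewer;
# they are not part of this module and are kept only as comments.
# You can not select more than 25 topics
# Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
# 352 lines
# 12 KiB
# 352 lines
# 12 KiB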
import json
import os.path as osp
import shutil
import sys
import tempfile
from argparse import Action, ArgumentParser
from collections import abc
from importlib import import_module
from addict import Dict
# Reserved keys understood by the config loader:
#   BASE_KEY   -- names the base config file(s) a config inherits from; it is
#                 popped from the loaded dict and the bases are merged in.
#   DELETE_KEY -- placed inside a dict value, a truthy value makes that dict
#                 replace (rather than merge with) the base dict of the same key.
BASE_KEY, DELETE_KEY = '_base_', '_delete_'
class ConfigDict(Dict):
    """``addict.Dict`` whose missing keys raise instead of being created.

    ``d['x']`` raises ``KeyError`` and ``d.x`` raises ``AttributeError`` when
    ``'x'`` is absent, so ``hasattr`` and ``getattr(d, 'x', default)`` work as
    they do for ordinary objects.
    """

    def __missing__(self, name):
        # Called by dict.__getitem__ for absent keys; refuse to auto-create.
        raise KeyError(name)

    def __getattr__(self, name):
        try:
            value = super(ConfigDict, self).__getattr__(name)
        except KeyError:
            # Attribute protocol requires AttributeError; the KeyError context
            # adds nothing for the caller, so it is suppressed.
            raise AttributeError(f"'{self.__class__.__name__}' object has no "
                                 f"attribute '{name}'") from None
        # Any other exception from the base lookup propagates unchanged.
        return value
def add_args(parser, cfg, prefix=''):
    """Register a command-line option for every leaf value of ``cfg``.

    Nested dicts are flattened into dotted option names, e.g.
    ``{'model': {'depth': 50}}`` becomes ``--model.depth``. The option type is
    inferred from the current value:

    * ``str``   -> plain string option
    * ``bool``  -> ``store_true`` flag
    * ``int`` / ``float`` -> option parsed with that type
    * other iterables -> ``nargs='+'`` option typed after the first element
      (untyped strings if the iterable is empty)

    Values of any other type are reported on stdout and skipped.

    Args:
        parser (ArgumentParser): Parser to add the options to.
        cfg: Mapping of config values (anything exposing ``items()``).
        prefix (str): Dotted prefix prepended to every option name.

    Returns:
        ArgumentParser: ``parser`` itself, for chaining.
    """
    for k, v in cfg.items():
        name = '--' + prefix + k
        if isinstance(v, str):
            parser.add_argument(name)
        # `bool` is a subclass of `int`, so it must be tested before `int`;
        # otherwise booleans would become `type=int` options.
        elif isinstance(v, bool):
            parser.add_argument(name, action='store_true')
        elif isinstance(v, int):
            parser.add_argument(name, type=int)
        elif isinstance(v, float):
            parser.add_argument(name, type=float)
        elif isinstance(v, dict):
            add_args(parser, v, prefix + k + '.')
        elif isinstance(v, abc.Iterable):
            # list() makes non-indexable iterables (e.g. sets) usable and lets
            # us detect emptiness, where there is no element type to infer.
            items = list(v)
            if items:
                parser.add_argument(name, type=type(items[0]), nargs='+')
            else:
                parser.add_argument(name, nargs='+')
        else:
            print(f'cannot parse key {prefix + k} of type {type(v)}')
    return parser
class Config(object):
    """A facility for config and config files.

    It supports common file formats as configs: python/json/yaml. The interface
    is the same as a dict object and also allows access to config values as
    attributes.

    Example:
        >>> cfg = Config(dict(a=1, b=dict(b1=[0, 1])))
        >>> cfg.a
        1
        >>> cfg.b
        {'b1': [0, 1]}
        >>> cfg.b.b1
        [0, 1]

        Loading from a file (here ``a.py`` defines ``item1`` .. ``item4``)::

            cfg = Config.fromfile('tests/data/config/a.py')
            cfg.filename  # 'tests/data/config/a.py'
            cfg.item4     # 'test'
            cfg
            # Config (path: tests/data/config/a.py): {'item1': [1, 2],
            #     'item2': {'a': 0}, 'item3': True, 'item4': 'test'}
    """

    @staticmethod
    def _file2dict(filename):
        """Load a config file into a dict plus its source text.

        Args:
            filename (str): Path to a ``.py``, ``.yml``, ``.yaml`` or
                ``.json`` file. ``~`` is expanded and the path made absolute.

        Returns:
            tuple[dict, str]: The config dict, with every file listed under
            ``BASE_KEY`` loaded (recursively) and this file's values merged on
            top; and the text of all involved files, each prefixed by its
            absolute path, bases first.

        Raises:
            FileNotFoundError: If ``filename`` does not exist.
            IOError: If the file extension is not supported.
            KeyError: If two base files define the same top-level key.
        """
        filename = osp.abspath(osp.expanduser(filename))
        if not osp.isfile(filename):
            raise FileNotFoundError(f'file "{filename}" does not exist')

        if filename.endswith('.py'):
            # NOTE(security): importing the file executes it as Python code;
            # only load config files from trusted sources.
            with tempfile.TemporaryDirectory() as temp_config_dir:
                # Import a copy of the file under a unique temporary module
                # name instead of importing the original path directly.
                temp_config_file = tempfile.NamedTemporaryFile(
                    dir=temp_config_dir, suffix='.py')
                temp_config_name = osp.basename(temp_config_file.name)
                # close temp file (frees the name for the copy below)
                temp_config_file.close()
                shutil.copyfile(filename,
                                osp.join(temp_config_dir, temp_config_name))
                temp_module_name = osp.splitext(temp_config_name)[0]
                sys.path.insert(0, temp_config_dir)
                try:
                    mod = import_module(temp_module_name)
                finally:
                    # The directory is deleted when the `with` exits; never
                    # leave it on sys.path, even if the import failed.
                    if temp_config_dir in sys.path:
                        sys.path.remove(temp_config_dir)
                cfg_dict = {
                    name: value
                    for name, value in mod.__dict__.items()
                    if not name.startswith('__')
                }
                # delete imported module
                del sys.modules[temp_module_name]
        elif filename.endswith(('.yml', '.yaml', '.json')):
            import mmcv
            cfg_dict = mmcv.load(filename)
        else:
            raise IOError('Only py/yml/yaml/json type are supported now!')

        cfg_text = filename + '\n'
        # Python source files default to UTF-8 (PEP 3120); do not depend on
        # the platform's locale encoding.
        with open(filename, 'r', encoding='utf-8') as f:
            cfg_text += f.read()

        if BASE_KEY in cfg_dict:
            cfg_dir = osp.dirname(filename)
            base_filename = cfg_dict.pop(BASE_KEY)
            base_filename = list(base_filename) if isinstance(
                base_filename, (list, tuple)) else [base_filename]

            cfg_dict_list = list()
            cfg_text_list = list()
            for f in base_filename:
                # Base paths are resolved relative to this config's directory.
                _cfg_dict, _cfg_text = Config._file2dict(osp.join(cfg_dir, f))
                cfg_dict_list.append(_cfg_dict)
                cfg_text_list.append(_cfg_text)

            base_cfg_dict = dict()
            for c in cfg_dict_list:
                if len(base_cfg_dict.keys() & c.keys()) > 0:
                    raise KeyError('Duplicate key is not allowed among bases')
                base_cfg_dict.update(c)

            # Values from this file override those inherited from the bases.
            cfg_dict = Config._merge_a_into_b(cfg_dict, base_cfg_dict)

            # merge cfg_text: bases first, then this file
            cfg_text_list.append(cfg_text)
            cfg_text = '\n'.join(cfg_text_list)

        return cfg_dict, cfg_text

    @staticmethod
    def _merge_a_into_b(a, b):
        """Recursively merge dict ``a`` into a copy of dict ``b``.

        Values in ``a`` overwrite those in ``b``. When both hold a dict under
        the same key, the two are merged recursively, unless ``a``'s dict has
        a truthy ``DELETE_KEY`` entry, in which case it replaces ``b``'s value
        wholesale (with ``DELETE_KEY`` removed). Neither input is modified.

        Raises:
            TypeError: If ``a`` holds a dict to merge where ``b`` holds a
                non-dict value.
        """
        # copy first to avoid inplace modification
        b = b.copy()
        for k, v in a.items():
            if isinstance(v, dict) and k in b:
                # Pop DELETE_KEY from a shallow copy so `a` is left untouched.
                v = v.copy()
                if not v.pop(DELETE_KEY, False):
                    if not isinstance(b[k], dict):
                        raise TypeError(
                            f'{k}={v} in child config cannot inherit from '
                            f'base because {k} is a dict in the child config '
                            f'but is of type {type(b[k])} in base config. '
                            f'You may set `{DELETE_KEY}=True` to ignore the '
                            f'base config')
                    v = Config._merge_a_into_b(v, b[k])
            b[k] = v
        return b

    @staticmethod
    def fromfile(filename):
        """Build a :class:`Config` from a py/yml/yaml/json file.

        ``filename`` is stored as given (not made absolute).
        """
        cfg_dict, cfg_text = Config._file2dict(filename)
        return Config(cfg_dict, cfg_text=cfg_text, filename=filename)

    @staticmethod
    def auto_argparser(description=None):
        """Generate argparser from config file automatically (experimental).

        The positional ``config`` argument is parsed first (ignoring unknown
        options) to load the config; a fresh parser is then built with that
        positional plus one option per config value (see ``add_args``).

        Args:
            description (str, optional): Description for both parsers.

        Returns:
            tuple[ArgumentParser, Config]: The full parser and loaded config.
        """
        partial_parser = ArgumentParser(description=description)
        partial_parser.add_argument('config', help='config file path')
        cfg_file = partial_parser.parse_known_args()[0].config
        cfg = Config.fromfile(cfg_file)
        parser = ArgumentParser(description=description)
        parser.add_argument('config', help='config file path')
        add_args(parser, cfg)
        return parser, cfg

    def __init__(self, cfg_dict=None, cfg_text=None, filename=None):
        """
        Args:
            cfg_dict (dict, optional): Initial values (default: empty).
            cfg_text (str, optional): Source text; when falsy and
                ``filename`` is given, the file's contents are read instead.
            filename (str, optional): Path the config came from.

        Raises:
            TypeError: If ``cfg_dict`` is not a dict.
        """
        if cfg_dict is None:
            cfg_dict = dict()
        elif not isinstance(cfg_dict, dict):
            raise TypeError('cfg_dict must be a dict, but '
                            f'got {type(cfg_dict)}')

        # Bypass our own __setattr__, which writes into _cfg_dict.
        super(Config, self).__setattr__('_cfg_dict', ConfigDict(cfg_dict))
        super(Config, self).__setattr__('_filename', filename)
        if cfg_text:
            text = cfg_text
        elif filename:
            with open(filename, 'r', encoding='utf-8') as f:
                text = f.read()
        else:
            text = ''
        super(Config, self).__setattr__('_text', text)

    @property
    def filename(self):
        """str or None: Path this config was created from."""
        return self._filename

    @property
    def text(self):
        """str: Source text of the config (empty if none was available)."""
        return self._text

    @property
    def pretty_text(self):
        """str: The config rendered as Python-style ``key=value`` lines."""

        indent = 4

        def _indent(s_, num_spaces):
            # Indent every line except the first by `num_spaces` spaces.
            s = s_.split('\n')
            if len(s) == 1:
                return s_
            first = s.pop(0)
            s = [(num_spaces * ' ') + line for line in s]
            s = '\n'.join(s)
            s = first + '\n' + s
            return s

        def _format_basic_types(k, v):
            if isinstance(v, str):
                v_str = f"'{v}'"
            else:
                v_str = str(v)
            attr_str = f'{str(k)}={v_str}'
            attr_str = _indent(attr_str, indent)

            return attr_str

        def _format_list(k, v):
            # A non-empty list whose items are all dicts is rendered as a
            # multi-line list of dict(...) entries; anything else (including
            # an empty list) uses its plain str() form.
            if v and all(isinstance(_, dict) for _ in v):
                v_str = '[\n'
                v_str += '\n'.join(
                    f'dict({_indent(_format_dict(v_), indent)}),'
                    for v_ in v).rstrip(',')
                attr_str = f'{str(k)}={v_str}'
                attr_str = _indent(attr_str, indent) + ']'
            else:
                attr_str = _format_basic_types(k, v)
            return attr_str

        def _format_dict(d, outest_level=False):
            # Top-level entries are separate statements (no trailing comma);
            # nested entries are comma-separated, except the last one.
            r = ''
            s = []
            for idx, (k, v) in enumerate(d.items()):
                is_last = idx >= len(d) - 1
                end = '' if outest_level or is_last else ','
                if isinstance(v, dict):
                    v_str = '\n' + _format_dict(v)
                    attr_str = f'{str(k)}=dict({v_str}'
                    attr_str = _indent(attr_str, indent) + ')' + end
                elif isinstance(v, list):
                    attr_str = _format_list(k, v) + end
                else:
                    attr_str = _format_basic_types(k, v) + end

                s.append(attr_str)
            r += '\n'.join(s)
            return r

        cfg_dict = self._cfg_dict.to_dict()
        text = _format_dict(cfg_dict, outest_level=True)

        return text

    def __repr__(self):
        return f'Config (path: {self.filename}): {self._cfg_dict.__repr__()}'

    def __len__(self):
        return len(self._cfg_dict)

    def __getattr__(self, name):
        # Only reached when normal lookup fails: delegate to the stored dict.
        return getattr(self._cfg_dict, name)

    def __getitem__(self, name):
        return self._cfg_dict.__getitem__(name)

    def __setattr__(self, name, value):
        # Plain dicts are wrapped so nested attribute access keeps working.
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setattr__(name, value)

    def __setitem__(self, name, value):
        if isinstance(value, dict):
            value = ConfigDict(value)
        self._cfg_dict.__setitem__(name, value)

    def __iter__(self):
        return iter(self._cfg_dict)

    def dump(self):
        """Return the config values serialized as indented JSON text."""
        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        format_text = json.dumps(cfg_dict, indent=2)
        return format_text

    def merge_from_dict(self, options):
        """Merge a dict of dotted keys into this config.

        Keys such as ``'model.backbone.depth'`` address nested entries. The
        intermediate dicts are created as needed and the result is merged
        into the existing config, so sibling keys are preserved. A typical
        input is the dict produced by :class:`DictAction`.

        Example:
            >>> options = {'model.backbone.depth': 50,
            ...            'model.backbone.with_cp': True}
            >>> cfg = Config(dict(model=dict(backbone=dict(type='ResNet'))))
            >>> cfg.merge_from_dict(options)
            >>> cfg._cfg_dict == dict(
            ...     model=dict(backbone=dict(type='ResNet', depth=50,
            ...                              with_cp=True)))
            True

        Args:
            options (dict): Dotted keys mapped to the values to set.
        """
        option_cfg_dict = {}
        for full_key, v in options.items():
            d = option_cfg_dict
            key_list = full_key.split('.')
            for subkey in key_list[:-1]:
                d.setdefault(subkey, ConfigDict())
                d = d[subkey]
            subkey = key_list[-1]
            d[subkey] = v

        cfg_dict = super(Config, self).__getattribute__('_cfg_dict')
        super(Config, self).__setattr__(
            '_cfg_dict', Config._merge_a_into_b(option_cfg_dict, cfg_dict))
class DictAction(Action):
    """argparse action that collects ``KEY=VALUE`` arguments into a dict.

    Each argument is split on the first ``=``. The value part is split on
    ``,`` and each piece is converted to ``int``, ``float`` or ``bool``
    (``true``/``false``, case-insensitive) when possible, otherwise kept as a
    string. A single piece is stored as a scalar and several as a list, i.e.
    ``KEY=V1,V2,V3``. The resulting dict is stored on ``namespace.<dest>``.
    """

    @staticmethod
    def _parse_int_float_bool(val):
        """Convert ``val`` to int, float or bool if possible, else return it."""
        try:
            return int(val)
        except ValueError:
            pass
        try:
            return float(val)
        except ValueError:
            pass
        if val.lower() in ('true', 'false'):
            return val.lower() == 'true'
        return val

    def __call__(self, parser, namespace, values, option_string=None):
        options = {}
        for kv in values:
            key, sep, val = kv.partition('=')
            if not sep:
                # Report through argparse (usage + exit) instead of crashing
                # with an unpacking traceback.
                parser.error(f'argument {option_string or self.dest}: '
                             f'expected KEY=VALUE, got {kv!r}')
            parsed = [self._parse_int_float_bool(v) for v in val.split(',')]
            options[key] = parsed[0] if len(parsed) == 1 else parsed
        setattr(namespace, self.dest, options)