diff --git a/notebook/config_manager.py b/notebook/config_manager.py index ff6f486ce..2cd84427a 100644 --- a/notebook/config_manager.py +++ b/notebook/config_manager.py @@ -9,6 +9,7 @@ import glob import io import json import os +import copy from six import PY3 from traitlets.config import LoggingConfigurable @@ -35,6 +36,19 @@ def recursive_update(target, new): else: target[k] = v +def remove_defaults(data, defaults): + """Recursively remove items from dict that are already in defaults""" + for key, value in list(data.items()): # copy the iterator, since data will be modified + new_value = None + if key in defaults: + if isinstance(value, dict): + remove_defaults(data[key], defaults[key]) + if not data[key]: # prune empty subdicts + del data[key] + else: + if value == defaults[key]: + del data[key] + class BaseJSONConfigManager(LoggingConfigurable): """General JSON config manager @@ -62,13 +76,16 @@ class BaseJSONConfigManager(LoggingConfigurable): """Returns the directory name for the section name: {config_dir}/{section_name}.d""" return os.path.join(self.config_dir, section_name+'.d') - def get(self, section_name): + def get(self, section_name, include_root=True): """Retrieve the config data for the specified section. Returns the data as a dictionary, or an empty dictionary if the file doesn't exist. + + When include_root is False, it will not read the root .json file, effectively + returning the default values. """ - paths = [self.file_name(section_name)] + paths = [self.file_name(section_name)] if include_root else [] if self.read_directory: pattern = os.path.join(self.directory(section_name), '*.json') # These json files should be processed first so that the @@ -91,6 +108,12 @@ class BaseJSONConfigManager(LoggingConfigurable): filename = self.file_name(section_name) self.ensure_config_dir_exists() + # we will modify data in place, so make a copy + data = copy.deepcopy(data) + defaults = self.get(section_name, include_root=False) + print(data, defaults) + remove_defaults(data, defaults) + # Generate the JSON up front, since it could raise an exception, # in order to avoid writing half-finished corrupted data to disk. json_content = json.dumps(data, indent=2) diff --git a/notebook/tests/test_config_manager.py b/notebook/tests/test_config_manager.py index 04ea9c443..ded2d33b9 100644 --- a/notebook/tests/test_config_manager.py +++ b/notebook/tests/test_config_manager.py @@ -9,20 +9,27 @@ from notebook.config_manager import BaseJSONConfigManager def test_json(): tmpdir = tempfile.mkdtemp() try: + root_data = dict(a=1, x=2, nest={'a':1, 'x':2}) with open(os.path.join(tmpdir, 'foo.json'), 'w') as f: - json.dump(dict(a=1), f) + json.dump(root_data, f) # also make a foo.d/ directory with multiple json files os.makedirs(os.path.join(tmpdir, 'foo.d')) with open(os.path.join(tmpdir, 'foo.d', 'a.json'), 'w') as f: - json.dump(dict(a=2, b=1), f) + json.dump(dict(a=2, b=1, nest={'a':2, 'b':1}), f) with open(os.path.join(tmpdir, 'foo.d', 'b.json'), 'w') as f: - json.dump(dict(a=3, b=2, c=3), f) + json.dump(dict(a=3, b=2, c=3, nest={'a':3, 'b':2, 'c':3}, only_in_b={'x':1}), f) manager = BaseJSONConfigManager(config_dir=tmpdir, read_directory=False) data = manager.get('foo') assert 'a' in data + assert 'x' in data assert 'b' not in data assert 'c' not in data assert data['a'] == 1 + assert 'x' in data['nest'] + # if we write it out, it also shouldn't pick up the subdirectoy + manager.set('foo', data) + data = manager.get('foo') + assert data == root_data manager = BaseJSONConfigManager(config_dir=tmpdir, read_directory=True) data = manager.get('foo') @@ -33,6 +40,17 @@ def test_json(): assert data['a'] == 1 assert data['b'] == 2 assert data['c'] == 3 + assert data['nest']['a'] == 1 + assert data['nest']['b'] == 2 + assert data['nest']['c'] == 3 + assert data['nest']['x'] == 2 + + # when writing out, we don't want foo.d/*.json data to be included in the root foo.json + manager.set('foo', data) + manager = BaseJSONConfigManager(config_dir=tmpdir, read_directory=False) + data = manager.get('foo') + assert data == root_data + finally: shutil.rmtree(tmpdir)