"""Interact with functions using widgets.""" #----------------------------------------------------------------------------- # Copyright (c) 2013, the IPython Development Team. # # Distributed under the terms of the Modified BSD License. # # The full license is in the file COPYING.txt, distributed with this software. #----------------------------------------------------------------------------- #----------------------------------------------------------------------------- # Imports #----------------------------------------------------------------------------- from __future__ import print_function try: # Python >= 3.3 from inspect import signature, Parameter except ImportError: from IPython.utils.signatures import signature, Parameter from inspect import getcallargs from IPython.html.widgets import (Widget, TextWidget, FloatSliderWidget, IntSliderWidget, CheckboxWidget, DropdownWidget, ContainerWidget, DOMWidget) from IPython.display import display, clear_output from IPython.utils.py3compat import string_types, unicode_type from IPython.utils.traitlets import HasTraits, Any, Unicode #----------------------------------------------------------------------------- # Classes and Functions #----------------------------------------------------------------------------- def _matches(o, pattern): """Match a pattern of types in a sequence.""" if not len(o) == len(pattern): return False comps = zip(o,pattern) return all(isinstance(obj,kind) for obj,kind in comps) def _get_min_max_value(min, max, value): """Return min, max, value given input values with possible None.""" if value is None: if not max > min: raise ValueError('max must be greater than min: (min={0}, max={1})'.format(min, max)) value = min + abs(min-max)/2 value = type(min)(value) elif min is None and max is None: if value == 0.0: min, max, value = 0.0, 1.0, 0.5 elif value == 0: min, max, value = 0, 1, 0 elif isinstance(value, float): min, max = (-value, 3.0*value) if value > 0 else (3.0*value, -value) elif isinstance(value, int): min, max = (-value, 3*value) if value > 0 else (3*value, -value) else: raise TypeError('expected a number, got: %r' % value) else: raise ValueError('unable to infer range, value from: ({0}, {1}, {2})'.format(min, max, value)) return min, max, value def _widget_abbrev_single_value(o): """Make widgets from single values, which can be used written as parameter defaults.""" if isinstance(o, string_types): return TextWidget(value=unicode_type(o)) elif isinstance(o, dict): # get a single value in a Python 2+3 way: value = next(iter(o.values())) return DropdownWidget(value=value, values=o) elif isinstance(o, bool): return CheckboxWidget(value=o) elif isinstance(o, float): min, max, value = _get_min_max_value(None, None, o) return FloatSliderWidget(value=o, min=min, max=max) elif isinstance(o, int): min, max, value = _get_min_max_value(None, None, o) return IntSliderWidget(value=o, min=min, max=max) else: return None def _widget_abbrev(o): """Make widgets from abbreviations: single values, lists or tuples.""" if isinstance(o, (list, tuple)): if _matches(o, (int, int)): min, max, value = _get_min_max_value(o[0], o[1], None) return IntSliderWidget(value=value, min=min, max=max) elif _matches(o, (int, int, int)): min, max, value = _get_min_max_value(o[0], o[1], None) return IntSliderWidget(value=value, min=min, max=max, step=o[2]) elif _matches(o, (float, float)): min, max, value = _get_min_max_value(o[0], o[1], None) return FloatSliderWidget(value=value, min=min, max=max) elif _matches(o, (float, float, float)): min, max, value = _get_min_max_value(o[0], o[1], None) return FloatSliderWidget(value=value, min=min, max=max, step=o[2]) elif _matches(o, (float, float, int)): min, max, value = _get_min_max_value(o[0], o[1], None) return FloatSliderWidget(value=value, min=min, max=max, step=float(o[2])) elif all(isinstance(x, string_types) for x in o): return DropdownWidget(value=unicode_type(o[0]), values=[unicode_type(k) for k in o]) else: return _widget_abbrev_single_value(o) def _widget_from_abbrev(abbrev): """Build a Widget intstance given an abbreviation or Widget.""" if isinstance(abbrev, Widget) or isinstance(abbrev, const): return abbrev widget = _widget_abbrev(abbrev) if widget is None: raise ValueError("%r cannot be transformed to a Widget" % abbrev) return widget def _yield_abbreviations_for_parameter(param, args, kwargs): """Get an abbreviation for a function parameter.""" # print(param, args, kwargs) name = param.name kind = param.kind ann = param.annotation default = param.default empty = Parameter.empty if kind == Parameter.POSITIONAL_ONLY: if args: yield name, args.pop(0), False elif ann is not empty: yield name, ann, False else: yield None, None, None elif kind == Parameter.POSITIONAL_OR_KEYWORD: if name in kwargs: yield name, kwargs.pop(name), True elif args: yield name, args.pop(0), False elif ann is not empty: if default is empty: yield name, ann, False else: yield name, ann, True elif default is not empty: yield name, default, True else: yield None, None, None elif kind == Parameter.VAR_POSITIONAL: # In this case name=args or something and we don't actually know the names. for item in args[::]: args.pop(0) yield '', item, False elif kind == Parameter.KEYWORD_ONLY: if name in kwargs: yield name, kwargs.pop(name), True elif ann is not empty: yield name, ann, True elif default is not empty: yield name, default, True else: yield None, None, None elif kind == Parameter.VAR_KEYWORD: # In this case name=kwargs and we yield the items in kwargs with their keys. for k, v in kwargs.copy().items(): kwargs.pop(k) yield k, v, True def _find_abbreviations(f, args, kwargs): """Find the abbreviations for a function and args/kwargs passed to interact.""" new_args = [] new_kwargs = [] for param in signature(f).parameters.values(): for name, value, kw in _yield_abbreviations_for_parameter(param, args, kwargs): if value is None: raise ValueError('cannot find widget or abbreviation for argument: {!r}'.format(name)) if kw: new_kwargs.append((name, value)) else: new_args.append((name, value)) return new_args, new_kwargs def _widgets_from_abbreviations(seq): """Given a sequence of (name, abbrev) tuples, return a sequence of Widgets.""" result = [] for name, abbrev in seq: widget = _widget_from_abbrev(abbrev) widget.description = name result.append(widget) return result def interactive(f, *args, **kwargs): """Build a group of widgets to interact with a function.""" co = kwargs.pop('clear_output', True) args_widgets = [] kwargs_widgets = [] container = ContainerWidget() container.result = None container.args = [] container.kwargs = dict() # We need this to be a list as we iteratively pop elements off it args = list(args) kwargs = kwargs.copy() new_args, new_kwargs = _find_abbreviations(f, args, kwargs) # Before we proceed, let's make sure that the user has passed a set of args+kwargs # that will lead to a valid call of the function. This protects against unspecified # and doubly-specified arguments. getcallargs(f, *[v for n,v in new_args], **{n:v for n,v in new_kwargs}) # Now build the widgets from the abbreviations. args_widgets.extend(_widgets_from_abbreviations(new_args)) kwargs_widgets.extend(_widgets_from_abbreviations(new_kwargs)) kwargs_widgets.extend(_widgets_from_abbreviations(sorted(kwargs.items(), key = lambda x: x[0]))) # This has to be done as an assignment, not using container.children.append, # so that traitlets notices the update. We skip any objects (such as const) that # are not DOMWidgets. c = [w for w in args_widgets+kwargs_widgets if isinstance(w, DOMWidget)] container.children = c # Build the callback def call_f(name, old, new): container.args = [] for widget in args_widgets: value = widget.value container.args.append(value) for widget in kwargs_widgets: value = widget.value container.kwargs[widget.description] = value if co: clear_output(wait=True) container.result = f(*container.args, **container.kwargs) # Wire up the widgets for widget in args_widgets: widget.on_trait_change(call_f, 'value') for widget in kwargs_widgets: widget.on_trait_change(call_f, 'value') container.on_displayed(lambda _: call_f(None, None, None)) return container def interact(*args, **kwargs): """Interact with a function using widgets.""" if args and callable(args[0]): # This branch handles the cases: # 1. interact(f, *args, **kwargs) # 2. @interact # def f(*args, **kwargs): # ... f = args[0] w = interactive(f, *args[1:], **kwargs) f.widget = w display(w) else: # This branch handles the case: # @interact(10, 20, a=30, b=40) # def f(*args, **kwargs): # ... def dec(f): w = interactive(f, *args, **kwargs) f.widget = w display(w) return f return dec class const(HasTraits): """A pseudo-widget whose value is constant and never client synced.""" value = Any(help="Any Python object") description = Unicode('', help="Any Python object") def __init__(self, value, **kwargs): super(const, self).__init__(value=value, **kwargs) def annotate(**kwargs): """Python 3 compatible function annotation for Python 2.""" if not kwargs: raise ValueError('annotations must be provided as keyword arguments') def dec(f): if hasattr(f, '__annotations__'): for k, v in kwargs.items(): f.__annotations__[k] = v else: f.__annotations__ = kwargs return f return dec