You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
414 lines
12 KiB
414 lines
12 KiB
from __future__ import annotations
|
|
|
|
from warnings import warn
|
|
import inspect
|
|
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
|
|
from .utils import expand_tuples
|
|
import itertools as itl
|
|
|
|
|
|
class MDNotImplementedError(NotImplementedError):
    """Signal that a dispatched implementation declines to handle its inputs.

    ``Dispatcher.__call__`` catches this exception and tries the next
    matching implementation in resolution order. It raises a plain
    ``NotImplementedError`` only when every candidate declines.
    """
|
|
|
|
|
|
### Functions for on_ambiguity
|
|
|
|
def ambiguity_warn(dispatcher, ambiguities):
    """Emit an ``AmbiguityWarning`` describing ambiguous signatures.

    This is the default ``on_ambiguity`` callback used by ``Dispatcher.add``
    and ``Dispatcher.reorder``.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher in which the ambiguity was found; its ``name`` is
        included in the warning text.
    ambiguities : set
        Pairs of type signatures that are ambiguous within ``dispatcher``.

    See Also:
        Dispatcher.add
        warning_text
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
|
|
|
|
|
|
class RaiseNotImplementedError:
    """Callable placeholder that raises ``NotImplementedError`` when invoked.

    ``ambiguity_register_error_ignore_dup`` registers instances of this class
    on ambiguous super-signatures. Calling one of those signatures then fails
    loudly instead of silently picking an implementation.
    """

    def __init__(self, dispatcher):
        # Kept only so the error message can name the dispatcher.
        self.dispatcher = dispatcher

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        message = "Ambiguous signature for %s: <%s>" % (
            self.dispatcher.name, str_signature(arg_types))
        raise NotImplementedError(message)
|
|
|
|
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
    """Register an erroring placeholder for each ambiguous super-signature.

    For every ambiguous pair, the super signature is computed. If that
    signature consists of a single repeated type, it is skipped. Otherwise
    a ``RaiseNotImplementedError`` is added to ``dispatcher`` under it.

    Parameters
    ----------
    dispatcher : Dispatcher
        Dispatcher in which the ambiguity was found.
    ambiguities : set
        Pairs of type signatures that are ambiguous within ``dispatcher``.

    See Also:
        Dispatcher.add
        ambiguity_warn
    """
    for pair in ambiguities:
        sig = tuple(super_signature(pair))
        # Only signatures made of more than one distinct type get a placeholder.
        if len(set(sig)) != 1:
            # This function passes itself as the callback, so any ambiguity
            # introduced by the new registration is handled the same way.
            dispatcher.add(
                sig, RaiseNotImplementedError(dispatcher),
                on_ambiguity=ambiguity_register_error_ignore_dup
            )
|
|
|
|
###
|
|
|
|
|
|
# Dispatchers whose ``reorder`` ran while ordering was halted; they are
# drained and reordered by ``restart_ordering``.
_unresolved_dispatchers: set[Dispatcher] = set()
# One-element mutable flag so functions can toggle it without ``global``.
# ``_resolve[0]`` is True when ``Dispatcher.reorder`` should recompute
# immediately.
_resolve: list[bool] = [True]
|
|
|
|
|
|
def halt_ordering():
    """Suspend ordering recomputation in ``Dispatcher.reorder``.

    While halted, ``reorder`` only records the dispatcher in
    ``_unresolved_dispatchers``. Call ``restart_ordering`` to resume and to
    reorder everything that was deferred.
    """
    _resolve[0] = False
|
|
|
|
|
|
def restart_ordering(on_ambiguity=ambiguity_warn):
    """Resume ordering and reorder every dispatcher deferred meanwhile.

    Parameters
    ----------
    on_ambiguity : callable
        Passed to each deferred dispatcher's ``reorder``. It is invoked as
        ``on_ambiguity(dispatcher, ambiguities)`` when ambiguities are found.
    """
    _resolve[0] = True
    # Pop one at a time so the set is empty afterwards.
    while _unresolved_dispatchers:
        _unresolved_dispatchers.pop().reorder(on_ambiguity=on_ambiguity)
|
|
|
|
|
|
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> from sympy.multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x): # noqa: F811
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        # ``funcs`` maps signature tuples to implementations.
        # ``ordering`` is the resolution order computed from ``funcs``.
        # ``_cache`` maps argument-type tuples to the implementation chosen
        # for them.
        self.name = self.__name__ = name
        self.funcs = {}
        self._cache = {}
        self.ordering = []
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Register dispatcher with new implementation

        >>> from sympy.multipledispatch.dispatcher import Dispatcher
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _(func):
            self.add(types, func, **kwargs)
            # Return the function unchanged so decorators can be stacked.
            return func
        return _

    @classmethod
    def get_func_params(cls, func):
        """Return the parameters of ``func`` as reported by ``inspect.signature``.

        Returns ``None`` if ``inspect.signature`` is unavailable.
        """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ Get annotations of function positional parameters

        Returns a tuple with the annotation of every positional parameter
        when all of them are annotated. Returns ``None`` when any positional
        parameter is unannotated, or when ``func`` has no parameters at all.

        NOTE(review): annotations are returned as stored on the signature.
        Under ``from __future__ import annotations`` in the defining module
        they are strings, which ``add`` rejects as non-types.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if not any(ann is Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func, on_ambiguity=ambiguity_warn):
        """ Add new types/method pair to dispatcher

        >>> from sympy.multipledispatch import Dispatcher
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        When ``add`` detects a warning it calls the ``on_ambiguity`` callback
        with a dispatcher/itself, and a set of ambiguous type signature pairs
        as inputs. See ``ambiguity_warn`` for an example.

        Raises ``TypeError`` if an element of the signature is not a type.
        """
        # Handle annotations: an empty signature falls back to the
        # annotations of func's positional parameters.
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Normalise to a tuple. This lets list signatures serve as ``funcs``
        # keys, and lets a one-shot iterable be traversed more than once.
        signature = tuple(signature)

        # Handle union types: a tuple element means "any of these types",
        # so register every combination separately.
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func, on_ambiguity)
            return

        for typ in signature:
            if not isinstance(typ, type):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

        self.funcs[signature] = func
        self.reorder(on_ambiguity=on_ambiguity)
        self._cache.clear()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """Recompute ``ordering`` and report any ambiguities.

        If ordering is halted (see ``halt_ordering``), the dispatcher is only
        queued for ``restart_ordering``.
        """
        if _resolve[0]:
            self.ordering = ordering(self.funcs)
            # Entries cached under the previous ordering may now be wrong.
            # This matters when reorder runs from restart_ordering after
            # calls were made while ordering was halted.
            self._cache.clear()
            amb = ambiguities(self.funcs)
            if amb:
                on_ambiguity(self, amb)
        else:
            _unresolved_dispatchers.add(self)

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types)))
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError:
            funcs = self.dispatch_iter(*types)
            # Burn the first match, which is assumed to be the implementation
            # that just declined. The ``None`` default keeps StopIteration
            # from escaping when ``ordering`` is stale (e.g. while halted)
            # and holds no match.
            next(funcs, None)
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass
            raise NotImplementedError(
                "Matching functions for "
                "%s: <%s> found, but none completed successfully"
                % (self.name, str_signature(types)))

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine appropriate implementation for this type signature

        This method is internal. Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        >>> from sympy.multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
            ``sympy.multipledispatch.conflict`` - module to determine resolution order
        """

        # An exact match short-circuits the ordering scan.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """Yield every implementation whose signature accepts ``types``.

        Candidates are yielded in ``ordering`` order. A signature matches
        when it has the same length and each type is a subclass of the
        corresponding signature entry.
        """
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        # stacklevel=2 attributes the warning to the caller of resolve().
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning, stacklevel=2)

        return self.dispatch(*types)

    def __getstate__(self):
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': self.doc}

    def __setstate__(self, d):
        # Every slot must be set, otherwise attribute access (including the
        # ``__doc__`` property, which reads ``self.doc``) raises
        # AttributeError. States without a 'doc' entry default to None,
        # as in __init__.
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        self.ordering = ordering(self.funcs)
        self._cache = {}
        self.doc = d.get('doc')

    @property
    def __doc__(self):
        """Combined documentation of the dispatcher and its implementations."""
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = 'Inputs: <%s>\n' % str_signature(sig)
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n    ' + '\n    '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        # NOTE(review): when nothing matches, dispatch returns None and this
        # yields ``None.__doc__`` instead of raising like ``_source``.
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
|
|
|
|
|
|
def source(func):
    """Return ``func``'s source code prefixed with the file it lives in.

    The result has the form ``'File: <path>\\n\\n<source text>'``.
    """
    location = inspect.getsourcefile(func)
    body = inspect.getsource(func)
    return 'File: %s\n\n' % location + body
|
|
|
|
|
|
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    Unlike ``Dispatcher``, the leading (instance) parameter of each
    implementation is ignored when reading annotations. The instance the
    descriptor was accessed through is passed as the first argument on call.

    See Also:
        Dispatcher
    """

    @classmethod
    def get_func_params(cls, func):
        # Skip the first parameter (the instance) so that only the real
        # arguments contribute annotations.
        if not hasattr(inspect, "signature"):
            return None
        all_params = inspect.signature(func).parameters.values()
        return itl.islice(all_params, 1, None)

    def __get__(self, instance, owner):
        # NOTE(review): the instance/owner are stored on this shared
        # descriptor object and ``self`` is returned. A reference obtained
        # via one instance therefore calls with whichever instance was
        # accessed most recently.
        self.obj = instance
        self.cls = owner
        return self

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        impl = self.dispatch(*arg_types)
        if not impl:
            raise NotImplementedError('Could not find signature for %s: <%s>' %
                                      (self.name, str_signature(arg_types)))
        return impl(self.obj, *args, **kwargs)
|
|
|
|
|
|
def str_signature(sig):
    """ Render a type signature as comma-separated type names

    >>> from sympy.multipledispatch.dispatcher import str_signature
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
|
|
|
|
|
|
def warning_text(name, amb):
    """Build the message used by ``ambiguity_warn``.

    The message lists each ambiguous signature pair in ``amb``, then one
    suggested ``@dispatch`` stub per pair, built from the pair's super
    signature.
    """
    parts = ["\nAmbiguities exist in dispatched function %s\n\n" % name,
             "The following signatures may result in ambiguous behavior:\n"]
    for pair in amb:
        rendered = ', '.join('[' + str_signature(s) + ']' for s in pair)
        parts.append("\t" + rendered + "\n")
    parts.append("\n\nConsider making the following additions:\n\n")
    stubs = ['@dispatch(' + str_signature(super_signature(s))
             + ')\ndef %s(...)' % name for s in amb]
    parts.append('\n\n'.join(stubs))
    return ''.join(parts)
|