You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

414 lines
12 KiB

5 months ago
from __future__ import annotations
from warnings import warn
import inspect
from .conflict import ordering, ambiguities, super_signature, AmbiguityWarning
from .utils import expand_tuples
import itertools as itl
class MDNotImplementedError(NotImplementedError):
    """A ``NotImplementedError`` specific to multiple dispatch.

    ``Dispatcher.__call__`` catches this exception when an implementation
    raises it and then tries the remaining matching implementations.
    """
### Functions for on_ambiguity
def ambiguity_warn(dispatcher, ambiguities):
    """ Emit an ``AmbiguityWarning`` describing the detected ambiguities

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        warning_text
    """
    message = warning_text(dispatcher.name, ambiguities)
    warn(message, AmbiguityWarning)
class RaiseNotImplementedError:
    """Callable placeholder that raises ``NotImplementedError`` when called.

    The raised message names the dispatcher and the types of the positional
    arguments the placeholder was called with.
    """

    def __init__(self, dispatcher):
        # Only ``dispatcher.name`` is read later, when building the message.
        self.dispatcher = dispatcher

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        message = "Ambiguous signature for %s: <%s>" % (
            self.dispatcher.name, str_signature(arg_types)
        )
        raise NotImplementedError(message)
def ambiguity_register_error_ignore_dup(dispatcher, ambiguities):
    """
    Register an error-raising implementation for each ambiguous pair.

    For every pair the super signature is computed. If all entries of that
    super signature are the same type, the pair is ignored. Otherwise an
    instance of ``RaiseNotImplementedError`` is registered on the dispatcher
    under the super signature, with this function again used as the
    ``on_ambiguity`` callback.

    Parameters
    ----------
    dispatcher : Dispatcher
        The dispatcher on which the ambiguity was detected
    ambiguities : set
        Set of type signature pairs that are ambiguous within this dispatcher

    See Also:
        Dispatcher.add
        ambiguity_warn
    """
    for pair in ambiguities:
        resolved = tuple(super_signature(pair))
        if len(set(resolved)) != 1:
            dispatcher.add(
                resolved,
                RaiseNotImplementedError(dispatcher),
                on_ambiguity=ambiguity_register_error_ignore_dup,
            )
###
# Dispatchers whose ``reorder`` was requested while ordering was halted;
# ``restart_ordering`` drains this set and reorders each one.
_unresolved_dispatchers: set[Dispatcher] = set()
# Single-element list used as a mutable module-level flag: ``_resolve[0]`` is
# True when ``Dispatcher.reorder`` should recompute the ordering immediately.
_resolve = [True]
def halt_ordering():
    """ Suspend ordering resolution

    While halted, ``Dispatcher.reorder`` only records the dispatcher as
    unresolved instead of recomputing its ordering.
    """
    _resolve[0] = False
def restart_ordering(on_ambiguity=ambiguity_warn):
    """ Resume ordering resolution and reorder every pending dispatcher

    Parameters
    ----------
    on_ambiguity : callable
        Callback forwarded to ``reorder`` for each dispatcher that was
        recorded as unresolved while ordering was halted.
    """
    _resolve[0] = True
    pending = _unresolved_dispatchers
    while pending:
        pending.pop().reorder(on_ambiguity=on_ambiguity)
class Dispatcher:
    """ Dispatch methods based on type signature

    Use ``dispatch`` to add implementations

    Examples
    --------

    >>> from sympy.multipledispatch import dispatch
    >>> @dispatch(int)
    ... def f(x):
    ...     return x + 1

    >>> @dispatch(float)
    ... def f(x): # noqa: F811
    ...     return x - 1

    >>> f(3)
    4
    >>> f(3.0)
    2.0
    """
    __slots__ = '__name__', 'name', 'funcs', 'ordering', '_cache', 'doc'

    def __init__(self, name, doc=None):
        self.name = self.__name__ = name
        # Maps a signature (tuple of types) to its implementation.
        self.funcs = {}
        # Maps the concrete argument types of a call to the resolved
        # implementation; cleared whenever a new implementation is added.
        self._cache = {}
        # Signatures in the order in which ``dispatch_iter`` tries them.
        self.ordering = []
        self.doc = doc

    def register(self, *types, **kwargs):
        """ Register dispatcher with new implementation

        >>> from sympy.multipledispatch.dispatcher import Dispatcher
        >>> f = Dispatcher('f')
        >>> @f.register(int)
        ... def inc(x):
        ...     return x + 1

        >>> @f.register(float)
        ... def dec(x):
        ...     return x - 1

        >>> @f.register(list)
        ... @f.register(tuple)
        ... def reverse(x):
        ...     return x[::-1]

        >>> f(1)
        2

        >>> f(1.0)
        0.0

        >>> f([1, 2, 3])
        [3, 2, 1]
        """
        def _(func):
            self.add(types, func, **kwargs)
            # Return the undecorated function so decorators can be stacked.
            return func
        return _

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` as reported by ``inspect`` """
        if hasattr(inspect, "signature"):
            sig = inspect.signature(func)
            return sig.parameters.values()

    @classmethod
    def get_func_annotations(cls, func):
        """ Get annotations of function positional parameters

        Returns a tuple of annotations when every positional parameter is
        annotated, otherwise ``None``.
        """
        params = cls.get_func_params(func)
        if params:
            Parameter = inspect.Parameter

            params = (param for param in params
                      if param.kind in
                      (Parameter.POSITIONAL_ONLY,
                       Parameter.POSITIONAL_OR_KEYWORD))

            annotations = tuple(
                param.annotation
                for param in params)

            if not any(ann is Parameter.empty for ann in annotations):
                return annotations

    def add(self, signature, func, on_ambiguity=ambiguity_warn):
        """ Add new types/method pair to dispatcher

        >>> from sympy.multipledispatch import Dispatcher
        >>> D = Dispatcher('add')
        >>> D.add((int, int), lambda x, y: x + y)
        >>> D.add((float, float), lambda x, y: x + y)

        >>> D(1, 2)
        3
        >>> D(1, 2.0)
        Traceback (most recent call last):
        ...
        NotImplementedError: Could not find signature for add: <int, float>

        When ``add`` detects an ambiguity it calls the ``on_ambiguity``
        callback with a dispatcher/itself, and a set of ambiguous type
        signature pairs as inputs.  See ``ambiguity_warn`` for an example.

        Raises ``TypeError`` if an entry of ``signature`` is not a type.
        """
        # Handle annotations
        if not signature:
            annotations = self.get_func_annotations(func)
            if annotations:
                signature = annotations

        # Handle union types
        if any(isinstance(typ, tuple) for typ in signature):
            for typs in expand_tuples(signature):
                self.add(typs, func, on_ambiguity)
            return

        for typ in signature:
            if not isinstance(typ, type):
                str_sig = ', '.join(c.__name__ if isinstance(c, type)
                                    else str(c) for c in signature)
                raise TypeError("Tried to dispatch on non-type: %s\n"
                                "In signature: <%s>\n"
                                "In function: %s" %
                                (typ, str_sig, self.name))

        self.funcs[signature] = func
        self.reorder(on_ambiguity=on_ambiguity)
        # Cached resolutions may be stale now that a new signature exists.
        self._cache.clear()

    def reorder(self, on_ambiguity=ambiguity_warn):
        """ Recompute the ordering, or defer it while ordering is halted """
        if _resolve[0]:
            self.ordering = ordering(self.funcs)
            amb = ambiguities(self.funcs)
            if amb:
                on_ambiguity(self, amb)
        else:
            _unresolved_dispatchers.add(self)

    def __call__(self, *args, **kwargs):
        types = tuple([type(arg) for arg in args])
        try:
            func = self._cache[types]
        except KeyError:
            func = self.dispatch(*types)
            if not func:
                raise NotImplementedError(
                    'Could not find signature for %s: <%s>' %
                    (self.name, str_signature(types)))
            self._cache[types] = func
        try:
            return func(*args, **kwargs)

        except MDNotImplementedError:
            # The chosen implementation declined; fall back to the other
            # matching implementations in order.
            funcs = self.dispatch_iter(*types)
            next(funcs)  # burn first (presumably the one just tried)
            for func in funcs:
                try:
                    return func(*args, **kwargs)
                except MDNotImplementedError:
                    pass
            raise NotImplementedError("Matching functions for "
                                      "%s: <%s> found, but none completed successfully"
                                      % (self.name, str_signature(types)))

    def __str__(self):
        return "<dispatched %s>" % self.name
    __repr__ = __str__

    def dispatch(self, *types):
        """ Determine appropriate implementation for this type signature

        This method is internal.  Users should call this object as a function.
        Implementation resolution occurs within the ``__call__`` method.

        >>> from sympy.multipledispatch import dispatch
        >>> @dispatch(int)
        ... def inc(x):
        ...     return x + 1

        >>> implementation = inc.dispatch(int)
        >>> implementation(3)
        4

        >>> print(inc.dispatch(float))
        None

        See Also:
            ``sympy.multipledispatch.conflict`` - module to determine resolution order
        """
        # Exact signature match takes precedence over subclass matching.
        if types in self.funcs:
            return self.funcs[types]

        try:
            return next(self.dispatch_iter(*types))
        except StopIteration:
            return None

    def dispatch_iter(self, *types):
        """ Yield every implementation whose signature accepts ``types`` """
        n = len(types)
        for signature in self.ordering:
            if len(signature) == n and all(map(issubclass, types, signature)):
                result = self.funcs[signature]
                yield result

    def resolve(self, types):
        """ Determine appropriate implementation for this type signature

        .. deprecated:: 0.4.4
            Use ``dispatch(*types)`` instead
        """
        warn("resolve() is deprecated, use dispatch(*types)",
             DeprecationWarning)

        return self.dispatch(*types)

    def __getstate__(self):
        # ``doc`` is part of the state so that it survives a pickle
        # round-trip; ``ordering`` and ``_cache`` are rebuilt on load.
        return {'name': self.name,
                'funcs': self.funcs,
                'doc': getattr(self, 'doc', None)}

    def __setstate__(self, d):
        # Every slot must be assigned here: with ``__slots__`` an unassigned
        # slot raises AttributeError on access, which would break both
        # ``__name__`` and the ``__doc__`` property (it reads ``self.doc``).
        self.name = self.__name__ = d['name']
        self.funcs = d['funcs']
        self.ordering = ordering(self.funcs)
        self._cache = {}
        # States written before ``doc`` was saved carry no 'doc' key.
        self.doc = d.get('doc')

    @property
    def __doc__(self):
        docs = ["Multiply dispatched method: %s" % self.name]

        if self.doc:
            docs.append(self.doc)

        # Signatures whose implementation has no docstring are listed
        # together at the end.
        other = []
        for sig in self.ordering[::-1]:
            func = self.funcs[sig]
            if func.__doc__:
                s = 'Inputs: <%s>\n' % str_signature(sig)
                s += '-' * len(s) + '\n'
                s += func.__doc__.strip()
                docs.append(s)
            else:
                other.append(str_signature(sig))

        if other:
            docs.append('Other signatures:\n ' + '\n '.join(other))

        return '\n\n'.join(docs)

    def _help(self, *args):
        return self.dispatch(*map(type, args)).__doc__

    def help(self, *args, **kwargs):
        """ Print docstring for the function corresponding to inputs """
        print(self._help(*args))

    def _source(self, *args):
        func = self.dispatch(*map(type, args))
        if not func:
            raise TypeError("No function found")
        return source(func)

    def source(self, *args, **kwargs):
        """ Print source code for the function corresponding to inputs """
        print(self._source(*args))
def source(func):
    """ Return the source file path and source code of ``func`` as text

    The result is a ``File: <path>`` header, a blank line, and then the
    source text reported by ``inspect.getsource``.
    """
    location = inspect.getsourcefile(func)
    code = inspect.getsource(func)
    return 'File: %s\n\n' % location + code
class MethodDispatcher(Dispatcher):
    """ Dispatch methods based on type signature

    Works as a descriptor: attribute access stores the instance and owner
    on the dispatcher, and calls pass the stored instance as the first
    argument to the selected implementation.

    See Also:
        Dispatcher
    """

    @classmethod
    def get_func_params(cls, func):
        """ Return the parameters of ``func`` without the first one """
        if not hasattr(inspect, "signature"):
            return None
        parameters = inspect.signature(func).parameters
        return itl.islice(parameters.values(), 1, None)

    def __get__(self, instance, owner):
        # Remember what the dispatcher was accessed through so that
        # ``__call__`` can supply the instance.
        self.obj, self.cls = instance, owner
        return self

    def __call__(self, *args, **kwargs):
        arg_types = tuple(map(type, args))
        impl = self.dispatch(*arg_types)
        if impl:
            return impl(self.obj, *args, **kwargs)
        raise NotImplementedError('Could not find signature for %s: <%s>' %
                                  (self.name, str_signature(arg_types)))
def str_signature(sig):
    """ Render a type signature as a comma-separated list of type names

    >>> from sympy.multipledispatch.dispatcher import str_signature
    >>> str_signature((int, float))
    'int, float'
    """
    names = [typ.__name__ for typ in sig]
    return ', '.join(names)
def warning_text(name, amb):
    """ The text for ambiguity warnings

    Lists every ambiguous signature pair in ``amb`` and then suggests, for
    each pair, a ``@dispatch`` declaration using the pair's super signature.
    """
    chunks = [
        "\nAmbiguities exist in dispatched function %s\n\n" % (name),
        "The following signatures may result in ambiguous behavior:\n",
    ]
    # First pass over ``amb``: one tab-indented line per ambiguous pair.
    for pair in amb:
        rendered = ', '.join('[' + str_signature(s) + ']' for s in pair)
        chunks.append("\t" + rendered + "\n")
    chunks.append("\n\nConsider making the following additions:\n\n")
    # Second pass over ``amb``: one suggested declaration per pair.
    suggestions = []
    for pair in amb:
        suggestions.append('@dispatch(' + str_signature(super_signature(pair))
                           + ')\ndef %s(...)' % name)
    chunks.append('\n\n'.join(suggestions))
    return ''.join(chunks)