You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
293 lines
10 KiB
293 lines
10 KiB
r"""
|
|
PyTorch provides two global :class:`ConstraintRegistry` objects that link
|
|
:class:`~torch.distributions.constraints.Constraint` objects to
|
|
:class:`~torch.distributions.transforms.Transform` objects. Both objects take
constraints as input and return transforms, but they have different guarantees
on bijectivity.
|
|
|
|
1. ``biject_to(constraint)`` looks up a bijective
|
|
:class:`~torch.distributions.transforms.Transform` from ``constraints.real``
|
|
to the given ``constraint``. The returned transform is guaranteed to have
|
|
``.bijective = True`` and should implement ``.log_abs_det_jacobian()``.
|
|
2. ``transform_to(constraint)`` looks up a not-necessarily bijective
|
|
:class:`~torch.distributions.transforms.Transform` from ``constraints.real``
|
|
to the given ``constraint``. The returned transform is not guaranteed to
|
|
implement ``.log_abs_det_jacobian()``.
|
|
|
|
The ``transform_to()`` registry is useful for performing unconstrained
|
|
optimization on constrained parameters of probability distributions, which are
|
|
indicated by each distribution's ``.arg_constraints`` dict. These transforms often
|
|
overparameterize a space in order to avoid rotation; they are thus more
|
|
suitable for coordinate-wise optimization algorithms like Adam::
|
|
|
|
loc = torch.zeros(100, requires_grad=True)
|
|
unconstrained = torch.zeros(100, requires_grad=True)
|
|
scale = transform_to(Normal.arg_constraints['scale'])(unconstrained)
|
|
loss = -Normal(loc, scale).log_prob(data).sum()
|
|
|
|
The ``biject_to()`` registry is useful for Hamiltonian Monte Carlo, where
|
|
samples from a probability distribution with constrained ``.support`` are
|
|
propagated in an unconstrained space, and algorithms are typically rotation
|
|
invariant::
|
|
|
|
dist = Exponential(rate)
|
|
unconstrained = torch.zeros(100, requires_grad=True)
|
|
sample = biject_to(dist.support)(unconstrained)
|
|
potential_energy = -dist.log_prob(sample).sum()
|
|
|
|
.. note::
|
|
|
|
An example where ``transform_to`` and ``biject_to`` differ is
|
|
``constraints.simplex``: ``transform_to(constraints.simplex)`` returns a
|
|
:class:`~torch.distributions.transforms.SoftmaxTransform` that simply
|
|
exponentiates and normalizes its inputs; this is a cheap and mostly
|
|
coordinate-wise operation appropriate for algorithms like SVI. In
|
|
contrast, ``biject_to(constraints.simplex)`` returns a
|
|
:class:`~torch.distributions.transforms.StickBreakingTransform` that
|
|
bijects its input down to a one-fewer-dimensional space; this is a more
expensive, less numerically stable transform but is needed for algorithms
like HMC.
|
|
|
|
The ``biject_to`` and ``transform_to`` objects can be extended by user-defined
|
|
constraints and transforms using their ``.register()`` method either as a
|
|
function on singleton constraints::
|
|
|
|
transform_to.register(my_constraint, my_transform)
|
|
|
|
or as a decorator on parameterized constraints::
|
|
|
|
@transform_to.register(MyConstraintClass)
|
|
def my_factory(constraint):
|
|
assert isinstance(constraint, MyConstraintClass)
|
|
return MyTransform(constraint.param1, constraint.param2)
|
|
|
|
You can create your own registry by creating a new :class:`ConstraintRegistry`
|
|
object.
|
|
"""
|
|
|
|
import numbers
|
|
|
|
from torch.distributions import constraints, transforms
|
|
|
|
# Names exported by ``from ... import *``: the registry class and the two
# module-level registries.
__all__ = ["ConstraintRegistry", "biject_to", "transform_to"]
|
|
|
|
|
|
class ConstraintRegistry:
    """
    Registry to link constraints to transforms.

    Factories are stored per :class:`~torch.distributions.constraints.Constraint`
    subclass. Calling the registry with a constraint object looks up the factory
    registered for that object's exact type (subclasses of a registered class are
    not matched) and returns ``factory(constraint)``.
    """

    def __init__(self):
        # Maps a Constraint subclass to the factory registered for it.
        self._registry = {}
        super().__init__()

    def register(self, constraint, factory=None):
        """
        Registers a :class:`~torch.distributions.constraints.Constraint`
        subclass in this registry. Usage::

            @my_registry.register(MyConstraintClass)
            def construct_transform(constraint):
                assert isinstance(constraint, MyConstraintClass)
                return MyTransform(constraint.param1, constraint.param2)

        Registering again for the same class replaces the earlier factory.

        Args:
            constraint (subclass of :class:`~torch.distributions.constraints.Constraint`):
                A subclass of :class:`~torch.distributions.constraints.Constraint`, or
                a singleton object of the desired class.
            factory (Callable): A callable that inputs a constraint object and returns
                a :class:`~torch.distributions.transforms.Transform` object. If
                omitted, a decorator is returned instead.

        Returns:
            ``factory`` itself. When ``factory`` is omitted, returns a decorator
            that registers the decorated function and returns it unchanged.

        Raises:
            TypeError: if ``constraint`` is neither a ``Constraint`` subclass nor
                a ``Constraint`` instance. In decorator form this is raised when
                the decorator is applied to a function.
        """
        # Support use as decorator.
        if factory is None:
            return lambda factory: self.register(constraint, factory)

        # Support calling on singleton instances: the registry is keyed by class.
        if isinstance(constraint, constraints.Constraint):
            constraint = type(constraint)

        if not isinstance(constraint, type) or not issubclass(
            constraint, constraints.Constraint
        ):
            raise TypeError(
                f"Expected constraint to be either a Constraint subclass or instance, but got {constraint}"
            )

        self._registry[constraint] = factory
        return factory

    def __call__(self, constraint):
        """
        Looks up a transform to constrained space, given a constraint object.
        Usage::

            constraint = Normal.arg_constraints['scale']
            scale = transform_to(constraint)(torch.zeros(1))  # constrained
            u = transform_to(constraint).inv(scale)  # unconstrained

        Args:
            constraint (:class:`~torch.distributions.constraints.Constraint`):
                A constraint object.

        Returns:
            Whatever the registered factory returns for ``constraint``; for the
            built-in registrations this is a
            :class:`~torch.distributions.transforms.Transform` object.

        Raises:
            NotImplementedError: if no factory has been registered for
                ``type(constraint)``.
        """
        # Look up by the exact Constraint subclass (no MRO walk).
        try:
            factory = self._registry[type(constraint)]
        except KeyError:
            raise NotImplementedError(
                f"Cannot transform {type(constraint).__name__} constraints"
            ) from None
        return factory(constraint)
|
|
|
|
|
|
# Two separate registries, with the guarantees described in the module
# docstring: ``biject_to`` returns bijective transforms, while ``transform_to``
# returns transforms that need not be bijective.
biject_to, transform_to = ConstraintRegistry(), ConstraintRegistry()
|
|
|
|
|
|
################################################################################
|
|
# Registration Table
|
|
################################################################################
|
|
|
|
|
|
@biject_to.register(constraints.real)
@transform_to.register(constraints.real)
def _transform_to_real(constraint):
    """Return the identity transform, since ``constraints.real`` needs no mapping.

    Registered in both registries. The ``constraint`` argument is not inspected.
    """
    identity = transforms.identity_transform
    return identity
|
|
|
|
|
|
@biject_to.register(constraints.independent)
def _biject_to_independent(constraint):
    """Build a bijection onto an ``independent`` constraint.

    The bijection for ``constraint.base_constraint`` is looked up recursively in
    ``biject_to``. It is then wrapped in an ``IndependentTransform`` using the
    constraint's ``reinterpreted_batch_ndims``.
    """
    return transforms.IndependentTransform(
        biject_to(constraint.base_constraint),
        constraint.reinterpreted_batch_ndims,
    )
|
|
|
|
|
|
@transform_to.register(constraints.independent)
def _transform_to_independent(constraint):
    """Build a (not necessarily bijective) transform onto an ``independent`` constraint.

    The transform for ``constraint.base_constraint`` is looked up recursively in
    ``transform_to``. It is then wrapped in an ``IndependentTransform`` using the
    constraint's ``reinterpreted_batch_ndims``.
    """
    inner = transform_to(constraint.base_constraint)
    ndims = constraint.reinterpreted_batch_ndims
    return transforms.IndependentTransform(inner, ndims)
|
|
|
|
|
|
@biject_to.register(constraints.positive)
@biject_to.register(constraints.nonnegative)
@transform_to.register(constraints.positive)
@transform_to.register(constraints.nonnegative)
def _transform_to_positive(constraint):
    """Map the reals onto positive values with ``ExpTransform``.

    NOTE(review): ``register`` keys on the constraint's class. If
    ``constraints.positive`` / ``constraints.nonnegative`` are instances of the
    ``greater_than`` / ``greater_than_eq`` classes, the registrations for those
    classes further down this file replace this factory. Confirm whether this
    function is ever reached.
    """
    exp = transforms.ExpTransform()
    return exp
|
|
|
|
|
|
@biject_to.register(constraints.greater_than)
@biject_to.register(constraints.greater_than_eq)
@transform_to.register(constraints.greater_than)
@transform_to.register(constraints.greater_than_eq)
def _transform_to_greater_than(constraint):
    """Map the reals to values above ``constraint.lower_bound``.

    Composes ``ExpTransform`` and then ``AffineTransform(lower_bound, 1)``,
    which is a unit-scale shift of the exponentiated value by the lower bound.
    """
    exp = transforms.ExpTransform()
    shift = transforms.AffineTransform(constraint.lower_bound, 1)
    return transforms.ComposeTransform([exp, shift])
|
|
|
|
|
|
@biject_to.register(constraints.less_than)
@transform_to.register(constraints.less_than)
def _transform_to_less_than(constraint):
    """Map the reals to values below ``constraint.upper_bound``.

    Composes ``ExpTransform`` and then ``AffineTransform(upper_bound, -1)``,
    which negates the exponentiated value and shifts it by the upper bound.
    """
    exp = transforms.ExpTransform()
    reflect = transforms.AffineTransform(constraint.upper_bound, -1)
    return transforms.ComposeTransform([exp, reflect])
|
|
|
|
|
|
@biject_to.register(constraints.interval)
@biject_to.register(constraints.half_open_interval)
@transform_to.register(constraints.interval)
@transform_to.register(constraints.half_open_interval)
def _transform_to_interval(constraint):
    """Map the reals into the interval spanned by the constraint's bounds, via a sigmoid.

    If both bounds are plain Python numbers equal to 0 and 1, a bare
    ``SigmoidTransform`` is returned. Otherwise the sigmoid output is rescaled
    with ``AffineTransform(lower_bound, upper_bound - lower_bound)``.
    """
    low = constraint.lower_bound
    high = constraint.upper_bound

    # Unit interval: no affine rescaling needed. Non-Number bounds (e.g.
    # tensors) never take this shortcut, even if their values are 0 and 1.
    if (
        isinstance(low, numbers.Number)
        and low == 0
        and isinstance(high, numbers.Number)
        and high == 1
    ):
        return transforms.SigmoidTransform()

    width = high - low
    return transforms.ComposeTransform(
        [transforms.SigmoidTransform(), transforms.AffineTransform(low, width)]
    )
|
|
|
|
|
|
@biject_to.register(constraints.simplex)
def _biject_to_simplex(constraint):
    """Return a ``StickBreakingTransform`` as the bijection onto the simplex.

    The ``constraint`` argument is not inspected.
    """
    stick_breaking = transforms.StickBreakingTransform()
    return stick_breaking
|
|
|
|
|
|
@transform_to.register(constraints.simplex)
def _transform_to_simplex(constraint):
    """Return a ``SoftmaxTransform`` as the (non-bijective) map onto the simplex.

    The ``constraint`` argument is not inspected.
    """
    softmax = transforms.SoftmaxTransform()
    return softmax
|
|
|
|
|
|
# TODO define a bijection for LowerCholeskyTransform
@transform_to.register(constraints.lower_cholesky)
def _transform_to_lower_cholesky(constraint):
    """Return a ``LowerCholeskyTransform`` for ``constraints.lower_cholesky``.

    Registered only in ``transform_to``. The ``constraint`` argument is not
    inspected.
    """
    lower_cholesky = transforms.LowerCholeskyTransform()
    return lower_cholesky
|
|
|
|
|
|
@transform_to.register(constraints.positive_definite)
@transform_to.register(constraints.positive_semidefinite)
def _transform_to_positive_definite(constraint):
    """Return a ``PositiveDefiniteTransform`` for positive (semi)definite constraints.

    One transform serves both constraints, and it is registered only in
    ``transform_to``. The ``constraint`` argument is not inspected.
    """
    positive_definite = transforms.PositiveDefiniteTransform()
    return positive_definite
|
|
|
|
|
|
@biject_to.register(constraints.corr_cholesky)
@transform_to.register(constraints.corr_cholesky)
def _transform_to_corr_cholesky(constraint):
    """Return a ``CorrCholeskyTransform`` for ``constraints.corr_cholesky``.

    Registered in both registries. The ``constraint`` argument is not inspected.
    """
    corr_cholesky = transforms.CorrCholeskyTransform()
    return corr_cholesky
|
|
|
|
|
|
@biject_to.register(constraints.cat)
def _biject_to_cat(constraint):
    """Build a bijection onto a ``cat`` constraint.

    Each sub-constraint in ``constraint.cseq`` gets its own bijection from
    ``biject_to``. The pieces are combined by a ``CatTransform`` using the
    constraint's ``dim`` and ``lengths``.
    """
    pieces = list(map(biject_to, constraint.cseq))
    return transforms.CatTransform(pieces, constraint.dim, constraint.lengths)
|
|
|
|
|
|
@transform_to.register(constraints.cat)
def _transform_to_cat(constraint):
    """Build a (not necessarily bijective) transform onto a ``cat`` constraint.

    Each sub-constraint in ``constraint.cseq`` is mapped through ``transform_to``.
    The results are combined by a ``CatTransform`` using the constraint's ``dim``
    and ``lengths``.
    """
    pieces = list(map(transform_to, constraint.cseq))
    return transforms.CatTransform(pieces, constraint.dim, constraint.lengths)
|
|
|
|
|
|
@biject_to.register(constraints.stack)
def _biject_to_stack(constraint):
    """Build a bijection onto a ``stack`` constraint.

    Each sub-constraint in ``constraint.cseq`` gets its own bijection from
    ``biject_to``. The pieces are combined by a ``StackTransform`` along
    ``constraint.dim``.
    """
    pieces = list(map(biject_to, constraint.cseq))
    return transforms.StackTransform(pieces, constraint.dim)
|
|
|
|
|
|
@transform_to.register(constraints.stack)
def _transform_to_stack(constraint):
    """Build a (not necessarily bijective) transform onto a ``stack`` constraint.

    Each sub-constraint in ``constraint.cseq`` is mapped through ``transform_to``.
    The results are combined by a ``StackTransform`` along ``constraint.dim``.
    """
    pieces = list(map(transform_to, constraint.cseq))
    return transforms.StackTransform(pieces, constraint.dim)
|