You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
658 lines
18 KiB
658 lines
18 KiB
r"""
The following constraints are implemented:

- ``constraints.boolean``
- ``constraints.cat``
- ``constraints.corr_cholesky``
- ``constraints.dependent``
- ``constraints.greater_than(lower_bound)``
- ``constraints.greater_than_eq(lower_bound)``
- ``constraints.half_open_interval(lower_bound, upper_bound)``
- ``constraints.independent(constraint, reinterpreted_batch_ndims)``
- ``constraints.integer_interval(lower_bound, upper_bound)``
- ``constraints.interval(lower_bound, upper_bound)``
- ``constraints.less_than(upper_bound)``
- ``constraints.lower_cholesky``
- ``constraints.lower_triangular``
- ``constraints.multinomial``
- ``constraints.nonnegative``
- ``constraints.nonnegative_integer``
- ``constraints.one_hot``
- ``constraints.positive_integer``
- ``constraints.positive``
- ``constraints.positive_semidefinite``
- ``constraints.positive_definite``
- ``constraints.real_vector``
- ``constraints.real``
- ``constraints.simplex``
- ``constraints.square``
- ``constraints.stack``
- ``constraints.symmetric``
- ``constraints.unit_interval``
"""
|
|
|
|
import torch
|
|
|
|
# Public names exported by ``from ... import *`` (order is significant only
# for readability; it mirrors the historical listing).
__all__ = [
    "Constraint", "boolean", "cat", "corr_cholesky", "dependent",
    "dependent_property", "greater_than", "greater_than_eq", "independent",
    "integer_interval", "interval", "half_open_interval", "is_dependent",
    "less_than", "lower_cholesky", "lower_triangular", "multinomial",
    "nonnegative", "nonnegative_integer", "one_hot", "positive",
    "positive_semidefinite", "positive_definite", "positive_integer", "real",
    "real_vector", "simplex", "square", "stack", "symmetric", "unit_interval",
]
|
|
|
|
|
|
class Constraint:
    """
    Abstract base class for constraints.

    A constraint object represents a region over which a variable is valid,
    e.g. within which a variable can be optimized.

    Attributes:
        is_discrete (bool): Whether constrained space is discrete.
            Defaults to False.
        event_dim (int): Number of rightmost dimensions that together define
            an event. The :meth:`check` method will remove this many dimensions
            when computing validity.
    """

    is_discrete = False  # Default to continuous.
    event_dim = 0  # Default to univariate.

    def check(self, value):
        """
        Returns a boolean tensor of ``sample_shape + batch_shape`` indicating
        whether each event in value satisfies this constraint.

        Subclasses must override this method.

        Raises:
            NotImplementedError: always, on this abstract base class.
        """
        raise NotImplementedError

    def __repr__(self):
        # Built-in constraints are private ``_Name`` classes exposed through a
        # public alias, so a single leading underscore is dropped for display.
        # Names without one (e.g. user-defined subclasses, or this base class)
        # are shown unchanged instead of losing their first character.
        name = self.__class__.__name__
        if name.startswith("_"):
            name = name[1:]
        return name + "()"
|
|
|
|
|
|
class _Dependent(Constraint):
    """
    Placeholder constraint for variables whose support depends on other
    variables, and which therefore obey no simple coordinate-wise constraint.

    Args:
        is_discrete (bool): Optional static value for ``.is_discrete``. When
            omitted, reading ``.is_discrete`` raises ``NotImplementedError``.
        event_dim (int): Optional static value for ``.event_dim``. When
            omitted, reading ``.event_dim`` raises ``NotImplementedError``.
    """

    def __init__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        self._is_discrete = is_discrete
        self._event_dim = event_dim
        super().__init__()

    @property
    def is_discrete(self):
        stored = self._is_discrete
        if stored is NotImplemented:
            raise NotImplementedError(".is_discrete cannot be determined statically")
        return stored

    @property
    def event_dim(self):
        stored = self._event_dim
        if stored is NotImplemented:
            raise NotImplementedError(".event_dim cannot be determined statically")
        return stored

    def __call__(self, *, is_discrete=NotImplemented, event_dim=NotImplemented):
        """
        Return a new dependent placeholder with some static attributes
        overridden, falling back to this instance's values otherwise::

            constraints.dependent(is_discrete=True, event_dim=1)
        """
        return _Dependent(
            is_discrete=(
                self._is_discrete if is_discrete is NotImplemented else is_discrete
            ),
            event_dim=(
                self._event_dim if event_dim is NotImplemented else event_dim
            ),
        )

    def check(self, x):
        # Validity cannot be decided without the other variables.
        raise ValueError("Cannot determine validity of dependent constraint")
|
|
|
|
|
|
def is_dependent(constraint):
    """
    Return ``True`` when ``constraint`` is a dependent placeholder, i.e. an
    instance of ``_Dependent`` (``dependent_property`` objects included, since
    ``_DependentProperty`` subclasses ``_Dependent``).
    """
    placeholder_types = (_Dependent,)
    return isinstance(constraint, placeholder_types)
|
|
|
|
|
|
class _DependentProperty(property, _Dependent):
    """
    Decorator that extends @property to act like a `Dependent` constraint when
    called on a class and act like a property when called on an object.

    Example::

        class Uniform(Distribution):
            def __init__(self, low, high):
                self.low = low
                self.high = high

            @constraints.dependent_property(is_discrete=False, event_dim=0)
            def support(self):
                return constraints.interval(self.low, self.high)

    Args:
        fn (Callable): The function to be decorated. May be ``None`` when the
            decorator is first called with keyword arguments only; the
            function is then supplied through :meth:`__call__`.
        is_discrete (bool): Optional value of ``.is_discrete`` in case this
            can be computed statically. If not provided, access to the
            ``.is_discrete`` attribute will raise a NotImplementedError.
        event_dim (int): Optional value of ``.event_dim`` in case this
            can be computed statically. If not provided, access to the
            ``.event_dim`` attribute will raise a NotImplementedError.
    """

    def __init__(
        self, fn=None, *, is_discrete=NotImplemented, event_dim=NotImplemented
    ):
        # ``property`` comes first in the MRO, so this installs ``fn`` as the
        # property getter.
        super().__init__(fn)
        # The static attributes read by ``_Dependent.is_discrete`` /
        # ``_Dependent.event_dim`` are stored explicitly here rather than via
        # ``_Dependent.__init__``.
        # NOTE(review): presumably because ``property.__init__`` does not chain
        # further up the MRO — confirm before refactoring.
        self._is_discrete = is_discrete
        self._event_dim = event_dim

    def __call__(self, fn):
        """
        Support for syntax to customize static attributes::

            @constraints.dependent_property(is_discrete=True, event_dim=1)
            def support(self):
                ...

        Returns a new ``_DependentProperty`` wrapping ``fn`` and carrying this
        instance's ``is_discrete`` / ``event_dim`` values.
        """
        return _DependentProperty(
            fn, is_discrete=self._is_discrete, event_dim=self._event_dim
        )
|
|
|
|
|
|
class _IndependentConstraint(Constraint):
    """
    Wraps a constraint by aggregating over ``reinterpreted_batch_ndims``-many
    dims in :meth:`check`, so that an event is valid only if all its
    independent entries are valid.

    Args:
        base_constraint (Constraint): Constraint applied to each entry.
        reinterpreted_batch_ndims (int): Number of rightmost dimensions of
            ``base_constraint.check(value)`` folded into the event.
    """

    def __init__(self, base_constraint, reinterpreted_batch_ndims):
        assert isinstance(base_constraint, Constraint)
        assert isinstance(reinterpreted_batch_ndims, int)
        assert reinterpreted_batch_ndims >= 0
        self.base_constraint = base_constraint
        self.reinterpreted_batch_ndims = reinterpreted_batch_ndims
        super().__init__()

    @property
    def is_discrete(self):
        return self.base_constraint.is_discrete

    @property
    def event_dim(self):
        return self.base_constraint.event_dim + self.reinterpreted_batch_ndims

    def check(self, value):
        """
        Returns a tensor that is true where every entry in the reinterpreted
        dimensions satisfies ``base_constraint``.

        Raises:
            ValueError: if ``base_constraint.check(value)`` has fewer than
                ``reinterpreted_batch_ndims`` dimensions.
        """
        result = self.base_constraint.check(value)
        ndims = self.reinterpreted_batch_ndims
        if result.dim() < ndims:
            expected = self.base_constraint.event_dim + ndims
            raise ValueError(
                f"Expected value.dim() >= {expected} but got {value.dim()}"
            )
        if ndims == 0:
            # Nothing to aggregate: reduce over a trailing singleton dim so the
            # output goes through the same ``all`` reduction as below.
            result = result.unsqueeze(-1)
        else:
            # Merge the reinterpreted dims into one. ``flatten`` computes the
            # merged size explicitly; ``reshape(..., -1)`` is ambiguous and
            # raises when the remaining batch shape contains a zero.
            result = result.flatten(start_dim=result.dim() - ndims)
        return result.all(-1)

    def __repr__(self):
        return f"{self.__class__.__name__[1:]}({repr(self.base_constraint)}, {self.reinterpreted_batch_ndims})"
|
|
|
|
|
|
class _Boolean(Constraint):
    """
    Constrain to the two values `{0, 1}`.
    """

    is_discrete = True

    def check(self, value):
        # Element-wise membership test against the two allowed values.
        is_zero = value == 0
        is_one = value == 1
        return is_zero | is_one
|
|
|
|
|
|
class _OneHot(Constraint):
    """
    Constrain to one-hot vectors: every entry of the rightmost dimension is
    0 or 1 and the entries sum to exactly 1.
    """

    is_discrete = True
    event_dim = 1

    def check(self, value):
        entries_are_binary = ((value == 0) | (value == 1)).all(-1)
        sums_to_one = value.sum(-1).eq(1)
        return entries_are_binary & sums_to_one
|
|
|
|
|
|
class _IntegerInterval(Constraint):
    """
    Constrain to an integer interval `[lower_bound, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        is_integral = value % 1 == 0
        within_bounds = (self.lower_bound <= value) & (value <= self.upper_bound)
        return is_integral & within_bounds

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return "{}(lower_bound={}, upper_bound={})".format(
            name, self.lower_bound, self.upper_bound
        )
|
|
|
|
|
|
class _IntegerLessThan(Constraint):
    """
    Constrain to an integer interval `(-inf, upper_bound]`.
    """

    is_discrete = True

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        is_integral = value % 1 == 0
        return is_integral & (value <= self.upper_bound)

    def __repr__(self):
        return "{}(upper_bound={})".format(
            self.__class__.__name__[1:], self.upper_bound
        )
|
|
|
|
|
|
class _IntegerGreaterThan(Constraint):
    """
    Constrain to an integer interval `[lower_bound, inf)`.
    """

    is_discrete = True

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        is_integral = value % 1 == 0
        return is_integral & (value >= self.lower_bound)

    def __repr__(self):
        return "{}(lower_bound={})".format(
            self.__class__.__name__[1:], self.lower_bound
        )
|
|
|
|
|
|
class _Real(Constraint):
    """
    Trivially constrain to the extended real line `[-inf, inf]`.
    """

    def check(self, value):
        # NaN is the only value that compares unequal to itself, so this
        # rejects NaNs and accepts everything else, including +/-inf.
        not_nan = value == value
        return not_nan
|
|
|
|
|
|
class _GreaterThan(Constraint):
    """
    Constrain to a real half line `(lower_bound, inf]` (strict lower bound).
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        above = self.lower_bound < value
        return above

    def __repr__(self):
        return "{}(lower_bound={})".format(
            self.__class__.__name__[1:], self.lower_bound
        )
|
|
|
|
|
|
class _GreaterThanEq(Constraint):
    """
    Constrain to a real half line `[lower_bound, inf]` (inclusive lower bound;
    `inf` itself also satisfies the comparison).
    """

    def __init__(self, lower_bound):
        self.lower_bound = lower_bound
        super().__init__()

    def check(self, value):
        at_or_above = self.lower_bound <= value
        return at_or_above

    def __repr__(self):
        return "{}(lower_bound={})".format(
            self.__class__.__name__[1:], self.lower_bound
        )
|
|
|
|
|
|
class _LessThan(Constraint):
    """
    Constrain to a real half line `[-inf, upper_bound)` (strict upper bound).
    """

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        below = value < self.upper_bound
        return below

    def __repr__(self):
        return "{}(upper_bound={})".format(
            self.__class__.__name__[1:], self.upper_bound
        )
|
|
|
|
|
|
class _Interval(Constraint):
    """
    Constrain to a closed real interval `[lower_bound, upper_bound]`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        above_lower = self.lower_bound <= value
        below_upper = value <= self.upper_bound
        return above_lower & below_upper

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return "{}(lower_bound={}, upper_bound={})".format(
            name, self.lower_bound, self.upper_bound
        )
|
|
|
|
|
|
class _HalfOpenInterval(Constraint):
    """
    Constrain to a half-open real interval `[lower_bound, upper_bound)`.
    """

    def __init__(self, lower_bound, upper_bound):
        self.lower_bound = lower_bound
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, value):
        above_lower = self.lower_bound <= value
        strictly_below_upper = value < self.upper_bound
        return above_lower & strictly_below_upper

    def __repr__(self):
        name = self.__class__.__name__[1:]
        return "{}(lower_bound={}, upper_bound={})".format(
            name, self.lower_bound, self.upper_bound
        )
|
|
|
|
|
|
class _Simplex(Constraint):
    """
    Constrain to the unit simplex in the innermost (rightmost) dimension:
    all entries are `>= 0` and they sum to 1 within an absolute tolerance
    of `1e-6`.
    """

    event_dim = 1

    def check(self, value):
        nonnegative_entries = torch.all(value >= 0, dim=-1)
        sums_to_one = (value.sum(-1) - 1).abs() < 1e-6
        return nonnegative_entries & sums_to_one
|
|
|
|
|
|
class _Multinomial(Constraint):
    """
    Constrain to nonnegative count vectors summing to at most an upper bound.

    Note due to limitations of the Multinomial distribution, this currently
    checks the weaker condition ``value.sum(-1) <= upper_bound``. In the future
    this may be strengthened to ``value.sum(-1) == upper_bound``. Integrality
    of the entries is not currently checked either.

    Args:
        upper_bound: Maximum allowed sum over the rightmost dimension.
    """

    is_discrete = True
    event_dim = 1

    def __init__(self, upper_bound):
        self.upper_bound = upper_bound
        super().__init__()

    def check(self, x):
        return (x >= 0).all(dim=-1) & (x.sum(dim=-1) <= self.upper_bound)

    def __repr__(self):
        # Match the other parametrized constraints by showing the bound.
        fmt_string = self.__class__.__name__[1:]
        fmt_string += f"(upper_bound={self.upper_bound})"
        return fmt_string
|
|
|
|
|
|
class _LowerTriangular(Constraint):
    """
    Constrain to lower-triangular square matrices (over the two rightmost
    dimensions).
    """

    event_dim = 2

    def check(self, value):
        value_tril = value.tril()
        # ``flatten`` + ``all`` replaces ``view(batch + (-1,))`` + ``min``:
        # the former raises for empty batches (ambiguous ``-1``) and the
        # latter for 0x0 matrices (empty reduction), whereas ``all`` over an
        # empty dimension is simply True.
        return (value_tril == value).flatten(start_dim=-2).all(-1)
|
|
|
|
|
|
class _LowerCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals.
    """

    event_dim = 2

    def check(self, value):
        value_tril = value.tril()
        # ``flatten``/``all`` instead of ``view(..., -1)``/``min`` so empty
        # batches and 0x0 matrices are handled instead of raising.
        lower_triangular = (value_tril == value).flatten(start_dim=-2).all(-1)
        positive_diagonal = (value.diagonal(dim1=-2, dim2=-1) > 0).all(-1)
        return lower_triangular & positive_diagonal
|
|
|
|
|
|
class _CorrCholesky(Constraint):
    """
    Constrain to lower-triangular square matrices with positive diagonals
    whose rows each have unit Euclidean norm (up to a dtype-dependent
    tolerance).
    """

    event_dim = 2

    def check(self, value):
        # Tolerance grows with machine epsilon and the row length; the factor
        # 10 is an adjustable fudge factor.
        row_length = value.size(-1)
        tol = torch.finfo(value.dtype).eps * row_length * 10
        norms = torch.linalg.norm(value.detach(), dim=-1)
        rows_have_unit_norm = (norms - 1.0).abs().le(tol).all(dim=-1)
        return _LowerCholesky().check(value) & rows_have_unit_norm
|
|
|
|
|
|
class _Square(Constraint):
    """
    Constrain to square matrices: the two rightmost dimensions have equal
    size. The answer is the same for every batch element.
    """

    event_dim = 2

    def check(self, value):
        is_square = value.shape[-2] == value.shape[-1]
        return torch.full(
            value.shape[:-2], is_square, dtype=torch.bool, device=value.device
        )
|
|
|
|
|
|
class _Symmetric(_Square):
    """
    Constrain to symmetric square matrices (equal to their transpose within
    an absolute tolerance of `1e-6`).
    """

    def check(self, value):
        square_check = super().check(value)
        if square_check.all():
            close_to_transpose = torch.isclose(value, value.mT, atol=1e-6)
            return close_to_transpose.all(-2).all(-1)
        # Non-square input: report the (all-False) squareness result as is.
        return square_check
|
|
|
|
|
|
class _PositiveSemidefinite(_Symmetric):
    """
    Constrain to positive-semidefinite matrices: symmetric, with all
    eigenvalues `>= 0`.
    """

    def check(self, value):
        sym_check = super().check(value)
        if sym_check.all():
            eigenvalues = torch.linalg.eigvalsh(value)
            return eigenvalues.ge(0).all(-1)
        return sym_check
|
|
|
|
|
|
class _PositiveDefinite(_Symmetric):
    """
    Constrain to positive-definite matrices: symmetric, and admitting a
    Cholesky factorization.
    """

    def check(self, value):
        sym_check = super().check(value)
        if sym_check.all():
            # ``info == 0`` marks a successful factorization per matrix.
            factorization = torch.linalg.cholesky_ex(value)
            return factorization.info.eq(0)
        return sym_check
|
|
|
|
|
|
class _Cat(Constraint):
    """
    Constraint functor that applies a sequence of constraints ``cseq`` to
    consecutive chunks of a tensor along dimension ``dim``, in a way
    compatible with :func:`torch.cat`: the ``i``-th chunk has size
    ``lengths[i]`` and is checked against ``cseq[i]``.

    Args:
        cseq (Iterable[Constraint]): One constraint per chunk.
        dim (int): Dimension along which the chunks were concatenated.
        lengths (Iterable[int], optional): Size of each chunk along ``dim``.
            Defaults to 1 for every chunk.
    """

    def __init__(self, cseq, dim=0, lengths=None):
        # Materialize before validating so a one-shot iterable (e.g. a
        # generator) is not exhausted by the ``all(...)`` check, which would
        # otherwise leave ``self.cseq`` empty.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        if lengths is None:
            lengths = [1] * len(self.cseq)
        self.lengths = list(lengths)
        assert len(self.lengths) == len(self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        return max(c.event_dim for c in self.cseq)

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        checks = []
        start = 0
        for constr, length in zip(self.cseq, self.lengths):
            v = value.narrow(self.dim, start, length)
            checks.append(constr.check(v))
            start = start + length  # avoid += for jit compat
        return torch.cat(checks, self.dim)
|
|
|
|
|
|
class _Stack(Constraint):
    """
    Constraint functor that applies a sequence of constraints ``cseq`` to the
    slices of a tensor along dimension ``dim``, in a way compatible with
    :func:`torch.stack`: slice ``value.select(dim, i)`` is checked against
    ``cseq[i]``.

    Args:
        cseq (Iterable[Constraint]): One constraint per slice.
        dim (int): Dimension along which the slices were stacked.
    """

    def __init__(self, cseq, dim=0):
        # Materialize before validating so a one-shot iterable (e.g. a
        # generator) is not exhausted by the ``all(...)`` check, which would
        # otherwise leave ``self.cseq`` empty.
        self.cseq = list(cseq)
        assert all(isinstance(c, Constraint) for c in self.cseq)
        self.dim = dim
        super().__init__()

    @property
    def is_discrete(self):
        return any(c.is_discrete for c in self.cseq)

    @property
    def event_dim(self):
        dim = max(c.event_dim for c in self.cseq)
        # A negative stacking dim that falls inside the event dims makes the
        # stacked dimension part of the event.
        if self.dim + dim < 0:
            dim += 1
        return dim

    def check(self, value):
        assert -value.dim() <= self.dim < value.dim()
        vs = [value.select(self.dim, i) for i in range(value.size(self.dim))]
        # NOTE(review): ``zip`` silently truncates if ``value.size(dim)``
        # differs from ``len(self.cseq)``.
        return torch.stack(
            [constr.check(v) for v, constr in zip(vs, self.cseq)], self.dim
        )
|
|
|
|
|
|
# Public interface.

# Placeholders and constraint functors (classes used as factories).
dependent = _Dependent()
dependent_property = _DependentProperty
independent = _IndependentConstraint
cat = _Cat
stack = _Stack

# Discrete supports.
boolean = _Boolean()
one_hot = _OneHot()
nonnegative_integer = _IntegerGreaterThan(0)
positive_integer = _IntegerGreaterThan(1)
integer_interval = _IntegerInterval
multinomial = _Multinomial

# Real-valued supports (``real_vector`` must follow ``real`` and ``independent``).
real = _Real()
real_vector = independent(real, 1)
positive = _GreaterThan(0.0)
nonnegative = _GreaterThanEq(0.0)
greater_than = _GreaterThan
greater_than_eq = _GreaterThanEq
less_than = _LessThan
unit_interval = _Interval(0.0, 1.0)
interval = _Interval
half_open_interval = _HalfOpenInterval
simplex = _Simplex()

# Matrix supports.
lower_triangular = _LowerTriangular()
lower_cholesky = _LowerCholesky()
corr_cholesky = _CorrCholesky()
square = _Square()
symmetric = _Symmetric()
positive_semidefinite = _PositiveSemidefinite()
positive_definite = _PositiveDefinite()
|