You can not select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
210 lines
5.9 KiB
210 lines
5.9 KiB
5 months ago
|
import math
|
||
|
|
||
|
import torch
|
||
|
import torch.jit
|
||
|
from torch.distributions import constraints
|
||
|
from torch.distributions.distribution import Distribution
|
||
|
from torch.distributions.utils import broadcast_all, lazy_property
|
||
|
|
||
|
# Public API of this module: only the distribution class is exported; the
# underscore-prefixed Bessel/polynomial/sampling helpers are module-private.
__all__ = ["VonMises"]
|
||
|
|
||
|
|
||
|
def _eval_poly(y, coef):
|
||
|
coef = list(coef)
|
||
|
result = coef.pop()
|
||
|
while coef:
|
||
|
result = coef.pop() + y * result
|
||
|
return result
|
||
|
|
||
|
|
||
|
_I0_COEF_SMALL = [
|
||
|
1.0,
|
||
|
3.5156229,
|
||
|
3.0899424,
|
||
|
1.2067492,
|
||
|
0.2659732,
|
||
|
0.360768e-1,
|
||
|
0.45813e-2,
|
||
|
]
|
||
|
_I0_COEF_LARGE = [
|
||
|
0.39894228,
|
||
|
0.1328592e-1,
|
||
|
0.225319e-2,
|
||
|
-0.157565e-2,
|
||
|
0.916281e-2,
|
||
|
-0.2057706e-1,
|
||
|
0.2635537e-1,
|
||
|
-0.1647633e-1,
|
||
|
0.392377e-2,
|
||
|
]
|
||
|
_I1_COEF_SMALL = [
|
||
|
0.5,
|
||
|
0.87890594,
|
||
|
0.51498869,
|
||
|
0.15084934,
|
||
|
0.2658733e-1,
|
||
|
0.301532e-2,
|
||
|
0.32411e-3,
|
||
|
]
|
||
|
_I1_COEF_LARGE = [
|
||
|
0.39894228,
|
||
|
-0.3988024e-1,
|
||
|
-0.362018e-2,
|
||
|
0.163801e-2,
|
||
|
-0.1031555e-1,
|
||
|
0.2282967e-1,
|
||
|
-0.2895312e-1,
|
||
|
0.1787654e-1,
|
||
|
-0.420059e-2,
|
||
|
]
|
||
|
|
||
|
_COEF_SMALL = [_I0_COEF_SMALL, _I1_COEF_SMALL]
|
||
|
_COEF_LARGE = [_I0_COEF_LARGE, _I1_COEF_LARGE]
|
||
|
|
||
|
|
||
|
def _log_modified_bessel_fn(x, order=0):
    """
    Returns ``log(I_order(x))`` for ``x > 0``,
    where `order` is either 0 or 1.

    ``I_order`` is approximated piecewise using the coefficient tables
    ``_COEF_SMALL`` / ``_COEF_LARGE``:

    * ``x < 3.75``: ``I_0(x) ~= P_0(t**2)`` and ``I_1(x) ~= |x| * P_1(t**2)``
      with ``t = x / 3.75``;
    * ``x >= 3.75``: ``I_order(x) ~= exp(x) / sqrt(x) * Q_order(3.75 / x)``.

    Both branches are computed for every element and then selected with
    ``torch.where``. Each branch is fed an input in which the elements outside
    its own region are replaced by the threshold 3.75, so the branch that is
    *not* selected always stays finite. Without this, e.g. a float32 ``x``
    around 1e-6 overflows the large-argument polynomial to ``inf``. The
    backward pass of ``torch.where`` then computes ``0 * inf = nan`` and
    poisons the gradient, even though the forward value is fine. Forward
    values are identical to evaluating each branch on the raw ``x``.

    Args:
        x (Tensor): argument; expected to be positive.
        order (int): Bessel order, 0 or 1.

    Returns:
        Tensor: ``log(I_order(x))`` with the same shape as ``x``.

    Raises:
        ValueError: if ``order`` is not 0 or 1.
    """
    # Explicit check instead of ``assert`` so it is not stripped under -O.
    if order not in (0, 1):
        raise ValueError(f"order must be 0 or 1, but got {order!r}")

    threshold = 3.75
    use_small = x < threshold
    # Constant replacement value for out-of-region elements (carries no grad).
    fill = torch.full_like(x, threshold)

    # compute small solution (valid for x < 3.75)
    x_small = torch.where(use_small, x, fill)
    y = x_small / threshold
    y = y * y
    small = _eval_poly(y, _COEF_SMALL[order])
    if order == 1:
        small = x_small.abs() * small
    small = small.log()

    # compute large solution (valid for x >= 3.75)
    x_large = torch.where(use_small, fill, x)
    y = threshold / x_large
    large = x_large - 0.5 * x_large.log() + _eval_poly(y, _COEF_LARGE[order]).log()

    return torch.where(use_small, small, large)
|
||
|
|
||
|
|
||
|
@torch.jit.script_if_tracing
|
||
|
def _rejection_sample(loc, concentration, proposal_r, x):
|
||
|
done = torch.zeros(x.shape, dtype=torch.bool, device=loc.device)
|
||
|
while not done.all():
|
||
|
u = torch.rand((3,) + x.shape, dtype=loc.dtype, device=loc.device)
|
||
|
u1, u2, u3 = u.unbind()
|
||
|
z = torch.cos(math.pi * u1)
|
||
|
f = (1 + proposal_r * z) / (proposal_r + z)
|
||
|
c = concentration * (proposal_r - f)
|
||
|
accept = ((c * (2 - c) - u2) > 0) | ((c / u2).log() + 1 - c >= 0)
|
||
|
if accept.any():
|
||
|
x = torch.where(accept, (u3 - 0.5).sign() * f.acos(), x)
|
||
|
done = done | accept
|
||
|
return (x + math.pi + loc) % (2 * math.pi) - math.pi
|
||
|
|
||
|
|
||
|
class VonMises(Distribution):
    """
    Circular von Mises distribution, parameterized in polar coordinates.

    ``loc`` and any ``value`` passed to :meth:`log_prob` may be arbitrary real
    numbers, which is convenient for unconstrained optimization. They are
    interpreted as angles modulo 2 pi.

    Example::

        >>> # xdoctest: +IGNORE_WANT("non-deterministic")
        >>> m = VonMises(torch.tensor([1.0]), torch.tensor([1.0]))
        >>> m.sample()  # von Mises distributed with loc=1 and concentration=1
        tensor([1.9777])

    :param torch.Tensor loc: an angle in radians.
    :param torch.Tensor concentration: concentration parameter
    """

    arg_constraints = {"loc": constraints.real, "concentration": constraints.positive}
    support = constraints.real
    has_rsample = False

    def __init__(self, loc, concentration, validate_args=None):
        loc, concentration = broadcast_all(loc, concentration)
        self.loc = loc
        self.concentration = concentration
        # Scalar event; the batch shape is the broadcast parameter shape.
        super().__init__(
            batch_shape=loc.shape,
            event_shape=torch.Size(),
            validate_args=validate_args,
        )

    def log_prob(self, value):
        """
        Log density ``kappa * cos(value - loc) - log(2 pi) - log(I_0(kappa))``.
        """
        if self._validate_args:
            self._validate_sample(value)
        unnormalized = self.concentration * torch.cos(value - self.loc)
        return (
            unnormalized
            - math.log(2 * math.pi)
            - _log_modified_bessel_fn(self.concentration, order=0)
        )

    @lazy_property
    def _loc(self):
        # float64 copy of ``loc``; sampling runs in double precision.
        return self.loc.double()

    @lazy_property
    def _concentration(self):
        # float64 copy of ``concentration``; sampling runs in double precision.
        return self.concentration.double()

    @lazy_property
    def _proposal_r(self):
        # Proposal parameter ``r`` of the Best-Fisher sampler, in float64.
        k = self._concentration
        tau = 1 + (1 + 4 * k**2).sqrt()
        rho = (tau - (2 * tau).sqrt()) / (2 * k)
        exact = (1 + rho**2) / (2 * rho)
        # The closed form above cancels catastrophically as kappa -> 0, so
        # switch to its second-order Taylor expansion around 0 for tiny kappa.
        taylor = 1 / k + k
        return torch.where(k < 1e-5, taylor, exact)

    @torch.no_grad()
    def sample(self, sample_shape=torch.Size()):
        """
        Draw samples with the rejection algorithm of D.J. Best and N.I. Fisher,
        "Efficient simulation of the von Mises distribution", Applied
        Statistics (1979): 152-157.

        The rejection loop always runs in double precision. In single
        precision it can hang for small concentrations, from roughly 1e-4
        (see issue #88443). The result is cast back to the dtype of ``loc``.
        """
        shape = self._extended_shape(sample_shape)
        buffer = torch.empty(shape, dtype=self._loc.dtype, device=self.loc.device)
        draws = _rejection_sample(
            self._loc, self._concentration, self._proposal_r, buffer
        )
        return draws.to(self.loc.dtype)

    def expand(self, batch_shape):
        try:
            return super().expand(batch_shape)
        except NotImplementedError:
            # Fallback: rebuild the distribution from expanded parameters,
            # keeping this instance's explicit validate_args setting (if any).
            keep_validation = self.__dict__.get("_validate_args")
            expanded = [p.expand(batch_shape) for p in (self.loc, self.concentration)]
            return type(self)(*expanded, validate_args=keep_validation)

    @property
    def mean(self):
        """
        Circular mean, which equals ``loc``.
        """
        return self.loc

    @property
    def mode(self):
        return self.loc

    @lazy_property
    def variance(self):
        """
        Circular variance ``1 - I_1(kappa) / I_0(kappa)``.
        """
        log_ratio = _log_modified_bessel_fn(
            self.concentration, order=1
        ) - _log_modified_bessel_fn(self.concentration, order=0)
        return 1 - log_ratio.exp()
|