You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

212 lines
7.0 KiB

5 months ago
"""This file exports ONNX ops for opset 17.
Note [ONNX Operators that are added/updated in opset 17]
~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~~
https://github.com/onnx/onnx/blob/main/docs/Changelog.md#version-17-of-the-default-onnx-operator-set
New operators:
BlackmanWindow
DFT
HammingWindow
HannWindow
LayerNormalization
MelWeightMatrix
STFT
SequenceMap
"""
import functools
from typing import Optional, Sequence
import torch
from torch import _C
from torch.onnx import _type_utils, errors, symbolic_helper
from torch.onnx._internal import _beartype, jit_utils, registration
# EDITING THIS FILE? READ THIS FIRST!
# see Note [Edit Symbolic Files] in README.md

# Names exported by `from <this module> import *`: the two symbolic functions
# defined below.
__all__ = ["layer_norm", "stft"]

# `registration.onnx_symbolic` with `opset=17` pre-bound, so the decorators
# below only need to pass the name of the operator they handle
# (e.g. "aten::layer_norm").
_onnx_symbolic = functools.partial(registration.onnx_symbolic, opset=17)
@_onnx_symbolic("aten::layer_norm")
@symbolic_helper.parse_args("v", "is", "v", "v", "f", "none")
def layer_norm(
    g: jit_utils.GraphContext,
    input: _C.Value,
    normalized_shape: Sequence[int],
    weight: _C.Value,
    bias: _C.Value,
    eps: float,
    cudnn_enable: bool,
):
    """Emit an ONNX `LayerNormalization` node for `aten::layer_norm`.

    Args:
        g: Graph context the nodes are written into.
        input: Value to normalize.
        normalized_shape: Shape of the trailing dimensions being normalized;
            its length determines the `axis` attribute.
        weight: Scale value. When it is none, a constant of ones with shape
            `normalized_shape` is emitted instead.
        bias: Shift value. When it is none, a constant of zeros with shape
            `normalized_shape` is emitted instead.
        eps: Forwarded as the `epsilon` attribute.
        cudnn_enable: Not used by this symbolic function.

    Returns:
        The result of the `LayerNormalization` op.
    """
    # Substituted constants use the input's scalar type, falling back to FLOAT
    # when `from_value` cannot determine one.
    fallback_type = _type_utils.JitScalarType.FLOAT
    fill_dtype = _type_utils.JitScalarType.from_value(input, fallback_type).dtype()

    def _filled_constant(factory):
        # `factory` is `torch.ones` or `torch.zeros`.
        return g.op("Constant", value_t=factory(normalized_shape, dtype=fill_dtype))

    scale = _filled_constant(torch.ones) if symbolic_helper._is_none(weight) else weight
    shift = _filled_constant(torch.zeros) if symbolic_helper._is_none(bias) else bias

    # Normalization covers the last D dimensions, D = len(normalized_shape),
    # so the first normalized axis counted from the end is -D.
    first_normalized_axis = -len(normalized_shape)
    return g.op(
        "LayerNormalization",
        input,
        scale,
        shift,
        epsilon_f=eps,
        axis_i=first_normalized_axis,
    )
def _compute_edge_sizes(n_fft, window_size):
"""Helper function to compute the sizes of the edges (left and right)
of a given window centered within an FFT size."""
left = (n_fft - window_size) // 2
right = n_fft - left - window_size
return left, right
@_onnx_symbolic("aten::stft")
@symbolic_helper.parse_args("v", "i", "i", "i", "v", "b", "b", "b")
@_beartype.beartype
def stft(
    g: jit_utils.GraphContext,
    input: _C.Value,
    n_fft: int,
    hop_length: Optional[int] = None,
    win_length: Optional[int] = None,
    window: Optional[_C.Value] = None,
    normalized: bool = False,
    onesided: Optional[bool] = True,
    return_complex: Optional[bool] = False,
) -> _C.Value:
    """Associates `torch.stft` with the `STFT` ONNX operator.

    Note that torch.stft calls _VF.stft, without centering or padding options.
    Hence, this function does not contain these two arguments.
    See torch.stft source code for more info.

    Args:
        g: Graph to write the ONNX representation into
        input: Input tensor for the transformation
        n_fft: FFT size
        hop_length: Size of the hop. Defaults to `floor(n_fft / 4)`
        win_length: Size of the analysis window. Defaults to `n_fft`
        window: Analysis window. Defaults to a window of all ones
        normalized: Whether to return a normalized STFT
        onesided: Whether to return only half (+1) of the results, given the
            symmetry of the STFT
        return_complex: Whether to return the complex value (Note: Must be
            `False` or `None`)

    Returns:
        op: Operator for torch.stft associated with STFT (ONNX)

    Raises:
        errors.SymbolicValueError: If `return_complex` is truthy, if the rank
            of `input` is unknown or greater than 2, or if no `window` is
            given and `win_length` is larger than `n_fft`.
        AssertionError: If the size of `window` is known and does not match
            `win_length` (or `n_fft` when `win_length` is not set).
    """
    # Checks
    if return_complex:
        raise errors.SymbolicValueError(
            msg="STFT does not currently support complex types", value=input
        )

    # Get STFT sizes
    frame_step_value = hop_length if hop_length is not None else n_fft // 4
    frame_step_const = g.op(
        "Constant", value_t=torch.tensor(frame_step_value, dtype=torch.int64)
    )
    frame_length_const = g.op(
        "Constant", value_t=torch.tensor(n_fft, dtype=torch.int64)
    )

    # Pre-process input if needed
    signal = input
    signal_rank = symbolic_helper._get_tensor_rank(signal)
    if signal_rank is None:
        # Without a known rank we cannot decide whether a batch dimension has
        # to be added, and the `> 2` comparison below would fail with an
        # unrelated TypeError. Report the real problem instead.
        raise errors.SymbolicValueError(
            msg="STFT requires the rank of the input signal to be known. "
            "Expected 1 [signal] or 2 [batch, signal] dimensions.",
            value=input,
        )
    if signal_rank == 1:
        # Add batch dimension
        signal = g.op(
            "Unsqueeze",
            signal,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )
    elif signal_rank > 2:
        raise errors.SymbolicValueError(
            msg="STFT can only take inputs of 1 [signal] or 2 [batch, signal] dimensions. "
            f"Current rank of signal is {signal_rank}, please reduce it.",
            value=input,
        )

    # Get window and make sure it's the same size as `win_length` or `n_fft`
    n_win = symbolic_helper._get_tensor_dim_size(window, dim=0)
    if n_win is not None:
        win_length_default = win_length if win_length else n_fft
        if n_win != win_length_default:
            # Raised explicitly (rather than via `assert`) so the check is not
            # stripped under `python -O`; the exception type is kept as
            # AssertionError for backward compatibility. The message is a
            # plain string, not a one-element tuple.
            raise AssertionError(
                "Analysis window size must equal `win_length` or `n_fft`. "
                f"Please, set `win_length` or `n_fft` to match `window` size ({n_win})"
            )

        # Center window around zeros if needed (required by ONNX's STFT)
        if n_win < n_fft:
            left, right = _compute_edge_sizes(n_fft, n_win)
            # Concat needs every input to have the same element type, and the
            # Cast to the signal type only happens afterwards, so build the
            # zero padding in the window's own dtype (FLOAT when unknown).
            pad_dtype = _type_utils.JitScalarType.from_value(
                window, _type_utils.JitScalarType.FLOAT
            ).dtype()
            left_win = g.op("Constant", value_t=torch.zeros(left, dtype=pad_dtype))
            right_win = g.op("Constant", value_t=torch.zeros(right, dtype=pad_dtype))
            window = g.op("Concat", left_win, window, right_win, axis_i=0)

    # Create window, if needed
    if symbolic_helper._is_none(window):
        if win_length:
            if win_length > n_fft:
                raise errors.SymbolicValueError(
                    msg="The analysis window can't be longer than the size of the FFT. "
                    f"Please set `win_length` ({win_length}) to `n_fft` ({n_fft}) or less.",
                    value=input,
                )

            # Center window, if needed
            left, right = _compute_edge_sizes(n_fft, win_length)
            torch_window = torch.hstack(
                (torch.zeros(left), torch.ones(win_length), torch.zeros(right))
            )
        else:
            # Rectangle window
            torch_window = torch.ones(n_fft)
        # Internal invariant: the synthesized window always spans the frame.
        assert torch_window.shape[0] == n_fft
        window = g.op("Constant", value_t=torch_window)

    # The window is cast to the signal's ONNX element type before use.
    window = g.op(
        "Cast", window, to_i=_type_utils.JitScalarType.from_value(signal).onnx_type()
    )

    # Run STFT
    result = g.op(
        "STFT",
        signal,
        frame_step_const,
        window,
        frame_length_const,
        onesided_i=1 if onesided is None or onesided else 0,
    )

    # Transpose to mimic torch.stft's behavior
    result = g.op("Transpose", result, perm_i=[0, 2, 1, 3])

    # Remove batch dimension, if needed
    if signal_rank == 1:
        result = g.op(
            "Squeeze",
            result,
            g.op("Constant", value_t=torch.tensor([0], dtype=torch.int64)),
        )

    # Normalize, if needed
    if normalized:
        sqrt_nfft = torch.sqrt(torch.tensor(n_fft, dtype=signal.type().dtype()))
        result = g.op("Div", result, g.op("Constant", value_t=sqrt_nfft))

    return result