You can not select more than 25 topics Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.

519 lines
21 KiB

import warnings
from collections import namedtuple
from typing import Any, Optional, Tuple, List, Callable, Dict
import torch
from torch.sparse._semi_structured_conversions import (
sparse_semi_structured_from_dense_cutlass,
sparse_semi_structured_to_dense_cutlass,
)
from torch.sparse._semi_structured_ops import (
fallback_dispatcher,
semi_sparse_values,
semi_sparse_indices,
semi_sparse_detach,
semi_sparse_t,
semi_sparse_view,
semi_sparse_mm,
semi_sparse_addmm,
semi_sparse_linear,
)
__all__ = [
"SparseSemiStructuredTensor",
"SparseSemiStructuredTensorCUTLASS",
"SparseSemiStructuredTensorCUSPARSELT",
"to_sparse_semi_structured",
]
_SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
"_SEMI_STRUCTURED_SPARSE_CONFIG",
"sparse_min_rows sparse_min_cols dense_min_rows dense_min_cols",
)
class SparseSemiStructuredTensor(torch.Tensor):
"""
This class implementes semi-structured sparsity as a Tensor subclass.
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
structured sparsity.
There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
Note that as such, this class cannot be insantiated directly.
-`_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- `def from_dense()` - backend specific compression routines
- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_linear)
"""
_DEFAULT_ALG_ID: int = 0
_DTYPE_SHAPE_CONSTRAINTS: Dict[torch.dtype, _SEMI_STRUCTURED_SPARSE_CONFIG]
_FORCE_CUTLASS: bool = True
_FUSE_TRANSPOSE: bool = False
_PROTOTYPE_WARNING_SHOWN: bool = False
SPARSE_DISPATCH: Dict[Callable, Callable]
packed: Optional[torch.Tensor]
meta: Optional[torch.Tensor]
packed_t: Optional[torch.Tensor]
meta_t: Optional[torch.Tensor]
threads_masks: Optional[torch.Tensor]
fuse_transpose_cusparselt: bool
alg_id_cusparselt: int
__slots__ = ["packed", "meta", "packed_t", "meta_t", "threads_masks"]
@staticmethod
def __new__( # noqa: PYI034
cls,
shape: torch.Size,
packed: Optional[torch.Tensor],
meta: Optional[torch.Tensor],
packed_t: Optional[torch.Tensor],
meta_t: Optional[torch.Tensor],
threads_masks: Optional[torch.Tensor],
fuse_transpose_cusparselt: bool = False,
alg_id_cusparselt: int = 0,
requires_grad: bool = False,
):
"""
Create a new instance of the tensor subclass from the compressed sparse representation.
We have the option to create the subclass with the compressed representations of both X and X', for training.
For inference, we only need a single representation (either X or X'), while the corresponding other set will be None.
Depending on the backend selected, certain fields will be set to None. (CUSPARSELT vs CUTLASS)
Args:
shape: The shape of the original dense tensor
packed: The compressed representation of the original dense tensor
meta: The metadata of the original dense tensor, if it is stored separately
packed_t: The compressed representation of the transposed original dense tensor
meta_t: The metadata of the transposed original dense tensor, if it is stored separately
threads_masks: The masks used by the CUTLASS backend to determine which threads should participate in the computation.
Used for pointwise ops.
fuse_transpose_cusparselt: When running with cuSPARSELt, we have the option to fuse a transposition
with a matmul, which is useful in the case of 2:4 sparse training.
alg_id_cusparselt: The algorithm id to use when using cuSPARSELT, will have effect on performance
Returns:
torch.Tensor: A torch.Tensor wrapper subclass.
Raises:
ValueError: If all of the tensor arguments are None.
"""
if not cls._PROTOTYPE_WARNING_SHOWN:
warnings.warn(
(
"The PyTorch API of SparseSemiStructuredTensor is in prototype stage "
"and will change in the near future. Please open a Github issue "
"for features requests and see our documentation on the torch.sparse "
"module for further information about the project."
),
UserWarning,
)
cls._PROTOTYPE_WARNING_SHOWN = True
# Because this only runs onces, we also load the dispatch table here as well.
# We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
cls._load_dispatch_table()
if packed is not None:
previous_tensor = packed
elif packed_t is not None:
previous_tensor = packed_t
else:
raise ValueError("At least one of packed or packed_t must be provided")
kwargs = {
"device": previous_tensor.device,
"dtype": previous_tensor.dtype,
"layout": previous_tensor.layout,
"requires_grad": requires_grad,
}
tensor = torch.Tensor._make_wrapper_subclass(cls, shape, **kwargs) # type: ignore[attr-defined]
tensor.packed = packed
tensor.meta = meta
tensor.packed_t = packed_t
tensor.meta_t = meta_t
tensor.threads_masks = threads_masks
tensor.fuse_transpose_cusparselt = fuse_transpose_cusparselt
tensor.alg_id_cusparselt = alg_id_cusparselt
return tensor
def __repr__(self) -> str: # type: ignore[override]
assert hasattr(self, "shape")
return f"{self.__class__.__name__}(shape={self.shape})"
def __tensor_flatten__(
self,
) -> Tuple[List[str], Tuple[torch.Size, bool, int, bool]]:
inner_tensors = list(
filter(lambda x: getattr(self, x) is not None, self.__slots__)
)
tensor_meta = (
self.shape,
self.fuse_transpose_cusparselt,
self.alg_id_cusparselt,
self.requires_grad,
)
return inner_tensors, tensor_meta
@classmethod
def __tensor_unflatten__(
cls,
inner_tensors,
tensor_meta : Tuple[torch.Size, bool, int, bool],
outer_size,
outer_stride,
) -> torch.Tensor:
shape, fuse_transpose_cusparselt, alg_id_cusparselt, requires_grad = tensor_meta
return cls(
shape=shape,
packed=inner_tensors.get("packed", None),
meta=inner_tensors.get("meta", None),
packed_t=inner_tensors.get("packed_t", None),
meta_t=inner_tensors.get("meta_t", None),
threads_masks=inner_tensors.get("threads_masks", None),
fuse_transpose_cusparselt=fuse_transpose_cusparselt,
alg_id_cusparselt=alg_id_cusparselt,
requires_grad=requires_grad,
)
__torch_function__ = torch._C._disabled_torch_function_impl
@classmethod
def __torch_dispatch__(cls, func, types, args, kwargs) -> Any:
if func._overloadpacket not in cls.SPARSE_DISPATCH:
raise NotImplementedError(
f"{cls.__name__} only supports a specific set of operations, "
f"can't perform requested op ({func.__name__})"
)
return cls.SPARSE_DISPATCH[func._overloadpacket](func, types, args, kwargs)
@classmethod
def _load_dispatch_table(cls, custom_dispatch_table=None) -> None:
"""
Loads the op overload sparse dispatch table for the current class.
"""
if getattr(cls, "SPARSE_DISPATCH", None) is None:
cls.SPARSE_DISPATCH = {
torch.ops.aten.values: semi_sparse_values,
torch.ops.aten.indices: semi_sparse_indices,
torch.ops.aten.is_same_size: fallback_dispatcher,
torch.ops.aten.detach_: fallback_dispatcher,
torch.ops.aten.detach: semi_sparse_detach,
torch.ops.aten.t: semi_sparse_t,
torch.ops.aten.view: semi_sparse_view,
torch.ops.aten.mm: semi_sparse_mm,
torch.ops.aten.matmul: semi_sparse_mm,
torch.ops.aten.addmm: semi_sparse_addmm,
torch.ops.aten.linear: semi_sparse_linear,
}
if custom_dispatch_table is not None:
cls.SPARSE_DISPATCH.update(custom_dispatch_table)
@classmethod
def _validate_device_dim_dtype_shape(cls, original_tensor : torch.Tensor) -> None:
"""
Assert that the given tensor is valid for semi-structured sparse compression.
"""
# check device
if not original_tensor.is_cuda:
raise RuntimeError(
f"Error original_tensor.device= {original_tensor.device} is not supported! "
"Only CUDA tensors are currently supported."
)
# check dim
if original_tensor.dim() != 2:
raise RuntimeError(
f"Error original_tensor.dim = {original_tensor.dim()} is not supported! "
"Only 2d tensors are currently supported."
)
# check contiguous
if not original_tensor.is_contiguous():
raise RuntimeError(
"Error original_tensor is not contiguous!"
"Only contiguous tensors are currently supported."
)
# check dtype
if original_tensor.dtype not in cls._DTYPE_SHAPE_CONSTRAINTS:
raise RuntimeError(
f"Error original_tensor.dtype {original_tensor.dtype} is not a supported dtype! "
"dtype must be one of: {cls._DTYPE_SHAPE_CONSTRAINTS}"
)
# check shape
m, n = original_tensor.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[original_tensor.dtype].sparse_min_cols
if m < min_rows or m % min_rows or n < min_cols or n % min_cols:
# TODO in the future we can add in padding to support sparse dimensions that aren't perfect multiples
raise RuntimeError(
f"Error original_tensor.shape {original_tensor.shape} is not supported! "
f"Both dimensions must be larger or equal than and a multiple of ({min_rows}, {min_cols})"
)
@classmethod
def _pad_dense_input(cls, dense_input: torch.Tensor) -> torch.Tensor:
"""
Calculates padding for dense tensor and pads tensor if necessary.
If padding is not required, this function returns the original tensor.
"""
# only 2d matmul
assert dense_input.dim() == 2
# check shape
m, n = dense_input.shape
min_rows = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_rows
min_cols = cls._DTYPE_SHAPE_CONSTRAINTS[dense_input.dtype].dense_min_cols
# calculate padding
to_pad_m = -m % min_rows if m < min_rows or m % min_rows else 0
to_pad_n = -n % min_cols if n < min_cols or n % min_rows else 0
if to_pad_m or to_pad_n:
return torch.nn.functional.pad(dense_input, (0, to_pad_n, 0, to_pad_m))
else:
return dense_input
def to_dense(self):
col = self.shape[-1]
return torch.mm(self, torch.eye(col, dtype=self.dtype, device=self.device))
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensor":
raise NotImplementedError
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs,
) -> torch.Tensor:
raise NotImplementedError
def to_sparse_semi_structured(
original_tensor: torch.Tensor,
transposed: bool = False,
) -> SparseSemiStructuredTensor:
"""
This function converts a dense tensor into a sparse semi-structured tensor.
It will return a SparseSemiStructuredTensor, a subclass of torch.Tensor.
This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
We currently only support semi-structured sparse tensors for 2d CUDA tensors.
Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in
`_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
Args:
original_tensor (Tensor): the dense tensor to convert
transposed (bool, optional): deprecated arg to be removed in another release. Do not use.
Returns:
SparseSemiStructuredTensor: A sparse semi-structured tensor created from the given original_tensor
Raises:
None
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> A = torch.Tensor([0, 0, 1, 1]).tile((128, 32)).half().cuda()
tensor([[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
...,
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.],
[0., 0., 1., ..., 0., 1., 1.]], device='cuda:0', dtype=torch.float16)
>>> A_sparse = to_sparse_semi_structured(A)
SparseSemiStructuredTensor(shape=torch.Size([128, 128]))
>>> A_sparse.values()
tensor([[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
...,
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.],
[1., 1., 1., ..., 1., 1., 1.]], device='cuda:0', dtype=torch.float16),
>>> A_sparse.indices()
tensor([[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
...,
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370],
[-4370, -4370, -4370, ..., -4370, -4370, -4370]], device='cuda:0', dtype=torch.int16))
"""
if transposed:
raise DeprecationWarning(
"Setting transpose from to_sparse_semi_structured is deprecated and will be removed in a future release."
"SparseSemiStructuredTensor only support contiguous input tensors. "
)
sparse_subclass = (
torch.sparse.SparseSemiStructuredTensorCUTLASS
if SparseSemiStructuredTensor._FORCE_CUTLASS
else torch.sparse.SparseSemiStructuredTensorCUSPARSELT
)
return sparse_subclass.from_dense(original_tensor)
class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
"""
This class implements semi-structured sparsity for the CUTLASS backend.
In this implementation, the specified elements and metadata are stored seprately,
in packed and meta respectively.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_linear
and sparse_semi_structured_from_dense for conversion to the compressed format.
"""
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 128, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 64, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 4, 4),
}
@classmethod
def from_dense(
cls, original_tensor: torch.Tensor
) -> "SparseSemiStructuredTensorCUTLASS":
cls._validate_device_dim_dtype_shape(original_tensor)
(
sparse_tensor_cutlass,
meta_tensor_cutlass,
) = sparse_semi_structured_from_dense_cutlass(original_tensor)
return cls(
original_tensor.shape,
packed=sparse_tensor_cutlass,
meta=meta_tensor_cutlass,
packed_t=None,
meta_t=None,
threads_masks=None,
requires_grad=original_tensor.requires_grad,
)
def to_dense(self):
assert self.meta is not None and self.packed is not None
return (
sparse_semi_structured_to_dense_cutlass(
self.packed,
self.meta,
)
if self.meta.ndim == 2
else super().to_dense()
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
cls_name = self.__class__.__name__
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{cls_name}` matmul: Broadcasting is not implemented"
)
if self.packed is None or self.meta is None:
raise NotImplementedError(
f"`{cls_name}` matmul: operation is not supported"
)
else:
res = torch._sparse_semi_structured_linear(
B.t(), self.packed, self.meta, bias=bias
).t()
return res[: self.shape[0]]
class SparseSemiStructuredTensorCUSPARSELT(SparseSemiStructuredTensor):
"""
The cuSPARSELt backend expects the specified elements and the metadata to be stored in a single tensor:
packed = [ specified elements of original tensor | metadata ]
For an original tensor of size (m, k) we expect the first m * k // 2 elements to be the kept elements
The rest of the tensor is metadata. Since there is only one tensor, we only use the packed and packed_t
attributes respectively.
cuSPARSELt also supports transposition fusion, which is necessary for performant 2:4 sparse training, as well
as specifying alg_id, a config that affects the performance of the matmul depending on matmul sizes.
"""
_DTYPE_SHAPE_CONSTRAINTS = {
torch.int8: _SEMI_STRUCTURED_SPARSE_CONFIG(32, 32, 16, 16),
torch.float16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
torch.bfloat16: _SEMI_STRUCTURED_SPARSE_CONFIG(16, 16, 8, 8),
torch.float32: _SEMI_STRUCTURED_SPARSE_CONFIG(8, 8, 4, 4),
}
@classmethod
def from_dense(cls, original_tensor : torch.Tensor) -> "SparseSemiStructuredTensorCUSPARSELT":
cls._validate_device_dim_dtype_shape(original_tensor)
return cls(
shape=original_tensor.shape,
packed=torch._cslt_compress(original_tensor),
meta=None,
packed_t=None,
meta_t=None,
threads_masks=None,
fuse_transpose_cusparselt=SparseSemiStructuredTensor._FUSE_TRANSPOSE,
alg_id_cusparselt=SparseSemiStructuredTensor._DEFAULT_ALG_ID,
requires_grad=original_tensor.requires_grad,
)
def _mm(
self,
B: torch.Tensor,
*,
bias: Optional[torch.Tensor] = None,
**kwargs
) -> torch.Tensor:
if isinstance(B, SparseSemiStructuredTensor):
raise ValueError(
"`SparseSemiStructuredTensor @ SparseSemiStructuredTensor` is not supported by the hardware"
)
if self.ndim != 2 or B.ndim != 2:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: Broadcasting is not implemented"
)
if B.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)}`, "
f"with A.dtype={self.dtype} and B.dtype={B.dtype}. "
"This operation is only supported when A and B have the same data type."
)
if bias is not None and bias.dtype != self.dtype:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: trying to do `A={tuple(self.shape)} @ B={tuple(B.shape)} + C`, "
"with A.dtype=B.dtype={self.dtype} and C.dtype={B.dtype}. "
"This operation is only supported when A, B and C have the same data type."
)
if self.packed is None:
raise NotImplementedError(
f"`{self.__class__.__name__}` matmul: operation is not supported"
)
else:
res = torch._cslt_sparse_mm(
self.packed,
B,
bias=bias,
transpose_result=self.fuse_transpose_cusparselt,
alg_id=self.alg_id_cusparselt,
)
return res.t() if self.fuse_transpose_cusparselt else res