[BE][6/16] fix typos in torch/ (#156316)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156316
Approved by: https://github.com/albanD
ghstack dependencies: #156313, #156314, #156315
Author: Xuehai Pan
Date: 2025-06-22 22:22:32 +08:00
Committed by: PyTorch MergeBot
Parent: 4ccc0381de
Commit: cec2977ed2
32 changed files with 58 additions and 59 deletions


@@ -331,11 +331,11 @@ def _compute_compressed_swizzled_bitmask(dense):
# we first need to split into the 8x8 tiles
bitmask_8x8_chunks = int_bitmask.unfold(0, 8, 8).unfold(1, 8, 8)
-# then we unfold again to get our indivdual 4x4 tiles
+# then we unfold again to get our individual 4x4 tiles
bitmask_4x4_chunks = bitmask_8x8_chunks.unfold(2, 4, 4).unfold(3, 4, 4)
# Each 4x4 bitmask defines two 8-bit integers, which encode the sparsity pattern
-# of that tile. Note that the least siginificant bit is stored first.
+# of that tile. Note that the least significant bit is stored first.
# [1 1 0 0]
# [1 1 0 0] -> 0011 0011 -> 51
# [0 0 1 1] 1100 1100 204
@@ -346,7 +346,7 @@ def _compute_compressed_swizzled_bitmask(dense):
*bitmask_4x4_chunks.shape[:2], 4, 2, 8
)
-# to convert from binary representaiton, we can do a matmul with powers of two
+# to convert from binary representation, we can do a matmul with powers of two
powers_of_two = 2 ** torch.arange(8, dtype=torch.float, device="cuda")
# To run on GPU: cast to float to do matmul and then cast back
compressed_swizzled_bitmask = (
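To make the encoding above concrete, here is a minimal, self-contained sketch (a CPU-only illustration, not the actual kernel code from this file) that packs the example 4x4 tile into its two 8-bit integers with the powers-of-two matmul:

import torch

# Illustrative sketch: pack one 4x4 bitmask tile into two 8-bit
# integers, least significant bit stored first, via a matmul with
# powers of two.
tile = torch.tensor(
    [[1, 1, 0, 0],
     [1, 1, 0, 0],
     [0, 0, 1, 1],
     [0, 0, 1, 1]],
    dtype=torch.float,
)

# Rows 0-1 form the first 8-bit value, rows 2-3 the second.
bits = tile.reshape(2, 8)

# LSB first: bit i contributes 2**i.
powers_of_two = 2 ** torch.arange(8, dtype=torch.float)

packed = bits @ powers_of_two
print(packed)  # tensor([ 51., 204.]) -- matches the worked example in the comment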


@@ -179,7 +179,7 @@ def semi_sparse_scaled_mm(func, types, args=(), kwargs=None) -> torch.Tensor:
assert A.dtype == torch.float8_e4m3fn
assert B.dtype == torch.float8_e4m3fn
-# only cuSPARSELt supports float8_e4m3fn currentl
+# only cuSPARSELt supports float8_e4m3fn currently
assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
assert A.packed is not None
# Currently we only support per-tensor scaling, with float32 scales
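For reference, the constraints asserted in this hunk can be summarized in a standalone checker; the helper below is a hedged sketch (its name and the scale arguments are illustrative, not part of the PyTorch API):

import torch

def _check_semi_sparse_scaled_mm_inputs(A, B, scale_a, scale_b):
    # Hypothetical helper mirroring the asserts above; illustrative only.
    # Both operands must be float8_e4m3fn ...
    assert A.dtype == torch.float8_e4m3fn and B.dtype == torch.float8_e4m3fn
    # ... and only the cuSPARSELt backend currently supports that dtype.
    assert isinstance(A, torch.sparse.SparseSemiStructuredTensorCUSPARSELT)
    assert A.packed is not None
    # Per-tensor scaling: each scale is a single float32 value.
    assert scale_a.numel() == 1 and scale_a.dtype == torch.float32
    assert scale_b.numel() == 1 and scale_b.dtype == torch.float32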


@@ -333,7 +333,7 @@ def scatter_mm(blocks, others, indices_data, *, accumulators=None):
this property enables defining swizzle operators via
rearrangements of ``r_offsets`` items..
-Auxilary functions are provided for pre-computing
+Auxiliary functions are provided for pre-computing
:attr:`indices_data`. For example,
:func:`bsr_scatter_mm_indices_data` is used to define indices data
for matrix multiplication of BSR and strided tensors.
@@ -836,7 +836,7 @@ def bsr_dense_addmm_meta(
class TensorAsKey:
"""A light-weight wrapper of a tensor that enables storing tensors as
-keys with efficient memory reference based comparision as an
+keys with efficient memory reference based comparison as an
approximation to data equality based keys.
Motivation: the hash value of a torch tensor is tensor instance
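A simplified, hedged sketch of that memory-reference-based key scheme (illustrative only, not the actual TensorAsKey implementation): two tensors compare equal as keys when they share the same storage pointer and layout, so no element-wise comparison is ever performed:

import torch

class _TensorKeySketch:
    # Simplified illustration of the TensorAsKey idea; not the real class.
    def __init__(self, t: torch.Tensor):
        # Same storage address + layout is a cheap stand-in for "same data".
        self._key = (t.data_ptr(), tuple(t.shape), t.stride(), t.dtype, t.device)

    def __hash__(self) -> int:
        return hash(self._key)

    def __eq__(self, other) -> bool:
        return isinstance(other, _TensorKeySketch) and self._key == other._key

# Usage sketch: cache expensive per-tensor results without hashing tensor data.
cache: dict = {}
t = torch.arange(6).reshape(2, 3)
cache[_TensorKeySketch(t)] = "expensive result"
assert cache[_TensorKeySketch(t)] == "expensive result"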


@@ -9,7 +9,7 @@ performance of operations several times. For example, for large tensor
shapes, the usage of a bsr tensor as mat1 argument in addmm-based
operations typically outperforms the corresponding operation with
strided-only inputs when the blocked representation of a tensor
-provides a better alignement with memory access than what the strided
+provides a better alignment with memory access than what the strided
representation would provide.
Pre-computed kernel parameters
@@ -57,7 +57,7 @@ Computing optimal kernel parameters
If the approximations listed above are unacceptable, e.g. when one
seeks a maximal performance possible, the optimal kernel parameters
for a particular GPU can be computed by simply running this script in
-the pytorch developement tree::
+the pytorch development tree::
cd /path/to/pytorch
python setup.py develop
@@ -91,7 +91,7 @@ torch.nn.functional.linear will benefit from using the computed
optimal set of kernel parameters.
Note that running tune_bsr_dense_addmm can take several minutes. So,
-use it wisely, e.g. by implementing persisten storage of optimized
+use it wisely, e.g. by implementing persistent storage of optimized
kernel parameters. See the source code of get_meta and
tune_bsr_dense_addmm to learn how to register a custom set of optimal
kernel parameters for addmm-based operations.
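A minimal sketch of such persistent storage (the cache path, key format, and parameter names below are illustrative assumptions, not PyTorch's actual storage format):

import json
import os

# Hedged sketch of a persistent cache for tuned kernel parameters.
_CACHE_PATH = os.path.expanduser("~/.cache/bsr_dense_addmm_meta.json")

def _load_cache() -> dict:
    if os.path.exists(_CACHE_PATH):
        with open(_CACHE_PATH) as f:
            return json.load(f)
    return {}

def _store_meta(key: str, meta: dict) -> None:
    cache = _load_cache()
    cache[key] = meta
    os.makedirs(os.path.dirname(_CACHE_PATH), exist_ok=True)
    with open(_CACHE_PATH, "w") as f:
        json.dump(cache, f, indent=2)

# Usage sketch: look up tuned parameters before falling back to re-tuning.
key = "float16,M=4096,K=4096,N=4096,blocksize=64"  # illustrative key layout
meta = _load_cache().get(key)
if meta is None:
    meta = {"SPLIT_N": 2, "num_warps": 4}  # placeholder tuned values
    _store_meta(key, meta)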
@@ -852,7 +852,7 @@ def main(op="scatter_mm", force=False, dtype=torch.float16, verbose=True):
if 0:
# Check performance dependence on sparsity and apply
-# adjustments when differences are noticable (more than 10%).
+# adjustments when differences are noticeable (more than 10%).
#
# When using NVIDIA A100 GPU, the performance dependence on
# sparsity is insignificant (0 % ... 10 %) for majority of


@@ -37,7 +37,7 @@ _SEMI_STRUCTURED_SPARSE_CONFIG = namedtuple(
class SparseSemiStructuredTensor(torch.Tensor):
"""
-This class implementes semi-structured sparsity as a Tensor subclass.
+This class implements semi-structured sparsity as a Tensor subclass.
Semi-structured sparsity describes a sparsity pattern where n in every 2n elements are sparse,
depending on the datatype. It is also referred to as 2:4 sparsity or fine-grained
@@ -46,11 +46,11 @@ class SparseSemiStructuredTensor(torch.Tensor):
There are two backends available for semi_structred sparsity, either cuSPARSELt or CUTLASS.
This class is meant to serve as a base class for both implementations. SparseSemiStructuredCUTLASS
and SparseSemiStructuredCUSPARSELT both inherit from this class and define three backend-specific items.
-Note that as such, this class cannot be insantiated directly.
+Note that as such, this class cannot be instantiated directly.
- `_DTYPE_SHAPE_CONSTRAINTS` - A dictionary holding backend specific dense/sparse min shape constraints
- `def from_dense()` - backend specific compression routines
-- `def _mm()` - backend specifc mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
+- `def _mm()` - backend specific mm op (either torch._cslt_sparse_mm or torch._sparse_semi_structured_(mm|addmm))
"""
_DEFAULT_ALG_ID: int = 0
@@ -123,7 +123,7 @@ class SparseSemiStructuredTensor(torch.Tensor):
)
cls._PROTOTYPE_WARNING_SHOWN = True
-# Because this only runs onces, we also load the dispatch table here as well.
+# Because this only runs once, we also load the dispatch table here as well.
# We can't define the dispatch table explicitly because of torch.ops import errors, so we do this instead
# But this is useful since it allows users to overload the dispatch table for debugging / testing.
cls._load_dispatch_table()
@@ -325,7 +325,7 @@ def to_sparse_semi_structured(
This function will check to ensure the dense tensor has the right dtype, size, dims, and device.
We currently only support semi-structured sparse tensors for 2d CUDA tensors.
-Additionally, your tensor must be a positive multiple of the mininum sparse block size, given in
+Additionally, your tensor must be a positive multiple of the minimum sparse block size, given in
`_DTYPE_TO_SHAPE_CONSTRAINTS` for each dtype (float32, float16, bfloat16, int8).
Args:
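A hedged usage sketch of `to_sparse_semi_structured` (requires a CUDA device with 2:4 sparse kernel support; the (128, 128) float16 shape is chosen to satisfy the minimum shape constraints):

import torch
from torch.sparse import to_sparse_semi_structured

# Build a float16 dense tensor that already has a 2:4 pattern: two of
# every four contiguous elements in each row are zero.
A = torch.tensor([0, 0, 1, 1], dtype=torch.float16, device="cuda").tile((128, 32))
A_sparse = to_sparse_semi_structured(A)

# The compressed tensor can then be used in supported ops, e.g. matmul.
B = torch.rand((128, 128), dtype=torch.float16, device="cuda")
out = torch.mm(A_sparse, B)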
@@ -388,7 +388,7 @@ class SparseSemiStructuredTensorCUTLASS(SparseSemiStructuredTensor):
This class implements semi-structured sparsity for the CUTLASS backend.
-In this implementation, the specified elements and metadata are stored seprately,
+In this implementation, the specified elements and metadata are stored separately,
in packed and meta respectively.
When _FORCE_CUTLASS is set, or when cuSPARSELt is not available, this subclass calls into _sparse_semi_structured_(mm|addmm) and