Files
pytorch/torch/nn/utils/memory_format.py
Maggie Moss c855f8632e Pyrefly suppressions 7/n (#164913)
Adds suppressions so that pyrefly will typecheck clean: https://github.com/pytorch/pytorch/issues/163283

Almost there!

Test plan:
dmypy restart && python3 scripts/lintrunner.py -a
pyrefly check

step 1: delete lines in the pyrefly.toml file from the project-excludes field
step 2: run pyrefly check
step 3: add suppressions, clean up unused suppressions
before: https://gist.github.com/maggiemoss/4b3bf2037014e116bc00706a16aef199

after:
 INFO 0 errors (6,884 ignored)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164913
Approved by: https://github.com/oulgen
2025-10-08 07:27:17 +00:00

175 lines
8.1 KiB
Python

from __future__ import annotations
from typing import TypeVar
import torch
# Type variable bound to ``torch.nn.Module`` so the converters below are typed as
# returning the same module type they receive (``module: _M`` -> ``_M``).
_M = TypeVar("_M", bound="torch.nn.Module")
def convert_conv2d_weight_memory_format(
    module: _M, memory_format: torch.memory_format
) -> _M:
    r"""Convert ``memory_format`` of ``nn.Conv2d.weight`` to ``memory_format``.

    The conversion recursively applies to nested ``nn.Module``, including ``module``
    itself. Only the weights of ``nn.Conv2d`` and ``nn.ConvTranspose2d`` layers are
    converted; every other parameter (including convolution biases) is left
    untouched. Note that it only changes the memory format, not the semantics of
    each dimension. This function is used to facilitate the computation to adopt
    NHWC kernels, which provides considerable speed up for fp16 data on CUDA
    devices with compute capability >= 7.0.

    .. note::
        Calling ``model.to(memory_format=torch.channels_last)`` is more aggressive
        than the utility function ``convert_conv2d_weight_memory_format``. Any
        layer with 4d weight will be affected by ``model.to``, which does not
        necessarily benefit from conversion to specified ``memory_format``.
        One place we are confident in is NHWC (channels_last) conversion for
        convolution in cuDNN: it is beneficial to run convolution in NHWC,
        even in cases where we have to apply permutation to input tensors.

        Hence our strategy here is to convert only the weight of convolution to
        channels_last. This ensures that:

        1. Fast convolution kernels will be used, the benefit of which could
           outweigh overhead of permutation (if input is not in the same format).
        2. No unnecessary permutations are applied on layers that do not benefit
           from memory_format conversion.

        The optimal case is that layers between convolution layers are channels
        last compatible. Input tensor would be permuted to channels last when it
        encounters the first convolution layer and stay in that memory format.
        Hence following convolutions will not need to permute their input tensor.

        In case where a channels last incompatible layer is between convolution
        layers, we need to permute the input tensor back to contiguous format
        for that layer. The input tensor will go through the remaining layers in
        contiguous format and be permuted to channels last when it encounters
        another convolution layer. There's no point in propagating that
        permutation to an earlier layer, as most layers are quite agnostic to
        ``memory_format``.

        This claim might change when PyTorch supports fusion of permutation, as
        there might have been a better spot to fuse the permutation other than
        immediately before a convolution.

    Args:
        module (nn.Module): ``nn.Conv2d`` & ``nn.ConvTranspose2d`` or container
            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last`` or ``torch.contiguous_format``

    Returns:
        ``module`` itself (modified in place), with the weight of every contained
        ``nn.Conv2d`` / ``nn.ConvTranspose2d`` converted to ``memory_format``.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> x = torch.randint(
        ...     1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda"
        ... )
        >>> model = nn.Sequential(
        ...     nn.Conv2d(8, 4, 3)
        ... ).cuda().half()
        >>> # Converts every Conv2d weight in ``model`` in place; the same
        >>> # module object is returned.
        >>> model = nn.utils.convert_conv2d_weight_memory_format(
        ...     model, torch.channels_last
        ... )
        >>> out = model(x)
    """
    # TODO: expand this to `_ConvNd` when channels_last support is extended
    # beyond only 4d tensors.
    if isinstance(module, (torch.nn.Conv2d, torch.nn.ConvTranspose2d)):
        # Copy the weight into a fresh tensor laid out in the requested format.
        weight_data = module.weight.detach().clone(memory_format=memory_format)
        # NOTE(review): resize_ to the same size re-applies `memory_format` to the
        # clone's strides; presumably this normalizes strides for degenerate
        # shapes (e.g. size-1 dims) -- confirm.
        # Assigning through `.data` swaps the underlying tensor in place, so the
        # Parameter object itself (and anything referencing it) is preserved.
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=memory_format
        )
    # Recurse into direct children; each call handles its own subtree.
    for child in module.children():
        convert_conv2d_weight_memory_format(child, memory_format)
    # pyrefly: ignore # bad-return
    return module
def convert_conv3d_weight_memory_format(
    module: _M, memory_format: torch.memory_format
) -> _M:
    r"""Convert ``memory_format`` of ``nn.Conv3d.weight`` to ``memory_format``.

    The conversion recursively applies to nested ``nn.Module``, including ``module``
    itself. Only the weights of ``nn.Conv3d`` and ``nn.ConvTranspose3d`` layers are
    converted; every other parameter (including convolution biases) is left
    untouched. Note that it only changes the memory format, not the semantics of
    each dimension. This function is used to facilitate the computation to adopt
    NDHWC (channels_last_3d) kernels, which can provide considerable speed up for
    fp16 data on CUDA devices with compute capability >= 7.0.

    .. note::
        Calling ``model.to(memory_format=torch.channels_last_3d)`` is more aggressive
        than the utility function ``convert_conv3d_weight_memory_format``. Any
        layer with 5d weight will be affected by ``model.to``, which does not
        necessarily benefit from conversion to specified ``memory_format``.
        One place we are confident in is NDHWC (channels_last_3d) conversion for
        convolution in cuDNN: it is beneficial to run convolution in NDHWC,
        even in cases where we have to apply permutation to input tensors.

        Hence our strategy here is to convert only the weight of convolution to
        channels_last_3d. This ensures that:

        1. Fast convolution kernels will be used, the benefit of which could
           outweigh overhead of permutation (if input is not in the same format).
        2. No unnecessary permutations are applied on layers that do not benefit
           from memory_format conversion.

        The optimal case is that layers between convolution layers are channels
        last compatible. Input tensor would be permuted to channels last when it
        encounters the first convolution layer and stay in that memory format.
        Hence following convolutions will not need to permute their input tensor.

        In case where a channels last incompatible layer is between convolution
        layers, we need to permute the input tensor back to contiguous format
        for that layer. The input tensor will go through the remaining layers in
        contiguous format and be permuted to channels last when it encounters
        another convolution layer. There's no point in propagating that
        permutation to an earlier layer, as most layers are quite agnostic to
        ``memory_format``.

        This claim might change when PyTorch supports fusion of permutation, as
        there might have been a better spot to fuse the permutation other than
        immediately before a convolution.

    Args:
        module (nn.Module): ``nn.Conv3d`` & ``nn.ConvTranspose3d`` or container
            ``nn.Module``
        memory_format: user specified ``memory_format``,
            e.g. ``torch.channels_last_3d`` or ``torch.contiguous_format``

    Returns:
        ``module`` itself (modified in place), with the weight of every contained
        ``nn.Conv3d`` / ``nn.ConvTranspose3d`` converted to ``memory_format``.

    Example:
        >>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
        >>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
        >>> x = torch.randint(
        ...     1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda"
        ... )
        >>> model = nn.Sequential(
        ...     nn.Conv3d(8, 4, 3)
        ... ).cuda().half()
        >>> # Converts every Conv3d weight in ``model`` in place; the same
        >>> # module object is returned.
        >>> model = nn.utils.convert_conv3d_weight_memory_format(
        ...     model, torch.channels_last_3d
        ... )
        >>> out = model(x)
    """
    # TODO: expand this to `_ConvNd` when channels-last support is extended
    # beyond 4d (channels_last) and 5d (channels_last_3d) tensors.
    if isinstance(module, (torch.nn.Conv3d, torch.nn.ConvTranspose3d)):
        # Copy the weight into a fresh tensor laid out in the requested format.
        weight_data = module.weight.detach().clone(memory_format=memory_format)
        # NOTE(review): resize_ to the same size re-applies `memory_format` to the
        # clone's strides; presumably this normalizes strides for degenerate
        # shapes (e.g. size-1 dims) -- confirm.
        # Assigning through `.data` swaps the underlying tensor in place, so the
        # Parameter object itself (and anything referencing it) is preserved.
        module.weight.data = weight_data.resize_(
            weight_data.size(), memory_format=memory_format
        )
    # Recurse into direct children; each call handles its own subtree.
    for child in module.children():
        convert_conv3d_weight_memory_format(child, memory_format)
    # pyrefly: ignore # bad-return
    return module
# Public names exported by a star-import of this module: the 2d and 3d
# convolution-weight memory-format converters.
__all__ = ["convert_conv2d_weight_memory_format",
           "convert_conv3d_weight_memory_format"]