[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang
Xuehai Pan
2025-06-14 00:48:12 +08:00
committed by PyTorch MergeBot
parent 3e38feb05f
commit 596b418391
65 changed files with 640 additions and 475 deletions
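
For anyone reproducing this migration locally, a minimal sketch of invoking ruff format on the migrated directories follows; this standalone invocation is an assumption for illustration only, since in the repository the formatter is driven through the lint tooling whose allowlist is edited further down in this commit.

# Hypothetical local reproduction (assumes `ruff` is installed and on PATH).
import subprocess

paths = ["torch/nn", "torch/optim", "test/nn", "test/optim"]
subprocess.run(["ruff", "format", *paths], check=True)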

View File

@ -470,9 +470,9 @@ def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
return cls(src._data)
return cls(src)
else:
assert isinstance(
src, cls
), f"Expected isinstance(src, {cls}) but got {type(src)}"
assert isinstance(src, cls), (
f"Expected isinstance(src, {cls}) but got {type(src)}"
)
assert (
type(dest) == torch.Tensor
or type(dest) == torch.nn.Parameter
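
The hunk above shows the single most common change in this commit: for an assert whose message pushes it past the line length, black wrapped the condition, while ruff format keeps the condition on one line and parenthesizes the message. A self-contained toy sketch of the two styles (toy values, not PyTorch code):

# Toy before/after of the recurring assert reformatting in this commit.
src = [1, 2, 3]
cls = list

# black (old): the condition is split and the message trails the closing paren.
assert isinstance(
    src, cls
), f"Expected isinstance(src, {cls}) but got {type(src)}"

# ruff format (new): the condition stays on one line, the message is parenthesized.
assert isinstance(src, cls), (
    f"Expected isinstance(src, {cls}) but got {type(src)}"
)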

View File

@ -1475,9 +1475,9 @@ class TestNNParametrization(NNTestCase):
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict["parametrizations.weight.0._v"]
snm.load_state_dict(non_strict_state_dict, strict=False)
non_strict_state_dict[
"weight"
] = snm.weight.detach().clone() # set W as a buffer
non_strict_state_dict["weight"] = (
snm.weight.detach().clone()
) # set W as a buffer
snm.load_state_dict(non_strict_state_dict, strict=False)
del non_strict_state_dict._metadata[
"parametrizations.weight.0"

View File

@ -107,9 +107,7 @@ class TestLRScheduler(TestCase):
[0]
+ [i + 1 for i, m in enumerate(self.milestones) if global_step >= m]
)[-1]
return [
init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr
]
return [init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr]
optimizer = SGD([torch.rand(1)], lr=1)

View File

@ -62,9 +62,9 @@ def _multistep_backprop_diff_hyperparams_fn(
kwargs: dict[str, Any],
*ignored: Any,
) -> tuple[Tensor, ...]:
assert (
kwargs["differentiable"] is True
), "Only call this test function when differentiable=True"
assert kwargs["differentiable"] is True, (
"Only call this test function when differentiable=True"
)
params = params.clone()
params.grad = grad
@ -81,9 +81,9 @@ def _multistep_backprop_diff_hyperparams_fn(
# so they're passed in as Tensors (not a tuple) and recognized by gradcheck
if "beta1" in kwargs or "beta2" in kwargs:
# Prevent just one beta kwarg from being passed in
assert (
"beta1" in kwargs and "beta2" in kwargs
), "Both betas should be defined in kwargs"
assert "beta1" in kwargs and "beta2" in kwargs, (
"Both betas should be defined in kwargs"
)
kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})
kwargs.update(

View File

@ -41,10 +41,9 @@ USE_BLACK_FILELIST = re.compile(
"test/[a-h]*/**",
# test/[i-j]*/**
"test/[i-j]*/**",
# test/[k-n]*/**
"test/[k-n]*/**",
# test/[k-m]*/**
"test/[k-m]*/**",
# test/optim/**
"test/optim/**",
# "test/[p-z]*/**",
"test/[p-z]*/**",
# torch/**
@ -58,10 +57,9 @@ USE_BLACK_FILELIST = re.compile(
# torch/[a-c]*/**
"torch/[a-c]*/**",
# torch/d*/**
# torch/[e-n]*/**
"torch/[e-n]*/**",
# torch/[e-m]*/**
"torch/[e-m]*/**",
# torch/optim/**
"torch/optim/**",
# torch/[p-z]*/**
"torch/[p-z]*/**",
],
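
This hunk is the switch itself: narrowing the ranges from test/[k-n] and torch/[e-n] to [k-m] and [e-m], and (judging by the hunk line counts) dropping the optim entries, removes {torch,test}/{nn,optim} from the black allowlist so those trees fall through to ruff format. A hypothetical illustration of the allowlist idea, using fnmatch-style globs rather than the adapter's actual compiled regex:

# Hypothetical sketch of the allowlist semantics; the real adapter compiles
# these globs into a regex, fnmatch is used here only to keep the example short.
from fnmatch import fnmatch

use_black_filelist = ["test/[a-h]*/**", "test/[k-m]*/**", "torch/[e-m]*/**"]

def formatter_for(path: str) -> str:
    if any(fnmatch(path, pattern) for pattern in use_black_filelist):
        return "black"
    return "ruff format"

print(formatter_for("torch/nn/modules/linear.py"))  # ruff format: 'n' is outside [e-m]
print(formatter_for("torch/optim/sgd.py"))          # ruff format: optim entry removed
print(formatter_for("torch/fx/graph.py"))           # black: still matches torch/[e-m]*/**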

View File

@ -529,9 +529,9 @@ def jagged_from_tensor_and_lengths(
)
# Calculate jagged offsets
assert (
len(tensor.shape) >= 2
), "tensor must at least be 2D for the nested narrow op to work"
assert len(tensor.shape) >= 2, (
"tensor must at least be 2D for the nested narrow op to work"
)
max_seq_len = tensor.shape[1]
offset_lengths = max_seq_len * torch.arange(
0, batch_size, dtype=torch.int64, device=tensor.device

View File

@ -73,9 +73,9 @@ def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
"""
from torch._prims_common import canonicalize_dims
assert isinstance(
dims, (tuple, list)
), f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
assert isinstance(dims, (tuple, list)), (
f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
)
wrapped_dims = [
canonicalize_dims(ndim, d) for d in dims
@ -535,9 +535,9 @@ def clone_default(func, *args, **kwargs):
from .nested_tensor import jagged_from_list
# TODO: We probably want the output to have the same ragged structure / nested int.
assert (
inp._ragged_idx == 1
), "NJT with ragged_idx != 1 not supported for contiguous clone"
assert inp._ragged_idx == 1, (
"NJT with ragged_idx != 1 not supported for contiguous clone"
)
contig, _ = jagged_from_list(inp.unbind(), offsets=None)
return contig
@ -1730,8 +1730,8 @@ def native_layer_norm_default(func, *args, **kwargs):
) # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm
padded_normalized = (
padded_input - mean
) * padded_mask # mask elements outside of the ragged dimension size for correct variance calculation
(padded_input - mean) * padded_mask
) # mask elements outside of the ragged dimension size for correct variance calculation
variance = (
torch.sum(

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention"""
import contextlib
from collections.abc import Iterable
from typing import Union
@ -119,6 +120,7 @@ def sdpa_kernel(
from torch.nn.functional import scaled_dot_product_attention
from torch.nn.attention import SDPBackend, sdpa_kernel
# Only enable flash attention backend
with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
scaled_dot_product_attention(...)
@ -130,9 +132,9 @@ def sdpa_kernel(
This context manager can be used to select which backend to use for scaled dot product attention.
Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends.
"""
assert isinstance(
backends, (list, SDPBackend)
), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
assert isinstance(backends, (list, SDPBackend)), (
"Backend must be an instance of SDPBackend or a list of SDPBackend instances"
)
if isinstance(backends, SDPBackend):
backends = [backends]

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Defines utilities for interacting with scaled_dot_product_attention"""
import math
from typing import Optional

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Defines bias subclasses that work with scaled_dot_product_attention"""
from enum import auto, IntEnum
from typing import Optional
from warnings import warn
@ -101,9 +102,15 @@ class CausalBias(torch.Tensor):
# Create a lower-right causal bias
attn_bias = causal_lower_right(seqlen_q, seqlen_kv)
q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
q = torch.randn(
bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16
)
k = torch.randn(
bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
)
v = torch.randn(
bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
)
out = F.scaled_dot_product_attention(q, k, v, attn_bias)

View File

@ -182,9 +182,7 @@ class PagedAttention:
logical_block_offset = input_pos % self.page_size # [B, S]
physical_block_idx = torch.gather(
self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)
).to(
torch.int32
) # [B, S]
).to(torch.int32) # [B, S]
addr = (physical_block_idx * self.page_size + logical_block_offset).view(
-1

View File

@ -1,6 +1,7 @@
# mypy: allow-untyped-defs
# flake8: noqa: B950
"""This module implements the user facing API for flex_attention in PyTorch."""
import functools
import inspect
import itertools
@ -293,12 +294,12 @@ class BlockMask:
assert kv_indices is not None, "kv_indices must be provided"
assert q_num_blocks is not None, "q_num_blocks must be provided"
assert q_indices is not None, "q_indices must be provided"
assert (full_kv_num_blocks is None) == (
full_kv_indices is None
), "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
assert (full_q_num_blocks is None) == (
full_q_indices is None
), "full_q_num_blocks and full_q_indices must be both provided or omitted"
assert (full_kv_num_blocks is None) == (full_kv_indices is None), (
"full_kv_num_blocks and full_kv_indices must be both provided or omitted"
)
assert (full_q_num_blocks is None) == (full_q_indices is None), (
"full_q_num_blocks and full_q_indices must be both provided or omitted"
)
self.seq_lengths = seq_lengths
self.kv_num_blocks = kv_num_blocks
@ -344,9 +345,9 @@ class BlockMask:
if kv_indices.dim() < 2:
raise RuntimeError("BlockMask must have at least 2 dimensions")
assert (full_kv_num_blocks is None) == (
full_kv_indices is None
), "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
assert (full_kv_num_blocks is None) == (full_kv_indices is None), (
"full_kv_num_blocks and full_kv_indices must be both provided or omitted"
)
# Generate q_num_blocks and q_indices
q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
@ -434,7 +435,10 @@ class BlockMask:
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda")
block_mask = create_block_mask(
causal_mask, 4, 2, 512, 512, device="cuda"
)
assert block_mask.kv_num_blocks.shape == (4, 2, 4)
assert block_mask.kv_indices.shape == (4, 2, 4, 4)
@ -454,7 +458,9 @@ class BlockMask:
assert new_block_mask.kv_indices.shape == (2, 1, 4, 4)
# slicing on batch, head, and query dimension
new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)]
new_block_mask = block_mask[
0:2, 1:2, torch.tensor([1], dtype=torch.int32)
]
assert new_block_mask.kv_num_blocks.shape == (2, 1, 1)
assert new_block_mask.kv_indices.shape == (2, 1, 1, 4)
"""
@ -857,6 +863,7 @@ def create_block_mask(
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda")
query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
@ -864,9 +871,9 @@ def create_block_mask(
output = flex_attention(query, key, value, block_mask=block_mask)
"""
mod_type = _get_mod_type(mask_mod)
assert (
mod_type == _ModificationType.MASK_MOD
), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
assert mod_type == _ModificationType.MASK_MOD, (
f"create-block_mask requires a mask_mod function! Got {mask_mod}"
)
if B is None:
B = 1
if H is None:
@ -962,7 +969,10 @@ def _nested_mod_func_adapter(
kv_seq_idx = q_seq_idx
else:
# cross attention case
kv_seq_idx = _build_seq_idx(kv_offsets, kv_nt._values.shape[kv_nt._ragged_idx - 1]) # type: ignore[attr-defined]
kv_seq_idx = _build_seq_idx(
kv_offsets,
kv_nt._values.shape[kv_nt._ragged_idx - 1], # type: ignore[attr-defined]
)
# Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers
# to the sequence length for each sequence in the NJT, for use in given
@ -1039,10 +1049,14 @@ def create_nested_block_mask(
key = torch.nested.nested_tensor(..., layout=torch.jagged)
value = torch.nested.nested_tensor(..., layout=torch.jagged)
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
block_mask = create_nested_block_mask(
causal_mask, 1, 1, query, _compile=True
)
output = flex_attention(query, key, value, block_mask=block_mask)
.. code-block:: python
@ -1052,11 +1066,15 @@ def create_nested_block_mask(
key = torch.nested.nested_tensor(..., layout=torch.jagged)
value = torch.nested.nested_tensor(..., layout=torch.jagged)
def causal_mask(b, h, q_idx, kv_idx):
return q_idx >= kv_idx
# cross attention case: pass both query and key/value NJTs
block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True)
block_mask = create_nested_block_mask(
causal_mask, 1, 1, query, key, _compile=True
)
output = flex_attention(query, key, value, block_mask=block_mask)
"""
# use same structure for kv as for q by default
@ -1381,7 +1399,13 @@ def flex_attention(
torch._dynamo.mark_static(x, -1)
out, lse = flex_attention_hop(
query, key, value, score_mod, block_mask.as_tuple(), scale, kernel_options # type: ignore[union-attr]
query,
key,
value,
score_mod,
block_mask.as_tuple(),
scale,
kernel_options, # type: ignore[union-attr]
)
if return_lse:
return out, lse * math.log(2)

View File

@ -55,9 +55,7 @@ Note:
Note:
This operator supports complex data types i.e. ``complex32, complex64, complex128``.
""".format(
**reproducibility_notes, **tf32_notes
)
""".format(**reproducibility_notes, **tf32_notes)
+ r"""
Args:
@ -106,9 +104,7 @@ Note:
Note:
This operator supports complex data types i.e. ``complex32, complex64, complex128``.
""".format(
**reproducibility_notes, **tf32_notes
)
""".format(**reproducibility_notes, **tf32_notes)
+ r"""
Args:
@ -159,9 +155,7 @@ Note:
Note:
This operator supports complex data types i.e. ``complex32, complex64, complex128``.
""".format(
**reproducibility_notes, **tf32_notes
)
""".format(**reproducibility_notes, **tf32_notes)
+ r"""
Args:
@ -208,9 +202,7 @@ See :class:`~torch.nn.ConvTranspose1d` for details and output shape.
Note:
{cudnn_reproducibility_note}
""".format(
**reproducibility_notes, **tf32_notes
)
""".format(**reproducibility_notes, **tf32_notes)
+ r"""
Args:
@ -251,9 +243,7 @@ See :class:`~torch.nn.ConvTranspose2d` for details and output shape.
Note:
{cudnn_reproducibility_note}
""".format(
**reproducibility_notes, **tf32_notes
)
""".format(**reproducibility_notes, **tf32_notes)
+ r"""
Args:
@ -296,9 +286,7 @@ See :class:`~torch.nn.ConvTranspose3d` for details and output shape.
Note:
{cudnn_reproducibility_note}
""".format(
**reproducibility_notes, **tf32_notes
)
""".format(**reproducibility_notes, **tf32_notes)
+ r"""
Args:
@ -2335,9 +2323,7 @@ Shape:
- Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)`
- Bias: :math:`(out\_features)` or :math:`()`
- Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight
""".format(
**sparse_support_notes
),
""".format(**sparse_support_notes),
)
@ -2535,13 +2521,13 @@ def embedding(
)
if padding_idx is not None:
if padding_idx > 0:
assert padding_idx < weight.size(
0
), "Padding_idx must be within num_embeddings"
assert padding_idx < weight.size(0), (
"Padding_idx must be within num_embeddings"
)
elif padding_idx < 0:
assert padding_idx >= -weight.size(
0
), "Padding_idx must be within num_embeddings"
assert padding_idx >= -weight.size(0), (
"Padding_idx must be within num_embeddings"
)
padding_idx = weight.size(0) + padding_idx
else:
padding_idx = -1
@ -5800,15 +5786,15 @@ def _in_projection(
Eq,
Ev,
), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
assert b_q is None or b_q.shape == (
Eq,
), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
assert b_k is None or b_k.shape == (
Eq,
), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
assert b_v is None or b_v.shape == (
Eq,
), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
assert b_q is None or b_q.shape == (Eq,), (
f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
)
assert b_k is None or b_k.shape == (Eq,), (
f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
)
assert b_v is None or b_v.shape == (Eq,), (
f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
)
return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)
@ -5914,9 +5900,7 @@ scaled_dot_product_attention = _add_docstr(
Note:
{cudnn_reproducibility_note}
""".format(
**reproducibility_notes
)
""".format(**reproducibility_notes)
+ r"""
Args:
query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
@ -6026,9 +6010,9 @@ def _mha_shape_check(
)
if attn_mask.dim() == 3:
expected_shape = (num_heads, query.shape[0], key.shape[0])
assert (
attn_mask.shape == expected_shape
), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
assert attn_mask.shape == expected_shape, (
f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
)
else:
raise AssertionError(
f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
@ -6289,45 +6273,45 @@ def multi_head_attention_forward(
# longer causal.
is_causal = False
assert (
embed_dim == embed_dim_to_check
), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
assert embed_dim == embed_dim_to_check, (
f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
)
if isinstance(embed_dim, torch.Tensor):
# embed_dim can be a tensor when JIT tracing
head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
else:
head_dim = embed_dim // num_heads
assert (
head_dim * num_heads == embed_dim
), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
assert head_dim * num_heads == embed_dim, (
f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
)
if use_separate_proj_weight:
# allow MHA to have different embedding dimensions when separate projection weights are used
assert (
key.shape[:2] == value.shape[:2]
), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
assert key.shape[:2] == value.shape[:2], (
f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
)
else:
assert (
key.shape == value.shape
), f"key shape {key.shape} does not match value shape {value.shape}"
assert key.shape == value.shape, (
f"key shape {key.shape} does not match value shape {value.shape}"
)
#
# compute in-projection
#
if not use_separate_proj_weight:
assert (
in_proj_weight is not None
), "use_separate_proj_weight is False but in_proj_weight is None"
assert in_proj_weight is not None, (
"use_separate_proj_weight is False but in_proj_weight is None"
)
q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
else:
assert (
q_proj_weight is not None
), "use_separate_proj_weight is True but q_proj_weight is None"
assert (
k_proj_weight is not None
), "use_separate_proj_weight is True but k_proj_weight is None"
assert (
v_proj_weight is not None
), "use_separate_proj_weight is True but v_proj_weight is None"
assert q_proj_weight is not None, (
"use_separate_proj_weight is True but q_proj_weight is None"
)
assert k_proj_weight is not None, (
"use_separate_proj_weight is True but k_proj_weight is None"
)
assert v_proj_weight is not None, (
"use_separate_proj_weight is True but v_proj_weight is None"
)
if in_proj_bias is None:
b_q = b_k = b_v = None
else:
@ -6388,23 +6372,23 @@ def multi_head_attention_forward(
k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_k.size(0) == bsz * num_heads
), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
assert (
static_k.size(2) == head_dim
), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
assert static_k.size(0) == bsz * num_heads, (
f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
)
assert static_k.size(2) == head_dim, (
f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
)
k = static_k
if static_v is None:
v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
else:
# TODO finish disentangling control flow so we don't do in-projections when statics are passed
assert (
static_v.size(0) == bsz * num_heads
), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
assert (
static_v.size(2) == head_dim
), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
assert static_v.size(0) == bsz * num_heads, (
f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
)
assert static_v.size(2) == head_dim, (
f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
)
v = static_v
# add zero attention along batch dimension (now first)
@ -6451,9 +6435,9 @@ def multi_head_attention_forward(
_B, _Nt, E = q.shape
q_scaled = q * math.sqrt(1.0 / float(E))
assert not (
is_causal and attn_mask is None
), "FIXME: is_causal not implemented for need_weights"
assert not (is_causal and attn_mask is None), (
"FIXME: is_causal not implemented for need_weights"
)
if attn_mask is not None:
attn_output_weights = torch.baddbmm(
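
The torch/nn/functional.py hunks above, and the conv module hunks below, mostly collapse the .format(**notes) call that builds each docstring onto the closing-quote line instead of splitting it over three lines. A toy sketch of the pattern with a stand-in notes dict (the real reproducibility_notes/tf32_notes dictionaries are not reproduced here):

# Toy docstring template showing the collapse of `.format(...)` onto the
# closing triple quote; the notes dict is a stand-in, not the real one.
reproducibility_notes = {
    "cudnn_reproducibility_note": "Results may differ between runs on some backends."
}

conv_doc = """
Applies a toy convolution.

Note:
    {cudnn_reproducibility_note}
""".format(**reproducibility_notes) + r"""
Args:
    input: the input tensor.
"""

print(conv_doc)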

View File

@ -168,7 +168,9 @@ def calculate_gain(
param: optional parameter for the non-linear function
Examples:
>>> gain = nn.init.calculate_gain('leaky_relu', 0.2) # leaky_relu with negative_slope=0.2
>>> gain = nn.init.calculate_gain(
... "leaky_relu", 0.2
... ) # leaky_relu with negative_slope=0.2
.. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
"""
@ -456,7 +458,7 @@ def xavier_uniform_(
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
>>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu"))
Note:
Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
@ -555,7 +557,7 @@ def kaiming_uniform_(
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
>>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu")
Note:
Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
@ -620,7 +622,7 @@ def kaiming_normal_(
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
>>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu")
Note:
Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
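
The init.py hunk above only touches example code inside docstrings: quotes are normalized to double quotes and long calls are wrapped with `...` continuation lines, which is consistent with ruff's docstring code formatting being applied to these files (an assumption read off the hunks, not verified against the formatter config). A runnable toy doctest in the resulting style:

# Toy doctest written in the style these hunks converge on: double quotes and
# explicit wrapping with `...` continuation lines inside the example.
import doctest


def scale(x, factor=2):
    """Multiply ``x`` by ``factor``.

    Examples:
        >>> scale(3, factor=2)
        6
        >>> scale(
        ...     10,
        ...     factor=5,
        ... )
        50
    """
    return x * factor


doctest.testmod()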

View File

@ -1077,9 +1077,9 @@ class MultiheadAttention(Module):
self.dropout = dropout
self.batch_first = batch_first
self.head_dim = embed_dim // num_heads
assert (
self.head_dim * num_heads == self.embed_dim
), "embed_dim must be divisible by num_heads"
assert self.head_dim * num_heads == self.embed_dim, (
"embed_dim must be divisible by num_heads"
)
if not self._qkv_same_embed_dim:
self.q_proj_weight = Parameter(
@ -1276,8 +1276,10 @@ class MultiheadAttention(Module):
elif query.is_nested and (
key_padding_mask is not None or attn_mask is not None
):
why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
why_not_fast_path = (
"supplying both src_key_padding_mask and src_mask at the same time \
is not supported with NestedTensor input"
)
elif torch.is_autocast_enabled():
why_not_fast_path = "autocast is enabled"

View File

@ -18,13 +18,15 @@ _ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])
class AdaptiveLogSoftmaxWithLoss(Module):
(
"""Efficient softmax approximation.
As described in
`Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou
<https://arxiv.org/abs/1609.04309>`__.
""" r"""
"""
r"""
Adaptive softmax is an approximate strategy for training models with large
output spaces. It is most effective when the label distribution is highly
imbalanced, for example in natural language modelling, where the word
@ -104,6 +106,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
.. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
"""
)
in_features: int
n_classes: int
@ -182,8 +185,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
if targ_dim == 1:
if input_.size(0) != target_.size(0):
raise RuntimeError(
"Input and target should have the same size "
"in the batch dimension."
"Input and target should have the same size in the batch dimension."
)
if input_.dim() != 2:
raise RuntimeError(

View File

@ -86,31 +86,30 @@ class Sequential(Module):
# for `Conv2d(20,64,5)`. Finally, the output of
# `Conv2d(20,64,5)` will be used as input to the second `ReLU`
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
)
# Using Sequential with OrderedDict. This is functionally the
# same as the above code
model = nn.Sequential(OrderedDict([
('conv1', nn.Conv2d(1,20,5)),
('relu1', nn.ReLU()),
('conv2', nn.Conv2d(20,64,5)),
('relu2', nn.ReLU())
]))
model = nn.Sequential(
OrderedDict(
[
("conv1", nn.Conv2d(1, 20, 5)),
("relu1", nn.ReLU()),
("conv2", nn.Conv2d(20, 64, 5)),
("relu2", nn.ReLU()),
]
)
)
"""
_modules: dict[str, Module] # type: ignore[assignment]
@overload
def __init__(self, *args: Module) -> None:
...
def __init__(self, *args: Module) -> None: ...
@overload
def __init__(self, arg: OrderedDict[str, Module]) -> None:
...
def __init__(self, arg: OrderedDict[str, Module]) -> None: ...
def __init__(self, *args):
super().__init__()
@ -365,12 +364,10 @@ class ModuleList(Module):
return str(idx)
@overload
def __getitem__(self, idx: slice) -> ModuleList:
...
def __getitem__(self, idx: slice) -> ModuleList: ...
@overload
def __getitem__(self, idx: int) -> Module:
...
def __getitem__(self, idx: int) -> Module: ...
@_copy_to_script_wrapper
def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]:
@ -521,14 +518,12 @@ class ModuleDict(Module):
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.choices = nn.ModuleDict({
'conv': nn.Conv2d(10, 10, 3),
'pool': nn.MaxPool2d(3)
})
self.activations = nn.ModuleDict([
['lrelu', nn.LeakyReLU()],
['prelu', nn.PReLU()]
])
self.choices = nn.ModuleDict(
{"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)}
)
self.activations = nn.ModuleDict(
[["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]]
)
def forward(self, x, choice, act):
x = self.choices[choice](x)
@ -653,7 +648,9 @@ class ParameterList(Module):
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
self.params = nn.ParameterList(
[nn.Parameter(torch.randn(10, 10)) for i in range(10)]
)
def forward(self, x):
# ParameterList can act as an iterable, or be indexed using ints
@ -678,12 +675,10 @@ class ParameterList(Module):
return str(idx)
@overload
def __getitem__(self, idx: int) -> Any:
...
def __getitem__(self, idx: int) -> Any: ...
@overload
def __getitem__(self: T, idx: slice) -> T:
...
def __getitem__(self: T, idx: slice) -> T: ...
def __getitem__(self, idx):
if isinstance(idx, slice):
@ -805,10 +800,12 @@ class ParameterDict(Module):
class MyModule(nn.Module):
def __init__(self) -> None:
super().__init__()
self.params = nn.ParameterDict({
'left': nn.Parameter(torch.randn(5, 10)),
'right': nn.Parameter(torch.randn(5, 10))
})
self.params = nn.ParameterDict(
{
"left": nn.Parameter(torch.randn(5, 10)),
"right": nn.Parameter(torch.randn(5, 10)),
}
)
def forward(self, x, choice):
x = self.params[choice].mm(x)

View File

@ -66,8 +66,9 @@ class _ConvNd(Module):
]
__annotations__ = {"bias": Optional[torch.Tensor]}
def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor: # type: ignore[empty-body]
...
def _conv_forward( # type: ignore[empty-body]
self, input: Tensor, weight: Tensor, bias: Optional[Tensor]
) -> Tensor: ...
in_channels: int
_reversed_padding_repeated_twice: list[int]
@ -187,10 +188,7 @@ class _ConvNd(Module):
init.uniform_(self.bias, -bound, bound)
def extra_repr(self):
s = (
"{in_channels}, {out_channels}, kernel_size={kernel_size}"
", stride={stride}"
)
s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}"
if self.padding != (0,) * len(self.padding):
s += ", padding={padding}"
if self.dilation != (1,) * len(self.dilation):
@ -279,9 +277,7 @@ class Conv1d(_ConvNd):
padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
``'replicate'`` or ``'circular'``. Default: ``'zeros'``
""".format(
**reproducibility_notes, **convolution_notes
)
""".format(**reproducibility_notes, **convolution_notes)
+ r"""
Shape:
@ -450,9 +446,7 @@ class Conv2d(_ConvNd):
output. Default: ``True``
padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
``'replicate'`` or ``'circular'``. Default: ``'zeros'``
""".format(
**reproducibility_notes, **convolution_notes
)
""".format(**reproducibility_notes, **convolution_notes)
+ r"""
Shape:
@ -619,9 +613,7 @@ class Conv3d(_ConvNd):
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
""".format(
**reproducibility_notes, **convolution_notes
)
""".format(**reproducibility_notes, **convolution_notes)
+ r"""
Shape:
@ -883,9 +875,7 @@ class ConvTranspose1d(_ConvTransposeNd):
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
""".format(
**reproducibility_notes, **convolution_notes
)
""".format(**reproducibility_notes, **convolution_notes)
+ r"""
Shape:
@ -1051,9 +1041,7 @@ class ConvTranspose2d(_ConvTransposeNd):
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
""".format(
**reproducibility_notes, **convolution_notes
)
""".format(**reproducibility_notes, **convolution_notes)
+ r"""
Shape:
@ -1249,9 +1237,7 @@ class ConvTranspose3d(_ConvTransposeNd):
groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
""".format(
**reproducibility_notes, **convolution_notes
)
""".format(**reproducibility_notes, **convolution_notes)
+ r"""
Shape:

View File

@ -96,8 +96,8 @@ class Unflatten(Module):
>>> output.size()
torch.Size([2, 2, 5, 5])
>>> # With namedshape (tuple of tuples)
>>> input = torch.randn(2, 50, names=('N', 'features'))
>>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
>>> input = torch.randn(2, 50, names=("N", "features"))
>>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5)))
>>> output = unflatten(input)
>>> output.size()
torch.Size([2, 2, 5, 5])

View File

@ -9,6 +9,7 @@ __all__ = ["Fold", "Unfold"]
class Fold(Module):
(
r"""Combines an array of sliding local blocks into a large containing tensor.
Consider a batched :attr:`input` tensor containing sliding local blocks,
@ -42,10 +43,12 @@ class Fold(Module):
* :attr:`padding` controls the amount of implicit zero-paddings on both
sides for :attr:`padding` number of points for each dimension before
reshaping.
""" """
"""
"""
* :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
""" r"""
"""
r"""
Args:
output_size (int or tuple): the shape of the spatial dimensions of the
output (i.e., ``output.sizes()[2:]``)
@ -119,6 +122,7 @@ class Fold(Module):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
)
__constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"]
output_size: _size_any_t
@ -162,6 +166,7 @@ class Fold(Module):
class Unfold(Module):
(
r"""Extracts sliding local blocks from a batched input tensor.
Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
@ -194,10 +199,12 @@ class Unfold(Module):
* :attr:`padding` controls the amount of implicit zero-paddings on both
sides for :attr:`padding` number of points for each dimension before
reshaping.
""" """
"""
"""
* :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
""" r"""
"""
r"""
Args:
kernel_size (int or tuple): the size of the sliding blocks
dilation (int or tuple, optional): a parameter that controls the
@ -283,6 +290,7 @@ class Unfold(Module):
https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md
"""
)
__constants__ = ["kernel_size", "dilation", "padding", "stride"]
kernel_size: _size_any_t

View File

@ -15,11 +15,9 @@ class _LazyProtocol(Protocol):
https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
"""
def _register_load_state_dict_pre_hook(self, hook):
...
def _register_load_state_dict_pre_hook(self, hook): ...
def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
...
def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): ...
def _lazy_load_hook(
self,
@ -30,34 +28,26 @@ class _LazyProtocol(Protocol):
missing_keys,
unexpected_keys,
error_msgs,
):
...
): ...
def _get_name(self):
...
def _get_name(self): ...
def _infer_parameters(self, module, input):
...
def _infer_parameters(self, module, input): ...
@property
def _parameters(self):
...
def _parameters(self): ...
@property
def _buffers(self):
...
def _buffers(self): ...
@property
def _non_persistent_buffers_set(self):
...
def _non_persistent_buffers_set(self): ...
@property
def _load_hook(self):
...
def _load_hook(self): ...
@property
def _initialize_hook(self):
...
def _initialize_hook(self): ...
class LazyModuleMixin:
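
The _LazyProtocol hunk above is entirely about stub bodies: a lone `...` now sits on the same line as the def. The same collapse shows up in container.py, module.py, rnn.py, and scatter_gather below. A self-contained sketch with a toy protocol and toy overloads (not the PyTorch ones):

# Toy illustration of the stub-body collapse: `def f(...): ...` on one line.
from typing import Protocol, overload


class _ToyProtocol(Protocol):
    def _get_name(self) -> str: ...  # previously a two-line `def` plus an indented `...`


@overload
def first(items: list[int]) -> int: ...
@overload
def first(items: list[str]) -> str: ...
def first(items):
    return items[0]


print(first([1, 2, 3]))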

View File

@ -119,6 +119,7 @@ class L1Loss(_Loss):
>>> output = loss(input, target)
>>> output.backward()
"""
__constants__ = ["reduction"]
def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@ -233,6 +234,7 @@ class NLLLoss(_WeightedLoss):
>>> loss = loss_fn(output, target)
>>> loss.backward()
"""
__constants__ = ["ignore_index", "reduction"]
ignore_index: int
@ -331,6 +333,7 @@ class PoissonNLLLoss(_Loss):
- Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`,
the same shape as the input.
"""
__constants__ = ["log_input", "full", "eps", "reduction"]
log_input: bool
full: bool
@ -427,6 +430,7 @@ class GaussianNLLLoss(_Loss):
Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
vol.1, doi: 10.1109/ICNN.1994.374138.
"""
__constants__ = ["full", "eps", "reduction"]
full: bool
eps: float
@ -527,6 +531,7 @@ class KLDivLoss(_Loss):
>>> log_target = F.log_softmax(torch.rand(3, 5), dim=1)
>>> output = kl_loss(input, log_target)
"""
__constants__ = ["reduction"]
def __init__(
@ -601,6 +606,7 @@ class MSELoss(_Loss):
>>> output = loss(input, target)
>>> output.backward()
"""
__constants__ = ["reduction"]
def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@ -684,6 +690,7 @@ class BCELoss(_WeightedLoss):
>>> output = loss(m(input), target)
>>> output.backward()
"""
__constants__ = ["reduction"]
def __init__(
@ -876,6 +883,7 @@ class HingeEmbeddingLoss(_Loss):
- Target: :math:`(*)`, same shape as the input
- Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input
"""
__constants__ = ["margin", "reduction"]
margin: float
@ -950,6 +958,7 @@ class MultiLabelMarginLoss(_Loss):
tensor(0.85...)
"""
__constants__ = ["reduction"]
def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@ -1030,6 +1039,7 @@ class SmoothL1Loss(_Loss):
- Target: :math:`(*)`, same shape as the input.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
"""
__constants__ = ["reduction"]
def __init__(
@ -1092,6 +1102,7 @@ class HuberLoss(_Loss):
- Target: :math:`(*)`, same shape as the input.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
"""
__constants__ = ["reduction", "delta"]
def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None:
@ -1134,6 +1145,7 @@ class SoftMarginLoss(_Loss):
shape as input.
"""
__constants__ = ["reduction"]
def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@ -1276,6 +1288,7 @@ class CrossEntropyLoss(_WeightedLoss):
>>> output = loss(input, target)
>>> output.backward()
"""
__constants__ = ["ignore_index", "reduction", "label_smoothing"]
ignore_index: int
label_smoothing: float
@ -1342,6 +1355,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
- Target: :math:`(N, C)`, label targets must have the same shape as the input.
- Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
"""
__constants__ = ["reduction"]
def __init__(
@ -1410,6 +1424,7 @@ class CosineEmbeddingLoss(_Loss):
>>> output = loss(input1, input2, target)
>>> output.backward()
"""
__constants__ = ["margin", "reduction"]
margin: float
@ -1475,6 +1490,7 @@ class MarginRankingLoss(_Loss):
>>> output = loss(input1, input2, target)
>>> output.backward()
"""
__constants__ = ["margin", "reduction"]
margin: float
@ -1554,6 +1570,7 @@ class MultiMarginLoss(_WeightedLoss):
>>> loss(x, y)
tensor(0.32...)
"""
__constants__ = ["p", "margin", "reduction"]
margin: float
p: int
@ -1657,6 +1674,7 @@ class TripletMarginLoss(_Loss):
.. _Learning shallow convolutional feature descriptors with triplet losses:
https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html
"""
__constants__ = ["margin", "p", "eps", "swap", "reduction"]
margin: float
p: float
@ -1794,6 +1812,7 @@ class TripletMarginWithDistanceLoss(_Loss):
V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html
"""
__constants__ = ["margin", "swap", "reduction"]
margin: float
swap: bool
@ -1902,7 +1921,12 @@ class CTCLoss(_Loss):
>>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
>>>
>>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
>>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
>>> target_lengths = torch.randint(
... low=S_min,
... high=S,
... size=(N,),
... dtype=torch.long,
... )
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
@ -1919,7 +1943,12 @@ class CTCLoss(_Loss):
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
>>> target = torch.randint(
... low=1,
... high=C,
... size=(sum(target_lengths),),
... dtype=torch.long,
... )
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
@ -1936,7 +1965,12 @@ class CTCLoss(_Loss):
>>>
>>> # Initialize random batch of targets (0 = blank, 1:C = classes)
>>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
>>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
>>> target = torch.randint(
... low=1,
... high=C,
... size=(target_lengths,),
... dtype=torch.long,
... )
>>> ctc_loss = nn.CTCLoss()
>>> loss = ctc_loss(input, target, input_lengths, target_lengths)
>>> loss.backward()
@ -1963,6 +1997,7 @@ class CTCLoss(_Loss):
True``.
Please see the notes on :doc:`/notes/randomness` for background.
"""
__constants__ = ["blank", "reduction"]
blank: int
zero_infinity: bool
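
Nearly every hunk in loss.py above only inserts a blank line between a class docstring and the first class-level statement (here __constants__); no behavior changes. The resulting layout looks like this toy class:

# Layout the loss.py hunks converge on: one blank line between the class
# docstring and the first class-level statement. Toy class for illustration.
class ToyLoss:
    """Toy loss used only to show the spacing change."""

    __constants__ = ["reduction"]

    def __init__(self, reduction: str = "mean") -> None:
        self.reduction = reduction


print(ToyLoss().reduction)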

View File

@ -412,6 +412,7 @@ class Module:
import torch.nn as nn
import torch.nn.functional as F
class Model(nn.Module):
def __init__(self) -> None:
super().__init__()
@ -1230,16 +1231,13 @@ class Module:
device: Optional[DeviceLikeType] = ...,
dtype: Optional[dtype] = ...,
non_blocking: bool = ...,
) -> Self:
...
) -> Self: ...
@overload
def to(self, dtype: dtype, non_blocking: bool = ...) -> Self:
...
def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: ...
@overload
def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self:
...
def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: ...
def to(self, *args, **kwargs):
r"""Move and/or cast the parameters and buffers.
@ -1752,7 +1750,11 @@ class Module:
if recording_scopes:
# type ignore was added because at this point one knows that
# torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None # type: ignore[index, operator] # noqa: B950
name = (
torch.jit._trace._trace_module_map[self] # type: ignore[index]
if self in torch.jit._trace._trace_module_map # type: ignore[operator]
else None
) # noqa: B950
if name:
tracing_state.push_scope(name)
else:
@ -2161,13 +2163,20 @@ class Module:
@overload
def state_dict(
self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...
) -> T_destination:
...
self,
*,
destination: T_destination,
prefix: str = ...,
keep_vars: bool = ...,
) -> T_destination: ...
@overload
def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]:
...
def state_dict(
self,
*,
prefix: str = ...,
keep_vars: bool = ...,
) -> dict[str, Any]: ...
# TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
# Also remove the logic for arg parsing together.

View File

@ -358,6 +358,7 @@ class RMSNorm(Module):
>>> rms_norm(input)
"""
__constants__ = ["normalized_shape", "eps", "elementwise_affine"]
normalized_shape: tuple[int, ...]
eps: Optional[float]

View File

@ -253,7 +253,8 @@ class RNNBase(Module):
# alias would break the assumptions of the uniqueness check in
# Module.named_parameters().
unique_data_ptrs = {
p.data_ptr() for p in self._flat_weights # type: ignore[union-attr]
p.data_ptr() # type: ignore[union-attr]
for p in self._flat_weights
}
if len(unique_data_ptrs) != len(self._flat_weights):
return
@ -611,12 +612,10 @@ class RNN(RNNBase):
bidirectional: bool = False,
device=None,
dtype=None,
) -> None:
...
) -> None: ...
@overload
def __init__(self, *args, **kwargs):
...
def __init__(self, *args, **kwargs): ...
def __init__(self, *args, **kwargs):
if "proj_size" in kwargs:
@ -969,12 +968,10 @@ class LSTM(RNNBase):
proj_size: int = 0,
device=None,
dtype=None,
) -> None:
...
) -> None: ...
@overload
def __init__(self, *args, **kwargs):
...
def __init__(self, *args, **kwargs): ...
def __init__(self, *args, **kwargs):
super().__init__("LSTM", *args, **kwargs)
@ -1304,12 +1301,10 @@ class GRU(RNNBase):
bidirectional: bool = False,
device=None,
dtype=None,
) -> None:
...
) -> None: ...
@overload
def __init__(self, *args, **kwargs):
...
def __init__(self, *args, **kwargs): ...
def __init__(self, *args, **kwargs):
if "proj_size" in kwargs:

View File

@ -59,9 +59,11 @@ class Embedding(Module):
embedding = nn.Embedding(n, d, max_norm=1.0)
W = torch.randn((m, d), requires_grad=True)
idx = torch.tensor([1, 2])
a = embedding.weight.clone() @ W.t() # weight must be cloned for this to be differentiable
a = (
embedding.weight.clone() @ W.t()
) # weight must be cloned for this to be differentiable
b = embedding(idx) @ W.t() # modifies weight in-place
out = (a.unsqueeze(0) + b.unsqueeze(1))
out = a.unsqueeze(0) + b.unsqueeze(1)
loss = out.sigmoid().prod()
loss.backward()
@ -150,13 +152,13 @@ class Embedding(Module):
self.embedding_dim = embedding_dim
if padding_idx is not None:
if padding_idx > 0:
assert (
padding_idx < self.num_embeddings
), "Padding_idx must be within num_embeddings"
assert padding_idx < self.num_embeddings, (
"Padding_idx must be within num_embeddings"
)
elif padding_idx < 0:
assert (
padding_idx >= -self.num_embeddings
), "Padding_idx must be within num_embeddings"
assert padding_idx >= -self.num_embeddings, (
"Padding_idx must be within num_embeddings"
)
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
self.max_norm = max_norm
@ -248,9 +250,9 @@ class Embedding(Module):
>>> embedding(input)
tensor([[ 4.0000, 5.1000, 6.3000]])
"""
assert (
embeddings.dim() == 2
), "Embeddings parameter is expected to be 2-dimensional"
assert embeddings.dim() == 2, (
"Embeddings parameter is expected to be 2-dimensional"
)
rows, cols = embeddings.shape
embedding = cls(
num_embeddings=rows,
@ -391,13 +393,13 @@ class EmbeddingBag(Module):
self.scale_grad_by_freq = scale_grad_by_freq
if padding_idx is not None:
if padding_idx > 0:
assert (
padding_idx < self.num_embeddings
), "padding_idx must be within num_embeddings"
assert padding_idx < self.num_embeddings, (
"padding_idx must be within num_embeddings"
)
elif padding_idx < 0:
assert (
padding_idx >= -self.num_embeddings
), "padding_idx must be within num_embeddings"
assert padding_idx >= -self.num_embeddings, (
"padding_idx must be within num_embeddings"
)
padding_idx = self.num_embeddings + padding_idx
self.padding_idx = padding_idx
if _weight is None:
@ -526,9 +528,9 @@ class EmbeddingBag(Module):
>>> embeddingbag(input)
tensor([[ 2.5000, 3.7000, 4.6500]])
"""
assert (
embeddings.dim() == 2
), "Embeddings parameter is expected to be 2-dimensional"
assert embeddings.dim() == 2, (
"Embeddings parameter is expected to be 2-dimensional"
)
rows, cols = embeddings.shape
embeddingbag = cls(
num_embeddings=rows,

View File

@ -256,7 +256,9 @@ class Transformer(Module):
Examples:
>>> # xdoctest: +SKIP
>>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
>>> output = transformer_model(
... src, tgt, src_mask=src_mask, tgt_mask=tgt_mask
... )
"""
is_batched = src.dim() == 3
if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
@ -686,7 +688,9 @@ class TransformerEncoderLayer(Module):
>>> out = encoder_layer(src)
Alternatively, when ``batch_first`` is ``True``:
>>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
>>> encoder_layer = nn.TransformerEncoderLayer(
... d_model=512, nhead=8, batch_first=True
... )
>>> src = torch.rand(32, 10, 512)
>>> out = encoder_layer(src)
@ -994,7 +998,9 @@ class TransformerDecoderLayer(Module):
>>> out = decoder_layer(tgt, memory)
Alternatively, when ``batch_first`` is ``True``:
>>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
>>> decoder_layer = nn.TransformerDecoderLayer(
... d_model=512, nhead=8, batch_first=True
... )
>>> memory = torch.rand(32, 10, 512)
>>> tgt = torch.rand(32, 20, 512)
>>> out = decoder_layer(tgt, memory)

View File

@ -11,9 +11,9 @@ from torch.nn.parallel import comm
class Broadcast(Function):
@staticmethod
def forward(ctx, target_gpus, *inputs):
assert all(
i.device.type != "cpu" for i in inputs
), "Broadcast function not implemented for CPU tensors"
assert all(i.device.type != "cpu" for i in inputs), (
"Broadcast function not implemented for CPU tensors"
)
target_gpus = [_get_device_index(x, True) for x in target_gpus]
ctx.target_gpus = target_gpus
if len(inputs) == 0:
@ -56,9 +56,9 @@ class ReduceAddCoalesced(Function):
class Gather(Function):
@staticmethod
def forward(ctx, target_device, dim, *inputs):
assert all(
i.device.type != "cpu" for i in inputs
), "Gather function not implemented for CPU tensors"
assert all(i.device.type != "cpu" for i in inputs), (
"Gather function not implemented for CPU tensors"
)
if target_device == "cpu":
ctx.target_device = "cpu"
else:

View File

@ -759,7 +759,7 @@ class DistributedDataParallel(Module, Joinable):
"DistributedDataParallel device_ids and output_device arguments "
"only work with single-device/multiple-device GPU modules or CPU modules, "
f"but got device_ids {device_ids}, output_device {output_device}, "
f"and module parameters {({p.device for p in self._module_parameters})}.",
f"and module parameters { ({p.device for p in self._module_parameters}) }.", # noqa: E201,E202
)
self.device_ids = None

View File

@ -46,9 +46,9 @@ def parallel_apply(
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(
inputs
), f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
assert len(modules) == len(inputs), (
f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
)
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
@ -88,9 +88,11 @@ def parallel_apply(
if stream is None:
stream = torch.cuda.current_stream(device)
try:
with torch.cuda.device(device), torch.cuda.stream(
stream
), torch.amp.autocast("cuda", enabled=autocast_enabled):
with (
torch.cuda.device(device),
torch.cuda.stream(stream),
torch.amp.autocast("cuda", enabled=autocast_enabled),
):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)

View File

@ -35,8 +35,7 @@ def scatter(
inputs: torch.Tensor,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> tuple[torch.Tensor, ...]:
...
) -> tuple[torch.Tensor, ...]: ...
@overload
@ -44,8 +43,7 @@ def scatter(
inputs: T,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> list[T]:
...
) -> list[T]: ...
def scatter(inputs, target_gpus, dim=0):

View File

@ -4,6 +4,7 @@ r"""QAT Dynamic Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""
from torch.nn.qat import dynamic, modules # noqa: F403
from torch.nn.qat.modules import * # noqa: F403

View File

@ -4,4 +4,5 @@ r"""QAT Dynamic Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""
from torch.nn.qat.dynamic.modules import * # noqa: F403

View File

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/dynamic/modules`,
while adding an import statement here.
"""
from torch.ao.nn.qat.dynamic.modules.linear import Linear

View File

@ -4,6 +4,7 @@ r"""QAT Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.modules` instead.
"""
from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.qat.modules.linear import Linear

View File

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/modules`,
while adding an import statement here.
"""
from torch.ao.nn.qat.modules.linear import Linear

View File

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantizable/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantizable.modules.activation import MultiheadAttention

View File

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
while adding an import statement here.
"""
from torch.ao.nn.quantized.dynamic.modules.linear import Linear

View File

@ -314,7 +314,7 @@ def unfold3d(
Example:
>>> # xdoctest: +SKIP
>>> B, C, D, H, W = 3, 4, 5, 6, 7
>>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W)
>>> tensor = torch.arange(1, B * C * D * H * W + 1.0).view(B, C, D, H, W)
>>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape
torch.Size([3, 32, 120])
"""

View File

@ -72,8 +72,10 @@ def allow_smaller_batches(args, kwargs):
@contextmanager
def setup_rnn(use_input_variant, args, kwargs):
with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(
args, kwargs
with (
batch_second(args, kwargs)
if use_input_variant
else allow_smaller_batches(args, kwargs)
):
yield

View File

@ -49,7 +49,9 @@ def call_for_per_sample_grads(
grad_outputs by 1 / batch_size from cross batch interaction.
>>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5
>>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(batched_input).mean()
>>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(
... batched_input
... ).mean()
>>> res.backward()
Note::

View File

@ -150,9 +150,7 @@ def _clip_grads_with_norm_(
return
grouped_grads: dict[
tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment]
clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so

View File

@ -135,9 +135,9 @@ def fuse_linear_bn_eval(
2. the number of features in bn is 1
Otherwise, skip the folding path
"""
assert (
linear.out_features == bn.num_features or bn.num_features == 1
), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
assert linear.out_features == bn.num_features or bn.num_features == 1, (
"To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
)
assert bn.running_mean is not None and bn.running_var is not None
fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(

View File

@ -63,12 +63,16 @@ def convert_conv2d_weight_memory_format(
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
>>> input = torch.randint(
... 1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda"
... )
>>> model = nn.Sequential(
>>> nn.Conv2d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
>>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
>>> model = nn.utils.convert_conv2d_weight_memory_format(
... model, torch.channels_last
... )
>>> out = model(input)
"""
# TODO: expand this to `_ConvNd` when channels_last support is extended
@ -137,12 +141,16 @@ def convert_conv3d_weight_memory_format(
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
>>> input = torch.randint(
... 1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda"
... )
>>> model = nn.Sequential(
>>> nn.Conv3d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(
... model, torch.channels_last_3d
... )
>>> out = model(input)
"""

View File

@ -46,6 +46,7 @@ def cached():
.. code-block:: python
import torch.nn.utils.parametrize as P
...
with P.cached():
output = model(inputs)
@ -536,7 +537,9 @@ def register_parametrization(
>>> s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> linear_rank_one = P.register_parametrization(
... nn.Linear(4, 4), "weight", RankOne()
... )
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Pruning methods."""
import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
@ -63,9 +64,9 @@ class BasePruningMethod(ABC):
"""
# to carry out the multiplication, the mask needs to have been computed,
# so the pruning method must know what tensor it's operating on
assert (
self._tensor_name is not None
), f"Module {module} has to be pruned" # this gets set in apply()
assert self._tensor_name is not None, (
f"Module {module} has to be pruned"
) # this gets set in apply()
mask = getattr(module, self._tensor_name + "_mask")
orig = getattr(module, self._tensor_name + "_orig")
pruned_tensor = mask.to(dtype=orig.dtype) * orig
@ -109,10 +110,10 @@ class BasePruningMethod(ABC):
old_method = hook
hooks_to_remove.append(k)
found += 1
assert (
found <= 1
), f"Avoid adding multiple pruning hooks to the\
assert found <= 1, (
f"Avoid adding multiple pruning hooks to the\
same tensor {name} of module {module}. Use a PruningContainer."
)
for k in hooks_to_remove:
del module._forward_pre_hooks[k]
@ -153,9 +154,9 @@ class BasePruningMethod(ABC):
orig = getattr(module, name)
if importance_scores is not None:
assert (
importance_scores.shape == orig.shape
), f"importance_scores should have the same shape as parameter {name} of {module}"
assert importance_scores.shape == orig.shape, (
f"importance_scores should have the same shape as parameter {name} of {module}"
)
else:
importance_scores = orig
@ -222,9 +223,9 @@ class BasePruningMethod(ABC):
pruned version of tensor ``t``.
"""
if importance_scores is not None:
assert (
importance_scores.shape == t.shape
), "importance_scores should have the same shape as tensor t"
assert importance_scores.shape == t.shape, (
"importance_scores should have the same shape as tensor t"
)
else:
importance_scores = t
default_mask = default_mask if default_mask is not None else torch.ones_like(t)
@ -241,9 +242,9 @@ class BasePruningMethod(ABC):
Pruning itself is NOT undone or reversed!
"""
# before removing pruning from a tensor, it has to have been applied
assert (
self._tensor_name is not None
), f"Module {module} has to be pruned before pruning can be removed" # this gets set in apply()
assert self._tensor_name is not None, (
f"Module {module} has to be pruned before pruning can be removed"
) # this gets set in apply()
# to update module[name] to latest trained weights
weight = self.apply_mask(module) # masked weights
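The recurring change in the pruning hunks above is purely mechanical: where the Black-era layout wrapped the assert condition in parentheses and let the message trail after the comma, ruff format keeps the condition on the `assert` line and wraps only the message. A minimal illustrative sketch of the two layouts (the helper names below are invented, not taken from the diff):

```python
# Illustrative only: invented helper names, mirroring the wrapping change above.


def check_black_style(tensor_name, module):
    # Pre-migration (Black) layout: condition wrapped, message trailing.
    assert (
        tensor_name is not None
    ), f"Module {module} has to be pruned"


def check_ruff_style(tensor_name, module):
    # Post-migration (ruff format) layout: condition inline, message wrapped.
    assert tensor_name is not None, (
        f"Module {module} has to be pruned"
    )
```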
@ -846,7 +847,7 @@ def identity(module, name):
Examples:
>>> # xdoctest: +SKIP
>>> m = prune.identity(nn.Linear(2, 3), 'bias')
>>> m = prune.identity(nn.Linear(2, 3), "bias")
>>> print(m.bias_mask)
tensor([1., 1., 1.])
"""
@ -882,7 +883,7 @@ def random_unstructured(module, name, amount):
Examples:
>>> # xdoctest: +SKIP
>>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
>>> m = prune.random_unstructured(nn.Linear(2, 3), "weight", amount=1)
>>> torch.sum(m.weight_mask == 0)
tensor(1)
@ -925,7 +926,7 @@ def l1_unstructured(module, name, amount, importance_scores=None):
Examples:
>>> # xdoctest: +SKIP
>>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
>>> m = prune.l1_unstructured(nn.Linear(2, 3), "weight", amount=0.2)
>>> m.state_dict().keys()
odict_keys(['bias', 'weight_orig', 'weight_mask'])
"""
@ -965,9 +966,7 @@ def random_structured(module, name, amount, dim):
Examples:
>>> # xdoctest: +SKIP
>>> m = prune.random_structured(
... nn.Linear(5, 3), 'weight', amount=3, dim=1
... )
>>> m = prune.random_structured(nn.Linear(5, 3), "weight", amount=3, dim=1)
>>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
>>> print(columns_pruned)
3
@ -1014,7 +1013,7 @@ def ln_structured(module, name, amount, n, dim, importance_scores=None):
Examples:
>>> from torch.nn.utils import prune
>>> m = prune.ln_structured(
... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
... nn.Conv2d(5, 3, 2), "weight", amount=0.3, dim=1, n=float("-inf")
... )
"""
LnStructured.apply(
@ -1067,13 +1066,17 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
Examples:
>>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(OrderedDict([
... ('first', nn.Linear(10, 4)),
... ('second', nn.Linear(4, 1)),
... ]))
>>> net = nn.Sequential(
... OrderedDict(
... [
... ("first", nn.Linear(10, 4)),
... ("second", nn.Linear(4, 1)),
... ]
... )
... )
>>> parameters_to_prune = (
... (net.first, 'weight'),
... (net.second, 'weight'),
... (net.first, "weight"),
... (net.second, "weight"),
... )
>>> prune.global_unstructured(
... parameters_to_prune,
@ -1165,7 +1168,7 @@ def custom_from_mask(module, name, mask):
Examples:
>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
@ -1191,8 +1194,8 @@ def remove(module, name):
will act.
Examples:
>>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
>>> m = remove(m, name='weight')
>>> m = random_unstructured(nn.Linear(5, 7), name="weight", amount=0.2)
>>> m = remove(m, name="weight")
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
@ -1223,7 +1226,7 @@ def is_pruned(module):
>>> m = nn.Linear(5, 7)
>>> print(prune.is_pruned(m))
False
>>> prune.random_unstructured(m, name='weight', amount=0.2)
>>> prune.random_unstructured(m, name="weight", amount=0.2)
>>> print(prune.is_pruned(m))
True
"""

View File

@ -105,8 +105,7 @@ class PackedSequence(PackedSequence_):
dtype: torch.dtype,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
) -> Self: ...
@overload
def to(
@ -115,8 +114,7 @@ class PackedSequence(PackedSequence_):
dtype: Optional[torch.dtype] = ...,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
) -> Self: ...
@overload
def to(
@ -124,8 +122,7 @@ class PackedSequence(PackedSequence_):
other: Tensor,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
) -> Self: ...
def to(self, *args: Any, **kwargs: Any) -> Self:
r"""Perform dtype and/or device conversion on `self.data`.
@ -354,7 +351,9 @@ def pad_packed_sequence(
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed = pack_padded_sequence(
... seq, lens, batch_first=True, enforce_sorted=False
... )
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
@ -473,7 +472,10 @@ def pad_sequence(
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
return torch._C._nn.pad_sequence(
sequences, batch_first, padding_value, padding_side # type: ignore[arg-type]
sequences, # type: ignore[arg-type]
batch_first,
padding_value,
padding_side, # type: ignore[arg-type]
)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Spectral Normalization from https://arxiv.org/abs/1802.05957."""
from typing import Any, Optional, TypeVar
import torch

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""
from typing import Any, TypeVar
from typing_extensions import deprecated

View File

@ -347,9 +347,9 @@ def _single_tensor_adafactor(
maximize: bool,
has_complex: bool,
):
assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)
if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
@ -381,9 +381,9 @@ def _single_tensor_adafactor(
param.mul_(1 - lr * weight_decay)
if grad.dim() > 1:
assert (
row_var is not None and col_var is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert row_var is not None and col_var is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_mean = (
torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
@ -397,9 +397,9 @@ def _single_tensor_adafactor(
var_estimate = row_var @ col_var
var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
else:
assert (
variance is not None
), "variance should be defined when grad is a vector"
assert variance is not None, (
"variance should be defined when grad is a vector"
)
grad_squared = grad * grad
variance.lerp_(grad_squared, one_minus_beta2_t)
# avoid writing into variance during update
@ -472,9 +472,9 @@ def _multi_tensor_adafactor(
if len(params) == 0:
return
assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)
lr = _to_scalar(lr)
@ -495,9 +495,9 @@ def _multi_tensor_adafactor(
device_grads = cast(list[Tensor], device_grads_)
device_state_steps = cast(list[Tensor], device_state_steps_)
if eps1 is None:
assert (
dtype is not None
), "dtype is needed to compute eps1 when eps1 is unset"
assert dtype is not None, (
"dtype is needed to compute eps1 when eps1 is unset"
)
eps1 = torch.finfo(dtype).eps
if TYPE_CHECKING:
@ -537,9 +537,9 @@ def _multi_tensor_adafactor(
if is_multidim:
device_row_vars = cast(list[Tensor], device_row_vars_)
device_col_vars = cast(list[Tensor], device_col_vars_)
assert (
device_row_vars[0] is not None and device_col_vars[0] is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert device_row_vars[0] is not None and device_col_vars[0] is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_means = [
torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
@ -570,9 +570,9 @@ def _multi_tensor_adafactor(
del row_var_means
else:
device_variances = cast(list[Tensor], device_variances_)
assert (
device_variances[0] is not None
), "variance should be defined when grad is a vector"
assert device_variances[0] is not None, (
"variance should be defined when grad is a vector"
)
grads_squared = torch._foreach_mul(device_grads, device_grads)
torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Functional interface."""
import math
from torch import Tensor

View File

@ -5,6 +5,7 @@ Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
"""
from functools import partialmethod
from torch import optim

View File

@ -267,7 +267,9 @@ def _single_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
if not torch.jit.is_scripting():
lr = _to_scalar(lr)
@ -326,7 +328,9 @@ def _multi_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
if len(params) == 0:
return

View File

@ -398,7 +398,9 @@ def _single_tensor_adam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
# update step
step_t += 1
@ -433,7 +435,9 @@ def _single_tensor_adam(
# cast to workaround https://github.com/pytorch/pytorch/issues/140601
key = (device, dtype)
if key not in beta1_dict:
beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True) # type: ignore[union-attr]
beta1_dict[key] = beta1.to( # type: ignore[union-attr]
device=device, dtype=dtype, non_blocking=True
)
device_beta1: Union[float, Tensor] = beta1_dict[key]
else:
@ -593,7 +597,9 @@ def _multi_tensor_adam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
assert grad_scale is None and found_inf is None
@ -769,7 +775,10 @@ def _multi_tensor_adam(
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(
device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type]
device_params,
device_exp_avgs,
exp_avg_sq_sqrt,
step_size, # type: ignore[arg-type]
)
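Splitting the `_foreach_addcdiv_` call one argument per line is not only cosmetic: it lets a `# type: ignore[arg-type]` comment sit on exactly the argument that needs it instead of annotating the whole call. A small hedged sketch with made-up tensors, chosen only so the call runs:

```python
# Made-up inputs; the point is the per-argument comment placement.
import torch

params = [torch.ones(3)]
grads = [torch.full((3,), 2.0)]
denoms = [torch.full((3,), 4.0)]
step_size = -0.1  # scalar step value for this sketch

torch._foreach_addcdiv_(
    params,
    grads,
    denoms,
    step_size,  # a per-line suppression (e.g. type: ignore) would attach just here
)
print(params[0])  # 1 - 0.1 * (2 / 4) -> tensor([0.9500, 0.9500, 0.9500])
```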

View File

@ -256,7 +256,9 @@ def _single_tensor_adamax(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
# update step
step_t += 1
@ -331,7 +333,9 @@ def _multi_tensor_adamax(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -305,7 +305,9 @@ def _multi_tensor_asgd(
p.device.type == mu.device.type == eta.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mu, eta, step in zip(params, mus, etas, state_steps)
), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Learning Rate Scheduler."""
from __future__ import annotations
import math
@ -827,7 +828,11 @@ class SequentialLR(LRScheduler):
>>> # lr = 0.0405 if epoch == 22
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> scheduler = SequentialLR(
... optimizer,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@ -1271,7 +1276,7 @@ class ReduceLROnPlateau(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, 'min')
>>> scheduler = ReduceLROnPlateau(optimizer, "min")
>>> for epoch in range(10):
>>> train(...)
>>> val_loss = validate(...)
@ -1502,7 +1507,12 @@ class CyclicLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=10)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(
... optimizer,
... base_lr=0.01,
... max_lr=0.1,
... step_size_up=10,
... )
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
@ -1729,7 +1739,9 @@ class CosineAnnealingWarmRestarts(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
... optimizer, T_0=20
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@ -1936,7 +1948,9 @@ class OneCycleLR(LRScheduler):
>>> # xdoctest: +SKIP
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
... )
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
@ -2141,8 +2155,6 @@ class OneCycleLR(LRScheduler):
if self.use_beta1:
group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined]
else:
group[
"momentum"
] = computed_momentum # type: ignore[possibly-undefined]
group["momentum"] = computed_momentum # type: ignore[possibly-undefined]
return lrs
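The scheduler docstrings above also pick up ruff format's call-wrapping convention: a call that no longer fits on one line is broken across lines, one argument per line with a trailing comma once the arguments cannot share a single continuation line. A hedged usage sketch (SGD and StepLR are used here purely for illustration):

```python
import torch

params = [torch.zeros(2, requires_grad=True)]
optimizer = torch.optim.SGD(
    params,
    lr=0.01,
    momentum=0.9,
)
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer,
    step_size=10,
    gamma=0.5,
)
print(scheduler.get_last_lr())  # [0.01]
```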

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""
from typing import cast, Optional, Union
import torch
@ -408,7 +409,11 @@ def _multi_tensor_nadam(
p.device.type == mp.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mp, step in zip(params, mu_products, state_steps)
), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
"If capturable=True, "
"params, mu_products, and state_steps must be on supported devices: "
f"{capturable_supported_devices}."
)
lr = _to_scalar(lr)
@ -576,10 +581,16 @@ def _multi_tensor_nadam(
)
torch._foreach_addcdiv_(
grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type]
grouped_params,
grouped_grads,
exp_avg_sq_sqrt,
step_size_grads, # type: ignore[arg-type]
)
torch._foreach_addcdiv_(
grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type]
grouped_params,
grouped_exp_avgs,
exp_avg_sq_sqrt,
step_size_expavg, # type: ignore[arg-type]
)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Base optimizer."""
import functools
import warnings
from collections import defaultdict, OrderedDict
@ -103,7 +104,7 @@ def _stack_if_compiling(x):
def _disable_dynamo_if_unsupported(
single_tensor_fn: Optional[Callable[..., object]] = None
single_tensor_fn: Optional[Callable[..., object]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# workaround for torchscript BC
# it requires all called functions to be in the
@ -349,15 +350,24 @@ class Optimizer:
options (used when a parameter group doesn't specify them).
"""
OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc]
OptimizerPreHook: TypeAlias = Callable[
[Self, Args, Kwargs], # type: ignore[misc]
Optional[tuple[Args, Kwargs]],
]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]
_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_pre_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer"], None]]'
)
def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107
torch._C._log_api_usage_once("python.optimizer")
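The hook-registry annotations above illustrate how ruff format handles a variable annotation that no longer fits on one line: it is wrapped in parentheses rather than left overlong. A reduced sketch with a simplified stand-in for the real `StateDict` alias:

```python
# Simplified stand-in types; only the parenthesized-annotation layout matters.
from collections import OrderedDict
from typing import Callable, Optional

StateDict = dict  # hypothetical stand-in for the real torch.optim alias

_example_state_dict_post_hooks: (
    "OrderedDict[int, Callable[[object, StateDict], Optional[StateDict]]]"
) = OrderedDict()
```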
@ -847,7 +857,9 @@ class Optimizer:
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
self._optimizer_load_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._optimizer_load_state_dict_post_hooks.move_to_end(
handle.id, last=False
) # type: ignore[attr-defined]
return handle
@torch._disable_dynamo
@ -877,12 +889,25 @@ class Optimizer:
>>> # xdoctest: +SKIP
>>> model = torch.nn.Linear(10, 10)
>>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(optim, start_factor=0.1, end_factor=1, total_iters=20)
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5)
>>> lr = torch.optim.lr_scheduler.SequentialLR(optim, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> lr.load_state_dict(torch.load('./save_seq.pt'))
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(
... optim,
... start_factor=0.1,
... end_factor=1,
... total_iters=20,
... )
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
... optim,
... T_max=80,
... eta_min=3e-5,
... )
>>> lr = torch.optim.lr_scheduler.SequentialLR(
... optim,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> lr.load_state_dict(torch.load("./save_seq.pt"))
>>> # now load the optimizer checkpoint after loading the LRScheduler
>>> optim.load_state_dict(torch.load('./save_optim.pt'))
>>> optim.load_state_dict(torch.load("./save_optim.pt"))
"""
# shallow copy, to be consistent with module API
@ -933,7 +958,10 @@ class Optimizer:
for k, v in value.items()
}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
return type(value)(
_cast(param, v, param_id=param_id, param_groups=param_groups)
for v in value
) # type: ignore[call-arg]
else:
return value
@ -1021,12 +1049,10 @@ class Optimizer:
torch._foreach_zero_(grads)
@overload
def step(self, closure: None = None) -> None:
...
def step(self, closure: None = None) -> None: ...
@overload
def step(self, closure: Callable[[], float]) -> float:
...
def step(self, closure: Callable[[], float]) -> float: ...
def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
r"""Perform a single optimization step to update parameter.

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""
from typing import cast, Optional, Union
import torch
@ -285,7 +286,9 @@ def _single_tensor_radam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
if torch.is_complex(param):
param = torch.view_as_real(param)
@ -386,7 +389,9 @@ def _multi_tensor_radam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""
from typing import cast, Optional, Union
import torch
@ -292,7 +293,9 @@ def _single_tensor_rmsprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
grad = grads[i]
grad = grad if not maximize else -grad
@ -366,7 +369,9 @@ def _multi_tensor_rmsprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
lr = _to_scalar(lr)

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""
from typing import cast, Optional, Union
import torch
@ -248,7 +249,9 @@ def _single_tensor_rprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
step += 1
@ -315,7 +318,9 @@ def _multi_tensor_rprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item]

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""
from typing import cast, Optional, Union
import torch
@ -397,7 +398,8 @@ def _multi_tensor_sgd(
lr = _to_scalar(lr)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=True,
)
for (
device_params_,
@ -502,7 +504,8 @@ def _fused_sgd(
for i, g in enumerate(grads):
momentum_buffer_list[i] = torch.empty_like(g)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=False,
)
for (device, _), (
(device_params_, device_grads_, device_momentum_buffer_list),

View File

@ -37,9 +37,9 @@ class SparseAdam(Optimizer):
sparse_params = []
complex_params = []
for index, param_group in enumerate(self.param_groups):
assert isinstance(
param_group, dict
), f"param_groups must be a list of dicts, but got {type(param_group)}"
assert isinstance(param_group, dict), (
f"param_groups must be a list of dicts, but got {type(param_group)}"
)
# given param group, convert given params to a list first before iterating
for d_index, d_param in enumerate(param_group["params"]):
if d_param.is_sparse:

View File

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Weight Averaging implementation."""
import itertools
import math
import warnings
@ -225,9 +226,9 @@ class AveragedModel(Module):
use_buffers=False,
): # noqa: D107
super().__init__()
assert (
avg_fn is None or multi_avg_fn is None
), "Only one of avg_fn and multi_avg_fn should be provided"
assert avg_fn is None or multi_avg_fn is None, (
"Only one of avg_fn and multi_avg_fn should be provided"
)
self.module = deepcopy(model)
if device is not None:
self.module = self.module.to(device)
@ -274,7 +275,9 @@ class AveragedModel(Module):
) in grouped_tensors.items():
if self.multi_avg_fn:
self.multi_avg_fn(
self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type]
self_params, # type: ignore[arg-type]
model_params, # type: ignore[arg-type]
self.n_averaged.to(device),
)
elif (
device is not None