Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)

[BE][PYFMT] migrate PYFMT for {torch,test}/{nn,optim}/** to ruff format (#144548)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/144548
Approved by: https://github.com/ezyang

Committed by: PyTorch MergeBot
parent 3e38feb05f
commit 596b418391
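Nearly everything in the diff below is mechanical restyling: where black wrapped a long `assert` by parenthesizing the condition and letting the message trail the closing parenthesis, `ruff format` keeps the condition on one line and parenthesizes the message instead; it also collapses `...`-bodied stubs onto a single line and inlines short `.format(...)` calls after docstrings. A minimal, runnable sketch of the two dominant patterns (the `value`/`limit` names are illustrative, not taken from the diff):

```python
# Both assert forms are equivalent at runtime; only the layout differs.
value, limit = 3, 10

# black style: the condition is wrapped and the message trails the parentheses.
assert (
    value < limit
), f"expected value < {limit}, but got {value}"

# ruff format style: the condition stays on one line; the message is parenthesized.
assert value < limit, (
    f"expected value < {limit}, but got {value}"
)

# Stub bodies consisting of `...` are collapsed onto one line:
def _stub(x): ...
```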
@@ -470,9 +470,9 @@ def load_torch_function_handler(cls, func, types, args=(), kwargs=None):
                 return cls(src._data)
             return cls(src)
         else:
-            assert isinstance(
-                src, cls
-            ), f"Expected isinstance(src, {cls}) but got {type(src)}"
+            assert isinstance(src, cls), (
+                f"Expected isinstance(src, {cls}) but got {type(src)}"
+            )
             assert (
                 type(dest) == torch.Tensor
                 or type(dest) == torch.nn.Parameter
@@ -1475,9 +1475,9 @@ class TestNNParametrization(NNTestCase):
             snm.load_state_dict(non_strict_state_dict, strict=False)
             del non_strict_state_dict["parametrizations.weight.0._v"]
             snm.load_state_dict(non_strict_state_dict, strict=False)
-            non_strict_state_dict[
-                "weight"
-            ] = snm.weight.detach().clone()  # set W as a buffer
+            non_strict_state_dict["weight"] = (
+                snm.weight.detach().clone()
+            )  # set W as a buffer
             snm.load_state_dict(non_strict_state_dict, strict=False)
             del non_strict_state_dict._metadata[
                 "parametrizations.weight.0"
@@ -107,9 +107,7 @@ class TestLRScheduler(TestCase):
                 [0]
                 + [i + 1 for i, m in enumerate(self.milestones) if global_step >= m]
             )[-1]
-            return [
-                init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr
-            ]
+            return [init_lr * (self.gamma**gamma_power) for init_lr in self.init_lr]

         optimizer = SGD([torch.rand(1)], lr=1)

@@ -62,9 +62,9 @@ def _multistep_backprop_diff_hyperparams_fn(
     kwargs: dict[str, Any],
     *ignored: Any,
 ) -> tuple[Tensor, ...]:
-    assert (
-        kwargs["differentiable"] is True
-    ), "Only call this test function when differentiable=True"
+    assert kwargs["differentiable"] is True, (
+        "Only call this test function when differentiable=True"
+    )

     params = params.clone()
     params.grad = grad
@@ -81,9 +81,9 @@ def _multistep_backprop_diff_hyperparams_fn(
     # so they're passed in as Tensors (not a tuple) and recognized by gradcheck
     if "beta1" in kwargs or "beta2" in kwargs:
         # Prevent just one beta kwarg from being passed in
-        assert (
-            "beta1" in kwargs and "beta2" in kwargs
-        ), "Both betas should be defined in kwargs"
+        assert "beta1" in kwargs and "beta2" in kwargs, (
+            "Both betas should be defined in kwargs"
+        )
         kwargs.update({"betas": (kwargs.pop("beta1"), kwargs.pop("beta2"))})

     kwargs.update(
@@ -41,10 +41,9 @@ USE_BLACK_FILELIST = re.compile(
         "test/[a-h]*/**",
         # test/[i-j]*/**
         "test/[i-j]*/**",
-        # test/[k-n]*/**
-        "test/[k-n]*/**",
+        # test/[k-m]*/**
+        "test/[k-m]*/**",
         # test/optim/**
-        "test/optim/**",
         # "test/[p-z]*/**",
         "test/[p-z]*/**",
         # torch/**
@@ -58,10 +57,9 @@ USE_BLACK_FILELIST = re.compile(
         # torch/[a-c]*/**
         "torch/[a-c]*/**",
         # torch/d*/**
-        # torch/[e-n]*/**
-        "torch/[e-n]*/**",
+        # torch/[e-m]*/**
+        "torch/[e-m]*/**",
         # torch/optim/**
-        "torch/optim/**",
         # torch/[p-z]*/**
         "torch/[p-z]*/**",
     ],
@@ -529,9 +529,9 @@ def jagged_from_tensor_and_lengths(
     )

     # Calculate jagged offsets
-    assert (
-        len(tensor.shape) >= 2
-    ), "tensor must at least be 2D for the nested narrow op to work"
+    assert len(tensor.shape) >= 2, (
+        "tensor must at least be 2D for the nested narrow op to work"
+    )
     max_seq_len = tensor.shape[1]
     offset_lengths = max_seq_len * torch.arange(
         0, batch_size, dtype=torch.int64, device=tensor.device
@@ -73,9 +73,9 @@ def _wrap_jagged_dims(ndim, dims, op_name, ragged_idx=1):
     """
     from torch._prims_common import canonicalize_dims

-    assert isinstance(
-        dims, (tuple, list)
-    ), f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
+    assert isinstance(dims, (tuple, list)), (
+        f"_wrap_jagged_dims(): cannot iterate over dimensions of type {type(dims)}"
+    )

     wrapped_dims = [
         canonicalize_dims(ndim, d) for d in dims
@@ -535,9 +535,9 @@ def clone_default(func, *args, **kwargs):
         from .nested_tensor import jagged_from_list

         # TODO: We probably want the output to have the same ragged structure / nested int.
-        assert (
-            inp._ragged_idx == 1
-        ), "NJT with ragged_idx != 1 not supported for contiguous clone"
+        assert inp._ragged_idx == 1, (
+            "NJT with ragged_idx != 1 not supported for contiguous clone"
+        )
         contig, _ = jagged_from_list(inp.unbind(), offsets=None)
         return contig

@@ -1730,8 +1730,8 @@ def native_layer_norm_default(func, *args, **kwargs):
     )  # a sum over (1, 2) ensures layer norm, whereas a sum over (1) would be an instance norm

     padded_normalized = (
-        padded_input - mean
-    ) * padded_mask  # mask elements outside of the ragged dimension size for correct variance calculation
+        (padded_input - mean) * padded_mask
+    )  # mask elements outside of the ragged dimension size for correct variance calculation

     variance = (
         torch.sum(
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 """This module contains functions and classes that alter the behavior of torch.nn.functional.scaled_dot_product_attention"""
+
 import contextlib
 from collections.abc import Iterable
 from typing import Union
@@ -119,6 +120,7 @@ def sdpa_kernel(

         from torch.nn.functional import scaled_dot_product_attention
         from torch.nn.attention import SDPBackend, sdpa_kernel
+
         # Only enable flash attention backend
         with sdpa_kernel(SDPBackend.FLASH_ATTENTION):
             scaled_dot_product_attention(...)
@@ -130,9 +132,9 @@ def sdpa_kernel(
     This context manager can be used to select which backend to use for scaled dot product attention.
     Upon exiting the context manager, the previous state of the flags will be restored, enabling all backends.
     """
-    assert isinstance(
-        backends, (list, SDPBackend)
-    ), "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
+    assert isinstance(backends, (list, SDPBackend)), (
+        "Backend must be an instance of SDPBackend or a list of SDPBackend instances"
+    )

     if isinstance(backends, SDPBackend):
         backends = [backends]
@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 """Defines utilities for interacting with scaled_dot_product_attention"""
+
 import math
 from typing import Optional

@@ -1,5 +1,6 @@
 # mypy: allow-untyped-defs
 """Defines bias subclasses that work with scaled_dot_product_attention"""
+
 from enum import auto, IntEnum
 from typing import Optional
 from warnings import warn
@@ -101,9 +102,15 @@ class CausalBias(torch.Tensor):
         # Create a lower-right causal bias
         attn_bias = causal_lower_right(seqlen_q, seqlen_kv)

-        q = torch.randn(bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16)
-        k = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
-        v = torch.randn(bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16)
+        q = torch.randn(
+            bsz, num_heads, seqlen_q, head_dim, device="cuda", dtype=torch.float16
+        )
+        k = torch.randn(
+            bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
+        )
+        v = torch.randn(
+            bsz, num_heads, seqlen_kv, head_dim, device="cuda", dtype=torch.float16
+        )

         out = F.scaled_dot_product_attention(q, k, v, attn_bias)

@@ -182,9 +182,7 @@ class PagedAttention:
         logical_block_offset = input_pos % self.page_size  # [B, S]
         physical_block_idx = torch.gather(
             self.page_table[batch_idx], 1, logical_block_idx.to(torch.int64)
-        ).to(
-            torch.int32
-        )  # [B, S]
+        ).to(torch.int32)  # [B, S]

         addr = (physical_block_idx * self.page_size + logical_block_offset).view(
             -1
@@ -1,6 +1,7 @@
 # mypy: allow-untyped-defs
 # flake8: noqa: B950
 """This module implements the user facing API for flex_attention in PyTorch."""
+
 import functools
 import inspect
 import itertools
@@ -293,12 +294,12 @@ class BlockMask:
         assert kv_indices is not None, "kv_indices must be provided"
         assert q_num_blocks is not None, "q_num_blocks must be provided"
         assert q_indices is not None, "q_indices must be provided"
-        assert (full_kv_num_blocks is None) == (
-            full_kv_indices is None
-        ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
-        assert (full_q_num_blocks is None) == (
-            full_q_indices is None
-        ), "full_q_num_blocks and full_q_indices must be both provided or omitted"
+        assert (full_kv_num_blocks is None) == (full_kv_indices is None), (
+            "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
+        )
+        assert (full_q_num_blocks is None) == (full_q_indices is None), (
+            "full_q_num_blocks and full_q_indices must be both provided or omitted"
+        )

         self.seq_lengths = seq_lengths
         self.kv_num_blocks = kv_num_blocks
@@ -344,9 +345,9 @@ class BlockMask:
         if kv_indices.dim() < 2:
             raise RuntimeError("BlockMask must have at least 2 dimensions")

-        assert (full_kv_num_blocks is None) == (
-            full_kv_indices is None
-        ), "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
+        assert (full_kv_num_blocks is None) == (full_kv_indices is None), (
+            "full_kv_num_blocks and full_kv_indices must be both provided or omitted"
+        )

         # Generate q_num_blocks and q_indices
         q_num_blocks, q_indices = _transpose_ordered(kv_num_blocks, kv_indices)
@@ -434,7 +435,10 @@ class BlockMask:
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx

-        block_mask = create_block_mask(causal_mask, 4, 2, 512, 512, device="cuda")
+
+        block_mask = create_block_mask(
+            causal_mask, 4, 2, 512, 512, device="cuda"
+        )
         assert block_mask.kv_num_blocks.shape == (4, 2, 4)
         assert block_mask.kv_indices.shape == (4, 2, 4, 4)

@@ -454,7 +458,9 @@ class BlockMask:
         assert new_block_mask.kv_indices.shape == (2, 1, 4, 4)

         # slicing on batch, head, and query dimension
-        new_block_mask = block_mask[0:2, 1:2, torch.tensor([1], dtype=torch.int32)]
+        new_block_mask = block_mask[
+            0:2, 1:2, torch.tensor([1], dtype=torch.int32)
+        ]
         assert new_block_mask.kv_num_blocks.shape == (2, 1, 1)
         assert new_block_mask.kv_indices.shape == (2, 1, 1, 4)
     """
@@ -857,6 +863,7 @@ def create_block_mask(
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx

+
         block_mask = create_block_mask(causal_mask, 1, 1, 8192, 8192, device="cuda")
         query = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
         key = torch.randn(1, 1, 8192, 64, device="cuda", dtype=torch.float16)
@@ -864,9 +871,9 @@ def create_block_mask(
         output = flex_attention(query, key, value, block_mask=block_mask)
     """
     mod_type = _get_mod_type(mask_mod)
-    assert (
-        mod_type == _ModificationType.MASK_MOD
-    ), f"create-block_mask requires a mask_mod function! Got {mask_mod}"
+    assert mod_type == _ModificationType.MASK_MOD, (
+        f"create-block_mask requires a mask_mod function! Got {mask_mod}"
+    )
     if B is None:
         B = 1
     if H is None:
@@ -962,7 +969,10 @@ def _nested_mod_func_adapter(
         kv_seq_idx = q_seq_idx
     else:
         # cross attention case
-        kv_seq_idx = _build_seq_idx(kv_offsets, kv_nt._values.shape[kv_nt._ragged_idx - 1])  # type: ignore[attr-defined]
+        kv_seq_idx = _build_seq_idx(
+            kv_offsets,
+            kv_nt._values.shape[kv_nt._ragged_idx - 1],  # type: ignore[attr-defined]
+        )

     # Converts q_idx / kv_idx from [0, total_length) -> [0, S), where S refers
     # to the sequence length for each sequence in the NJT, for use in given
@@ -1039,10 +1049,14 @@ def create_nested_block_mask(
         key = torch.nested.nested_tensor(..., layout=torch.jagged)
         value = torch.nested.nested_tensor(..., layout=torch.jagged)

+
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx

-        block_mask = create_nested_block_mask(causal_mask, 1, 1, query, _compile=True)
+
+        block_mask = create_nested_block_mask(
+            causal_mask, 1, 1, query, _compile=True
+        )
         output = flex_attention(query, key, value, block_mask=block_mask)

     .. code-block:: python
@@ -1052,11 +1066,15 @@ def create_nested_block_mask(
         key = torch.nested.nested_tensor(..., layout=torch.jagged)
         value = torch.nested.nested_tensor(..., layout=torch.jagged)

+
         def causal_mask(b, h, q_idx, kv_idx):
             return q_idx >= kv_idx

+
         # cross attention case: pass both query and key/value NJTs
-        block_mask = create_nested_block_mask(causal_mask, 1, 1, query, key, _compile=True)
+        block_mask = create_nested_block_mask(
+            causal_mask, 1, 1, query, key, _compile=True
+        )
         output = flex_attention(query, key, value, block_mask=block_mask)
     """
     # use same structure for kv as for q by default
@@ -1381,7 +1399,13 @@ def flex_attention(
             torch._dynamo.mark_static(x, -1)

     out, lse = flex_attention_hop(
-        query, key, value, score_mod, block_mask.as_tuple(), scale, kernel_options  # type: ignore[union-attr]
+        query,
+        key,
+        value,
+        score_mod,
+        block_mask.as_tuple(),
+        scale,
+        kernel_options,  # type: ignore[union-attr]
     )
     if return_lse:
         return out, lse * math.log(2)
@@ -55,9 +55,7 @@ Note:

 Note:
     This operator supports complex data types i.e. ``complex32, complex64, complex128``.
-""".format(
-    **reproducibility_notes, **tf32_notes
-)
+""".format(**reproducibility_notes, **tf32_notes)
 + r"""

 Args:
@@ -106,9 +104,7 @@ Note:

 Note:
     This operator supports complex data types i.e. ``complex32, complex64, complex128``.
-""".format(
-    **reproducibility_notes, **tf32_notes
-)
+""".format(**reproducibility_notes, **tf32_notes)
 + r"""

 Args:
@@ -159,9 +155,7 @@ Note:

 Note:
     This operator supports complex data types i.e. ``complex32, complex64, complex128``.
-""".format(
-    **reproducibility_notes, **tf32_notes
-)
+""".format(**reproducibility_notes, **tf32_notes)
 + r"""

 Args:
@@ -208,9 +202,7 @@ See :class:`~torch.nn.ConvTranspose1d` for details and output shape.

 Note:
     {cudnn_reproducibility_note}
-""".format(
-    **reproducibility_notes, **tf32_notes
-)
+""".format(**reproducibility_notes, **tf32_notes)
 + r"""

 Args:
@@ -251,9 +243,7 @@ See :class:`~torch.nn.ConvTranspose2d` for details and output shape.

 Note:
     {cudnn_reproducibility_note}
-""".format(
-    **reproducibility_notes, **tf32_notes
-)
+""".format(**reproducibility_notes, **tf32_notes)
 + r"""

 Args:
@@ -296,9 +286,7 @@ See :class:`~torch.nn.ConvTranspose3d` for details and output shape.

 Note:
     {cudnn_reproducibility_note}
-""".format(
-    **reproducibility_notes, **tf32_notes
-)
+""".format(**reproducibility_notes, **tf32_notes)
 + r"""

 Args:
@@ -2335,9 +2323,7 @@ Shape:
    - Weight: :math:`(out\_features, in\_features)` or :math:`(in\_features)`
    - Bias: :math:`(out\_features)` or :math:`()`
    - Output: :math:`(*, out\_features)` or :math:`(*)`, based on the shape of the weight
-""".format(
-    **sparse_support_notes
-),
+""".format(**sparse_support_notes),
 )

@@ -2535,13 +2521,13 @@ def embedding(
         )
     if padding_idx is not None:
         if padding_idx > 0:
-            assert padding_idx < weight.size(
-                0
-            ), "Padding_idx must be within num_embeddings"
+            assert padding_idx < weight.size(0), (
+                "Padding_idx must be within num_embeddings"
+            )
         elif padding_idx < 0:
-            assert padding_idx >= -weight.size(
-                0
-            ), "Padding_idx must be within num_embeddings"
+            assert padding_idx >= -weight.size(0), (
+                "Padding_idx must be within num_embeddings"
+            )
             padding_idx = weight.size(0) + padding_idx
     else:
         padding_idx = -1
@@ -5800,15 +5786,15 @@ def _in_projection(
         Eq,
         Ev,
     ), f"expecting value weights shape of {(Eq, Ev)}, but got {w_v.shape}"
-    assert b_q is None or b_q.shape == (
-        Eq,
-    ), f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
-    assert b_k is None or b_k.shape == (
-        Eq,
-    ), f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
-    assert b_v is None or b_v.shape == (
-        Eq,
-    ), f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
+    assert b_q is None or b_q.shape == (Eq,), (
+        f"expecting query bias shape of {(Eq,)}, but got {b_q.shape}"
+    )
+    assert b_k is None or b_k.shape == (Eq,), (
+        f"expecting key bias shape of {(Eq,)}, but got {b_k.shape}"
+    )
+    assert b_v is None or b_v.shape == (Eq,), (
+        f"expecting value bias shape of {(Eq,)}, but got {b_v.shape}"
+    )
     return linear(q, w_q, b_q), linear(k, w_k, b_k), linear(v, w_v, b_v)

@@ -5914,9 +5900,7 @@ scaled_dot_product_attention = _add_docstr(
 Note:

     {cudnn_reproducibility_note}
-""".format(
-    **reproducibility_notes
-)
+""".format(**reproducibility_notes)
 + r"""
 Args:
     query (Tensor): Query tensor; shape :math:`(N, ..., Hq, L, E)`.
@@ -6026,9 +6010,9 @@ def _mha_shape_check(
             )
         if attn_mask.dim() == 3:
             expected_shape = (num_heads, query.shape[0], key.shape[0])
-            assert (
-                attn_mask.shape == expected_shape
-            ), f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
+            assert attn_mask.shape == expected_shape, (
+                f"Expected `attn_mask` shape to be {expected_shape} but got {attn_mask.shape}"
+            )
     else:
         raise AssertionError(
             f"query should be unbatched 2D or batched 3D tensor but received {query.dim()}-D query tensor"
@@ -6289,45 +6273,45 @@ def multi_head_attention_forward(
         # longer causal.
         is_causal = False

-    assert (
-        embed_dim == embed_dim_to_check
-    ), f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+    assert embed_dim == embed_dim_to_check, (
+        f"was expecting embedding dimension of {embed_dim_to_check}, but got {embed_dim}"
+    )
     if isinstance(embed_dim, torch.Tensor):
         # embed_dim can be a tensor when JIT tracing
         head_dim = embed_dim.div(num_heads, rounding_mode="trunc")
     else:
         head_dim = embed_dim // num_heads
-    assert (
-        head_dim * num_heads == embed_dim
-    ), f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+    assert head_dim * num_heads == embed_dim, (
+        f"embed_dim {embed_dim} not divisible by num_heads {num_heads}"
+    )
     if use_separate_proj_weight:
         # allow MHA to have different embedding dimensions when separate projection weights are used
-        assert (
-            key.shape[:2] == value.shape[:2]
-        ), f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+        assert key.shape[:2] == value.shape[:2], (
+            f"key's sequence and batch dims {key.shape[:2]} do not match value's {value.shape[:2]}"
+        )
     else:
-        assert (
-            key.shape == value.shape
-        ), f"key shape {key.shape} does not match value shape {value.shape}"
+        assert key.shape == value.shape, (
+            f"key shape {key.shape} does not match value shape {value.shape}"
+        )

     #
     # compute in-projection
     #
     if not use_separate_proj_weight:
-        assert (
-            in_proj_weight is not None
-        ), "use_separate_proj_weight is False but in_proj_weight is None"
+        assert in_proj_weight is not None, (
+            "use_separate_proj_weight is False but in_proj_weight is None"
+        )
         q, k, v = _in_projection_packed(query, key, value, in_proj_weight, in_proj_bias)
     else:
-        assert (
-            q_proj_weight is not None
-        ), "use_separate_proj_weight is True but q_proj_weight is None"
-        assert (
-            k_proj_weight is not None
-        ), "use_separate_proj_weight is True but k_proj_weight is None"
-        assert (
-            v_proj_weight is not None
-        ), "use_separate_proj_weight is True but v_proj_weight is None"
+        assert q_proj_weight is not None, (
+            "use_separate_proj_weight is True but q_proj_weight is None"
+        )
+        assert k_proj_weight is not None, (
+            "use_separate_proj_weight is True but k_proj_weight is None"
+        )
+        assert v_proj_weight is not None, (
+            "use_separate_proj_weight is True but v_proj_weight is None"
+        )
         if in_proj_bias is None:
             b_q = b_k = b_v = None
         else:
@@ -6388,23 +6372,23 @@ def multi_head_attention_forward(
         k = k.view(k.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
     else:
         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
-        assert (
-            static_k.size(0) == bsz * num_heads
-        ), f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
-        assert (
-            static_k.size(2) == head_dim
-        ), f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+        assert static_k.size(0) == bsz * num_heads, (
+            f"expecting static_k.size(0) of {bsz * num_heads}, but got {static_k.size(0)}"
+        )
+        assert static_k.size(2) == head_dim, (
+            f"expecting static_k.size(2) of {head_dim}, but got {static_k.size(2)}"
+        )
         k = static_k
     if static_v is None:
         v = v.view(v.shape[0], bsz * num_heads, head_dim).transpose(0, 1)
     else:
         # TODO finish disentangling control flow so we don't do in-projections when statics are passed
-        assert (
-            static_v.size(0) == bsz * num_heads
-        ), f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
-        assert (
-            static_v.size(2) == head_dim
-        ), f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+        assert static_v.size(0) == bsz * num_heads, (
+            f"expecting static_v.size(0) of {bsz * num_heads}, but got {static_v.size(0)}"
+        )
+        assert static_v.size(2) == head_dim, (
+            f"expecting static_v.size(2) of {head_dim}, but got {static_v.size(2)}"
+        )
         v = static_v

     # add zero attention along batch dimension (now first)
@@ -6451,9 +6435,9 @@ def multi_head_attention_forward(
         _B, _Nt, E = q.shape
         q_scaled = q * math.sqrt(1.0 / float(E))

-        assert not (
-            is_causal and attn_mask is None
-        ), "FIXME: is_causal not implemented for need_weights"
+        assert not (is_causal and attn_mask is None), (
+            "FIXME: is_causal not implemented for need_weights"
+        )

         if attn_mask is not None:
             attn_output_weights = torch.baddbmm(
@@ -168,7 +168,9 @@ def calculate_gain(
         param: optional parameter for the non-linear function

     Examples:
-        >>> gain = nn.init.calculate_gain('leaky_relu', 0.2)  # leaky_relu with negative_slope=0.2
+        >>> gain = nn.init.calculate_gain(
+        ...     "leaky_relu", 0.2
+        ... )  # leaky_relu with negative_slope=0.2

     .. _Self-Normalizing Neural Networks: https://papers.nips.cc/paper/2017/hash/5d44ee6f2c3f71b73125876103c8f6c4-Abstract.html
     """
@@ -456,7 +458,7 @@ def xavier_uniform_(

     Examples:
         >>> w = torch.empty(3, 5)
-        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain('relu'))
+        >>> nn.init.xavier_uniform_(w, gain=nn.init.calculate_gain("relu"))

     Note:
         Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
@@ -555,7 +557,7 @@ def kaiming_uniform_(

     Examples:
         >>> w = torch.empty(3, 5)
-        >>> nn.init.kaiming_uniform_(w, mode='fan_in', nonlinearity='relu')
+        >>> nn.init.kaiming_uniform_(w, mode="fan_in", nonlinearity="relu")

     Note:
         Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
@@ -620,7 +622,7 @@ def kaiming_normal_(

     Examples:
         >>> w = torch.empty(3, 5)
-        >>> nn.init.kaiming_normal_(w, mode='fan_out', nonlinearity='relu')
+        >>> nn.init.kaiming_normal_(w, mode="fan_out", nonlinearity="relu")

     Note:
         Be aware that ``fan_in`` and ``fan_out`` are calculated assuming
@@ -1077,9 +1077,9 @@ class MultiheadAttention(Module):
         self.dropout = dropout
         self.batch_first = batch_first
         self.head_dim = embed_dim // num_heads
-        assert (
-            self.head_dim * num_heads == self.embed_dim
-        ), "embed_dim must be divisible by num_heads"
+        assert self.head_dim * num_heads == self.embed_dim, (
+            "embed_dim must be divisible by num_heads"
+        )

         if not self._qkv_same_embed_dim:
             self.q_proj_weight = Parameter(
@@ -1276,8 +1276,10 @@ class MultiheadAttention(Module):
         elif query.is_nested and (
             key_padding_mask is not None or attn_mask is not None
         ):
-            why_not_fast_path = "supplying both src_key_padding_mask and src_mask at the same time \
-                is not supported with NestedTensor input"
+            why_not_fast_path = (
+                "supplying both src_key_padding_mask and src_mask at the same time \
+                is not supported with NestedTensor input"
+            )
         elif torch.is_autocast_enabled():
             why_not_fast_path = "autocast is enabled"

@@ -18,13 +18,15 @@ _ASMoutput = namedtuple("_ASMoutput", ["output", "loss"])


 class AdaptiveLogSoftmaxWithLoss(Module):
+    (
     """Efficient softmax approximation.

     As described in
     `Efficient softmax approximation for GPUs by Edouard Grave, Armand Joulin,
     Moustapha Ciss\u00e9, David Grangier, and Herv\u00e9 J\u00e9gou
     <https://arxiv.org/abs/1609.04309>`__.
-    """ r"""
+    """
+    r"""
     Adaptive softmax is an approximate strategy for training models with large
     output spaces. It is most effective when the label distribution is highly
     imbalanced, for example in natural language modelling, where the word
@@ -104,6 +106,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):

     .. _Zipf's law: https://en.wikipedia.org/wiki/Zipf%27s_law
     """
+    )

     in_features: int
     n_classes: int
@@ -182,8 +185,7 @@ class AdaptiveLogSoftmaxWithLoss(Module):
         if targ_dim == 1:
             if input_.size(0) != target_.size(0):
                 raise RuntimeError(
-                    "Input and target should have the same size "
-                    "in the batch dimension."
+                    "Input and target should have the same size in the batch dimension."
                 )
             if input_.dim() != 2:
                 raise RuntimeError(
@@ -86,31 +86,30 @@ class Sequential(Module):
         # for `Conv2d(20,64,5)`. Finally, the output of
         # `Conv2d(20,64,5)` will be used as input to the second `ReLU`
         model = nn.Sequential(
-                  nn.Conv2d(1,20,5),
-                  nn.ReLU(),
-                  nn.Conv2d(20,64,5),
-                  nn.ReLU()
+            nn.Conv2d(1, 20, 5), nn.ReLU(), nn.Conv2d(20, 64, 5), nn.ReLU()
         )

         # Using Sequential with OrderedDict. This is functionally the
         # same as the above code
-        model = nn.Sequential(OrderedDict([
-                  ('conv1', nn.Conv2d(1,20,5)),
-                  ('relu1', nn.ReLU()),
-                  ('conv2', nn.Conv2d(20,64,5)),
-                  ('relu2', nn.ReLU())
-                ]))
+        model = nn.Sequential(
+            OrderedDict(
+                [
+                    ("conv1", nn.Conv2d(1, 20, 5)),
+                    ("relu1", nn.ReLU()),
+                    ("conv2", nn.Conv2d(20, 64, 5)),
+                    ("relu2", nn.ReLU()),
+                ]
+            )
+        )
     """

     _modules: dict[str, Module]  # type: ignore[assignment]

     @overload
-    def __init__(self, *args: Module) -> None:
-        ...
+    def __init__(self, *args: Module) -> None: ...

     @overload
-    def __init__(self, arg: OrderedDict[str, Module]) -> None:
-        ...
+    def __init__(self, arg: OrderedDict[str, Module]) -> None: ...

     def __init__(self, *args):
         super().__init__()
@@ -365,12 +364,10 @@ class ModuleList(Module):
         return str(idx)

     @overload
-    def __getitem__(self, idx: slice) -> ModuleList:
-        ...
+    def __getitem__(self, idx: slice) -> ModuleList: ...

     @overload
-    def __getitem__(self, idx: int) -> Module:
-        ...
+    def __getitem__(self, idx: int) -> Module: ...

     @_copy_to_script_wrapper
     def __getitem__(self, idx: Union[int, slice]) -> Union[Module, ModuleList]:
@@ -521,14 +518,12 @@ class ModuleDict(Module):
         class MyModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.choices = nn.ModuleDict({
-                        'conv': nn.Conv2d(10, 10, 3),
-                        'pool': nn.MaxPool2d(3)
-                })
-                self.activations = nn.ModuleDict([
-                        ['lrelu', nn.LeakyReLU()],
-                        ['prelu', nn.PReLU()]
-                ])
+                self.choices = nn.ModuleDict(
+                    {"conv": nn.Conv2d(10, 10, 3), "pool": nn.MaxPool2d(3)}
+                )
+                self.activations = nn.ModuleDict(
+                    [["lrelu", nn.LeakyReLU()], ["prelu", nn.PReLU()]]
+                )

             def forward(self, x, choice, act):
                 x = self.choices[choice](x)
@@ -653,7 +648,9 @@ class ParameterList(Module):
         class MyModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.params = nn.ParameterList([nn.Parameter(torch.randn(10, 10)) for i in range(10)])
+                self.params = nn.ParameterList(
+                    [nn.Parameter(torch.randn(10, 10)) for i in range(10)]
+                )

             def forward(self, x):
                 # ParameterList can act as an iterable, or be indexed using ints
@@ -678,12 +675,10 @@ class ParameterList(Module):
         return str(idx)

     @overload
-    def __getitem__(self, idx: int) -> Any:
-        ...
+    def __getitem__(self, idx: int) -> Any: ...

     @overload
-    def __getitem__(self: T, idx: slice) -> T:
-        ...
+    def __getitem__(self: T, idx: slice) -> T: ...

     def __getitem__(self, idx):
         if isinstance(idx, slice):
@@ -805,10 +800,12 @@ class ParameterDict(Module):
         class MyModule(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
-                self.params = nn.ParameterDict({
-                        'left': nn.Parameter(torch.randn(5, 10)),
-                        'right': nn.Parameter(torch.randn(5, 10))
-                })
+                self.params = nn.ParameterDict(
+                    {
+                        "left": nn.Parameter(torch.randn(5, 10)),
+                        "right": nn.Parameter(torch.randn(5, 10)),
+                    }
+                )

             def forward(self, x, choice):
                 x = self.params[choice].mm(x)
@@ -66,8 +66,9 @@ class _ConvNd(Module):
     ]
     __annotations__ = {"bias": Optional[torch.Tensor]}

-    def _conv_forward(self, input: Tensor, weight: Tensor, bias: Optional[Tensor]) -> Tensor:  # type: ignore[empty-body]
-        ...
+    def _conv_forward(  # type: ignore[empty-body]
+        self, input: Tensor, weight: Tensor, bias: Optional[Tensor]
+    ) -> Tensor: ...

     in_channels: int
     _reversed_padding_repeated_twice: list[int]
@@ -187,10 +188,7 @@ class _ConvNd(Module):
             init.uniform_(self.bias, -bound, bound)

     def extra_repr(self):
-        s = (
-            "{in_channels}, {out_channels}, kernel_size={kernel_size}"
-            ", stride={stride}"
-        )
+        s = "{in_channels}, {out_channels}, kernel_size={kernel_size}, stride={stride}"
         if self.padding != (0,) * len(self.padding):
             s += ", padding={padding}"
         if self.dilation != (1,) * len(self.dilation):
@@ -279,9 +277,7 @@ class Conv1d(_ConvNd):
         padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
             ``'replicate'`` or ``'circular'``. Default: ``'zeros'``

-    """.format(
-        **reproducibility_notes, **convolution_notes
-    )
+    """.format(**reproducibility_notes, **convolution_notes)
     + r"""

     Shape:
@@ -450,9 +446,7 @@ class Conv2d(_ConvNd):
             output. Default: ``True``
         padding_mode (str, optional): ``'zeros'``, ``'reflect'``,
             ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
-    """.format(
-        **reproducibility_notes, **convolution_notes
-    )
+    """.format(**reproducibility_notes, **convolution_notes)
     + r"""

     Shape:
@@ -619,9 +613,7 @@ class Conv3d(_ConvNd):
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
         padding_mode (str, optional): ``'zeros'``, ``'reflect'``, ``'replicate'`` or ``'circular'``. Default: ``'zeros'``
-    """.format(
-        **reproducibility_notes, **convolution_notes
-    )
+    """.format(**reproducibility_notes, **convolution_notes)
     + r"""

     Shape:
@@ -883,9 +875,7 @@ class ConvTranspose1d(_ConvTransposeNd):
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
         dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
-    """.format(
-        **reproducibility_notes, **convolution_notes
-    )
+    """.format(**reproducibility_notes, **convolution_notes)
     + r"""

     Shape:
@@ -1051,9 +1041,7 @@ class ConvTranspose2d(_ConvTransposeNd):
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
         dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
-    """.format(
-        **reproducibility_notes, **convolution_notes
-    )
+    """.format(**reproducibility_notes, **convolution_notes)
     + r"""

     Shape:
@@ -1249,9 +1237,7 @@ class ConvTranspose3d(_ConvTransposeNd):
         groups (int, optional): Number of blocked connections from input channels to output channels. Default: 1
         bias (bool, optional): If ``True``, adds a learnable bias to the output. Default: ``True``
         dilation (int or tuple, optional): Spacing between kernel elements. Default: 1
-    """.format(
-        **reproducibility_notes, **convolution_notes
-    )
+    """.format(**reproducibility_notes, **convolution_notes)
     + r"""

     Shape:
@@ -96,8 +96,8 @@ class Unflatten(Module):
         >>> output.size()
         torch.Size([2, 2, 5, 5])
         >>> # With namedshape (tuple of tuples)
-        >>> input = torch.randn(2, 50, names=('N', 'features'))
-        >>> unflatten = nn.Unflatten('features', (('C', 2), ('H', 5), ('W', 5)))
+        >>> input = torch.randn(2, 50, names=("N", "features"))
+        >>> unflatten = nn.Unflatten("features", (("C", 2), ("H", 5), ("W", 5)))
         >>> output = unflatten(input)
         >>> output.size()
         torch.Size([2, 2, 5, 5])
@@ -9,6 +9,7 @@ __all__ = ["Fold", "Unfold"]


 class Fold(Module):
+    (
     r"""Combines an array of sliding local blocks into a large containing tensor.

     Consider a batched :attr:`input` tensor containing sliding local blocks,
@@ -42,10 +43,12 @@ class Fold(Module):
     * :attr:`padding` controls the amount of implicit zero-paddings on both
       sides for :attr:`padding` number of points for each dimension before
       reshaping.
-    """ """
+    """
+    """
     * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
       It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
-    """ r"""
+    """
+    r"""
     Args:
         output_size (int or tuple): the shape of the spatial dimensions of the
             output (i.e., ``output.sizes()[2:]``)
@@ -119,6 +122,7 @@ class Fold(Module):
     https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

     """
+    )

     __constants__ = ["output_size", "kernel_size", "dilation", "padding", "stride"]
     output_size: _size_any_t
@@ -162,6 +166,7 @@ class Fold(Module):


 class Unfold(Module):
+    (
     r"""Extracts sliding local blocks from a batched input tensor.

     Consider a batched :attr:`input` tensor of shape :math:`(N, C, *)`,
@@ -194,10 +199,12 @@ class Unfold(Module):
     * :attr:`padding` controls the amount of implicit zero-paddings on both
       sides for :attr:`padding` number of points for each dimension before
       reshaping.
-    """ """
+    """
+    """
     * :attr:`dilation` controls the spacing between the kernel points; also known as the \u00e0 trous algorithm.
       It is harder to describe, but this `link`_ has a nice visualization of what :attr:`dilation` does.
-    """ r"""
+    """
+    r"""
     Args:
         kernel_size (int or tuple): the size of the sliding blocks
         dilation (int or tuple, optional): a parameter that controls the
@@ -283,6 +290,7 @@ class Unfold(Module):
     https://github.com/vdumoulin/conv_arithmetic/blob/master/README.md

     """
+    )

     __constants__ = ["kernel_size", "dilation", "padding", "stride"]
     kernel_size: _size_any_t
@@ -15,11 +15,9 @@ class _LazyProtocol(Protocol):
     https://mypy.readthedocs.io/en/latest/more_types.html#mixin-classes
     """

-    def _register_load_state_dict_pre_hook(self, hook):
-        ...
+    def _register_load_state_dict_pre_hook(self, hook): ...

-    def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False):
-        ...
+    def register_forward_pre_hook(self, hook, *, prepend=False, with_kwargs=False): ...

     def _lazy_load_hook(
         self,
@@ -30,34 +28,26 @@ class _LazyProtocol(Protocol):
         missing_keys,
         unexpected_keys,
         error_msgs,
-    ):
-        ...
+    ): ...

-    def _get_name(self):
-        ...
+    def _get_name(self): ...

-    def _infer_parameters(self, module, input):
-        ...
+    def _infer_parameters(self, module, input): ...

     @property
-    def _parameters(self):
-        ...
+    def _parameters(self): ...

     @property
-    def _buffers(self):
-        ...
+    def _buffers(self): ...

     @property
-    def _non_persistent_buffers_set(self):
-        ...
+    def _non_persistent_buffers_set(self): ...

     @property
-    def _load_hook(self):
-        ...
+    def _load_hook(self): ...

     @property
-    def _initialize_hook(self):
-        ...
+    def _initialize_hook(self): ...


 class LazyModuleMixin:
@@ -119,6 +119,7 @@ class L1Loss(_Loss):
         >>> output = loss(input, target)
         >>> output.backward()
     """
+
     __constants__ = ["reduction"]

     def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@@ -233,6 +234,7 @@ class NLLLoss(_WeightedLoss):
         >>> loss = loss_fn(output, target)
         >>> loss.backward()
     """
+
     __constants__ = ["ignore_index", "reduction"]
     ignore_index: int

@@ -331,6 +333,7 @@ class PoissonNLLLoss(_Loss):
         - Output: scalar by default. If :attr:`reduction` is ``'none'``, then :math:`(*)`,
           the same shape as the input.
     """
+
     __constants__ = ["log_input", "full", "eps", "reduction"]
     log_input: bool
     full: bool
@@ -427,6 +430,7 @@ class GaussianNLLLoss(_Loss):
         Conference on Neural Networks (ICNN'94), Orlando, FL, USA, 1994, pp. 55-60
         vol.1, doi: 10.1109/ICNN.1994.374138.
     """
+
     __constants__ = ["full", "eps", "reduction"]
     full: bool
     eps: float
@@ -527,6 +531,7 @@ class KLDivLoss(_Loss):
         >>> log_target = F.log_softmax(torch.rand(3, 5), dim=1)
         >>> output = kl_loss(input, log_target)
     """
+
     __constants__ = ["reduction"]

     def __init__(
@@ -601,6 +606,7 @@ class MSELoss(_Loss):
         >>> output = loss(input, target)
         >>> output.backward()
     """
+
     __constants__ = ["reduction"]

     def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@@ -684,6 +690,7 @@ class BCELoss(_WeightedLoss):
         >>> output = loss(m(input), target)
         >>> output.backward()
     """
+
     __constants__ = ["reduction"]

     def __init__(
@@ -876,6 +883,7 @@ class HingeEmbeddingLoss(_Loss):
         - Target: :math:`(*)`, same shape as the input
         - Output: scalar. If :attr:`reduction` is ``'none'``, then same shape as the input
     """
+
     __constants__ = ["margin", "reduction"]
     margin: float

@@ -950,6 +958,7 @@ class MultiLabelMarginLoss(_Loss):
         tensor(0.85...)

     """
+
     __constants__ = ["reduction"]

     def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@@ -1030,6 +1039,7 @@ class SmoothL1Loss(_Loss):
         - Target: :math:`(*)`, same shape as the input.
         - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
     """
+
     __constants__ = ["reduction"]

     def __init__(
@@ -1092,6 +1102,7 @@ class HuberLoss(_Loss):
         - Target: :math:`(*)`, same shape as the input.
         - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(*)`, same shape as the input.
     """
+
     __constants__ = ["reduction", "delta"]

     def __init__(self, reduction: str = "mean", delta: float = 1.0) -> None:
@@ -1134,6 +1145,7 @@ class SoftMarginLoss(_Loss):
           shape as input.

     """
+
     __constants__ = ["reduction"]

     def __init__(self, size_average=None, reduce=None, reduction: str = "mean") -> None:
@@ -1276,6 +1288,7 @@ class CrossEntropyLoss(_WeightedLoss):
         >>> output = loss(input, target)
         >>> output.backward()
     """
+
     __constants__ = ["ignore_index", "reduction", "label_smoothing"]
     ignore_index: int
     label_smoothing: float
@@ -1342,6 +1355,7 @@ class MultiLabelSoftMarginLoss(_WeightedLoss):
         - Target: :math:`(N, C)`, label targets must have the same shape as the input.
         - Output: scalar. If :attr:`reduction` is ``'none'``, then :math:`(N)`.
     """
+
     __constants__ = ["reduction"]

     def __init__(
@@ -1410,6 +1424,7 @@ class CosineEmbeddingLoss(_Loss):
         >>> output = loss(input1, input2, target)
         >>> output.backward()
     """
+
     __constants__ = ["margin", "reduction"]
     margin: float

@@ -1475,6 +1490,7 @@ class MarginRankingLoss(_Loss):
         >>> output = loss(input1, input2, target)
         >>> output.backward()
     """
+
     __constants__ = ["margin", "reduction"]
     margin: float

@@ -1554,6 +1570,7 @@ class MultiMarginLoss(_WeightedLoss):
         >>> loss(x, y)
         tensor(0.32...)
     """
+
     __constants__ = ["p", "margin", "reduction"]
     margin: float
     p: int
@@ -1657,6 +1674,7 @@ class TripletMarginLoss(_Loss):
     .. _Learning shallow convolutional feature descriptors with triplet losses:
         https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html
     """
+
     __constants__ = ["margin", "p", "eps", "swap", "reduction"]
     margin: float
     p: float
@@ -1794,6 +1812,7 @@ class TripletMarginWithDistanceLoss(_Loss):
         V. Balntas, et al.: Learning shallow convolutional feature descriptors with triplet losses:
         https://bmva-archive.org.uk/bmvc/2016/papers/paper119/index.html
     """
+
     __constants__ = ["margin", "swap", "reduction"]
     margin: float
     swap: bool
@@ -1902,7 +1921,12 @@ class CTCLoss(_Loss):
         >>> target = torch.randint(low=1, high=C, size=(N, S), dtype=torch.long)
         >>>
         >>> input_lengths = torch.full(size=(N,), fill_value=T, dtype=torch.long)
-        >>> target_lengths = torch.randint(low=S_min, high=S, size=(N,), dtype=torch.long)
+        >>> target_lengths = torch.randint(
+        ...     low=S_min,
+        ...     high=S,
+        ...     size=(N,),
+        ...     dtype=torch.long,
+        ... )
         >>> ctc_loss = nn.CTCLoss()
         >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
         >>> loss.backward()
@@ -1919,7 +1943,12 @@ class CTCLoss(_Loss):
         >>>
         >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
        >>> target_lengths = torch.randint(low=1, high=T, size=(N,), dtype=torch.long)
-        >>> target = torch.randint(low=1, high=C, size=(sum(target_lengths),), dtype=torch.long)
+        >>> target = torch.randint(
+        ...     low=1,
+        ...     high=C,
+        ...     size=(sum(target_lengths),),
+        ...     dtype=torch.long,
+        ... )
         >>> ctc_loss = nn.CTCLoss()
         >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
         >>> loss.backward()
@@ -1936,7 +1965,12 @@ class CTCLoss(_Loss):
         >>>
         >>> # Initialize random batch of targets (0 = blank, 1:C = classes)
         >>> target_lengths = torch.randint(low=1, high=T, size=(), dtype=torch.long)
-        >>> target = torch.randint(low=1, high=C, size=(target_lengths,), dtype=torch.long)
+        >>> target = torch.randint(
+        ...     low=1,
+        ...     high=C,
+        ...     size=(target_lengths,),
+        ...     dtype=torch.long,
+        ... )
         >>> ctc_loss = nn.CTCLoss()
         >>> loss = ctc_loss(input, target, input_lengths, target_lengths)
         >>> loss.backward()
@@ -1963,6 +1997,7 @@ class CTCLoss(_Loss):
         True``.
         Please see the notes on :doc:`/notes/randomness` for background.
     """
+
     __constants__ = ["blank", "reduction"]
     blank: int
     zero_infinity: bool
@@ -412,6 +412,7 @@ class Module:
         import torch.nn as nn
         import torch.nn.functional as F

+
         class Model(nn.Module):
             def __init__(self) -> None:
                 super().__init__()
@@ -1230,16 +1231,13 @@ class Module:
         device: Optional[DeviceLikeType] = ...,
         dtype: Optional[dtype] = ...,
         non_blocking: bool = ...,
-    ) -> Self:
-        ...
+    ) -> Self: ...

     @overload
-    def to(self, dtype: dtype, non_blocking: bool = ...) -> Self:
-        ...
+    def to(self, dtype: dtype, non_blocking: bool = ...) -> Self: ...

     @overload
-    def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self:
-        ...
+    def to(self, tensor: Tensor, non_blocking: bool = ...) -> Self: ...

     def to(self, *args, **kwargs):
         r"""Move and/or cast the parameters and buffers.
@@ -1752,7 +1750,11 @@ class Module:
         if recording_scopes:
             # type ignore was added because at this point one knows that
             # torch.jit._trace._trace_module_map is not Optional and has type Dict[Any, Any]
-            name = torch.jit._trace._trace_module_map[self] if self in torch.jit._trace._trace_module_map else None  # type: ignore[index, operator] # noqa: B950
+            name = (
+                torch.jit._trace._trace_module_map[self]  # type: ignore[index]
+                if self in torch.jit._trace._trace_module_map  # type: ignore[operator]
+                else None
+            )  # noqa: B950
             if name:
                 tracing_state.push_scope(name)
             else:
@@ -2161,13 +2163,20 @@ class Module:

     @overload
     def state_dict(
-        self, *, destination: T_destination, prefix: str = ..., keep_vars: bool = ...
-    ) -> T_destination:
-        ...
+        self,
+        *,
+        destination: T_destination,
+        prefix: str = ...,
+        keep_vars: bool = ...,
+    ) -> T_destination: ...

     @overload
-    def state_dict(self, *, prefix: str = ..., keep_vars: bool = ...) -> dict[str, Any]:
-        ...
+    def state_dict(
+        self,
+        *,
+        prefix: str = ...,
+        keep_vars: bool = ...,
+    ) -> dict[str, Any]: ...

     # TODO: Change `*args` to `*` and remove the corresponding warning in docs when BC allows.
     # Also remove the logic for arg parsing together.
@@ -358,6 +358,7 @@ class RMSNorm(Module):
         >>> rms_norm(input)

     """
+
     __constants__ = ["normalized_shape", "eps", "elementwise_affine"]
     normalized_shape: tuple[int, ...]
     eps: Optional[float]
@@ -253,7 +253,8 @@ class RNNBase(Module):
         # alias would break the assumptions of the uniqueness check in
         # Module.named_parameters().
         unique_data_ptrs = {
-            p.data_ptr() for p in self._flat_weights  # type: ignore[union-attr]
+            p.data_ptr()  # type: ignore[union-attr]
+            for p in self._flat_weights
         }
         if len(unique_data_ptrs) != len(self._flat_weights):
             return
@@ -611,12 +612,10 @@ class RNN(RNNBase):
         bidirectional: bool = False,
         device=None,
         dtype=None,
-    ) -> None:
-        ...
+    ) -> None: ...

     @overload
-    def __init__(self, *args, **kwargs):
-        ...
+    def __init__(self, *args, **kwargs): ...

     def __init__(self, *args, **kwargs):
         if "proj_size" in kwargs:
@@ -969,12 +968,10 @@ class LSTM(RNNBase):
         proj_size: int = 0,
         device=None,
         dtype=None,
-    ) -> None:
-        ...
+    ) -> None: ...

     @overload
-    def __init__(self, *args, **kwargs):
-        ...
+    def __init__(self, *args, **kwargs): ...

     def __init__(self, *args, **kwargs):
         super().__init__("LSTM", *args, **kwargs)
@@ -1304,12 +1301,10 @@ class GRU(RNNBase):
         bidirectional: bool = False,
         device=None,
         dtype=None,
-    ) -> None:
-        ...
+    ) -> None: ...

     @overload
-    def __init__(self, *args, **kwargs):
-        ...
+    def __init__(self, *args, **kwargs): ...

     def __init__(self, *args, **kwargs):
         if "proj_size" in kwargs:
@@ -59,9 +59,11 @@ class Embedding(Module):
         embedding = nn.Embedding(n, d, max_norm=1.0)
         W = torch.randn((m, d), requires_grad=True)
         idx = torch.tensor([1, 2])
-        a = embedding.weight.clone() @ W.t()  # weight must be cloned for this to be differentiable
+        a = (
+            embedding.weight.clone() @ W.t()
+        )  # weight must be cloned for this to be differentiable
         b = embedding(idx) @ W.t()  # modifies weight in-place
-        out = (a.unsqueeze(0) + b.unsqueeze(1))
+        out = a.unsqueeze(0) + b.unsqueeze(1)
         loss = out.sigmoid().prod()
         loss.backward()

@@ -150,13 +152,13 @@ class Embedding(Module):
         self.embedding_dim = embedding_dim
         if padding_idx is not None:
             if padding_idx > 0:
-                assert (
-                    padding_idx < self.num_embeddings
-                ), "Padding_idx must be within num_embeddings"
+                assert padding_idx < self.num_embeddings, (
+                    "Padding_idx must be within num_embeddings"
+                )
             elif padding_idx < 0:
-                assert (
-                    padding_idx >= -self.num_embeddings
-                ), "Padding_idx must be within num_embeddings"
+                assert padding_idx >= -self.num_embeddings, (
+                    "Padding_idx must be within num_embeddings"
+                )
                 padding_idx = self.num_embeddings + padding_idx
         self.padding_idx = padding_idx
         self.max_norm = max_norm
@@ -248,9 +250,9 @@ class Embedding(Module):
         >>> embedding(input)
         tensor([[ 4.0000, 5.1000, 6.3000]])
         """
-        assert (
-            embeddings.dim() == 2
-        ), "Embeddings parameter is expected to be 2-dimensional"
+        assert embeddings.dim() == 2, (
+            "Embeddings parameter is expected to be 2-dimensional"
+        )
         rows, cols = embeddings.shape
         embedding = cls(
             num_embeddings=rows,
@@ -391,13 +393,13 @@ class EmbeddingBag(Module):
         self.scale_grad_by_freq = scale_grad_by_freq
         if padding_idx is not None:
             if padding_idx > 0:
-                assert (
-                    padding_idx < self.num_embeddings
-                ), "padding_idx must be within num_embeddings"
+                assert padding_idx < self.num_embeddings, (
+                    "padding_idx must be within num_embeddings"
+                )
             elif padding_idx < 0:
-                assert (
-                    padding_idx >= -self.num_embeddings
-                ), "padding_idx must be within num_embeddings"
+                assert padding_idx >= -self.num_embeddings, (
+                    "padding_idx must be within num_embeddings"
+                )
                 padding_idx = self.num_embeddings + padding_idx
         self.padding_idx = padding_idx
         if _weight is None:
@@ -526,9 +528,9 @@ class EmbeddingBag(Module):
         >>> embeddingbag(input)
         tensor([[ 2.5000, 3.7000, 4.6500]])
         """
-        assert (
-            embeddings.dim() == 2
-        ), "Embeddings parameter is expected to be 2-dimensional"
+        assert embeddings.dim() == 2, (
+            "Embeddings parameter is expected to be 2-dimensional"
+        )
         rows, cols = embeddings.shape
         embeddingbag = cls(
             num_embeddings=rows,
@@ -256,7 +256,9 @@ class Transformer(Module):

         Examples:
             >>> # xdoctest: +SKIP
-            >>> output = transformer_model(src, tgt, src_mask=src_mask, tgt_mask=tgt_mask)
+            >>> output = transformer_model(
+            ...     src, tgt, src_mask=src_mask, tgt_mask=tgt_mask
+            ... )
         """
         is_batched = src.dim() == 3
         if not self.batch_first and src.size(1) != tgt.size(1) and is_batched:
@@ -686,7 +688,9 @@ class TransformerEncoderLayer(Module):
         >>> out = encoder_layer(src)

     Alternatively, when ``batch_first`` is ``True``:
-        >>> encoder_layer = nn.TransformerEncoderLayer(d_model=512, nhead=8, batch_first=True)
+        >>> encoder_layer = nn.TransformerEncoderLayer(
+        ...     d_model=512, nhead=8, batch_first=True
+        ... )
         >>> src = torch.rand(32, 10, 512)
         >>> out = encoder_layer(src)

@@ -994,7 +998,9 @@ class TransformerDecoderLayer(Module):
         >>> out = decoder_layer(tgt, memory)

     Alternatively, when ``batch_first`` is ``True``:
-        >>> decoder_layer = nn.TransformerDecoderLayer(d_model=512, nhead=8, batch_first=True)
+        >>> decoder_layer = nn.TransformerDecoderLayer(
+        ...     d_model=512, nhead=8, batch_first=True
+        ... )
         >>> memory = torch.rand(32, 10, 512)
         >>> tgt = torch.rand(32, 20, 512)
         >>> out = decoder_layer(tgt, memory)
@@ -11,9 +11,9 @@ from torch.nn.parallel import comm
 class Broadcast(Function):
     @staticmethod
     def forward(ctx, target_gpus, *inputs):
-        assert all(
-            i.device.type != "cpu" for i in inputs
-        ), "Broadcast function not implemented for CPU tensors"
+        assert all(i.device.type != "cpu" for i in inputs), (
+            "Broadcast function not implemented for CPU tensors"
+        )
         target_gpus = [_get_device_index(x, True) for x in target_gpus]
         ctx.target_gpus = target_gpus
         if len(inputs) == 0:
@@ -56,9 +56,9 @@ class ReduceAddCoalesced(Function):
 class Gather(Function):
     @staticmethod
     def forward(ctx, target_device, dim, *inputs):
-        assert all(
-            i.device.type != "cpu" for i in inputs
-        ), "Gather function not implemented for CPU tensors"
+        assert all(i.device.type != "cpu" for i in inputs), (
+            "Gather function not implemented for CPU tensors"
+        )
         if target_device == "cpu":
             ctx.target_device = "cpu"
         else:
@ -759,7 +759,7 @@ class DistributedDataParallel(Module, Joinable):
"DistributedDataParallel device_ids and output_device arguments "
"only work with single-device/multiple-device GPU modules or CPU modules, "
f"but got device_ids {device_ids}, output_device {output_device}, "
f"and module parameters {({p.device for p in self._module_parameters})}.",
f"and module parameters { ({p.device for p in self._module_parameters}) }.", # noqa: E201,E202
)

self.device_ids = None
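
The DistributedDataParallel hunk is the one edit in this region that is not purely mechanical: spaces are inserted between the set comprehension's braces and the enclosing f-string braces, with `# noqa: E201,E202` silencing the whitespace lint this triggers. A plausible reading is that the spacing keeps the adjacent braces from ever collapsing into the `{{`/`}}` escape form, which in an f-string means a literal brace rather than an interpolation. A runnable illustration of that distinction:

    devices = {"cpu", "cuda:0"}
    # doubled braces are literal braces, not an expression field
    assert f"{{devices}}" == "{devices}"
    # a space keeps the outer braces interpolating the set comprehension
    assert f"{ {d for d in devices} }" == str({d for d in devices})
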
@ -46,9 +46,9 @@ def parallel_apply(
element of :attr:`inputs` can either be a single object as the only argument
to a module, or a collection of positional arguments.
"""
assert len(modules) == len(
inputs
), f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
assert len(modules) == len(inputs), (
f"The number of modules {len(modules)} is not equal to the number of inputs {len(inputs)}"
)
if kwargs_tup is not None:
assert len(modules) == len(kwargs_tup)
else:
@ -88,9 +88,11 @@ def parallel_apply(
if stream is None:
stream = torch.cuda.current_stream(device)
try:
with torch.cuda.device(device), torch.cuda.stream(
stream
), torch.amp.autocast("cuda", enabled=autocast_enabled):
with (
torch.cuda.device(device),
torch.cuda.stream(stream),
torch.amp.autocast("cuda", enabled=autocast_enabled),
):
# this also avoids accidental slicing of `input` if it is a Tensor
if not isinstance(input, (list, tuple)):
input = (input,)
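
Where black broke a long `with` line inside one of the context-manager calls, ruff format uses parenthesized context managers, a syntax CPython's PEG parser has accepted since 3.9 and that is documented from Python 3.10. A runnable sketch:

    import contextlib

    # the parentheses let each context manager sit on its own line
    with (
        contextlib.suppress(KeyError),
        contextlib.nullcontext() as ctx,  # `as` bindings work here too
    ):
        {}["missing"]  # the raised KeyError is swallowed by suppress()
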
@ -35,8 +35,7 @@ def scatter(
inputs: torch.Tensor,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> tuple[torch.Tensor, ...]:
...
) -> tuple[torch.Tensor, ...]: ...


@overload
@ -44,8 +43,7 @@ def scatter(
inputs: T,
target_gpus: Sequence[Union[int, torch.device]],
dim: int = ...,
) -> list[T]:
...
) -> list[T]: ...


def scatter(inputs, target_gpus, dim=0):
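
For typing stubs such as these `@overload` signatures, the new style joins the `...` placeholder body to the signature line instead of giving it a line of its own. A runnable sketch with a hypothetical `scale` function:

    from typing import Union, overload

    @overload
    def scale(x: int) -> int: ...
    @overload
    def scale(x: float) -> float: ...
    def scale(x: Union[int, float]) -> Union[int, float]:
        return x * 2

    assert scale(2) == 4
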
@ -4,6 +4,7 @@ r"""QAT Dynamic Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""

from torch.nn.qat import dynamic, modules # noqa: F403
from torch.nn.qat.modules import * # noqa: F403

@ -4,4 +4,5 @@ r"""QAT Dynamic Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.dynamic` instead.
"""

from torch.nn.qat.dynamic.modules import * # noqa: F403

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/dynamic/modules`,
while adding an import statement here.
"""

from torch.ao.nn.qat.dynamic.modules.linear import Linear

@ -4,6 +4,7 @@ r"""QAT Modules.
This package is in the process of being deprecated.
Please, use `torch.ao.nn.qat.modules` instead.
"""

from torch.ao.nn.qat.modules.conv import Conv1d, Conv2d, Conv3d
from torch.ao.nn.qat.modules.embedding_ops import Embedding, EmbeddingBag
from torch.ao.nn.qat.modules.linear import Linear

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/qat/modules`,
while adding an import statement here.
"""

from torch.ao.nn.qat.modules.linear import Linear

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantizable/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantizable.modules.activation import MultiheadAttention

@ -7,4 +7,5 @@ If you are adding a new entry/functionality, please, add it to the
appropriate file under the `torch/ao/nn/quantized/dynamic/modules`,
while adding an import statement here.
"""

from torch.ao.nn.quantized.dynamic.modules.linear import Linear
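
Every hunk in this run of deprecation shims makes the same one-line change: the formatter enforces exactly one blank line between a module docstring and the first import. A complete hypothetical module showing the resulting layout:

    r"""Module docstring; the formatter separates it from the imports."""

    import math  # the first import now sits one blank line below the docstring

    assert math.pi > 3
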
@ -314,7 +314,7 @@ def unfold3d(
Example:
>>> # xdoctest: +SKIP
>>> B, C, D, H, W = 3, 4, 5, 6, 7
>>> tensor = torch.arange(1, B * C * D * H * W + 1.).view(B, C, D, H, W)
>>> tensor = torch.arange(1, B * C * D * H * W + 1.0).view(B, C, D, H, W)
>>> unfold3d(tensor, kernel_size=2, padding=0, stride=1).shape
torch.Size([3, 32, 120])
"""
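
The only change in the `unfold3d` doctest is the float literal: like black, ruff format normalizes a bare trailing-dot float by adding the missing digit. The value is unchanged:

    assert 1. == 1.0  # the left spelling is rewritten to the right one
    end = 120.  # would be normalized to 120.0
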
@ -72,8 +72,10 @@ def allow_smaller_batches(args, kwargs):

@contextmanager
def setup_rnn(use_input_variant, args, kwargs):
with batch_second(args, kwargs) if use_input_variant else allow_smaller_batches(
args, kwargs
with (
batch_second(args, kwargs)
if use_input_variant
else allow_smaller_batches(args, kwargs)
):
yield
@ -49,7 +49,9 @@ def call_for_per_sample_grads(
grad_outputs by 1 / batch_size from cross batch interaction.
>>> model = nn.Linear(4, 3)
>>> batched_input = torch.randn(5, 4) # batch size of 5
>>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(batched_input).mean()
>>> res = call_for_per_sample_grads(model, 5, loss_reduction="mean")(
... batched_input
... ).mean()
>>> res.backward()

Note::
@ -150,9 +150,7 @@ def _clip_grads_with_norm_(
return
grouped_grads: dict[
tuple[torch.device, torch.dtype], tuple[list[list[Tensor]], list[int]]
] = _group_tensors_by_device_and_dtype(
[grads]
) # type: ignore[assignment]
] = _group_tensors_by_device_and_dtype([grads]) # type: ignore[assignment]

clip_coef = max_norm / (total_norm + 1e-6)
# Note: multiplying by the clamped coef is redundant when the coef is clamped to 1, but doing so
@ -135,9 +135,9 @@ def fuse_linear_bn_eval(
2. the number of features in bn is 1
Otherwise, skip the folding path
"""
assert (
linear.out_features == bn.num_features or bn.num_features == 1
), "To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
assert linear.out_features == bn.num_features or bn.num_features == 1, (
"To fuse, linear.out_features == bn.num_features or bn.num_features == 1"
)

assert bn.running_mean is not None and bn.running_var is not None
fused_linear.weight, fused_linear.bias = fuse_linear_bn_weights(
@ -63,12 +63,16 @@ def convert_conv2d_weight_memory_format(
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda")
>>> input = torch.randint(
... 1, 10, (2, 8, 4, 4), dtype=torch.float16, device="cuda"
... )
>>> model = nn.Sequential(
>>> nn.Conv2d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
>>> model = nn.utils.convert_conv2d_weight_memory_format(model, torch.channels_last)
>>> model = nn.utils.convert_conv2d_weight_memory_format(
... model, torch.channels_last
... )
>>> out = model(input)
"""
# TODO: expand this to `_ConvNd` when channels_last support is extended
@ -137,12 +141,16 @@ def convert_conv3d_weight_memory_format(
Example:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_CUDA)
>>> # xdoctest: +REQUIRES(env:CUBLAS_WORKSPACE_CONFIG)
>>> input = torch.randint(1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda")
>>> input = torch.randint(
... 1, 10, (2, 8, 4, 4, 4), dtype=torch.float16, device="cuda"
... )
>>> model = nn.Sequential(
>>> nn.Conv3d(8, 4, 3)).cuda().half()
>>> # This is identical to:
>>> # nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(model, torch.channels_last_3d)
>>> model = nn.utils.convert_conv3d_weight_memory_format(
... model, torch.channels_last_3d
... )
>>> out = model(input)
"""
@ -46,6 +46,7 @@ def cached():
.. code-block:: python

import torch.nn.utils.parametrize as P

...
with P.cached():
output = model(inputs)
@ -536,7 +537,9 @@ def register_parametrization(
>>> s0_sqrt = S[0].sqrt().unsqueeze(-1)
>>> return U[..., :, 0] * s0_sqrt, Vh[..., 0, :] * s0_sqrt
>>>
>>> linear_rank_one = P.register_parametrization(nn.Linear(4, 4), "weight", RankOne())
>>> linear_rank_one = P.register_parametrization(
... nn.Linear(4, 4), "weight", RankOne()
... )
>>> print(torch.linalg.matrix_rank(linear_rank_one.weight).item())
1
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Pruning methods."""

import numbers
from abc import ABC, abstractmethod
from collections.abc import Iterable
@ -63,9 +64,9 @@ class BasePruningMethod(ABC):
"""
# to carry out the multiplication, the mask needs to have been computed,
# so the pruning method must know what tensor it's operating on
assert (
self._tensor_name is not None
), f"Module {module} has to be pruned" # this gets set in apply()
assert self._tensor_name is not None, (
f"Module {module} has to be pruned"
) # this gets set in apply()
mask = getattr(module, self._tensor_name + "_mask")
orig = getattr(module, self._tensor_name + "_orig")
pruned_tensor = mask.to(dtype=orig.dtype) * orig
@ -109,10 +110,10 @@ class BasePruningMethod(ABC):
old_method = hook
hooks_to_remove.append(k)
found += 1
assert (
found <= 1
), f"Avoid adding multiple pruning hooks to the\
assert found <= 1, (
f"Avoid adding multiple pruning hooks to the\
same tensor {name} of module {module}. Use a PruningContainer."
)

for k in hooks_to_remove:
del module._forward_pre_hooks[k]
@ -153,9 +154,9 @@ class BasePruningMethod(ABC):

orig = getattr(module, name)
if importance_scores is not None:
assert (
importance_scores.shape == orig.shape
), f"importance_scores should have the same shape as parameter {name} of {module}"
assert importance_scores.shape == orig.shape, (
f"importance_scores should have the same shape as parameter {name} of {module}"
)
else:
importance_scores = orig

@ -222,9 +223,9 @@ class BasePruningMethod(ABC):
pruned version of tensor ``t``.
"""
if importance_scores is not None:
assert (
importance_scores.shape == t.shape
), "importance_scores should have the same shape as tensor t"
assert importance_scores.shape == t.shape, (
"importance_scores should have the same shape as tensor t"
)
else:
importance_scores = t
default_mask = default_mask if default_mask is not None else torch.ones_like(t)
@ -241,9 +242,9 @@ class BasePruningMethod(ABC):
Pruning itself is NOT undone or reversed!
"""
# before removing pruning from a tensor, it has to have been applied
assert (
self._tensor_name is not None
), f"Module {module} has to be pruned before pruning can be removed" # this gets set in apply()
assert self._tensor_name is not None, (
f"Module {module} has to be pruned before pruning can be removed"
) # this gets set in apply()

# to update module[name] to latest trained weights
weight = self.apply_mask(module) # masked weights
@ -846,7 +847,7 @@ def identity(module, name):

Examples:
>>> # xdoctest: +SKIP
>>> m = prune.identity(nn.Linear(2, 3), 'bias')
>>> m = prune.identity(nn.Linear(2, 3), "bias")
>>> print(m.bias_mask)
tensor([1., 1., 1.])
"""
@ -882,7 +883,7 @@ def random_unstructured(module, name, amount):

Examples:
>>> # xdoctest: +SKIP
>>> m = prune.random_unstructured(nn.Linear(2, 3), 'weight', amount=1)
>>> m = prune.random_unstructured(nn.Linear(2, 3), "weight", amount=1)
>>> torch.sum(m.weight_mask == 0)
tensor(1)

@ -925,7 +926,7 @@ def l1_unstructured(module, name, amount, importance_scores=None):

Examples:
>>> # xdoctest: +SKIP
>>> m = prune.l1_unstructured(nn.Linear(2, 3), 'weight', amount=0.2)
>>> m = prune.l1_unstructured(nn.Linear(2, 3), "weight", amount=0.2)
>>> m.state_dict().keys()
odict_keys(['bias', 'weight_orig', 'weight_mask'])
"""
@ -965,9 +966,7 @@ def random_structured(module, name, amount, dim):

Examples:
>>> # xdoctest: +SKIP
>>> m = prune.random_structured(
... nn.Linear(5, 3), 'weight', amount=3, dim=1
... )
>>> m = prune.random_structured(nn.Linear(5, 3), "weight", amount=3, dim=1)
>>> columns_pruned = int(sum(torch.sum(m.weight, dim=0) == 0))
>>> print(columns_pruned)
3
@ -1014,7 +1013,7 @@ def ln_structured(module, name, amount, n, dim, importance_scores=None):
Examples:
>>> from torch.nn.utils import prune
>>> m = prune.ln_structured(
... nn.Conv2d(5, 3, 2), 'weight', amount=0.3, dim=1, n=float('-inf')
... nn.Conv2d(5, 3, 2), "weight", amount=0.3, dim=1, n=float("-inf")
... )
"""
LnStructured.apply(
@ -1067,13 +1066,17 @@ def global_unstructured(parameters, pruning_method, importance_scores=None, **kw
Examples:
>>> from torch.nn.utils import prune
>>> from collections import OrderedDict
>>> net = nn.Sequential(OrderedDict([
... ('first', nn.Linear(10, 4)),
... ('second', nn.Linear(4, 1)),
... ]))
>>> net = nn.Sequential(
... OrderedDict(
... [
... ("first", nn.Linear(10, 4)),
... ("second", nn.Linear(4, 1)),
... ]
... )
... )
>>> parameters_to_prune = (
... (net.first, 'weight'),
... (net.second, 'weight'),
... (net.first, "weight"),
... (net.second, "weight"),
... )
>>> prune.global_unstructured(
... parameters_to_prune,
@ -1165,7 +1168,7 @@ def custom_from_mask(module, name, mask):
Examples:
>>> from torch.nn.utils import prune
>>> m = prune.custom_from_mask(
... nn.Linear(5, 3), name='bias', mask=torch.tensor([0, 1, 0])
... nn.Linear(5, 3), name="bias", mask=torch.tensor([0, 1, 0])
... )
>>> print(m.bias_mask)
tensor([0., 1., 0.])
@ -1191,8 +1194,8 @@ def remove(module, name):
will act.

Examples:
>>> m = random_unstructured(nn.Linear(5, 7), name='weight', amount=0.2)
>>> m = remove(m, name='weight')
>>> m = random_unstructured(nn.Linear(5, 7), name="weight", amount=0.2)
>>> m = remove(m, name="weight")
"""
for k, hook in module._forward_pre_hooks.items():
if isinstance(hook, BasePruningMethod) and hook._tensor_name == name:
@ -1223,7 +1226,7 @@ def is_pruned(module):
>>> m = nn.Linear(5, 7)
>>> print(prune.is_pruned(m))
False
>>> prune.random_unstructured(m, name='weight', amount=0.2)
>>> prune.random_unstructured(m, name="weight", amount=0.2)
>>> print(prune.is_pruned(m))
True
"""
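
Besides the wrapping changes, the prune doctests show quote normalization: docstring code is formatted like regular source, so single-quoted strings become double-quoted. The runtime value is untouched:

    name = 'weight'  # docstring-code formatting rewrites these quotes...
    assert name == "weight"  # ...to the double-quote default
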
@ -105,8 +105,7 @@ class PackedSequence(PackedSequence_):
dtype: torch.dtype,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
) -> Self: ...

@overload
def to(
@ -115,8 +114,7 @@ class PackedSequence(PackedSequence_):
dtype: Optional[torch.dtype] = ...,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
) -> Self: ...

@overload
def to(
@ -124,8 +122,7 @@ class PackedSequence(PackedSequence_):
other: Tensor,
non_blocking: bool = ...,
copy: bool = ...,
) -> Self:
...
) -> Self: ...

def to(self, *args: Any, **kwargs: Any) -> Self:
r"""Perform dtype and/or device conversion on `self.data`.
@ -354,7 +351,9 @@ def pad_packed_sequence(
>>> from torch.nn.utils.rnn import pack_padded_sequence, pad_packed_sequence
>>> seq = torch.tensor([[1, 2, 0], [3, 0, 0], [4, 5, 6]])
>>> lens = [2, 1, 3]
>>> packed = pack_padded_sequence(seq, lens, batch_first=True, enforce_sorted=False)
>>> packed = pack_padded_sequence(
... seq, lens, batch_first=True, enforce_sorted=False
... )
>>> packed
PackedSequence(data=tensor([4, 1, 3, 5, 2, 6]), batch_sizes=tensor([3, 2, 1]),
sorted_indices=tensor([2, 0, 1]), unsorted_indices=tensor([1, 2, 0]))
@ -473,7 +472,10 @@ def pad_sequence(
# assuming trailing dimensions and type of all the Tensors
# in sequences are same and fetching those from sequences[0]
return torch._C._nn.pad_sequence(
sequences, batch_first, padding_value, padding_side # type: ignore[arg-type]
sequences, # type: ignore[arg-type]
batch_first,
padding_value,
padding_side, # type: ignore[arg-type]
)
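
In `pad_sequence`, one trailing `# type: ignore` comment used to sit at the end of a line carrying four arguments. Because such comments apply to specific arguments, the formatter explodes the call to one argument per line so each comment attaches unambiguously. A sketch with a hypothetical stand-in:

    def pad(sequences, batch_first, padding_value, padding_side):
        # hypothetical stand-in for the real binding, for illustration only
        return sequences

    out = pad(
        [[1, 2], [3]],  # a per-argument comment pins this layout in place
        True,
        0.0,
        "right",
    )
    assert out == [[1, 2], [3]]
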
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Spectral Normalization from https://arxiv.org/abs/1802.05957."""

from typing import Any, Optional, TypeVar

import torch

@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Weight Normalization from https://arxiv.org/abs/1602.07868."""

from typing import Any, TypeVar
from typing_extensions import deprecated
@ -347,9 +347,9 @@ def _single_tensor_adafactor(
maximize: bool,
has_complex: bool,
):
assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)

if torch.jit.is_scripting():
# this assert is due to JIT being dumb and not realizing that the ops below
@ -381,9 +381,9 @@ def _single_tensor_adafactor(
param.mul_(1 - lr * weight_decay)

if grad.dim() > 1:
assert (
row_var is not None and col_var is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert row_var is not None and col_var is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_mean = (
torch.norm(grad, dim=-1, keepdim=True).square_().div_(grad.size(-1))
@ -397,9 +397,9 @@ def _single_tensor_adafactor(
var_estimate = row_var @ col_var
var_estimate.div_(row_var.mean(dim=-2, keepdim=True).clamp_(min=eps1))
else:
assert (
variance is not None
), "variance should be defined when grad is a vector"
assert variance is not None, (
"variance should be defined when grad is a vector"
)
grad_squared = grad * grad
variance.lerp_(grad_squared, one_minus_beta2_t)
# avoid writing into variance during update
@ -472,9 +472,9 @@ def _multi_tensor_adafactor(
if len(params) == 0:
return

assert (
grad_scale is None and found_inf is None
), "Grad scaling should occur outside of optimizer.step()"
assert grad_scale is None and found_inf is None, (
"Grad scaling should occur outside of optimizer.step()"
)

lr = _to_scalar(lr)

@ -495,9 +495,9 @@ def _multi_tensor_adafactor(
device_grads = cast(list[Tensor], device_grads_)
device_state_steps = cast(list[Tensor], device_state_steps_)
if eps1 is None:
assert (
dtype is not None
), "dtype is needed to compute eps1 when eps1 is unset"
assert dtype is not None, (
"dtype is needed to compute eps1 when eps1 is unset"
)
eps1 = torch.finfo(dtype).eps

if TYPE_CHECKING:
@ -537,9 +537,9 @@ def _multi_tensor_adafactor(
if is_multidim:
device_row_vars = cast(list[Tensor], device_row_vars_)
device_col_vars = cast(list[Tensor], device_col_vars_)
assert (
device_row_vars[0] is not None and device_col_vars[0] is not None
), "row_var and col_var should be defined when grad is multidimensional"
assert device_row_vars[0] is not None and device_col_vars[0] is not None, (
"row_var and col_var should be defined when grad is multidimensional"
)
# same as (g * g).mean(dim=-1) w/o materializing an intermediate size g
row_means = [
torch.norm(grad, dim=-1, keepdim=True) for grad in device_grads
@ -570,9 +570,9 @@ def _multi_tensor_adafactor(
del row_var_means
else:
device_variances = cast(list[Tensor], device_variances_)
assert (
device_variances[0] is not None
), "variance should be defined when grad is a vector"
assert device_variances[0] is not None, (
"variance should be defined when grad is a vector"
)

grads_squared = torch._foreach_mul(device_grads, device_grads)
torch._foreach_lerp_(device_variances, grads_squared, one_minus_beta2_ts)
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Functional interface."""

import math

from torch import Tensor

@ -5,6 +5,7 @@ Most commonly used methods are already supported, and the interface is general
enough, so that more sophisticated ones can be also easily integrated in the
future.
"""

from functools import partialmethod

from torch import optim
@ -267,7 +267,9 @@ def _single_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

if not torch.jit.is_scripting():
lr = _to_scalar(lr)
@ -326,7 +328,9 @@ def _multi_tensor_adadelta(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

if len(params) == 0:
return
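
The optimizer files repeat one more variant of the assert rewrite: when the condition is itself a wrapped `all(...)` call, only the message moves into its own parentheses, producing the `), (` shape seen above, and the overlong f-string line is left intact. A runnable sketch with hypothetical values:

    supported = {"cpu", "cuda"}
    devices = ["cpu", "cpu"]
    assert all(
        d in supported for d in devices
    ), (
        f"If capturable=True, tensors must be on supported devices: {supported}."
    )
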
@ -398,7 +398,9 @@ def _single_tensor_adam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

# update step
step_t += 1
@ -433,7 +435,9 @@ def _single_tensor_adam(
# cast to workaround https://github.com/pytorch/pytorch/issues/140601
key = (device, dtype)
if key not in beta1_dict:
beta1_dict[key] = beta1.to(device=device, dtype=dtype, non_blocking=True) # type: ignore[union-attr]
beta1_dict[key] = beta1.to( # type: ignore[union-attr]
device=device, dtype=dtype, non_blocking=True
)

device_beta1: Union[float, Tensor] = beta1_dict[key]
else:
@ -593,7 +597,9 @@ def _multi_tensor_adam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

assert grad_scale is None and found_inf is None

@ -769,7 +775,10 @@ def _multi_tensor_adam(
torch._foreach_div_(exp_avg_sq_sqrt, bias_correction2_sqrt)
torch._foreach_add_(exp_avg_sq_sqrt, eps)
torch._foreach_addcdiv_(
device_params, device_exp_avgs, exp_avg_sq_sqrt, step_size # type: ignore[arg-type]
device_params,
device_exp_avgs,
exp_avg_sq_sqrt,
step_size, # type: ignore[arg-type]
)
@ -256,7 +256,9 @@ def _single_tensor_adamax(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

# update step
step_t += 1
@ -331,7 +333,9 @@ def _multi_tensor_adamax(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)

@ -305,7 +305,9 @@ def _multi_tensor_asgd(
p.device.type == mu.device.type == eta.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mu, eta, step in zip(params, mus, etas, state_steps)
), f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params, mus, etas, and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Learning Rate Scheduler."""

from __future__ import annotations

import math
@ -827,7 +828,11 @@ class SequentialLR(LRScheduler):
>>> # lr = 0.0405 if epoch == 22
>>> scheduler1 = ConstantLR(optimizer, factor=0.1, total_iters=20)
>>> scheduler2 = ExponentialLR(optimizer, gamma=0.9)
>>> scheduler = SequentialLR(optimizer, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> scheduler = SequentialLR(
... optimizer,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@ -1271,7 +1276,7 @@ class ReduceLROnPlateau(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = ReduceLROnPlateau(optimizer, 'min')
>>> scheduler = ReduceLROnPlateau(optimizer, "min")
>>> for epoch in range(10):
>>> train(...)
>>> val_loss = validate(...)
@ -1502,7 +1507,12 @@ class CyclicLR(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.1, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(optimizer, base_lr=0.01, max_lr=0.1, step_size_up=10)
>>> scheduler = torch.optim.lr_scheduler.CyclicLR(
... optimizer,
... base_lr=0.01,
... max_lr=0.1,
... step_size_up=10,
... )
>>> data_loader = torch.utils.data.DataLoader(...)
>>> for epoch in range(10):
>>> for batch in data_loader:
@ -1729,7 +1739,9 @@ class CosineAnnealingWarmRestarts(LRScheduler):
Example:
>>> # xdoctest: +SKIP
>>> optimizer = torch.optim.SGD(model.parameters(), lr=0.05)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(optimizer, T_0=20)
>>> scheduler = torch.optim.lr_scheduler.CosineAnnealingWarmRestarts(
... optimizer, T_0=20
... )
>>> for epoch in range(100):
>>> train(...)
>>> validate(...)
@ -1936,7 +1948,9 @@ class OneCycleLR(LRScheduler):
>>> # xdoctest: +SKIP
>>> data_loader = torch.utils.data.DataLoader(...)
>>> optimizer = torch.optim.SGD(model.parameters(), lr=1e-4, momentum=0.9)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10)
>>> scheduler = torch.optim.lr_scheduler.OneCycleLR(
... optimizer, max_lr=0.01, steps_per_epoch=len(data_loader), epochs=10
... )
>>> for epoch in range(10):
>>> for batch in data_loader:
>>> train_batch(...)
@ -2141,8 +2155,6 @@ class OneCycleLR(LRScheduler):
if self.use_beta1:
group["betas"] = (computed_momentum, *group["betas"][1:]) # type: ignore[possibly-undefined]
else:
group[
"momentum"
] = computed_momentum # type: ignore[possibly-undefined]
group["momentum"] = computed_momentum # type: ignore[possibly-undefined]

return lrs
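
The OneCycleLR hunk undoes black's split-subscript workaround: the assignment stays on one line and the trailing `# type: ignore` comment is simply allowed to run past the line limit. Sketch:

    group = {"momentum": 0.0}
    computed_momentum = 0.9
    group["momentum"] = computed_momentum  # an overlong trailing comment is tolerated here
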
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the NAdam algorithm."""

from typing import cast, Optional, Union

import torch
@ -408,7 +409,11 @@ def _multi_tensor_nadam(
p.device.type == mp.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, mp, step in zip(params, mu_products, state_steps)
), f"If capturable=True, params, mu_products, and state_steps must be on supported devices: {capturable_supported_devices}."
), (
"If capturable=True, "
"params, mu_products, and state_steps must be on supported devices: "
f"{capturable_supported_devices}."
)

lr = _to_scalar(lr)

@ -576,10 +581,16 @@ def _multi_tensor_nadam(
)

torch._foreach_addcdiv_(
grouped_params, grouped_grads, exp_avg_sq_sqrt, step_size_grads # type: ignore[arg-type]
grouped_params,
grouped_grads,
exp_avg_sq_sqrt,
step_size_grads, # type: ignore[arg-type]
)
torch._foreach_addcdiv_(
grouped_params, grouped_exp_avgs, exp_avg_sq_sqrt, step_size_expavg # type: ignore[arg-type]
grouped_params,
grouped_exp_avgs,
exp_avg_sq_sqrt,
step_size_expavg, # type: ignore[arg-type]
)
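
The NAdam hunk is the one place where the message itself is too long even when parenthesized, so it is split into implicitly concatenated fragments; adjacent string literals merge at compile time, so the assembled message is identical:

    msg = (
        "If capturable=True, "
        "params, mu_products, and state_steps must be on supported devices: "
        "cpu, cuda."
    )
    assert msg.startswith("If capturable=True, params, mu_products")
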
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
"""Base optimizer."""

import functools
import warnings
from collections import defaultdict, OrderedDict
@ -103,7 +104,7 @@ def _stack_if_compiling(x):


def _disable_dynamo_if_unsupported(
single_tensor_fn: Optional[Callable[..., object]] = None
single_tensor_fn: Optional[Callable[..., object]] = None,
) -> Callable[[Callable[_P, _T]], Callable[_P, _T]]:
# workaround for torchscript BC
# it requires all called functions to be in the
@ -349,15 +350,24 @@ class Optimizer:
options (used when a parameter group doesn't specify them).
"""

OptimizerPreHook: TypeAlias = Callable[[Self, Args, Kwargs], Optional[tuple[Args, Kwargs]]] # type: ignore[misc]
OptimizerPreHook: TypeAlias = Callable[
[Self, Args, Kwargs], # type: ignore[misc]
Optional[tuple[Args, Kwargs]],
]
OptimizerPostHook: TypeAlias = Callable[[Self, Args, Kwargs], None] # type: ignore[misc]

_optimizer_step_pre_hooks: dict[int, OptimizerPreHook]
_optimizer_step_post_hooks: dict[int, OptimizerPostHook]
_optimizer_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_pre_hooks: 'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
_optimizer_load_state_dict_post_hooks: 'OrderedDict[int, Callable[["Optimizer"], None]]'
_optimizer_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_pre_hooks: (
'OrderedDict[int, Callable[["Optimizer", StateDict], Optional[StateDict]]]'
)
_optimizer_load_state_dict_post_hooks: (
'OrderedDict[int, Callable[["Optimizer"], None]]'
)

def __init__(self, params: ParamsT, defaults: dict[str, Any]) -> None: # noqa: D107
torch._C._log_api_usage_once("python.optimizer")
@ -847,7 +857,9 @@ class Optimizer:
handle = hooks.RemovableHandle(self._optimizer_load_state_dict_post_hooks)
self._optimizer_load_state_dict_post_hooks[handle.id] = hook
if prepend:
self._optimizer_load_state_dict_post_hooks.move_to_end(handle.id, last=False) # type: ignore[attr-defined]
self._optimizer_load_state_dict_post_hooks.move_to_end(
handle.id, last=False
) # type: ignore[attr-defined]
return handle

@torch._disable_dynamo
@ -877,12 +889,25 @@ class Optimizer:
>>> # xdoctest: +SKIP
>>> model = torch.nn.Linear(10, 10)
>>> optim = torch.optim.SGD(model.parameters(), lr=3e-4)
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(optim, start_factor=0.1, end_factor=1, total_iters=20)
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(optim, T_max=80, eta_min=3e-5)
>>> lr = torch.optim.lr_scheduler.SequentialLR(optim, schedulers=[scheduler1, scheduler2], milestones=[20])
>>> lr.load_state_dict(torch.load('./save_seq.pt'))
>>> scheduler1 = torch.optim.lr_scheduler.LinearLR(
... optim,
... start_factor=0.1,
... end_factor=1,
... total_iters=20,
... )
>>> scheduler2 = torch.optim.lr_scheduler.CosineAnnealingLR(
... optim,
... T_max=80,
... eta_min=3e-5,
... )
>>> lr = torch.optim.lr_scheduler.SequentialLR(
... optim,
... schedulers=[scheduler1, scheduler2],
... milestones=[20],
... )
>>> lr.load_state_dict(torch.load("./save_seq.pt"))
>>> # now load the optimizer checkpoint after loading the LRScheduler
>>> optim.load_state_dict(torch.load('./save_optim.pt'))
>>> optim.load_state_dict(torch.load("./save_optim.pt"))

"""
# shallow copy, to be consistent with module API
@ -933,7 +958,10 @@ class Optimizer:
for k, v in value.items()
}
elif isinstance(value, Iterable):
return type(value)(_cast(param, v, param_id=param_id, param_groups=param_groups) for v in value) # type: ignore[call-arg]
return type(value)(
_cast(param, v, param_id=param_id, param_groups=param_groups)
for v in value
) # type: ignore[call-arg]
else:
return value

@ -1021,12 +1049,10 @@ class Optimizer:
torch._foreach_zero_(grads)

@overload
def step(self, closure: None = None) -> None:
...
def step(self, closure: None = None) -> None: ...

@overload
def step(self, closure: Callable[[], float]) -> float:
...
def step(self, closure: Callable[[], float]) -> float: ...

def step(self, closure: Optional[Callable[[], float]] = None) -> Optional[float]:
r"""Perform a single optimization step to update parameter.
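
In the base `Optimizer` class, annotations that overflow the line limit are wrapped in parentheses around the (string) annotation rather than being split inside it, and the long `TypeAlias` is broken at the `Callable` arguments. A runnable sketch of the parenthesized-annotation form, with hypothetical names:

    from collections import OrderedDict
    from typing import Callable, Optional

    StateDict = dict
    _state_dict_post_hooks: (
        'OrderedDict[int, Callable[["object", StateDict], Optional[StateDict]]]'
    ) = OrderedDict()
    assert not _state_dict_post_hooks  # starts empty; the annotation is unevaluated
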
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RAdam algorithm."""

from typing import cast, Optional, Union

import torch
@ -285,7 +286,9 @@ def _single_tensor_radam(
assert (
param.device.type == step_t.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

if torch.is_complex(param):
param = torch.view_as_real(param)
@ -386,7 +389,9 @@ def _multi_tensor_radam(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the RMSprop algorithm."""

from typing import cast, Optional, Union

import torch
@ -292,7 +293,9 @@ def _single_tensor_rmsprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

grad = grads[i]
grad = grad if not maximize else -grad
@ -366,7 +369,9 @@ def _multi_tensor_rmsprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

lr = _to_scalar(lr)
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for the Resilient backpropagation."""

from typing import cast, Optional, Union

import torch
@ -248,7 +249,9 @@ def _single_tensor_rprop(
assert (
param.device.type == step.device.type
and param.device.type in capturable_supported_devices
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

step += 1

@ -315,7 +318,9 @@ def _multi_tensor_rprop(
p.device.type == step.device.type
and p.device.type in capturable_supported_devices
for p, step in zip(params, state_steps)
), f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
), (
f"If capturable=True, params and state_steps must be on supported devices: {capturable_supported_devices}."
)

grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, prevs, step_sizes, state_steps] # type: ignore[list-item]
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Gradient Descent optimizer."""

from typing import cast, Optional, Union

import torch
@ -397,7 +398,8 @@ def _multi_tensor_sgd(
lr = _to_scalar(lr)

grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=True # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=True,
)
for (
device_params_,
@ -502,7 +504,8 @@ def _fused_sgd(
for i, g in enumerate(grads):
momentum_buffer_list[i] = torch.empty_like(g)
grouped_tensors = Optimizer._group_tensors_by_device_and_dtype(
[params, grads, momentum_buffer_list], with_indices=False # type: ignore[list-item]
[params, grads, momentum_buffer_list], # type: ignore[list-item]
with_indices=False,
)
for (device, _), (
(device_params_, device_grads_, device_momentum_buffer_list),
@ -37,9 +37,9 @@ class SparseAdam(Optimizer):
sparse_params = []
complex_params = []
for index, param_group in enumerate(self.param_groups):
assert isinstance(
param_group, dict
), f"param_groups must be a list of dicts, but got {type(param_group)}"
assert isinstance(param_group, dict), (
f"param_groups must be a list of dicts, but got {type(param_group)}"
)
# given param group, convert given params to a list first before iterating
for d_index, d_param in enumerate(param_group["params"]):
if d_param.is_sparse:
@ -1,5 +1,6 @@
# mypy: allow-untyped-defs
r"""Implementation for Stochastic Weight Averaging implementation."""

import itertools
import math
import warnings
@ -225,9 +226,9 @@ class AveragedModel(Module):
use_buffers=False,
): # noqa: D107
super().__init__()
assert (
avg_fn is None or multi_avg_fn is None
), "Only one of avg_fn and multi_avg_fn should be provided"
assert avg_fn is None or multi_avg_fn is None, (
"Only one of avg_fn and multi_avg_fn should be provided"
)
self.module = deepcopy(model)
if device is not None:
self.module = self.module.to(device)
@ -274,7 +275,9 @@ class AveragedModel(Module):
) in grouped_tensors.items():
if self.multi_avg_fn:
self.multi_avg_fn(
self_params, model_params, self.n_averaged.to(device) # type: ignore[arg-type]
self_params, # type: ignore[arg-type]
model_params, # type: ignore[arg-type]
self.n_averaged.to(device),
)
elif (
device is not None