Enable UFMT on test/test_fake_tensor.py, test/test_flop_counter.py and some files (#125747)
Part of: #123062

Ran lintrunner on:

- test/test_fake_tensor.py
- test/test_flop_counter.py
- test/test_function_schema.py
- test/test_functional_autograd_benchmark.py
- test/test_functional_optim.py
- test/test_functionalization_of_rng_ops.py

Detail:

```bash
$ lintrunner -a --take UFMT --all-files
ok No lint issues.
Successfully applied all patches.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/125747
Approved by: https://github.com/malfet
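For reference, a minimal sketch of how this check can be reproduced locally. It assumes a pytorch/pytorch checkout with lintrunner installed and initialized (`lintrunner init`); the file list is taken from the commit message above, and passing explicit paths instead of `--all-files` is a variation not shown in the original commit.

```bash
# Sketch: run only the UFMT formatter check on the test files touched by this PR.
# Assumes lintrunner has been installed and initialized in the repo root
# (pip install lintrunner && lintrunner init).
lintrunner --take UFMT \
    test/test_fake_tensor.py \
    test/test_flop_counter.py \
    test/test_function_schema.py \
    test/test_functional_autograd_benchmark.py \
    test/test_functional_optim.py \
    test/test_functionalization_of_rng_ops.py

# Add -a to apply the formatting patches in place, as was done for this PR.
lintrunner -a --take UFMT --all-files
```

The UFMT check wraps the ufmt formatter (usort for import sorting plus black for code style), which is why the diff below only reorders and regroups imports, normalizes blank lines, and wraps lines longer than black's default 88-character limit.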
commit ba3cd6e463 (parent 187aeaeabf), committed by PyTorch MergeBot
--- a/test/test_functional_optim.py
+++ b/test/test_functional_optim.py
@@ -1,15 +1,16 @@
 # Owner(s): ["oncall: distributed"]
 
-from typing import List, Optional, Tuple
 import unittest
+from typing import List, Optional, Tuple
+
 import torch
 import torch.distributed
 import torch.nn as nn
 import torch.nn.functional as F
 from torch import Tensor
-from torch.optim import SGD, Adam, AdamW
-from torch.testing._internal.common_utils import TestCase, run_tests
+from torch.optim import Adam, AdamW, SGD
+from torch.testing._internal.common_utils import run_tests, TestCase
 
 
 class MyModule(torch.nn.Module):
     def __init__(self):
@@ -21,6 +22,7 @@ class MyModule(torch.nn.Module):
     def forward(self, t1):
         return self.lin2(F.relu(self.lin1(t1)))
 
+
 # dummy class to showcase custom optimizer registration with functional wrapper
 class MyDummyFnOptimizer:
     def __init__(
@@ -32,7 +34,6 @@ class MyDummyFnOptimizer:
         weight_decay: float = 0.0,
         _allow_empty_param_list: bool = False,
     ):
-
         if not 0.0 <= lr:
             raise ValueError(f"Invalid learning rate: {lr}")
         if not 0.0 <= eps:
@@ -58,17 +59,26 @@ class MyDummyFnOptimizer:
     def step_param(self, param: Tensor, grad: Optional[Tensor]):
         # call the custom optimizer step_param implementation
         with torch.no_grad():
-            raise RuntimeError("MyDummyFnOptimizer does not support step_param() as of now")
+            raise RuntimeError(
+                "MyDummyFnOptimizer does not support step_param() as of now"
+            )
 
     def step(self, gradients: List[Optional[Tensor]]):
         # call the custom optimizer step implementation
         with torch.no_grad():
             raise RuntimeError("MyDummyFnOptimizer does not support step() as of now")
 
-if torch.distributed.is_available():
-    from torch.distributed.optim.utils import functional_optim_map, register_functional_optim
-
-@unittest.skipIf(not torch.distributed.is_available(), "These are testing distributed functions")
+
+if torch.distributed.is_available():
+    from torch.distributed.optim.utils import (
+        functional_optim_map,
+        register_functional_optim,
+    )
+
+
+@unittest.skipIf(
+    not torch.distributed.is_available(), "These are testing distributed functions"
+)
 class TestFunctionalOptimParity(TestCase):
     def _validate_parameters(self, params_1, params_2):
         for p1, p2 in zip(params_1, params_2):