[BE] remove _SUPPORTED_OPTIM_MAP from tests (#63383)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/63383

Per title

ghstack-source-id: 135966157

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D30358921

fbshipit-source-id: 965e054e525194b1ee55980340df275bab355c9b
Committed by: Facebook GitHub Bot
Parent: 5b8862abf1
Commit: dcf90b797c
@@ -5,17 +5,7 @@ import torch.nn as nn
 import torch.nn.functional as F
 from torch.optim import SGD, Adam, AdamW
 from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS

 if not IS_WINDOWS:
-    from torch.distributed.optim.functional_sgd import _FunctionalSGD
-    from torch.distributed.optim.functional_adam import _FunctionalAdam
-    from torch.distributed.optim.functional_adamw import _FunctionalAdamW
-    _SUPPORTED_OPTIM_MAPPING = {
-        SGD: _FunctionalSGD,
-        Adam: _FunctionalAdam,
-        AdamW: _FunctionalAdamW,
-    }
+    from torch.distributed.optim import functional_optim_map

 class MyModule(torch.nn.Module):
     def __init__(self):
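For orientation, a minimal sketch of what the shared map imported above is assumed to provide. Its shape mirrors the test-local _SUPPORTED_OPTIM_MAPPING that this commit deletes; the real functional_optim_map is defined in torch.distributed.optim and may cover additional optimizers, so the literal contents below are illustrative, not authoritative.

# Illustrative sketch (assumption): approximate shape of
# torch.distributed.optim.functional_optim_map, which replaces the
# per-test _SUPPORTED_OPTIM_MAPPING removed in the hunk above.
from torch.optim import SGD, Adam, AdamW
from torch.distributed.optim.functional_sgd import _FunctionalSGD
from torch.distributed.optim.functional_adam import _FunctionalAdam
from torch.distributed.optim.functional_adamw import _FunctionalAdamW

functional_optim_map = {
    SGD: _FunctionalSGD,    # eager optimizer class -> its functional counterpart
    Adam: _FunctionalAdam,
    AdamW: _FunctionalAdamW,
}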
@@ -39,7 +29,7 @@ class TestFunctionalOptimParity(TestCase):
         optim_params = module_optim.parameters()
         functional_params = module_functional.parameters()
         optim = optim_cls(optim_params, *args, **kwargs)
-        functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
+        functional_optim_cls = functional_optim_map.get(optim_cls, None)
         if not functional_optim_cls:
             raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
         optim_functional = functional_optim_cls(
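Below is a hedged usage sketch of the lookup pattern the updated test now relies on: resolve the eager optimizer class to its functional counterpart via the shared map, then construct it. The parameter list and lr value are illustrative; the real TestFunctionalOptimParity test forwards each case's *args/**kwargs and compares stepping behavior against the eager optimizer.

# Minimal sketch, assuming a non-Windows build with torch.distributed available.
import torch
from torch.optim import SGD, Adam, AdamW
from torch.distributed.optim import functional_optim_map

params = [torch.nn.Parameter(torch.randn(2, 2))]
for optim_cls in (SGD, Adam, AdamW):
    # Same lookup as the diff above: eager class -> functional optimizer class.
    functional_optim_cls = functional_optim_map.get(optim_cls, None)
    if not functional_optim_cls:
        raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
    # lr=0.01 is an illustrative hyperparameter, not taken from the test.
    functional_optim = functional_optim_cls(params, lr=0.01)
    print(f"{optim_cls.__name__} -> {type(functional_optim).__name__}")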