[BE] remove _SUPPORTED_OPTIM_MAP from tests (#63383)

Summary:
Pull Request resolved: https://github.com/pytorch/pytorch/pull/63383

Per title
ghstack-source-id: 135966157

Test Plan: CI

Reviewed By: SciPioneer

Differential Revision: D30358921

fbshipit-source-id: 965e054e525194b1ee55980340df275bab355c9b
This commit is contained in:
Rohan Varma
2021-08-17 17:12:32 -07:00
committed by Facebook GitHub Bot
parent 5b8862abf1
commit dcf90b797c
3 changed files with 9 additions and 24 deletions

View File

@ -5,17 +5,7 @@ import torch.nn as nn
import torch.nn.functional as F
from torch.optim import SGD, Adam, AdamW
from torch.testing._internal.common_utils import TestCase, run_tests, IS_WINDOWS
if not IS_WINDOWS:
from torch.distributed.optim.functional_sgd import _FunctionalSGD
from torch.distributed.optim.functional_adam import _FunctionalAdam
from torch.distributed.optim.functional_adamw import _FunctionalAdamW
_SUPPORTED_OPTIM_MAPPING = {
SGD: _FunctionalSGD,
Adam: _FunctionalAdam,
AdamW: _FunctionalAdamW,
}
from torch.distributed.optim import functional_optim_map
class MyModule(torch.nn.Module):
def __init__(self):
@ -39,7 +29,7 @@ class TestFunctionalOptimParity(TestCase):
optim_params = module_optim.parameters()
functional_params = module_functional.parameters()
optim = optim_cls(optim_params, *args, **kwargs)
functional_optim_cls = _SUPPORTED_OPTIM_MAPPING.get(optim_cls, None)
functional_optim_cls = functional_optim_map.get(optim_cls, None)
if not functional_optim_cls:
raise ValueError(f"Functional optimizer not implemented for {optim_cls}")
optim_functional = functional_optim_cls(