mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Rename NewCriterionTest to CriterionTest. (#44056)
Summary: Pull Request resolved: https://github.com/pytorch/pytorch/pull/44056 Test Plan: Imported from OSS Reviewed By: zou3519 Differential Revision: D23482573 Pulled By: gchanan fbshipit-source-id: dde0f1624330dc85f48e5a0b9d98fb55fdb72f68
This commit is contained in:
committed by
Facebook GitHub Bot
parent
7d95eb8633
commit
5973b44d9e
@ -21,7 +21,7 @@ TorchNNModuleTestParams = namedtuple(
|
||||
# Unique identifier for this module config (e.g. "BCELoss_weights_cuda")
|
||||
'module_variant_name',
|
||||
|
||||
# An instance of an NN test class (e.g. `NewCriterionTest`) which stores
|
||||
# An instance of an NN test class (e.g. `CriterionTest`) which stores
|
||||
# necessary information (e.g. input / target / extra_args) for running the Python test
|
||||
'test_instance',
|
||||
|
||||
@ -184,7 +184,7 @@ def move_cpp_tensors_to_device(cpp_tensor_stmts, device):
|
||||
return ['{}.to("{}")'.format(tensor_stmt, device) for tensor_stmt in cpp_tensor_stmts]
|
||||
|
||||
def is_criterion_test(test_instance):
|
||||
return isinstance(test_instance, common_nn.NewCriterionTest)
|
||||
return isinstance(test_instance, common_nn.CriterionTest)
|
||||
|
||||
# This function computes the following:
|
||||
# - What variable declaration statements should show up in the C++ parity test function
|
||||
|
@ -30,8 +30,8 @@ for test_params_dicts, test_instance_class in [
|
||||
(sample_functional.functional_tests, common_nn.NewModuleTest),
|
||||
(common_nn.module_tests, common_nn.ModuleTest),
|
||||
(common_nn.new_module_tests, common_nn.NewModuleTest),
|
||||
(common_nn.criterion_tests, common_nn.NewCriterionTest),
|
||||
(common_nn.new_criterion_tests, common_nn.NewCriterionTest),
|
||||
(common_nn.criterion_tests, common_nn.CriterionTest),
|
||||
(common_nn.new_criterion_tests, common_nn.CriterionTest),
|
||||
]:
|
||||
for test_params_dict in test_params_dicts:
|
||||
if test_params_dict.get('test_cpp_api_parity', True):
|
||||
|
@ -38,7 +38,7 @@ from torch.testing._internal.common_utils import freeze_rng_state, run_tests, Te
|
||||
get_function_arglist, load_tests, repeat_test_for_types, ALL_TENSORTYPES, \
|
||||
ALL_TENSORTYPES2, TemporaryFileName, TEST_WITH_UBSAN, IS_PPC
|
||||
from torch.testing._internal.common_cuda import TEST_CUDA, TEST_MULTIGPU, TEST_CUDNN, TEST_CUDNN_VERSION
|
||||
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, NewCriterionTest, \
|
||||
from torch.testing._internal.common_nn import NNTestCase, NewModuleTest, CriterionTest, \
|
||||
module_tests, criterion_tests, new_criterion_tests, loss_reference_fns, \
|
||||
ctcloss_reference, new_module_tests
|
||||
from torch.testing._internal.common_device_type import instantiate_device_type_tests, dtypes, \
|
||||
@ -8742,7 +8742,7 @@ for test_params in module_tests + new_module_tests:
|
||||
for test_params in criterion_tests + new_criterion_tests:
|
||||
name = test_params.pop('module_name')
|
||||
test_params['constructor'] = getattr(nn, name)
|
||||
test = NewCriterionTest(**test_params)
|
||||
test = CriterionTest(**test_params)
|
||||
decorator = test_params.pop('decorator', None)
|
||||
add_test(test, decorator)
|
||||
if 'check_sum_reduction' in test_params:
|
||||
@ -8757,7 +8757,7 @@ for test_params in criterion_tests + new_criterion_tests:
|
||||
return sum_reduction_constructor
|
||||
|
||||
test_params['constructor'] = gen_sum_reduction_constructor(test_params['constructor'])
|
||||
test = NewCriterionTest(**test_params)
|
||||
test = CriterionTest(**test_params)
|
||||
add_test(test, decorator)
|
||||
|
||||
|
||||
|
@ -5032,7 +5032,7 @@ class NewModuleTest(InputVariableMixin, ModuleTest):
|
||||
return self._get_arg('constructor_args', False)
|
||||
|
||||
|
||||
class NewCriterionTest(InputVariableMixin, TestBase):
|
||||
class CriterionTest(InputVariableMixin, TestBase):
|
||||
# TODO: check that criterions don't ignore grad_output
|
||||
|
||||
_required_arg_names = TestBase._required_arg_names.union({'target'})
|
||||
|
Reference in New Issue
Block a user