Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
This is a lot of files changed! Don't panic! Here's how it works:

* Previously, we set `follow_imports = silent` in our mypy.ini configuration. Per https://mypy.readthedocs.io/en/stable/running_mypy.html#follow-imports, this means that whenever a module is imported but is not listed as a file to be typechecked, mypy typechecks it as normal but suppresses all errors that occur in that file.
* When mypy is run inside lintrunner, the list of files is precisely the files covered by the glob in .lintrunner.toml, minus the files in the exclude list.
* The top-level directive `# mypy: ignore-errors` instructs mypy to typecheck the file as normal but ignore all errors.
* Therefore, switching to `follow_imports = normal` should be equivalent, provided we put `# mypy: ignore-errors` in all files that were previously excluded from the file list.
* Having done this, we can remove the exclude list from .lintrunner.toml, since excluding a file from typechecking is now baked into the files themselves.
* torch/_dynamo and torch/_inductor were previously in the exclude list because they were covered by MYPYINDUCTOR. It is not OK to mark these as `# mypy: ignore-errors`, as that would impede typechecking under the alternate configuration. So they are temporarily being checked twice, but I am suppressing the errors in these files since the configurations are not quite the same. I plan to unify the configurations, so this is only a temporary state.
* There were some straggler type errors after these changes, so I fixed them as needed. There weren't that many.

In the future, to start type checking a file, just remove the ignore-errors directive from the top of the file.

The codemod was done with this script authored by GPT-4:

```
import glob

exclude_patterns = [
    ...
]

for pattern in exclude_patterns:
    for filepath in glob.glob(pattern, recursive=True):
        if filepath.endswith('.py'):
            with open(filepath, 'r+') as f:
                content = f.read()
                f.seek(0, 0)
                f.write('# mypy: ignore-errors\n\n' + content)
```

Signed-off-by: Edward Z. Yang <ezyang@meta.com>

Pull Request resolved: https://github.com/pytorch/pytorch/pull/118414
Approved by: https://github.com/thiagocrepaldi, https://github.com/albanD
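As a quick illustration of the "just remove the ignore-errors directive" step described above, here is a minimal sketch of a helper that undoes the codemod for a single file. The helper name (`enable_typechecking`) and its handling of the blank line after the directive are assumptions for illustration, not part of the PR.

```
# Hypothetical helper, not from the PR: drop the "# mypy: ignore-errors"
# header (and the blank line the codemod inserted after it) so mypy starts
# reporting errors for this file again.
def enable_typechecking(filepath: str) -> None:
    with open(filepath, "r+") as f:
        lines = f.readlines()
        if lines and lines[0].startswith("# mypy: ignore-errors"):
            lines = lines[1:]
            # The codemod wrote a blank line after the directive; remove it too.
            if lines and lines[0].strip() == "":
                lines = lines[1:]
        f.seek(0)
        f.truncate()
        f.writelines(lines)
```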
324 lines
16 KiB
Python
# mypy: ignore-errors

# Torch
import torch
import torch.cuda
import torch.jit
import torch.jit._logging
import torch.jit.frontend
import torch.jit.quantized

# Testing utils
from torch.testing._internal.common_dtype import floating_and_complex_types_and
from torch.testing._internal.common_utils import TestCase, \
    freeze_rng_state, TemporaryFileName, enable_profiling_mode_for_profiling_tests, is_iterable_of_tensors
from torch.testing._internal.common_utils import enable_profiling_mode  # noqa: F401

# Standard library
from itertools import chain
from typing import List, Union
from torch._C import TensorType

import io


def check_output_types(self, func, ref_outputs, args, kwargs):
    graph = getattr(func, 'last_graph', None)
    types = [o.type() for o in graph.outputs()]
    self.assertTrue(len(types) == 1)
    t = types[0]
    torch._C._jit_assert_is_instance(ref_outputs, t)


# Test names in this set are only checked for a single derivative
nn_functional_single_grad = frozenset('test_nn_' + name for name in [
    'pdist',
    'multilabel_margin_loss',
    'max_unpool3d',
    'multi_margin_loss',
    'binary_cross_entropy',
    'binary_cross_entropy_size_average',
    'ctc_loss',
    'grid_sample',
])


def check_against_reference(self, func, reference_func, output_func, args, kwargs=None,
                            allow_unused=True, check_types=True, no_grad=False, no_gradgrad=False):
    """Verifies a function performs identically to some reference implementation.

    Commonly, this is used to verify that a JIT implementation
    (output_func) matches the behavior of the eager implementation
    (reference_func).
    """
    kwargs = kwargs if kwargs else {}

    def allSum(vs):
        if isinstance(vs, torch.Tensor):
            vs = (vs,)
        return sum((i + 1) * v.sum().abs() if v.dtype.is_complex else (i + 1) * v.sum()
                   for i, v in enumerate(vs)
                   if v is not None and v.dtype in floating_and_complex_types_and(torch.half, torch.bfloat16))

    def clone_tensor(t, preserve_requires_grad):
        require_grad = preserve_requires_grad and t.requires_grad
        return t.detach().clone().requires_grad_(require_grad)

    def clone_inputs(preserve_requires_grad: bool):
        inputs: List[Union[torch.Tensor, List[torch.Tensor]]] = []

        for arg in args:
            if isinstance(arg, torch.Tensor):
                inputs.append(clone_tensor(arg, preserve_requires_grad))
            elif is_iterable_of_tensors(arg):
                inputs.append([clone_tensor(t, preserve_requires_grad) for t in arg])
            else:
                inputs.append(arg)

        return inputs

    # Returns tensors in args that requires_grad, including tensors in TensorList args
    def get_recording_tensors(args):
        recording_tensors: List[torch.Tensor] = []

        for arg in args:
            if isinstance(arg, torch.Tensor) and arg.requires_grad:
                recording_tensors.append(arg)
            elif is_iterable_of_tensors(arg):
                recording_tensors.extend(filter(lambda t: t.requires_grad, arg))

        return recording_tensors

    # test no gradients case
    nograd_inputs = clone_inputs(preserve_requires_grad=False)
    outputs = self.runAndSaveRNG(reference_func, nograd_inputs, kwargs)
    with enable_profiling_mode_for_profiling_tests():
        outputs_test = self.runAndSaveRNG(func, nograd_inputs, kwargs)
    self.assertEqual(outputs, outputs_test)

    if check_types:
        check_output_types(self, func, outputs_test, nograd_inputs, kwargs)

    if no_grad:
        # skip grad tests
        return

    with enable_profiling_mode_for_profiling_tests():
        # test single grad case
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        grads = torch.autograd.grad(allSum(outputs), recording_tensors,
                                    allow_unused=allow_unused)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        grads_test = torch.autograd.grad(allSum(outputs_test), recording_tensors,
                                         allow_unused=allow_unused)
        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        # test the grad grad case
        if self._testMethodName in nn_functional_single_grad or no_gradgrad:
            return

        outputs = output_func(self.runAndSaveRNG(reference_func, recording_inputs, kwargs))
        l1 = allSum(outputs)
        grads = torch.autograd.grad(l1, recording_tensors, create_graph=True,
                                    allow_unused=allow_unused)

        l2 = (allSum(grads) * l1)
        grads2 = torch.autograd.grad(l2, recording_tensors, allow_unused=allow_unused)
        recording_inputs = clone_inputs(preserve_requires_grad=True)
        recording_tensors = get_recording_tensors(recording_inputs)
        outputs_test = output_func(self.runAndSaveRNG(func, recording_inputs, kwargs))
        l1_test = allSum(outputs_test)
        grads_test = torch.autograd.grad(
            l1_test, recording_tensors, create_graph=True, allow_unused=allow_unused)

        l2_test = (allSum(grads_test) * l1_test)
        grads2_test = torch.autograd.grad(l2_test, recording_tensors, allow_unused=allow_unused)

        self.assertEqual(outputs, outputs_test)
        self.assertEqual(grads, grads_test)
        for g2, g2_test in zip(grads2, grads2_test):
            if g2 is None and g2_test is None:
                continue
            self.assertEqual(g2, g2_test, atol=5e-4, rtol=1e-4)


class JitCommonTestCase(TestCase):
    def createFunctionFromGraph(self, trace):
        graph = trace if isinstance(trace, torch._C.Graph) else trace.graph()
        return torch._C._create_function_from_graph("forward", graph)

    def assertExportImport(self, trace, inputs):
        m = self.createFunctionFromGraph(trace)
        self.assertExportImportModule(m, inputs)

    def assertExportImportModule(self, m, inputs):
        m_import = self.getExportImportCopy(m)
        a = self.runAndSaveRNG(m, inputs)
        b = self.runAndSaveRNG(m_import, inputs)
        self.assertEqual(a, b, "Results of original model and "
                               "exported/imported version of model differed")

    def runAndSaveRNG(self, func, inputs, kwargs=None):
        kwargs = kwargs if kwargs else {}
        with freeze_rng_state():
            results = func(*inputs, **kwargs)
        return results

    def getExportImportCopy(self, m, also_test_file=True, map_location=None):
        buffer = io.BytesIO()
        torch.jit.save(m, buffer)
        buffer.seek(0)
        imported = torch.jit.load(buffer, map_location=map_location)

        if not also_test_file:
            return imported

        with TemporaryFileName() as fname:
            torch.jit.save(imported, fname)
            return torch.jit.load(fname, map_location=map_location)

    def autoDiffErrorMessage(self, should_autodiff_node, nodes_not_in_diff_graph,
                             fusion_nodes_not_found, non_fusible_nodes_being_fused,
                             fusion_nodes_found, nodes_in_diff_graph):
        err_msg = "\nFailure in testing nodes' autodifferentiation. "
        if should_autodiff_node:
            err_msg += "One or more nodes were expected to be autodiffed, " \
                       "but were not found in specified fusible/nonfusible " \
                       "DifferentiableGraph groups. \nSpecifically:"
            # The node is intended to appear in a differentiable graph but doesn't
            diff_nodes_missing = []
            # The node is intended to appear in a differentiable graph
            # outside of a fusion group but instead is in a fusion group
            diff_nodes_in_fusion = []
            # The node is intended to appear in a fusion group but doesn't
            fusion_nodes_missing = []
            # The node is intended to appear in a fusion group but instead
            # is just in an outer differentiable graph
            fusion_nodes_in_diff = []
            for node in nodes_not_in_diff_graph:
                if node in non_fusible_nodes_being_fused:
                    diff_nodes_in_fusion.append(node)
                else:
                    diff_nodes_missing.append(node)
            for node in fusion_nodes_not_found:
                if node in nodes_in_diff_graph:
                    fusion_nodes_in_diff.append(node)
                else:
                    fusion_nodes_missing.append(node)
            if len(diff_nodes_missing) > 0:
                err_msg += f"\n {diff_nodes_missing} were not in one of the " \
                           "DifferentiableGraphs when they were expected to be. " \
                           "Did you intend for these nodes to be autodiffed? " \
                           "If not, remove them from the list of nonfusible nodes."
            if len(diff_nodes_in_fusion) > 0:
                err_msg += f"\n {diff_nodes_in_fusion} were found in one of the FusionGroups " \
                           "when they were expected to be just in a DifferentiableGraph. If it was " \
                           "intended for these nodes to be in FusionGroups, reclassify these nodes as " \
                           "fusible nodes. If these nodes were not intended to be fused, your " \
                           "autodifferentiation logic might be wrong."
            if len(fusion_nodes_missing) > 0:
                err_msg += f"\n {fusion_nodes_missing} were not in one of the FusionGroups " \
                           "of the DifferentiableGraphs when they were expected to be. " \
                           "They were also not found in an outer DifferentiableGraph. Did you " \
                           "intend for these nodes to be autodifferentiated? If not, you should " \
                           "remove these nodes from the test's fusible nodes. Otherwise your " \
                           "autodifferentiation logic might be wrong."
            if len(fusion_nodes_in_diff) > 0:
                err_msg += f"\n {fusion_nodes_in_diff} were not in one of the FusionGroups " \
                           "of the DifferentiableGraphs when they were expected to be, " \
                           "instead they were found just in an outer DifferentiableGraph. " \
                           "Did you intend for these nodes to be fused? If not, you should " \
                           "move these nodes into the test's nonfusible nodes. Otherwise your " \
                           "autodifferentiation logic might be wrong."
        else:
            err_msg += "One or more nodes were not expected to be autodiffed " \
                       "but were found in a DifferentiableGraph or in a FusionGroup " \
                       "of a DifferentiableGraph. Did you intend for these nodes to be " \
                       "autodiffed? If so, change this test to expect autodifferentiation. " \
                       "\nSpecifically:"
            if len(fusion_nodes_found) > 0:
                err_msg += f"\n {fusion_nodes_found} were not expected to be in " \
                           "one of the DifferentiableGraphs, but appeared in a FusionGroup " \
                           "of a DifferentiableGraph. "
            if len(nodes_in_diff_graph) > 0:
                err_msg += f"\n {nodes_in_diff_graph} were not expected to " \
                           "be in one of the DifferentiableGraphs but were."
        return err_msg

    def assertAutodiffNode(self, graph, should_autodiff_node, nonfusible_nodes, fusible_nodes):
        diff_nodes = graph.findAllNodes('prim::DifferentiableGraph')
        diff_subgraphs = [node.g('Subgraph') for node in diff_nodes]

        # Note: currently no tests have fusible_nodes
        fusion_nodes = list(chain.from_iterable([g.findAllNodes('prim::FusionGroup') for g in diff_subgraphs]))
        fusion_subgraphs = [node.g('Subgraph') for node in fusion_nodes]

        # For any non-fusible node, it must show up in one of the DifferentiableGraphs.
        nodes_in_diff_graph = []
        nodes_not_in_diff_graph = []
        non_fusible_nodes_being_fused = []
        for node in nonfusible_nodes:
            if any(g.findNode(node) is not None for g in diff_subgraphs):
                nodes_in_diff_graph.append(node)
            else:
                nodes_not_in_diff_graph.append(node)
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                non_fusible_nodes_being_fused.append(node)
        found_all_nonfusible_nodes = len(nodes_in_diff_graph) == len(nonfusible_nodes)

        # For any fusible node, it must show up in one of the FusionGroups in one of the DifferentiableGraphs.
        fusion_nodes_found = []
        fusion_nodes_not_found = []
        for node in fusible_nodes:
            if any(g.findNode(node) is not None for g in fusion_subgraphs):
                fusion_nodes_found.append(node)
            else:
                fusion_nodes_not_found.append(node)
        found_all_fusible_nodes = len(fusion_nodes_found) == len(fusible_nodes)

        if should_autodiff_node is not None:
            err_msg = self.autoDiffErrorMessage(should_autodiff_node,
                                                nodes_not_in_diff_graph,
                                                fusion_nodes_not_found,
                                                non_fusible_nodes_being_fused,
                                                fusion_nodes_found,
                                                nodes_in_diff_graph)
            self.assertEqual(should_autodiff_node,
                             found_all_nonfusible_nodes and found_all_fusible_nodes, err_msg)

    def checkShapeAnalysis(self, out_sizes: Union[List[int], List[List[int]]],
                           traced_graph, assert_propagation, constant_prop=True):
        # repropagate input shapes provided by tracing
        prev_symbolic_shapes_test_enabled = torch._C._jit_symbolic_shapes_test_mode_enabled()
        for enable_test_mode in [True, False]:
            # here we are testing allowing/disallowing substituting in complete shapes as constants,
            # disallowing constants helps stress test partial eval and substitution pipeline
            torch._C._jit_set_symbolic_shapes_test_mode(enable_test_mode)
            torch._C._jit_erase_non_input_shape_information(traced_graph)
            if constant_prop:
                torch._C._jit_pass_constant_propagation(traced_graph)
            torch._C._jit_pass_propagate_shapes_on_graph(traced_graph)
            # Add sizes to default tensor type to avoid checking something out of scope
            # and difficulties with tracer leaving in other parts of tensor type
            output = next(traced_graph.outputs()).type()

            def test_type(type, actual_size):
                sizes = type.symbolic_sizes()
                out_type = TensorType.get().with_sizes(sizes)
                actual_type = TensorType.get().with_sizes(actual_size)

                # always check actual shape is a subtype of the output
                self.assertTrue(actual_type.isSubtypeOf(out_type))

                # and then if assertion flag is provided, check shape analysis
                # is successful
                if assert_propagation:
                    self.assertEqual(out_type.sizes(), actual_size)

            if output.isSubtypeOf(torch._C.TensorType.get()):
                test_type(output, out_sizes)
            else:
                tuple_elements = output.elements()
                for i in range(len(tuple_elements)):
                    test_type(tuple_elements[i], out_sizes[i])

        torch._C._jit_set_symbolic_shapes_test_mode(prev_symbolic_shapes_test_enabled)