diff --git a/test/ao/sparsity/test_data_sparsifier.py b/test/ao/sparsity/test_data_sparsifier.py index efd75b057aa6..999438215743 100644 --- a/test/ao/sparsity/test_data_sparsifier.py +++ b/test/ao/sparsity/test_data_sparsifier.py @@ -494,9 +494,7 @@ class TestBaseDataSparsifier(_BaseDataSparsiferTestCase): ( emb1, emb2, - ) = nn.Embedding( - 10, 3 - ), nn.Embedding(20, 3) + ) = nn.Embedding(10, 3), nn.Embedding(20, 3) emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) @@ -627,9 +625,7 @@ class TestNormDataSparsifiers(_NormDataSparsifierTestCase): ( emb1, emb2, - ) = nn.Embedding( - 10, 3 - ), nn.Embedding(20, 3) + ) = nn.Embedding(10, 3), nn.Embedding(20, 3) emb1_bag, emb2_bag = nn.EmbeddingBag(10, 3), nn.EmbeddingBag(20, 3) emb3, emb3_bag = nn.Embedding(15, 3), nn.EmbeddingBag(20, 3) diff --git a/test/ao/sparsity/test_kernels.py b/test/ao/sparsity/test_kernels.py index be2f3b8cd982..1ffdca5fd343 100644 --- a/test/ao/sparsity/test_kernels.py +++ b/test/ao/sparsity/test_kernels.py @@ -218,12 +218,12 @@ def _sparse_layer_test_helper( qmodule_to_check = fqn_to_module(qmodel, fqn_to_check) # check that the modules were converted as expected - assert isinstance( - sqmodule_to_check, sqmodule_expected_converted_class - ), "Convert failed" - assert isinstance( - qmodule_to_check, qmodule_expected_converted_class - ), "Mapping failed" + assert isinstance(sqmodule_to_check, sqmodule_expected_converted_class), ( + "Convert failed" + ) + assert isinstance(qmodule_to_check, qmodule_expected_converted_class), ( + "Mapping failed" + ) row_block_size, col_block_size = sqmodel.linear._packed_params._weight_bias()[ 2: diff --git a/test/ao/sparsity/test_structured_sparsifier.py b/test/ao/sparsity/test_structured_sparsifier.py index 21a951ddec20..c62cc3d30539 100644 --- a/test/ao/sparsity/test_structured_sparsifier.py +++ b/test/ao/sparsity/test_structured_sparsifier.py @@ -1055,9 +1055,9 @@ class TestFPGMPruner(TestCase): mask1 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-1] mask2 = pruner.groups[0]["module"].parametrizations.weight[0].mask[-2] # Check if either of the least-norm filters is not pruned - assert ( - mask1.item() is not False or mask2.item() is not False - ), "Do not prune all least-norm filters" + assert mask1.item() is not False or mask2.item() is not False, ( + "Do not prune all least-norm filters" + ) # fusion step pruned_model = pruner.prune() diff --git a/test/benchmark_utils/test_benchmark_utils.py b/test/benchmark_utils/test_benchmark_utils.py index 969a1584b68c..1d8d8e7e3594 100644 --- a/test/benchmark_utils/test_benchmark_utils.py +++ b/test/benchmark_utils/test_benchmark_utils.py @@ -66,9 +66,10 @@ def generate_callgrind_artifacts() -> None: json.dump(artifacts, f, indent=4) -def load_callgrind_artifacts() -> ( - tuple[benchmark_utils.CallgrindStats, benchmark_utils.CallgrindStats] -): +def load_callgrind_artifacts() -> tuple[ + benchmark_utils.CallgrindStats, + benchmark_utils.CallgrindStats, +]: """Hermetic artifact to unit test Callgrind wrapper. In addition to collecting counts, this wrapper provides some facilities for diff --git a/test/cpp_api_parity/functional_impl_check.py b/test/cpp_api_parity/functional_impl_check.py index b4272a2df1bd..34b9ac158127 100644 --- a/test/cpp_api_parity/functional_impl_check.py +++ b/test/cpp_api_parity/functional_impl_check.py @@ -158,7 +158,8 @@ def compute_functional_name(test_params_dict): return test_params_dict["cpp_function_call"].split("(")[0].replace("F::", "") else: raise RuntimeError( - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}" ) @@ -179,7 +180,8 @@ def compute_cpp_function_call(test_params_dict, arg_dict, functional_name): ) else: raise RuntimeError( - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}" ) @@ -217,7 +219,8 @@ def write_test_to_test_class( or "cpp_function_call" in test_params_dict ), ( "To enable C++ API parity test, " - f"`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n{pprint.pformat(test_params_dict)}. \n" # noqa: B950 + "`cpp_options_args` or `cpp_function_call` entry must be present in test params dict:\n" + f"{pprint.pformat(test_params_dict)}. \n" "If you are interested in adding the C++ API parity test, please see:\n" "NOTE [How to check NN module / functional API parity between Python and C++ frontends]. \n" "If not, please add `test_cpp_api_parity=False` to the test params dict and file an issue about this." @@ -233,14 +236,16 @@ def write_test_to_test_class( functional_name = compute_functional_name(test_params_dict) - assert hasattr( - torch.nn.functional, functional_name - ), f"`torch.nn.functional` doesn't have function `{functional_name}`. (Discovered while processing\n{pprint.pformat(test_params_dict)}.)" # noqa: B950 + assert hasattr(torch.nn.functional, functional_name), ( + f"`torch.nn.functional` doesn't have function `{functional_name}`. " + f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)" + ) functional_full_name = "F::" + functional_name assert functional_full_name in parity_table["torch::nn::functional"], ( - f"Please add `{functional_full_name}` entry to `torch::nn::functional` section of `test/cpp_api_parity/parity-tracker.md`. " + f"Please add `{functional_full_name}` entry to `torch::nn::functional` " + "section of `test/cpp_api_parity/parity-tracker.md`. " f"(Discovered while processing\n{pprint.pformat(test_params_dict)}.)" ) diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py index 8d98387cf5f2..d4c49bd28d45 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_aten_impl.py @@ -78,9 +78,9 @@ def _kernel_fallback(op, *args, **kwargs): elif op._schema.is_mutable or op is torch.ops.aten._copy_from.default: # Only handle inplace ops returning their first arg assert len(args) >= 1, f"Inplace {op} needs at least one arg" - assert ( - len(op._schema.returns) == 1 - ), f"NYI Inplace {op} with more than one return" + assert len(op._schema.returns) == 1, ( + f"NYI Inplace {op} with more than one return" + ) op_name = op.overloadpacket._qualified_op_name real_res = args[0] elif any(r.alias_info is not None for r in op._schema.returns): diff --git a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py index 80194b38aaeb..0f54f2ec4df0 100644 --- a/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py +++ b/test/cpp_extensions/open_registration_extension/pytorch_openreg/_meta_parser.py @@ -67,7 +67,7 @@ def prepare_for_sending(args, kwargs): def convert(obj): if type(obj) not in VALID_QUEUE_TYPES_IN: raise RuntimeError( - f"Cannot send object of type {type(obj)} " "over openreg device pipe." + f"Cannot send object of type {type(obj)} over openreg device pipe." ) if isinstance(obj, torch.Tensor): @@ -82,8 +82,7 @@ def receive_after_sending(allocator, args, kwargs): def convert(obj): if type(obj) not in VALID_QUEUE_TYPES_OUT: raise RuntimeError( - f"Received invalid object of type {type(obj)} " - "over openreg device pipe." + f"Received invalid object of type {type(obj)} over openreg device pipe." ) if isinstance(obj, OpenRegTensorMeta): diff --git a/test/distributed/_composable/fsdp/test_fully_shard_comm.py b/test/distributed/_composable/fsdp/test_fully_shard_comm.py index cd37107e5a9f..dcc34b5489aa 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_comm.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_comm.py @@ -561,8 +561,9 @@ class TestFullyShardPrefetch(FSDPTest): FSDPParamGroup.post_backward, events ) # Check the order for normal 1 forward, 1 backward, 1 optimizer step - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for iter_idx in range(3): loss = model(inp) @@ -617,8 +618,9 @@ class TestFullyShardPrefetch(FSDPTest): FSDPParamGroup.post_backward, events ) # Check the order for multiple forwards before 1 backward - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): loss1 = model(inp) loss2 = model(inp) @@ -703,8 +705,9 @@ class TestFullyShardPrefetch(FSDPTest): post_backward_with_record = self._get_post_backward_with_record( FSDPParamGroup.post_backward, events ) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): loss1, loss2 = model(inp) expected_events = [ @@ -794,9 +797,11 @@ class TestFullyShardPrefetch(FSDPTest): ("reshard", "", TrainingState.POST_BACKWARD), ("post_backward", "", TrainingState.POST_BACKWARD), ] - with patch_unshard(unshard_with_record), patch_reshard( - reshard_with_record - ), patch_post_backward(post_backward_with_record): + with ( + patch_unshard(unshard_with_record), + patch_reshard(reshard_with_record), + patch_post_backward(post_backward_with_record), + ): set_forward_prefetch(model, num_to_prefetch=1) loss = model(inp) expected_forward_events = [ @@ -882,9 +887,11 @@ class TestFullyShardPrefetch(FSDPTest): ("reshard", "layers.3", TrainingState.FORWARD), ("reshard", "", TrainingState.FORWARD), ] - with patch_unshard(unshard_with_record), patch_reshard( - reshard_with_record - ), patch_post_backward(post_backward_with_record): + with ( + patch_unshard(unshard_with_record), + patch_reshard(reshard_with_record), + patch_post_backward(post_backward_with_record), + ): set_backward_prefetch(model, num_to_prefetch=1) loss = model(inp) self.assertEqual(events, expected_forward_events) @@ -967,8 +974,9 @@ class TestFullyShardPrefetch(FSDPTest): (2, model_args.max_seq_len), device=device_type.type, ) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for _ in range(3): loss = model(inp) @@ -1046,8 +1054,9 @@ class TestFullyShardPrefetch(FSDPTest): FSDPParamGroup.post_backward, events ) inp = torch.randn((2, 16), device=device_type.type) - with patch_unshard(unshard_with_record), patch_post_backward( - post_backward_with_record + with ( + patch_unshard(unshard_with_record), + patch_post_backward(post_backward_with_record), ): for _ in range(3): loss = model(inp) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_compile.py b/test/distributed/_composable/fsdp/test_fully_shard_compile.py index 02a3377babf4..c376aa0e1aa0 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_compile.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_compile.py @@ -222,9 +222,7 @@ class TestFullyShardCompile(FSDPTest): ): unsharded_param_graph_inputs.add(node.args[0]) assert len(unsharded_param_graph_inputs) > 0 - assert len(unsharded_param_graph_inputs) == len( - list(model.parameters()) - ), """\ + assert len(unsharded_param_graph_inputs) == len(list(model.parameters())), """\ Expected all model parameters to be wrapped by FSDP2 and have their unsharded version as graph input, but it's not true! """ @@ -237,7 +235,7 @@ have their unsharded version as graph input, but it's not true! no_aliased_unsharded_params_in_graph_inputs = False err_msg += f"""\n Found aliased unsharded param in graph inputs: {aliased_graph_inputs}, -val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, +val.shape: {[node.meta["val"].shape for node in aliased_graph_inputs]}, """ self.assertTrue(no_aliased_unsharded_params_in_graph_inputs, err_msg) @@ -466,10 +464,9 @@ val.shape: {[node.meta['val'].shape for node in aliased_graph_inputs]}, @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") def test_compiled_autograd_ctx(self): self.skipTestForOldSm() - with torch._dynamo.config.patch( - skip_fsdp_hooks=False, - ), torch._functorch.config.patch( - recompute_views=True, + with ( + torch._dynamo.config.patch(skip_fsdp_hooks=False), + torch._functorch.config.patch(recompute_views=True), ): inputs = torch.randn(8, 8) model = torch.nn.Linear(8, 8) @@ -567,24 +564,28 @@ Unsupported Tensor.backward() call torch._dynamo.reset() torch._dynamo.compiled_autograd.reset() - with torch._dynamo.config.patch( - compiled_autograd=True, - compiled_autograd_kwargs_override={ - "fullgraph": True, - }, - inline_inbuilt_nn_modules=True, - skip_fsdp_hooks=False, - ), torch._functorch.config.patch( - enable_autograd_cache=False, - recompute_views=True, - ), torch._inductor.config.patch( - force_disable_caches=True, - reorder_for_compute_comm_overlap=True, - reorder_for_compute_comm_overlap_passes=[ - "sink_waits", - "raise_comms", - "reorder_compute_for_overlap", - ], + with ( + torch._dynamo.config.patch( + compiled_autograd=True, + compiled_autograd_kwargs_override={ + "fullgraph": True, + }, + inline_inbuilt_nn_modules=True, + skip_fsdp_hooks=False, + ), + torch._functorch.config.patch( + enable_autograd_cache=False, + recompute_views=True, + ), + torch._inductor.config.patch( + force_disable_caches=True, + reorder_for_compute_comm_overlap=True, + reorder_for_compute_comm_overlap_passes=[ + "sink_waits", + "raise_comms", + "reorder_compute_for_overlap", + ], + ), ): losses_compiled = test_compiled() losses_eager = test_eager() @@ -741,20 +742,21 @@ Unsupported Tensor.backward() call def _test_nested_fully_shard_backend_inductor_fullgraph_True(self): self.skipTestForOldSm() for fwd_fullgraph in [True]: - with self._reinplace_all_gather_with_optional_checks( - fwd_fullgraph - ), torch._inductor.config.patch( - post_grad_custom_post_pass=( - functools.partial( - self._check_fsdp_copy_and_resize_ops_count_in_graph, - fwd_copy_count=0, - fwd_resize_count=0, - bwd_copy_count=0, - bwd_resize_count=0, + with ( + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + torch._inductor.config.patch( + post_grad_custom_post_pass=( + functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + fwd_copy_count=0, + fwd_resize_count=0, + bwd_copy_count=0, + bwd_resize_count=0, + ) + if fwd_fullgraph + else None ) - if fwd_fullgraph - else None - ) + ), ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( @@ -943,9 +945,10 @@ Unsupported Tensor.backward() call for fwd_fullgraph, all_requires_grad in itertools.product( [True], [True, False] ): - with self._maybe_add_graph_break_to_sdpa( - fwd_fullgraph - ), self._reinplace_all_gather_with_optional_checks(fwd_fullgraph): + with ( + self._maybe_add_graph_break_to_sdpa(fwd_fullgraph), + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + ): self._test_traceable_fsdp( *self._create_transformer_factory_fns( all_requires_grad=all_requires_grad @@ -982,23 +985,24 @@ Unsupported Tensor.backward() call log.warning( f"fwd_fullgraph={fwd_fullgraph}, all_requires_grad={all_requires_grad}, activation_checkpoint={activation_checkpoint}" # noqa: G004, G001, B950 ) - with self._reinplace_all_gather_with_optional_checks( - fwd_fullgraph - ), torch._inductor.config.patch( - post_grad_custom_post_pass=( - functools.partial( - self._check_fsdp_copy_and_resize_ops_count_in_graph, - # NOTE: For the root unsharded params, we don't reshard after forward since for training, - # the parameters would be freed and all-gathered immediately. Hence we still have - # their resize and copy ops in the graph. - fwd_copy_count=4, - fwd_resize_count=4, - bwd_copy_count=0, - bwd_resize_count=4, + with ( + self._reinplace_all_gather_with_optional_checks(fwd_fullgraph), + torch._inductor.config.patch( + post_grad_custom_post_pass=( + functools.partial( + self._check_fsdp_copy_and_resize_ops_count_in_graph, + # NOTE: For the root unsharded params, we don't reshard after forward since for training, + # the parameters would be freed and all-gathered immediately. Hence we still have + # their resize and copy ops in the graph. + fwd_copy_count=4, + fwd_resize_count=4, + bwd_copy_count=0, + bwd_resize_count=4, + ) + if fwd_fullgraph + else None ) - if fwd_fullgraph - else None - ) + ), ): _, triton_codes = run_and_get_code( lambda: self._test_traceable_fsdp( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_extensions.py b/test/distributed/_composable/fsdp/test_fully_shard_extensions.py index f8888d12fc9a..0b25e09b3def 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_extensions.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_extensions.py @@ -385,9 +385,9 @@ class TestFullyShardAllGatherExtensionsMultiThread( only some ranks may require padding, in which case only those ranks will error out and the all-gather will timeout. """ - assert ( - self.world_size >= 2 - ), f"Assumes world size of at least 2 but got {self.world_size=}" + assert self.world_size >= 2, ( + f"Assumes world size of at least 2 but got {self.world_size=}" + ) model = MLP(dim=3, dim_multiplier=3) for module in model.modules(): for param_name, param in module.named_parameters(recurse=False): diff --git a/test/distributed/_composable/fsdp/test_fully_shard_frozen.py b/test/distributed/_composable/fsdp/test_fully_shard_frozen.py index 467b63563b82..f56c5e76c122 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_frozen.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_frozen.py @@ -115,9 +115,10 @@ class TestFullyShardFrozen(FSDPTest): torch.manual_seed(42 + self.rank + 1) device = device_type - with patch_reduce_scatter( - reduce_scatter - ), patch_register_post_backward_hook_backward(backward_with_count): + with ( + patch_reduce_scatter(reduce_scatter), + patch_register_post_backward_hook_backward(backward_with_count), + ): for iter_idx in range(10): inp = torch.randn((8, lin_dim), device=device) losses: list[torch.Tensor] = [] diff --git a/test/distributed/_composable/fsdp/test_fully_shard_init.py b/test/distributed/_composable/fsdp/test_fully_shard_init.py index 6f5326dab5a7..714145f8b976 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_init.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_init.py @@ -910,9 +910,9 @@ class TestFullyShardProcessGroupInit(FSDPTestMultiThread): @skip_if_lt_x_gpu(1) def test_2d_process_group_init(self): shard_mesh_dim_size = 2 - assert ( - self.world_size % shard_mesh_dim_size == 0 - ), f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + assert self.world_size % shard_mesh_dim_size == 0, ( + f"Expects {self.world_size} to be divisible by {shard_mesh_dim_size}" + ) replicate_mesh_dim_size = self.world_size // shard_mesh_dim_size mesh_dim_names = ("replicate", "shard") ref_mesh = init_device_mesh( diff --git a/test/distributed/_composable/fsdp/test_fully_shard_memory.py b/test/distributed/_composable/fsdp/test_fully_shard_memory.py index c3b8f04688ef..44d05ade98f7 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_memory.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_memory.py @@ -54,9 +54,9 @@ class TestFullyShardMemory(FSDPTest): ) ): return # skip since not a common use case - assert ( - self.world_size == 2 - ), f"Requires world size of 2 since some values are hard coded: {self.world_size}" + assert self.world_size == 2, ( + f"Requires world size of 2 since some values are hard coded: {self.world_size}" + ) torch.manual_seed(42) # Pre-run a linear forward (gemm and bias) and backward (gemm) to # allocate the cuBLAS workspaces before measuring the memory usage diff --git a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py index 4e5bf9465b45..06881442b748 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_mixed_precision.py @@ -284,9 +284,7 @@ class TestFullyShardMixedPrecisionTraining(FSDPTest): ) # bf16 reduction param.grad = funcol.all_gather_tensor( sharded_grad, gather_dim=0, group=group - ).to( - param.dtype - ) # upcast to fp32 + ).to(param.dtype) # upcast to fp32 ref_optim.step() # fp32 optimizer step self.assertEqual(fsdp_loss, ref_loss) diff --git a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py index c9653d06adea..e8d52f70e0f4 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_overlap.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_overlap.py @@ -139,8 +139,9 @@ class TestFullyShardOverlap(FSDPTest): dist.reduce_scatter_tensor(dummy_rs_output, dummy_rs_input) def fwd_bwd(): - with patch_all_gather(delayed_all_gather), patch_reduce_scatter( - delayed_reduce_scatter + with ( + patch_all_gather(delayed_all_gather), + patch_reduce_scatter(delayed_reduce_scatter), ): loss = model(inp).sum() loss.backward() diff --git a/test/distributed/_composable/fsdp/test_fully_shard_training.py b/test/distributed/_composable/fsdp/test_fully_shard_training.py index b3a32575fc0c..96b2a8b4dfd4 100644 --- a/test/distributed/_composable/fsdp/test_fully_shard_training.py +++ b/test/distributed/_composable/fsdp/test_fully_shard_training.py @@ -74,12 +74,12 @@ class TestFullyShardForwardInputs(FSDPTestMultiThread): # Check that FSDP moved the inputs to GPU, including recursing # into the tuple data structure assert x.device == device, f"Expects {device} but got {x.device}" - assert ( - ys[0].device == device - ), f"Expects {device} but got {ys[0].device}" - assert ( - ys[1].device == device - ), f"Expects {device} but got {ys[1].device}" + assert ys[0].device == device, ( + f"Expects {device} but got {ys[0].device}" + ) + assert ys[1].device == device, ( + f"Expects {device} but got {ys[1].device}" + ) y = ys[0] + ys[1] return x + y + 1 diff --git a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py index ec85a668d74f..89a893037c3b 100644 --- a/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py +++ b/test/distributed/algorithms/ddp_comm_hooks/test_ddp_hooks.py @@ -234,8 +234,8 @@ class DistributedDataParallelCommHookTest(MultiProcessTestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/algorithms/test_join.py b/test/distributed/algorithms/test_join.py index 60982d29cc62..8fd613a47d77 100644 --- a/test/distributed/algorithms/test_join.py +++ b/test/distributed/algorithms/test_join.py @@ -250,9 +250,11 @@ class TestJoin(MultiProcessTestCase): else "Detected at least one rank that exhausted inputs. " "Throwing across all ranks." ) - with self.assertRaisesRegex( - RuntimeError, expected_msg - ) if throw_on_early_termination else contextlib.nullcontext(): + with ( + self.assertRaisesRegex(RuntimeError, expected_msg) + if throw_on_early_termination + else contextlib.nullcontext() + ): with Join( allreducers, enable=enable, diff --git a/test/distributed/checkpoint/test_state_dict.py b/test/distributed/checkpoint/test_state_dict.py index b0a8ae3f58c9..9c4f6fb005a3 100644 --- a/test/distributed/checkpoint/test_state_dict.py +++ b/test/distributed/checkpoint/test_state_dict.py @@ -677,9 +677,7 @@ class TestStateDict(DTensorTestBase, VerifyStateDictMixin): fully_shard(layer) fully_shard(model) optim = torch.optim.Adam(model.parameters(), lr=1e-2) - torch.optim.lr_scheduler.LambdaLR( - optim, lr_lambda=[lambda epoch: 0.95**epoch] - ) + torch.optim.lr_scheduler.LambdaLR(optim, lr_lambda=[lambda epoch: 0.95**epoch]) opt_state_dict = ptd_state_dict.get_optimizer_state_dict( model, optim, diff --git a/test/distributed/checkpoint/test_state_dict_stager.py b/test/distributed/checkpoint/test_state_dict_stager.py index 4b721bc4d19b..57f3c014e88d 100644 --- a/test/distributed/checkpoint/test_state_dict_stager.py +++ b/test/distributed/checkpoint/test_state_dict_stager.py @@ -228,9 +228,9 @@ class TestStateDictStager(TestCase): # Validate tensor count and bytes expected_storage_cnt = 2 - assert ( - num_storages == expected_storage_cnt - ), f"Expected {expected_storage_cnt} storages, got {num_storages}" + assert num_storages == expected_storage_cnt, ( + f"Expected {expected_storage_cnt} storages, got {num_storages}" + ) # Calculate expected bytes # Note: Only unique storages are counted in the byte count @@ -239,9 +239,9 @@ class TestStateDictStager(TestCase): + tensor3.numel() # tensor1 and tensor2 share storage * tensor3.element_size() # tensor3 and its narrow view share storage ) - assert ( - num_bytes == expected_bytes - ), f"Expected {expected_bytes} bytes, got {num_bytes}" + assert num_bytes == expected_bytes, ( + f"Expected {expected_bytes} bytes, got {num_bytes}" + ) # Verify that the CPU state dict is equivalent to the original CUDA state dict result, error = compare_state_dicts(state_dict, cpu_state_dict) assert result, f"State dicts are not equivalent: {error}" @@ -301,9 +301,9 @@ class TestStateDictStager(TestCase): # Verify the first result is correct result, error = compare_state_dicts(state_dict, cpu_state_dict1) - assert ( - result - ), f"First state dict is not equivalent to original: {error}" + assert result, ( + f"First state dict is not equivalent to original: {error}" + ) # Modify the original tensors tensor1.fill_(0) @@ -317,14 +317,14 @@ class TestStateDictStager(TestCase): # Verify that the second CPU state dict is equivalent to the modified original state dict result, error = compare_state_dicts(state_dict, cpu_state_dict2) - assert ( - result - ), f"Second state dict is not equivalent to modified original: {error}" + assert result, ( + f"Second state dict is not equivalent to modified original: {error}" + ) # Verify that the number of cached storages hasn't changed - assert ( - num_storages1 == num_storages2 - ), f"Storage count changed: {num_storages1} vs {num_storages2}" + assert num_storages1 == num_storages2, ( + f"Storage count changed: {num_storages1} vs {num_storages2}" + ) # Verify that the tensors in the second state dict have the same storage pointers as the first assert ( @@ -347,12 +347,12 @@ class TestStateDictStager(TestCase): cpu_state_dict3 = stager.stage(state_dict) # Verify that the third CPU state dict reflects the updated values - assert torch.all( - cpu_state_dict3["tensor1"] == 42.0 - ), "Updated values should be reflected in the cached state dict" - assert torch.all( - cpu_state_dict3["tensor2"] == 42.0 - ), "Updated values should be reflected in the cached state dict" + assert torch.all(cpu_state_dict3["tensor1"] == 42.0), ( + "Updated values should be reflected in the cached state dict" + ) + assert torch.all(cpu_state_dict3["tensor2"] == 42.0), ( + "Updated values should be reflected in the cached state dict" + ) @requires_cuda def test_tensor_attrs(self): @@ -381,24 +381,24 @@ class TestStateDictStager(TestCase): cpu_state_dict = stager.stage(state_dict) # Verify that tensor attributes are preserved - assert hasattr( - cpu_state_dict["tensor1"], "a" - ), "Tensor attribute 'a' was not preserved" - assert ( - cpu_state_dict["tensor1"].a == 42 - ), "Tensor attribute 'a' has incorrect value" - assert hasattr( - cpu_state_dict["tensor1"], "b" - ), "Tensor attribute 'b' was not preserved" - assert ( - cpu_state_dict["tensor1"].b == 43 - ), "Tensor attribute 'b' has incorrect value" - assert hasattr( - cpu_state_dict["recursive"]["tensor3"], "c" - ), "Tensor attribute 'c' was not preserved" - assert ( - cpu_state_dict["recursive"]["tensor3"].c == 44 - ), "Tensor attribute 'c' has incorrect value" + assert hasattr(cpu_state_dict["tensor1"], "a"), ( + "Tensor attribute 'a' was not preserved" + ) + assert cpu_state_dict["tensor1"].a == 42, ( + "Tensor attribute 'a' has incorrect value" + ) + assert hasattr(cpu_state_dict["tensor1"], "b"), ( + "Tensor attribute 'b' was not preserved" + ) + assert cpu_state_dict["tensor1"].b == 43, ( + "Tensor attribute 'b' has incorrect value" + ) + assert hasattr(cpu_state_dict["recursive"]["tensor3"], "c"), ( + "Tensor attribute 'c' was not preserved" + ) + assert cpu_state_dict["recursive"]["tensor3"].c == 44, ( + "Tensor attribute 'c' has incorrect value" + ) @requires_cuda def test_different_dtypes(self): diff --git a/test/distributed/elastic/multiprocessing/api_test.py b/test/distributed/elastic/multiprocessing/api_test.py index a6acc177ec81..6e0f273a7c8e 100644 --- a/test/distributed/elastic/multiprocessing/api_test.py +++ b/test/distributed/elastic/multiprocessing/api_test.py @@ -500,11 +500,13 @@ if not (TEST_WITH_DEV_DBG_ASAN or IS_WINDOWS or IS_MACOS): logs_specs=DefaultLogsSpecs(log_dir=self.log_dir()), ) - with mock.patch.object( - mpc, "_is_done", return_value=True - ), mock.patch.object(mpc, "_pc"), mock.patch.object( - mpc._pc, "join", side_effect=[True, False, False, True] - ) as mock_join: + with ( + mock.patch.object(mpc, "_is_done", return_value=True), + mock.patch.object(mpc, "_pc"), + mock.patch.object( + mpc._pc, "join", side_effect=[True, False, False, True] + ) as mock_join, + ): mpc._poll() self.assertEqual(4, mock_join.call_count) diff --git a/test/distributed/fsdp/test_distributed_checkpoint.py b/test/distributed/fsdp/test_distributed_checkpoint.py index 42111efc8922..ac34246ee643 100644 --- a/test/distributed/fsdp/test_distributed_checkpoint.py +++ b/test/distributed/fsdp/test_distributed_checkpoint.py @@ -56,32 +56,36 @@ class TestDistributedCheckpoint(FSDPTest): torch.manual_seed(200) new_model = wrap(SkipModel(double_nest=True)) - with FullyShardedDataParallel.summon_full_params( - model - ), FullyShardedDataParallel.summon_full_params(new_model): + with ( + FullyShardedDataParallel.summon_full_params(model), + FullyShardedDataParallel.summon_full_params(new_model), + ): params = list(model.parameters()) new_params = list(new_model.parameters()) self.assertNotEqual(params, new_params) writer = FileSystemWriter(self.temp_dir) reader = FileSystemReader(self.temp_dir) - with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( - new_model, state_dict_type + with ( + FSDP.state_dict_type(model, state_dict_type), + FSDP.state_dict_type(new_model, state_dict_type), ): state_dict = model.state_dict() save(state_dict, writer) - with FSDP.state_dict_type(model, state_dict_type), FSDP.state_dict_type( - new_model, state_dict_type + with ( + FSDP.state_dict_type(model, state_dict_type), + FSDP.state_dict_type(new_model, state_dict_type), ): state_dict = new_model.state_dict() load(state_dict, reader) new_model.load_state_dict(state_dict) - with FullyShardedDataParallel.summon_full_params( - model - ), FullyShardedDataParallel.summon_full_params(new_model): + with ( + FullyShardedDataParallel.summon_full_params(model), + FullyShardedDataParallel.summon_full_params(new_model), + ): params = list(model.parameters()) new_params = list(new_model.parameters()) self.assertEqual(params, new_params) diff --git a/test/distributed/fsdp/test_fsdp_comm.py b/test/distributed/fsdp/test_fsdp_comm.py index aedeb688977d..42fa03162301 100644 --- a/test/distributed/fsdp/test_fsdp_comm.py +++ b/test/distributed/fsdp/test_fsdp_comm.py @@ -242,11 +242,10 @@ class TestCommunication(FSDPTest): # and if `use_no_sync=False`, we only run `num_iters` iterations # outside `no_sync()` num_iters = 3 - with patch( - "torch.distributed.all_gather_into_tensor" - ) as mock_all_gather, patch( - "torch.distributed.reduce_scatter_tensor" - ) as mock_reduce_scatter: + with ( + patch("torch.distributed.all_gather_into_tensor") as mock_all_gather, + patch("torch.distributed.reduce_scatter_tensor") as mock_reduce_scatter, + ): def reset_mocks(): mock_all_gather.reset_mock() diff --git a/test/distributed/fsdp/test_fsdp_core.py b/test/distributed/fsdp/test_fsdp_core.py index 5f8b88bb6e59..d6ee32c1f2e3 100644 --- a/test/distributed/fsdp/test_fsdp_core.py +++ b/test/distributed/fsdp/test_fsdp_core.py @@ -379,12 +379,15 @@ class TestHooks(FSDPTest): register_pre_backward_hooks_call_count += 1 return orig_register_pre_backward_hooks(*args, **kwargs) - with mock.patch( - "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks", - _register_pre_backward_hooks_with_count, - ), mock.patch( - "torch.distributed.fsdp._runtime_utils._register_post_backward_hook" - ) as register_post_bwd_mock: + with ( + mock.patch( + "torch.distributed.fsdp._runtime_utils._register_pre_backward_hooks", + _register_pre_backward_hooks_with_count, + ), + mock.patch( + "torch.distributed.fsdp._runtime_utils._register_post_backward_hook" + ) as register_post_bwd_mock, + ): self.assertEqual(register_pre_backward_hooks_call_count, 0) self.assertFalse(register_post_bwd_mock.called) fsdp_model(*input) diff --git a/test/distributed/fsdp/test_fsdp_grad_acc.py b/test/distributed/fsdp/test_fsdp_grad_acc.py index 1e51938a033f..b674b408462c 100644 --- a/test/distributed/fsdp/test_fsdp_grad_acc.py +++ b/test/distributed/fsdp/test_fsdp_grad_acc.py @@ -152,9 +152,9 @@ class TestGradAcc(FSDPTest): batches.append(tuple(permute_tensor(t) for t in batch)) for batch1, batch2 in itertools.combinations(batches, r=2): for t1, t2 in zip(batch1, batch2): - assert not torch.all( - t1 == t2 - ), "Check the test to make sure that batches are distinct" + assert not torch.all(t1 == t2), ( + "Check the test to make sure that batches are distinct" + ) # Concatenate the batches along the given batch dimension concat_batch: tuple[torch.Tensor, ...] = tuple( diff --git a/test/distributed/fsdp/test_fsdp_hybrid_shard.py b/test/distributed/fsdp/test_fsdp_hybrid_shard.py index dc9b54be2dd7..70c415ae1fe7 100644 --- a/test/distributed/fsdp/test_fsdp_hybrid_shard.py +++ b/test/distributed/fsdp/test_fsdp_hybrid_shard.py @@ -121,8 +121,9 @@ class TestFSDPHybridShard(FSDPTest): def test_hsdp_save_load_state_dict(self): model = MyModel().cuda() num_node_devices = torch.cuda.device_count() - shard_rank_lists = list(range(0, num_node_devices // 2)), list( - range(num_node_devices // 2, num_node_devices) + shard_rank_lists = ( + list(range(0, num_node_devices // 2)), + list(range(num_node_devices // 2, num_node_devices)), ) shard_groups = ( dist.new_group(shard_rank_lists[0]), @@ -171,8 +172,9 @@ class TestFSDPHybridShard(FSDPTest): def test_hsdp_sync_module_state(self): model = MyModel().cuda() num_node_devices = torch.cuda.device_count() - shard_rank_lists = list(range(0, num_node_devices // 2)), list( - range(num_node_devices // 2, num_node_devices) + shard_rank_lists = ( + list(range(0, num_node_devices // 2)), + list(range(num_node_devices // 2, num_node_devices)), ) shard_groups = ( dist.new_group(shard_rank_lists[0]), @@ -310,8 +312,9 @@ class TestFSDPHybridShard(FSDPTest): cntr = Counter() patched_allreduce = partial(patched_collective, orig_ar, cntr) patched_reduce_scatter = partial(patched_collective, orig_rs, cntr) - with patch_allreduce(patched_allreduce), patch_reduce_scatter( - patched_reduce_scatter + with ( + patch_allreduce(patched_allreduce), + patch_reduce_scatter(patched_reduce_scatter), ): inp = hsdp_model.get_input(device=torch.cuda.current_device()) out = hsdp_model(inp[0], inp[1]) @@ -355,9 +358,9 @@ class TestFSDPHybridShard(FSDPTest): use_orig_params, hsdp_process_groups=hsdp_pgs, ) - assert ( - hsdp_model._inter_node_pg.size() > 1 - ), "HSDP model initialized without replication" + assert hsdp_model._inter_node_pg.size() > 1, ( + "HSDP model initialized without replication" + ) fsdp_optim = torch.optim.Adam(fsdp_model.parameters(), lr=1e-2) hsdp_optim = torch.optim.Adam(hsdp_model.parameters(), lr=1e-2) torch.manual_seed(global_pg.rank() + 1) diff --git a/test/distributed/fsdp/test_fsdp_mixed_precision.py b/test/distributed/fsdp/test_fsdp_mixed_precision.py index bb54f1c2d2c9..dee38d040346 100644 --- a/test/distributed/fsdp/test_fsdp_mixed_precision.py +++ b/test/distributed/fsdp/test_fsdp_mixed_precision.py @@ -766,9 +766,9 @@ class TestFSDPMixedPrecisionSharded(TestFSDPMixedPrecision): if expect_use_full_prec_in_eval: assert x.dtype == torch.float32, f"Expected fp32, got {x.dtype}" else: - assert ( - x.dtype == low_prec_dtype - ), f"Expected {low_prec_dtype}, got {x.dtype}" + assert x.dtype == low_prec_dtype, ( + f"Expected {low_prec_dtype}, got {x.dtype}" + ) return self.a(x) mp_config = MixedPrecision( diff --git a/test/distributed/fsdp/test_fsdp_tp_integration.py b/test/distributed/fsdp/test_fsdp_tp_integration.py index 326157ec9e41..2cc3858e1269 100644 --- a/test/distributed/fsdp/test_fsdp_tp_integration.py +++ b/test/distributed/fsdp/test_fsdp_tp_integration.py @@ -91,9 +91,9 @@ class TestTPFSDPIntegration(FSDPTest): tensor_parallel_size: int, ) -> tuple[dict[str, int], dict[str, tuple[torch.Size, int]]]: """ """ - assert ( - type(model) is SimpleModel - ), "Expects a `SimpleModel` since the sharding cases on the model definition" + assert type(model) is SimpleModel, ( + "Expects a `SimpleModel` since the sharding cases on the model definition" + ) param_name_to_numel = OrderedDict() param_name_to_sharding_info = OrderedDict() for param_name, param in model.named_parameters(): diff --git a/test/distributed/fsdp/test_fsdp_use_orig_params.py b/test/distributed/fsdp/test_fsdp_use_orig_params.py index a0e1d0a50cc0..7efe6ec6661c 100644 --- a/test/distributed/fsdp/test_fsdp_use_orig_params.py +++ b/test/distributed/fsdp/test_fsdp_use_orig_params.py @@ -654,9 +654,12 @@ class TestFSDPUseOrigParamsUnshardReshard(FSDPTest): losses1 = [] losses2 = [] losses = [] - for _model, _optim in (fsdp_model, optim), ( - fsdp_model_orig_params, - optim_orig_params, + for _model, _optim in ( + (fsdp_model, optim), + ( + fsdp_model_orig_params, + optim_orig_params, + ), ): _optim.zero_grad() loss1 = _model(*inp1) @@ -1166,9 +1169,9 @@ class TestFSDPUseOrigParamsFQNs(FSDPTest): clean_tensor_name(tup[0]) for tup in self.named_parameters() ] params = [tup[1] for tup in self.named_parameters()] - assert ( - param_shapes[0] is not None and param_shapes[1] is not None - ), "`param_sizes` should be set" + assert param_shapes[0] is not None and param_shapes[1] is not None, ( + "`param_sizes` should be set" + ) assert_equal_fn( param_names, [ diff --git a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py index 691c43ddb542..f3ab4090e8dc 100755 --- a/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py +++ b/test/distributed/launcher/bin/test_script_is_torchelastic_launched.py @@ -19,6 +19,7 @@ The script itself is not a test case hence no assertions are made in this script see: - test/distributed/launcher/run_test.py#test_is_torchelastic_launched() - test/distributed/launcher/run_test.py#test_is_not_torchelastic_launched() """ + import argparse import torch.distributed as dist diff --git a/test/distributed/optim/test_zero_redundancy_optimizer.py b/test/distributed/optim/test_zero_redundancy_optimizer.py index 0f0ee84cee3d..603f671546a5 100644 --- a/test/distributed/optim/test_zero_redundancy_optimizer.py +++ b/test/distributed/optim/test_zero_redundancy_optimizer.py @@ -292,9 +292,9 @@ class TestZeroRedundancyOptimizerSingleRank(TestZeroRedundancyOptimizer): betas=BETAS, eps=EPS, ) - assert ( - len(o.param_groups) == 2 - ), f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" + assert len(o.param_groups) == 2, ( + f"Expected 2 ZeRO param groups, but got {len(o.param_groups)}" + ) assert len(o.optim.param_groups) == 2, ( "Expected 2 local optimizer param groups, but got " f"{len(o.optim.param_groups)}" @@ -713,9 +713,9 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): LR = 1e-3 MOMENTUM = 0.99 REFERENCE_RANK = 0 - assert ( - REFERENCE_RANK in subgroup_ranks - ), "Reference rank must be in the new process group" + assert REFERENCE_RANK in subgroup_ranks, ( + "Reference rank must be in the new process group" + ) loss_fn = torch.nn.L1Loss().to(device) def check(optimizer): @@ -1165,22 +1165,28 @@ class TestZeroRedundancyOptimizerDistributed(TestZeroRedundancyOptimizer): # Increased tolerances are needed to pass when using TF32 # See: https://github.com/pytorch/pytorch/issues/67764 - torch.testing.assert_close( - local_loss.cpu(), - ddp_loss.cpu(), - rtol=1e-03, - atol=1e-08, - ), "Losses differ between local optimizer and ZeRO" + ( + torch.testing.assert_close( + local_loss.cpu(), + ddp_loss.cpu(), + rtol=1e-03, + atol=1e-08, + ), + "Losses differ between local optimizer and ZeRO", + ) for local_p, ddp_p in zip( local_model.parameters(), ddp_model.parameters() ): - torch.testing.assert_close( - local_p.cpu(), - ddp_p.cpu(), - rtol=1e-03, - atol=1e-04, - ), "Models differ after a step" + ( + torch.testing.assert_close( + local_p.cpu(), + ddp_p.cpu(), + rtol=1e-03, + atol=1e-04, + ), + "Models differ after a step", + ) @skipIfHpu @skip_if_lt_x_gpu(4) diff --git a/test/distributed/pipelining/test_pipe.py b/test/distributed/pipelining/test_pipe.py index 3e02c4de3c93..8ddb5634811c 100644 --- a/test/distributed/pipelining/test_pipe.py +++ b/test/distributed/pipelining/test_pipe.py @@ -89,9 +89,9 @@ class PipeTests(TestCase): mb_args=(x, y), ) - assert ( - pipe.num_stages == EXPECTED_N_STAGES[ModelClass] - ), f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" + assert pipe.num_stages == EXPECTED_N_STAGES[ModelClass], ( + f"nstages = {pipe.num_stages}, expect {EXPECTED_N_STAGES[ModelClass]}" + ) ref_out = mod(x, y) out = pipe(x, y)[0] @@ -109,9 +109,7 @@ class PipeTests(TestCase): new_names.update(stage_fqns) if CHECK_FQN_SET_EQUALITY: - assert ( - old_names == new_names - ), f""" + assert old_names == new_names, f""" old names {old_names} new names {new_names} """ diff --git a/test/distributed/pipelining/test_unflatten.py b/test/distributed/pipelining/test_unflatten.py index 5fb30b5e1d17..ae1e684d7c22 100644 --- a/test/distributed/pipelining/test_unflatten.py +++ b/test/distributed/pipelining/test_unflatten.py @@ -60,9 +60,9 @@ class UnflattenTests(TestCase): for stage_idx in range(pipe.num_stages): stage_mod = pipe.get_stage_module(stage_idx) for param_name, _ in stage_mod.named_parameters(): - assert ( - param_name in orig_state_dict - ), f"{param_name} not in original state dict" + assert param_name in orig_state_dict, ( + f"{param_name} not in original state dict" + ) print("Param qualname test passed") # Check equivalence diff --git a/test/distributed/rpc/test_share_memory.py b/test/distributed/rpc/test_share_memory.py index bda98b1df949..97273981d082 100644 --- a/test/distributed/rpc/test_share_memory.py +++ b/test/distributed/rpc/test_share_memory.py @@ -45,9 +45,9 @@ class ShareMemoryRPCPickler(_InternalRPCPickler): for t in torch._tensor_classes: self._dispatch_table[t] = TorchMpReductions.reduce_tensor self._dispatch_table[torch.Tensor] = TorchMpReductions.reduce_tensor - self._dispatch_table[ - torch.nn.parameter.Parameter - ] = TorchMpReductions.reduce_tensor + self._dispatch_table[torch.nn.parameter.Parameter] = ( + TorchMpReductions.reduce_tensor + ) def worker_loop(a): diff --git a/test/distributed/tensor/test_dtensor_compile.py b/test/distributed/tensor/test_dtensor_compile.py index c0859b5925ff..23114f87f46a 100644 --- a/test/distributed/tensor/test_dtensor_compile.py +++ b/test/distributed/tensor/test_dtensor_compile.py @@ -872,9 +872,7 @@ def forward(self, primals_1): "buf0 = torch.ops._c10d_functional.all_gather_into_tensor.default(primal" ).check("torch.ops._c10d_functional.wait_tensor.default(buf0").check( "extern_kernels.mm(buf0," - ).run( - code - ) + ).run(code) @unittest.skipIf(not HAS_GPU, "Inductor+gpu needs triton and recent GPU arch") @skip_if_lt_x_gpu(1) diff --git a/test/distributed/tensor/test_math_ops.py b/test/distributed/tensor/test_math_ops.py index 0ce1206ae1bd..48f92c4ecd74 100644 --- a/test/distributed/tensor/test_math_ops.py +++ b/test/distributed/tensor/test_math_ops.py @@ -441,9 +441,11 @@ class DistMathOpsTest(DTensorTestBase): out_req_grad: bool subtest_fails = {} - valid_filter = lambda cfg: not ( # noqa: E731 - cfg.ln_req_grad and not cfg.elementwise_affine - ) and any(cfg[2:]) + valid_filter = ( # noqa: E731 + lambda cfg: ( + not (cfg.ln_req_grad and not cfg.elementwise_affine) and any(cfg[2:]) + ) + ) subtest_cfgs = list( filter( valid_filter, @@ -566,9 +568,9 @@ class DistMathOpsTest(DTensorTestBase): except Exception as e: subtest_fails[subtest_cfg] = e # if any subtest fails, provide the failed subtests and report the overall failure - assert ( - not subtest_fails - ), f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" + assert not subtest_fails, ( + f"{len(subtest_fails)}/{len(subtest_cfgs)} subtests failed: {pformat(subtest_fails)}" + ) @with_comms def test_topk(self): diff --git a/test/distributed/tensor/test_xla_integration.py b/test/distributed/tensor/test_xla_integration.py index 179b5bc796c8..3fbfcffbd76c 100644 --- a/test/distributed/tensor/test_xla_integration.py +++ b/test/distributed/tensor/test_xla_integration.py @@ -26,7 +26,9 @@ def with_xla(func: Callable) -> Callable: @wraps(func) # pyre-ignore[6] def wrapper( - self, *args: tuple[object], **kwargs: dict[str, Any] # type: ignore[misc] + self, + *args: tuple[object], + **kwargs: dict[str, Any], # type: ignore[misc] ) -> None: # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag. os.environ["XLA_USE_SPMD"] = "1" diff --git a/test/distributed/test_c10d_common.py b/test/distributed/test_c10d_common.py index 8590f25a351c..efac131e6c38 100644 --- a/test/distributed/test_c10d_common.py +++ b/test/distributed/test_c10d_common.py @@ -2234,8 +2234,8 @@ class LocalRankTest(MultiProcessTestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_functional_native.py b/test/distributed/test_c10d_functional_native.py index ebdf2c0dcdcb..e49fb2b1036c 100644 --- a/test/distributed/test_c10d_functional_native.py +++ b/test/distributed/test_c10d_functional_native.py @@ -796,13 +796,11 @@ class CompileTest(TestCase): .check("buf6 = empty") # Expect in-place with inductor allocated buf .check( - "torch.ops._c10d_functional.all_reduce_coalesced_" - ".default([buf0, buf1]" + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf0, buf1]" ) # Expect no in-place with graph input (buf5, buf6 are clones) .check( - "torch.ops._c10d_functional.all_reduce_coalesced_" - ".default([buf5, buf6]" + "torch.ops._c10d_functional.all_reduce_coalesced_.default([buf5, buf6]" ) .check("torch.ops._c10d_functional.wait_tensor.default(buf0") .check("torch.ops._c10d_functional.wait_tensor.default(buf1") diff --git a/test/distributed/test_c10d_gloo.py b/test/distributed/test_c10d_gloo.py index b770d52c01ec..96ad01b95b18 100644 --- a/test/distributed/test_c10d_gloo.py +++ b/test/distributed/test_c10d_gloo.py @@ -2705,8 +2705,8 @@ class LargeCommTest(test_c10d_common.AbstractLargeCommTest, MultiProcessTestCase if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_nccl.py b/test/distributed/test_c10d_nccl.py index 6240f1315183..c02e968e23fb 100644 --- a/test/distributed/test_c10d_nccl.py +++ b/test/distributed/test_c10d_nccl.py @@ -2869,9 +2869,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): self.assertTrue(t.is_alive()) if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @skip_if_lt_x_gpu(3) @@ -2931,9 +2931,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = "0" self._test_barrier_error() if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @@ -2984,9 +2984,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): process_group.abort() if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) @requires_nccl() @requires_nccl_version((2, 4, 0), "Need NCCL 2.4+ for error checking") @@ -3065,9 +3065,9 @@ class NcclErrorHandlingTest(MultiProcessTestCase): os.remove(new_file_name) if prev_nccl_async_error_handling is not None: - os.environ[ - "TORCH_NCCL_ASYNC_ERROR_HANDLING" - ] = prev_nccl_async_error_handling + os.environ["TORCH_NCCL_ASYNC_ERROR_HANDLING"] = ( + prev_nccl_async_error_handling + ) def _run_invalid_nccl_blocking_wait_env(self, val): os.environ["TORCH_NCCL_BLOCKING_WAIT"] = val @@ -3360,9 +3360,7 @@ class CommTest(test_c10d_common.AbstractCommTest, MultiProcessTestCase): self.assertEqual(_get_intra_node_comm_usage_counter(), 3) # Verify that IntraNodeComm is not used beyond 10MB - t = torch.full( - (10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16 - ).cuda() + t = torch.full((10 * 1024**2 // 2 + 1,), self.rank, dtype=torch.bfloat16).cuda() c10d.all_reduce(t, c10d.ReduceOp.SUM) self.assertTrue(t.eq(expect).all()) self.assertEqual(_get_intra_node_comm_usage_counter(), 3) @@ -4249,9 +4247,9 @@ class SparseCollective(MultiProcessTestCase): class NCCLTraceTestBase(MultiProcessTestCase): def setUp(self): super().setUp() - os.environ[ - "TORCH_NCCL_ENABLE_TIMING" - ] = "0" # see 'timing_enabled' parametrized tests + os.environ["TORCH_NCCL_ENABLE_TIMING"] = ( + "0" # see 'timing_enabled' parametrized tests + ) os.environ["TORCH_NCCL_TRACE_BUFFER_SIZE"] = "1000" os.environ["TORCH_NCCL_DUMP_ON_TIMEOUT"] = "1" self.tempdir = tempfile.TemporaryDirectory() @@ -5331,8 +5329,8 @@ class ProcessGroupNCCLLargerScaleTest(MultiProcessTestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_c10d_ucc.py b/test/distributed/test_c10d_ucc.py index 5e7af710c8a9..e3a4764d594f 100644 --- a/test/distributed/test_c10d_ucc.py +++ b/test/distributed/test_c10d_ucc.py @@ -1090,8 +1090,8 @@ class UccProcessGroupWithDispatchedCollectivesTests( if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_collective_utils.py b/test/distributed/test_collective_utils.py index 2999f318d690..a150a55f77be 100644 --- a/test/distributed/test_collective_utils.py +++ b/test/distributed/test_collective_utils.py @@ -90,9 +90,9 @@ class TestCollectiveUtils(MultiProcessTestCase): res = all_gather(data_or_fn=func, pg=pg) func.assert_called_once() - assert res == list( - range(self.world_size) - ), f"Expect res to be list of 0 through {self.world_size} (got {res})" + assert res == list(range(self.world_size)), ( + f"Expect res to be list of 0 through {self.world_size} (got {res})" + ) def test_all_gather_result_no_pg(self) -> None: """ diff --git a/test/distributed/test_control_collectives.py b/test/distributed/test_control_collectives.py index 594c028ae9d4..8e48735c7779 100644 --- a/test/distributed/test_control_collectives.py +++ b/test/distributed/test_control_collectives.py @@ -207,8 +207,8 @@ class TestCollectives(TestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_device_mesh.py b/test/distributed/test_device_mesh.py index 7ad4c33de431..06502943934f 100644 --- a/test/distributed/test_device_mesh.py +++ b/test/distributed/test_device_mesh.py @@ -490,8 +490,9 @@ class DeviceMeshTestNDim(DTensorTestBase): # Create shard groups (e.g. (0, 1, 2, 3), (4, 5, 6, 7)) # and assign the correct shard group to each rank - shard_rank_lists = list(range(0, self.world_size // 2)), list( - range(self.world_size // 2, self.world_size) + shard_rank_lists = ( + list(range(0, self.world_size // 2)), + list(range(self.world_size // 2, self.world_size)), ) shard_groups = ( new_group(shard_rank_lists[0]), diff --git a/test/distributed/test_dynamo_distributed.py b/test/distributed/test_dynamo_distributed.py index d43c0c3e3e0c..8446282c84ff 100644 --- a/test/distributed/test_dynamo_distributed.py +++ b/test/distributed/test_dynamo_distributed.py @@ -1830,9 +1830,7 @@ class TestSingleProc(DynamoDistributedSingleProcTestCase): f"""{expected_guard_source} "L['self']._modules['net']" TYPE_MATCH""" ).check( f"""{expected_guard_source} "L['self']._modules['net']._modules['0']" TYPE_MATCH""" - ).run( - GUARDS_FILE.getvalue() - ) + ).run(GUARDS_FILE.getvalue()) self.assertTrue(same(correct_outputs, outputs)) diff --git a/test/distributed/test_inductor_collectives.py b/test/distributed/test_inductor_collectives.py index 6c33a6031d28..77dd871a5520 100644 --- a/test/distributed/test_inductor_collectives.py +++ b/test/distributed/test_inductor_collectives.py @@ -519,12 +519,13 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase): out = a2a / a2a.sum(dim=0) return out - with _dynamo_dist_per_rank_init( - self.rank, self.world_size - ), torch._dynamo.config.patch( - dynamic_shapes=True, - capture_dynamic_output_shape_ops=True, - capture_scalar_outputs=True, + with ( + _dynamo_dist_per_rank_init(self.rank, self.world_size), + torch._dynamo.config.patch( + dynamic_shapes=True, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + ), ): row = self.world_size * (self.rank + 1) * (self.world_size + 1) / 2 input_split_sizes_tensor = torch.tensor( @@ -680,15 +681,15 @@ class TestCollectivesMultiProc(DynamoDistributedMultiProcTestCase): return torch.ops.custom_ns.foo(a2a) - with _dynamo_dist_per_rank_init( - self.rank, self.world_size - ), torch._dynamo.config.patch( - dynamic_shapes=True, - capture_dynamic_output_shape_ops=True, - capture_scalar_outputs=True, - ), torch.library._scoped_library( - "custom_ns", "FRAGMENT" - ) as lib: + with ( + _dynamo_dist_per_rank_init(self.rank, self.world_size), + torch._dynamo.config.patch( + dynamic_shapes=True, + capture_dynamic_output_shape_ops=True, + capture_scalar_outputs=True, + ), + torch.library._scoped_library("custom_ns", "FRAGMENT") as lib, + ): lib.define( "alltoall_autograd(Tensor input, SymInt[]? output_split_sizes, SymInt[]? input_split_sizes, str tag, int[] ranks, int group_size) -> Tensor" # noqa: B950 ) diff --git a/test/distributed/test_pg_wrapper.py b/test/distributed/test_pg_wrapper.py index d7e59f1c90a7..4c96d4b564d6 100644 --- a/test/distributed/test_pg_wrapper.py +++ b/test/distributed/test_pg_wrapper.py @@ -464,8 +464,8 @@ class ProcessGroupGlooWrapperTest(AbstractProcessGroupWrapperTest): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_pg_wrapper must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_pg_wrapper must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/distributed/test_store.py b/test/distributed/test_store.py index 25a554942c82..e9abb1d90717 100644 --- a/test/distributed/test_store.py +++ b/test/distributed/test_store.py @@ -372,12 +372,8 @@ class TCPStoreTest(TestCase, StoreTestBase): # Use noqa to silence flake8. # Need to store in an unused variable here to ensure the first # object is not destroyed before the second object is created. - store1 = dist.TCPStore( - addr, port, 1, True, use_libuv=self._use_libuv - ) # noqa: F841 - store2 = dist.TCPStore( - addr, port, 1, True, use_libuv=self._use_libuv - ) # noqa: F841 + store1 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841 + store2 = dist.TCPStore(addr, port, 1, True, use_libuv=self._use_libuv) # noqa: F841 self.assertEqual(store1.libuvBackend, self._use_libuv) self.assertEqual(store2.libuvBackend, self._use_libuv) @@ -767,7 +763,7 @@ class RendezvousFileTest(TestCase): def test_nominal(self): with tempfile.NamedTemporaryFile(delete=False) as file: - url = f'file:///{file.name.replace(os.path.sep, "/")}?world_size=2' + url = f"file:///{file.name.replace(os.path.sep, '/')}?world_size=2" gen0 = dist.rendezvous(url + "&rank=0") store0, rank0, size0 = next(gen0) self.assertEqual(0, rank0) @@ -1178,8 +1174,8 @@ class TestClientProtocol(TestCase): if __name__ == "__main__": - assert ( - not torch.cuda._initialized - ), "test_distributed must not have initialized CUDA context on main process" + assert not torch.cuda._initialized, ( + "test_distributed must not have initialized CUDA context on main process" + ) run_tests() diff --git a/test/dynamo/test_activation_checkpointing.py b/test/dynamo/test_activation_checkpointing.py index c1ab329a137d..6699c973052b 100644 --- a/test/dynamo/test_activation_checkpointing.py +++ b/test/dynamo/test_activation_checkpointing.py @@ -81,9 +81,9 @@ def count_ops( for node in gm.graph.nodes: if match_rng_op(node, op) or node.target == op: actual_count += 1 - assert ( - actual_count >= freq_ge - ), f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." + assert actual_count >= freq_ge, ( + f"In graph {gm}, expected {op} to have occurred at least {freq_ge} times in the graph, but got {actual_count}." + ) return gm diff --git a/test/dynamo/test_autograd_function.py b/test/dynamo/test_autograd_function.py index 80aa5c1025fe..978a23f80946 100644 --- a/test/dynamo/test_autograd_function.py +++ b/test/dynamo/test_autograd_function.py @@ -455,9 +455,9 @@ class AutogradFunctionTests(torch._dynamo.test_case.TestCase): # Modify gradient using .data (Dangerous: Breaks autograd tracking!) modified_grad = grad_output.clone() - modified_grad.data[ - input_tensor.data < 0 - ] = 0 # Zero-out gradients for negative inputs + modified_grad.data[input_tensor.data < 0] = ( + 0 # Zero-out gradients for negative inputs + ) return modified_grad * 3 diff --git a/test/dynamo/test_base_hop.py b/test/dynamo/test_base_hop.py index b185a1a13339..18cdf78c61f2 100644 --- a/test/dynamo/test_base_hop.py +++ b/test/dynamo/test_base_hop.py @@ -195,10 +195,13 @@ class GraphModule(torch.nn.Module): def f(x, y): return invoke_quant_test(inner, [x, y], scheme="nf4") - with mock.patch( - "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", - True, - ), torch.no_grad(): + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", + True, + ), + torch.no_grad(), + ): torch.compile(f, backend=bk, fullgraph=True)(x.clone(), y) self.assertEqual(len(bk.graphs), 1) @@ -319,10 +322,13 @@ class GraphModule(torch.nn.Module): x = torch.randn(3, 3, requires_grad=False) x_clone = x.clone() y = torch.randn(3, 3, requires_grad=True) - with mock.patch( - "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", - True, - ), torch.no_grad(): + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.BaseHOPVariable.supports_input_mutation", + True, + ), + torch.no_grad(), + ): compiled_out = torch.compile(f, backend=backend, fullgraph=True)(x, y) self.assertEqual(x, x_clone + 1) self.assertEqual(compiled_out, x_clone + y + 1) diff --git a/test/dynamo/test_bytecode_utils.py b/test/dynamo/test_bytecode_utils.py index fa906a2ac162..b91b8156ec18 100644 --- a/test/dynamo/test_bytecode_utils.py +++ b/test/dynamo/test_bytecode_utils.py @@ -53,8 +53,8 @@ class BytecodeTests(torch._dynamo.test_case.TestCase): fn_str = f"""\ def fn(): foo.bar(1, 2, 3) -{str(chr(10)).join(' ' * 4 + 'x' + str(i) + ' = 1' for i in range(1 << 9))} - l = [{' '.join('x' + str(i) + ',' for i in range(1 << 9))}] +{str(chr(10)).join(" " * 4 + "x" + str(i) + " = 1" for i in range(1 << 9))} + l = [{" ".join("x" + str(i) + "," for i in range(1 << 9))}] """ locals = {} exec(fn_str, {}, locals) diff --git a/test/dynamo/test_callback.py b/test/dynamo/test_callback.py index 86bd692e8d5e..1d221f635538 100644 --- a/test/dynamo/test_callback.py +++ b/test/dynamo/test_callback.py @@ -26,9 +26,10 @@ class CallbackTests(TestCase): def test_callbacks_with_duplicate_prevention(self) -> None: trigger = CallbackTrigger.DYNAMO compile_id = CompileId(0, 0) - with callback_handler.install_callbacks( - trigger, compile_id - ), callback_handler.install_callbacks(trigger, compile_id): + with ( + callback_handler.install_callbacks(trigger, compile_id), + callback_handler.install_callbacks(trigger, compile_id), + ): self._on_compile_start.assert_called_once() self._on_compile_end.assert_called_once() diff --git a/test/dynamo/test_ctx_manager.py b/test/dynamo/test_ctx_manager.py index d0216ed59038..3f0edd939a56 100644 --- a/test/dynamo/test_ctx_manager.py +++ b/test/dynamo/test_ctx_manager.py @@ -1252,10 +1252,13 @@ class CtxManagerTests(torch._dynamo.test_case.TestCase): def f(x, y): return x + y - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) @@ -1289,10 +1292,13 @@ class GraphModule(torch.nn.Module): def f(x, y): return x + y - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) @@ -1335,10 +1341,13 @@ class GraphModule(torch.nn.Module): return inner_fn(x, y) + x - x, y = torch.ones( - 1, - ), torch.zeros( - 1, + x, y = ( + torch.ones( + 1, + ), + torch.zeros( + 1, + ), ) return f(x, y) diff --git a/test/dynamo/test_exceptions.py b/test/dynamo/test_exceptions.py index 07bb57603260..f3a2a1d7c772 100644 --- a/test/dynamo/test_exceptions.py +++ b/test/dynamo/test_exceptions.py @@ -19,7 +19,7 @@ from torch.testing._internal.common_utils import ( class CustomException(Exception): - ... + pass class CustomExceptionMeta(type): @@ -28,7 +28,7 @@ class CustomExceptionMeta(type): class CustomExceptionWithInstanceCheck(Exception, metaclass=CustomExceptionMeta): - ... + pass class CustomExceptionWithArgs(Exception): @@ -358,7 +358,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): def test_raise_custom_exception(self): class Exc(Exception): - ... + pass @torch.compile(backend="eager", fullgraph=True) def fn(t): @@ -375,7 +375,7 @@ class ExceptionTests(torch._dynamo.test_case.TestCase): def test_raise_custom_exception_with_args(self): class Exc(Exception): - ... + pass @torch.compile(backend="eager", fullgraph=True) def fn(t): diff --git a/test/dynamo/test_guard_serialization.py b/test/dynamo/test_guard_serialization.py index 6fa40064beb1..9e93f3048ea8 100644 --- a/test/dynamo/test_guard_serialization.py +++ b/test/dynamo/test_guard_serialization.py @@ -324,9 +324,13 @@ class TestGuardSerialization(torch._inductor.test_case.TestCase): distributed_state=None, package=None, ) - with compile_context(CompileContext(CompileId(0, 0))), tracing( - tracer.output.tracing_context - ), tracer.set_current_tx(), get_metrics_context(), dynamo_timed(""): + with ( + compile_context(CompileContext(CompileId(0, 0))), + tracing(tracer.output.tracing_context), + tracer.set_current_tx(), + get_metrics_context(), + dynamo_timed(""), + ): tracer.run() check_fn_manager = CheckFunctionManager( diff --git a/test/dynamo/test_misc.py b/test/dynamo/test_misc.py index ced119969ace..50ede0b54656 100644 --- a/test/dynamo/test_misc.py +++ b/test/dynamo/test_misc.py @@ -1092,9 +1092,7 @@ not ___dict_contains('bbbbbbbb', G['sys'].modules) not ___dict_contains('cccccccc', G['sys'].modules) str(L['x'].device) == 'cpu' str(L['x'].dtype) == 'torch.float32' -utils_device.CURRENT_DEVICE == None""".split( - "\n" - ): +utils_device.CURRENT_DEVICE == None""".split("\n"): self.assertIn( line, guard_code_str, @@ -2806,7 +2804,7 @@ utils_device.CURRENT_DEVICE == None""".split( "int", np.intp, np.int32, - np.uint8 + np.uint8, # np.dtype('int') # XXX: as above ] @@ -5527,9 +5525,9 @@ utils_device.CURRENT_DEVICE == None""".split( def forward(self, idx, targets=None): b, t = idx.size() - assert ( - t <= self.block_size - ), "Cannot forward, model block size is exhausted." + assert t <= self.block_size, ( + "Cannot forward, model block size is exhausted." + ) # forward the GPT model token_embeddings = self.tok_emb( @@ -6075,15 +6073,17 @@ utils_device.CURRENT_DEVICE == None""".split( def count_graph_break_msgs(msgs): return sum("Graph break in user code" in msg for msg in msgs) - with self.assertLogs( - logger="torch._dynamo", level=logging.DEBUG - ) as log, torch._dynamo.config.patch(verbose=True): + with ( + self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log, + torch._dynamo.config.patch(verbose=True), + ): f1(torch.randn(10), torch.randn(10)) self.assertGreater(count_graph_break_msgs(log.output), 1) - with self.assertLogs( - logger="torch._dynamo", level=logging.DEBUG - ) as log, torch._dynamo.config.patch(verbose=False): + with ( + self.assertLogs(logger="torch._dynamo", level=logging.DEBUG) as log, + torch._dynamo.config.patch(verbose=False), + ): g1(torch.randn(10), torch.randn(10)) self.assertEqual(count_graph_break_msgs(log.output), 1) @@ -8235,8 +8235,9 @@ utils_device.CURRENT_DEVICE == None""".split( def f(a): return h(a) - with warnings.catch_warnings(record=True) as w, self.assertRaises( - torch._dynamo.exc.BackendCompilerFailed + with ( + warnings.catch_warnings(record=True) as w, + self.assertRaises(torch._dynamo.exc.BackendCompilerFailed), ): f(torch.randn(2, 2, requires_grad=True)) @@ -8429,8 +8430,7 @@ utils_device.CURRENT_DEVICE == None""".split( def test_torch_compile_ctx_on_forward_and_training_step(self): class MyModel(torch.nn.Module): - def forward(self): - ... + def forward(self): ... def training_step(self): self() diff --git a/test/dynamo/test_modules.py b/test/dynamo/test_modules.py index 9a8fe50bc8ec..b6cb548647aa 100644 --- a/test/dynamo/test_modules.py +++ b/test/dynamo/test_modules.py @@ -2094,11 +2094,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): mod = MockModule() # Each submod is compiled separately and has a different nn module # guard. Ensure that recompilation logic is handle correctly. - with unittest.mock.patch( - "torch._dynamo.config.error_on_recompile", True - ), unittest.mock.patch( - "torch._dynamo.config.recompile_limit", - recompile_limit, + with ( + unittest.mock.patch("torch._dynamo.config.error_on_recompile", True), + unittest.mock.patch( + "torch._dynamo.config.recompile_limit", + recompile_limit, + ), ): x = torch.randn(*size, requires_grad=True) mod(x) @@ -2160,11 +2161,12 @@ class OptimizedModuleTest(torch._dynamo.test_case.TestCase): mod = MockModule() # Each submod is compiled separately and has a different nn module # guard. Ensure that recompilation logic is handle correctly. - with unittest.mock.patch( - "torch._dynamo.config.error_on_recompile", True - ), unittest.mock.patch( - "torch._dynamo.config.recompile_limit", - recompile_limit, + with ( + unittest.mock.patch("torch._dynamo.config.error_on_recompile", True), + unittest.mock.patch( + "torch._dynamo.config.recompile_limit", + recompile_limit, + ), ): x = torch.randn(*size, requires_grad=True) mod(x) diff --git a/test/dynamo/test_optimizers.py b/test/dynamo/test_optimizers.py index 614baec1e3dc..e74ebc225871 100644 --- a/test/dynamo/test_optimizers.py +++ b/test/dynamo/test_optimizers.py @@ -3,6 +3,7 @@ PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes with test_adam in OptimizerTests) """ + import functools import torch diff --git a/test/dynamo/test_recompile_ux.py b/test/dynamo/test_recompile_ux.py index 4507d3394620..e69c23c95243 100644 --- a/test/dynamo/test_recompile_ux.py +++ b/test/dynamo/test_recompile_ux.py @@ -238,9 +238,7 @@ class RecompileUxTests(torch._dynamo.test_case.TestCase): tensor 'x' size mismatch at index 0. expected 11, actual 12 tensor 'x' size mismatch at index 0. expected 10, actual 12 tensor 'x' size mismatch at index 0. expected 9, actual 12 -tensor 'x' size mismatch at index 0. expected 8, actual 12""".split( - "\n" - ): +tensor 'x' size mismatch at index 0. expected 8, actual 12""".split("\n"): self.assertIn( line, failure_str, @@ -276,9 +274,7 @@ tensor 'x' size mismatch at index 0. expected 8, actual 12""".split( opt_f([7, 8]) for line in """\ -len(x) == 3""".split( - "\n" - ): +len(x) == 3""".split("\n"): self.assertIn(line, filter_reasons()) failure_reasons.clear() @@ -286,9 +282,7 @@ len(x) == 3""".split( for line in """\ len(x) == 2 -len(x) == 3""".split( - "\n" - ): +len(x) == 3""".split("\n"): self.assertIn(line, filter_reasons()) @torch._dynamo.config.patch(recompile_limit=1) diff --git a/test/dynamo/test_repros.py b/test/dynamo/test_repros.py index 49fcacd33428..c9aa02a44ad8 100644 --- a/test/dynamo/test_repros.py +++ b/test/dynamo/test_repros.py @@ -179,9 +179,9 @@ def shapes_to_tensor(x, device=None): if torch.jit.is_scripting(): return torch.as_tensor(x, device=device) if torch.jit.is_tracing(): - assert all( - isinstance(t, torch.Tensor) for t in x - ), "Shape should be tensor during tracing!" + assert all(isinstance(t, torch.Tensor) for t in x), ( + "Shape should be tensor during tracing!" + ) # as_tensor should not be used in tracing because it records a constant ret = torch.stack(x) if ret.device != device: # avoid recording a hard-coded device if not necessary @@ -480,9 +480,9 @@ class PartialT5(torch.nn.Module): real_seq_length = seq_length if past_key_value is not None: - assert ( - len(past_key_value) == 2 - ), f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + assert len(past_key_value) == 2, ( + f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states" + ) real_seq_length += ( past_key_value[0].shape[2] if query_length is None else query_length ) @@ -4877,9 +4877,9 @@ def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor): with warnings.catch_warnings(record=True): data_len = len(value) if len(self._fields): - assert ( - len(self) == data_len - ), f"Adding a field of length {data_len} to a Instances of length {len(self)}" + assert len(self) == data_len, ( + f"Adding a field of length {data_len} to a Instances of length {len(self)}" + ) self._fields[name] = value def get(self, name: str) -> Any: diff --git a/test/dynamo/test_subclasses.py b/test/dynamo/test_subclasses.py index b741c6b5b9c4..c4e0fdadeedc 100644 --- a/test/dynamo/test_subclasses.py +++ b/test/dynamo/test_subclasses.py @@ -108,9 +108,10 @@ def get_view_test_cases(): for requires_grad_1, requires_grad_2 in itertools.product( [True, False], repeat=2 ): - yield partial( - mk_leaf, base_is_nt, requires_grad_1, requires_grad_2 - ), f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}" + yield ( + partial(mk_leaf, base_is_nt, requires_grad_1, requires_grad_2), + f"{prefix}_leaf_{requires_grad_1}_{requires_grad_2}", + ) # (3) obscure case: # view is not a leaf (implies requires_grad True) @@ -118,9 +119,10 @@ def get_view_test_cases(): yield partial(mk_obscure, base_is_nt), f"{prefix}_obscure" # Subclass -> Dense - yield lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[ - 0 - ].clone(), "subclass_dense" + yield ( + lambda: get_jagged_tensor(((2, 3, 4), 3), None, requires_grad=True)[0].clone(), + "subclass_dense", + ) # Dense -> Subclass -> Dense -> Subclass def mk_dense_subclass_dense_subclass(): diff --git a/test/dynamo/test_trace_rules.py b/test/dynamo/test_trace_rules.py index 90aa18caee48..0125b06c64bd 100644 --- a/test/dynamo/test_trace_rules.py +++ b/test/dynamo/test_trace_rules.py @@ -151,9 +151,9 @@ def gen_allowed_objs_and_ids(record=False, c_binding_only=True) -> AllowedObject types.WrapperDescriptorType, ), ) or is_special_functions(obj): - torch_name_rule_map[ - f"{module.__name__}.{name}" - ] = TorchInGraphFunctionVariable + torch_name_rule_map[f"{module.__name__}.{name}"] = ( + TorchInGraphFunctionVariable + ) if c_binding_only: if not hasattr(obj, "__code__"): c_binding_in_graph_functions.add(obj) @@ -398,12 +398,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase): ) self.assertTrue("torch._dynamo" not in torch._dynamo.trace_rules.MOD_INLINELIST) - with unittest.mock.patch( - "torch._dynamo.trace_rules.torch_name_rule_map", - _torch_name_rule_map, - ), unittest.mock.patch( - "torch._dynamo.trace_rules.get_torch_obj_rule_map", - torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, # bypass functools.lru_cache + ), ): x = torch.rand(3) opt_fn = torch.compile(backend="eager", fullgraph=True)(fn) @@ -419,9 +422,9 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase): _manual_torch_name_rule_map = manual_torch_name_rule_map.copy() # Force inline `mod.func` by setting trace rule. - _manual_torch_name_rule_map[ - f"{mod.__name__}.{func.__name__}" - ] = UserFunctionVariable + _manual_torch_name_rule_map[f"{mod.__name__}.{func.__name__}"] = ( + UserFunctionVariable + ) _torch_name_rule_map = [ _manual_torch_name_rule_map, @@ -429,12 +432,15 @@ class TraceRuleTests(torch._dynamo.test_case.TestCase): torch_non_c_binding_in_graph_functions, ] - with unittest.mock.patch( - "torch._dynamo.trace_rules.torch_name_rule_map", - _torch_name_rule_map, - ), unittest.mock.patch( - "torch._dynamo.trace_rules.get_torch_obj_rule_map", - torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, + with ( + unittest.mock.patch( + "torch._dynamo.trace_rules.torch_name_rule_map", + _torch_name_rule_map, + ), + unittest.mock.patch( + "torch._dynamo.trace_rules.get_torch_obj_rule_map", + torch._dynamo.trace_rules.get_torch_obj_rule_map.__wrapped__, + ), ): # First adding the module to SKIP_DIRS so that it will be skipped by default. torch._dynamo.trace_rules.add(mod.__name__) diff --git a/test/dynamo/test_utils.py b/test/dynamo/test_utils.py index 007c56e6a26e..c9ab3b781887 100644 --- a/test/dynamo/test_utils.py +++ b/test/dynamo/test_utils.py @@ -593,9 +593,10 @@ class TestDynamoTimed(TestCase): ) compilation_events = [] - with dynamo_config.patch({"automatic_dynamic_shapes": False}), mock.patch( - "torch._dynamo.utils.log_compilation_event" - ) as log_event: + with ( + dynamo_config.patch({"automatic_dynamic_shapes": False}), + mock.patch("torch._dynamo.utils.log_compilation_event") as log_event, + ): @torch.compile() def f(x): diff --git a/test/export/test_draft_export.py b/test/export/test_draft_export.py index 9f2dd833b3b6..6cf819958fcc 100644 --- a/test/export/test_draft_export.py +++ b/test/export/test_draft_export.py @@ -181,9 +181,12 @@ class TestDraftExport(TestCase): self.assertEqual(len(report.op_profiles), 1) self.assertEqual(len(report.op_profiles["mylib.foo8.default"]), 1) - with torch._library.fake_profile.unsafe_generate_fake_kernels( - report.op_profiles - ), FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()): + with ( + torch._library.fake_profile.unsafe_generate_fake_kernels( + report.op_profiles + ), + FakeTensorMode(allow_non_fake_inputs=True, shape_env=ShapeEnv()), + ): torch.ops.mylib.foo8(*new_inp) # Existing registration has been updated to match the new diff --git a/test/export/test_export.py b/test/export/test_export.py index f6ba272bc91e..dda4ddacd2eb 100755 --- a/test/export/test_export.py +++ b/test/export/test_export.py @@ -12834,12 +12834,14 @@ def forward(self, x, y): "y": [Dim("dy")], # y & z incorrect, export is supposed to fail. "z": [Dim("dz")], # suggested fix should be to match these up. } - with self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. - torch._dynamo.exc.UserError, - r".*Constraints violated(.*\n)*" - r"Suggested fixes:(.*\n)*" - r".*dz = dy(.*\n)*", - ) as msg: + with ( + self.assertRaisesRegex( # if disable=True, suggested fixes should not specialize. + torch._dynamo.exc.UserError, + r".*Constraints violated(.*\n)*" + r"Suggested fixes:(.*\n)*" + r".*dz = dy(.*\n)*", + ) as msg + ): export( Foo(), inputs, @@ -13675,8 +13677,7 @@ def forward(self, x): """Make sure the metadata is kept after exported program run_decompositions.""" @torch.library.custom_op("mylib::add", mutates_args=()) - def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: - ... + def add(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: ... @torch.library.register_fake("mylib::add") def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor: diff --git a/test/export/test_serialize.py b/test/export/test_serialize.py index f28611f04ba7..75a30ccf3da9 100644 --- a/test/export/test_serialize.py +++ b/test/export/test_serialize.py @@ -947,10 +947,12 @@ class TestDeserialize(TestCase): ep = torch.export.export(M(), (torch.ones(3, 3), None, torch.ones(3, 3))) serialized_program = ExportedProgramSerializer(None, 2).serialize(ep) - serialized_program.exported_program.graph_module.signature.input_specs[ - 1 - ] = schema.InputSpec.create( - user_input=schema.UserInputSpec(arg=schema.Argument.create(as_none=True)) + serialized_program.exported_program.graph_module.signature.input_specs[1] = ( + schema.InputSpec.create( + user_input=schema.UserInputSpec( + arg=schema.Argument.create(as_none=True) + ) + ) ) ep = ExportedProgramDeserializer(None).deserialize( serialized_program.exported_program, {}, {}, {} diff --git a/test/export/test_swap.py b/test/export/test_swap.py index 8833c3c94ae7..d9b2269dc324 100644 --- a/test/export/test_swap.py +++ b/test/export/test_swap.py @@ -22,7 +22,9 @@ from torch.testing._internal.common_utils import IS_WINDOWS, run_tests, TestCase {"strict": False}, {"strict": True}, ], - class_name_func=lambda cls, _, params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}", + class_name_func=lambda cls, + _, + params: f"{cls.__name__}_{'strict' if params['strict'] else 'nonstrict'}", ) class TestSwap(TestCase): def test_unflatten_preserve_signature(self): diff --git a/test/forward_backward_compatibility/check_forward_backward_compatibility.py b/test/forward_backward_compatibility/check_forward_backward_compatibility.py index 3fe6b66039ca..d6cf2df4343f 100644 --- a/test/forward_backward_compatibility/check_forward_backward_compatibility.py +++ b/test/forward_backward_compatibility/check_forward_backward_compatibility.py @@ -348,8 +348,7 @@ def check_fc(existing_schemas): "\n\t".join(str(s) for s in matching_new_schemas), ) log.warning( - "Refer to following reasons for failure " - "to find FC schema:\n[\n%s\n]", + "Refer to following reasons for failure to find FC schema:\n[\n%s\n]", "\n\t".join(str(r) for r in possible_failure_reasons), ) broken_ops.append(str(existing_schema)) diff --git a/test/functorch/common_utils.py b/test/functorch/common_utils.py index 72a41dad777f..4fa17b89f19e 100644 --- a/test/functorch/common_utils.py +++ b/test/functorch/common_utils.py @@ -523,15 +523,15 @@ def decorateForModules(decorator, module_classes, device_type=None, dtypes=None) dtypes=dtypes, ): name_parts = fn.__qualname__.split(".") - assert ( - len(name_parts) == 2 - ), "Decorator only applies to a test function of a test class" + assert len(name_parts) == 2, ( + "Decorator only applies to a test function of a test class" + ) test_case_name, base_test_name = name_parts for module_cls in module_classes: matching_module_infos = [m for m in module_db if m.module_cls == module_cls] - assert ( - len(matching_module_infos) == 1 - ), f"Couldn't find single ModuleInfo for {module_cls}" + assert len(matching_module_infos) == 1, ( + f"Couldn't find single ModuleInfo for {module_cls}" + ) module_info = matching_module_infos[0] decorators = list(module_info.decorators) new_decorator = DecorateInfo( diff --git a/test/functorch/test_ac_knapsack.py b/test/functorch/test_ac_knapsack.py index f0a3c3916e6b..751a4c4d2185 100644 --- a/test/functorch/test_ac_knapsack.py +++ b/test/functorch/test_ac_knapsack.py @@ -124,9 +124,7 @@ class TestGraphInfoProvider(TestCase): ) def test_recomputable_node_only_graph_with_larger_graph_context(self): - recomputable_node_only_graph_with_larger_graph_context = ( - self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context - ) + recomputable_node_only_graph_with_larger_graph_context = self.graph_info_provider.recomputable_node_only_graph_with_larger_graph_context # noqa: B950 expected_nodes = self.all_recomputable_banned_nodes # node1 does not have an indirect path to node5 because of node2 # node2 has an indirect path to node5 diff --git a/test/functorch/test_aotdispatch.py b/test/functorch/test_aotdispatch.py index 1edd3845df92..9bd326304fa2 100644 --- a/test/functorch/test_aotdispatch.py +++ b/test/functorch/test_aotdispatch.py @@ -2568,8 +2568,9 @@ def forward(self, primals_1, primals_2): def fn(x: torch.Tensor, x1: torch.Tensor) -> torch.Tensor: return torch.ops._test._clone_create_graph(x, x1) - inp_x, inp_x1 = torch.randn(3, requires_grad=True), torch.randn( - 3, requires_grad=True + inp_x, inp_x1 = ( + torch.randn(3, requires_grad=True), + torch.randn(3, requires_grad=True), ) ref_x, ref_x1 = inp_x.clone(), inp_x1.clone() @@ -5283,11 +5284,12 @@ def forward(self, arg0_1): mod = TestMod(fn) inp = torch.randn(2) - with patch( - "functorch.compile.config.functionalize_rng_ops", True - ), self.assertRaisesRegex( - RuntimeError, - "Functionalized RNG is not currently supported in the aot_export", + with ( + patch("functorch.compile.config.functionalize_rng_ops", True), + self.assertRaisesRegex( + RuntimeError, + "Functionalized RNG is not currently supported in the aot_export", + ), ): aot_export_joint_simple(fn, [mod.p, inp], trace_joint=False) aot_export_joint_simple(fn, [mod.p, inp], trace_joint=True) diff --git a/test/functorch/test_eager_transforms.py b/test/functorch/test_eager_transforms.py index 1c53063b7a7c..bd8abbc3ea85 100644 --- a/test/functorch/test_eager_transforms.py +++ b/test/functorch/test_eager_transforms.py @@ -4324,9 +4324,7 @@ class TestExamplesCorrectness(TestCase): def lennard_jones_force(r): """Get magnitude of LJ force""" - return -epsilon * ( - (-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7) - ) + return -epsilon * ((-12 * sigma**12 / r**13) + (6 * sigma**6 / r**7)) r = torch.linspace(0.5, 2 * sigma, steps=100, requires_grad=True, device=device) drs = torch.outer(r, torch.tensor([1.0, 0, 0], device=device)) @@ -4495,8 +4493,9 @@ class TestExamplesCorrectness(TestCase): # This example mimics what a user might do when trying to find the optimal learning rate. They would # want to run a bunch of models with the same behavior (including the same dropout!) and have them # each run with different learning rates. Specifically, this is an example of using same randomness with vmap - points, labels = torch.randn(100, 2, 2, 2, 2, device=device), torch.randint( - 0, 2, (100,), device=device + points, labels = ( + torch.randn(100, 2, 2, 2, 2, device=device), + torch.randint(0, 2, (100,), device=device), ) class MLPClassifier(nn.Module): diff --git a/test/functorch/test_memory_efficient_fusion.py b/test/functorch/test_memory_efficient_fusion.py index 7bf263431ad0..4926781d7f65 100644 --- a/test/functorch/test_memory_efficient_fusion.py +++ b/test/functorch/test_memory_efficient_fusion.py @@ -208,33 +208,33 @@ def check(f, t, delta, check_val=True, graph_input=False): old_num_nodes = len(fx_g.graph.nodes) new_num_nodes = len(new_graph.nodes) if delta == -1: - assert ( - old_num_nodes >= new_num_nodes - ), f"number of nodes increased {old_num_nodes}, {new_num_nodes}" + assert old_num_nodes >= new_num_nodes, ( + f"number of nodes increased {old_num_nodes}, {new_num_nodes}" + ) else: - assert ( - old_num_nodes == new_num_nodes + delta - ), f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" + assert old_num_nodes == new_num_nodes + delta, ( + f"number of nodes not the same {old_num_nodes - delta}, {new_num_nodes}\n {fx_g.graph} \n {new_graph}" + ) # a second pass should not reduce more nodes pass_2_graph = fx_graph_cse(new_graph) pass_2_num_nodes = len(pass_2_graph.nodes) - assert ( - pass_2_num_nodes == new_num_nodes - ), f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" + assert pass_2_num_nodes == new_num_nodes, ( + f"second pass graph has less node {pass_2_num_nodes}, {new_num_nodes}\n {new_graph} \n {pass_2_graph}" + ) # check correctness if check_val: true_result = fx_g(t) our_result = new_g(t) if true_result is None: # both return None - assert ( - our_result is None - ), f"true result is None, CSE result is {our_result}" + assert our_result is None, ( + f"true result is None, CSE result is {our_result}" + ) else: # results returned are the same - assert torch.all( - true_result == our_result - ), f"results are different {true_result}, {our_result}" # check results are the same + assert torch.all(true_result == our_result), ( + f"results are different {true_result}, {our_result}" + ) # check results are the same class NoChangeTestCase(TestCase): diff --git a/test/functorch/test_ops.py b/test/functorch/test_ops.py index 8a0bf6ad40f5..cef00f83eb72 100644 --- a/test/functorch/test_ops.py +++ b/test/functorch/test_ops.py @@ -2154,9 +2154,9 @@ class TestOperators(TestCase): else: weight = torch.randn(weight_shape, device=device) target = torch.randint(0, C, target_shape, device=device) - target[ - 0 - ] = 1 # since we're ignoring index 0, at least one element must be non-zero + target[0] = ( + 1 # since we're ignoring index 0, at least one element must be non-zero + ) fn = functools.partial( torch.nn.functional.nll_loss, target=target, weight=weight, **kwargs diff --git a/test/functorch/test_parsing.py b/test/functorch/test_parsing.py index 46c9b340c594..8183755ebd4d 100644 --- a/test/functorch/test_parsing.py +++ b/test/functorch/test_parsing.py @@ -24,6 +24,7 @@ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ + from typing import Any from unittest import mock @@ -107,7 +108,7 @@ class TestParsedExpression(TestCase): ParsedExpression("(a) ((b c) (d ...))") # invalid identifiers - ParsedExpression("camelCase under_scored cApiTaLs \u00DF ...") + ParsedExpression("camelCase under_scored cApiTaLs \u00df ...") with self.assertRaises(ValueError): ParsedExpression("1a") with self.assertRaises(ValueError): diff --git a/test/functorch/test_rearrange.py b/test/functorch/test_rearrange.py index d5f55d7e7a3b..b3c8f7753687 100644 --- a/test/functorch/test_rearrange.py +++ b/test/functorch/test_rearrange.py @@ -25,7 +25,6 @@ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE SOFTWARE. """ - import numpy as np import torch diff --git a/test/functorch/test_vmap.py b/test/functorch/test_vmap.py index dc4f239ca2d4..1222e8905978 100644 --- a/test/functorch/test_vmap.py +++ b/test/functorch/test_vmap.py @@ -1532,7 +1532,7 @@ class TestVmapOperators(Namespace.TestVmapBase): self._test_unary(op, getter, "cpu") # test in-place - method = getattr(Tensor, f'{op.__name__ + "_"}') + method = getattr(Tensor, f"{op.__name__ + '_'}") self._test_unary(method, getter, "cpu", check_propagates_grad=False) def test_clone(self): diff --git a/test/fx/quantization.py b/test/fx/quantization.py index 3daa4da479ec..33550702ca6c 100644 --- a/test/fx/quantization.py +++ b/test/fx/quantization.py @@ -2,6 +2,7 @@ r""" **This file is EXPERIMENTAL and is mostly used for testing purposes! Do not rely on it for anything!** """ + import operator import sys diff --git a/test/fx/test_cse_pass.py b/test/fx/test_cse_pass.py index 06ecd2e14280..74eb2ca3af42 100644 --- a/test/fx/test_cse_pass.py +++ b/test/fx/test_cse_pass.py @@ -46,9 +46,9 @@ def check(self, f, t, delta, check_val=True, graph_input=False, P=None): old_num_nodes = len(fx_g.graph.nodes) new_num_nodes = len(new_graph.nodes) - assert ( - new_num_nodes < old_num_nodes - ) == modified, "modified should be True if the number of nodes decrease" + assert (new_num_nodes < old_num_nodes) == modified, ( + "modified should be True if the number of nodes decrease" + ) if delta == -1: self.assertTrue( diff --git a/test/fx/test_z3_gradual_types.py b/test/fx/test_z3_gradual_types.py index 70430e03c3a5..9b1a3878ed6a 100644 --- a/test/fx/test_z3_gradual_types.py +++ b/test/fx/test_z3_gradual_types.py @@ -1783,8 +1783,9 @@ class TestSingleOperation(unittest.TestCase): self.assertEqual(s.check(), z3.sat) add_result = z3.Const(3, tensor_type) - broadcast_res1, broadcast_res2 = z3.Const(4, tensor_type), z3.Const( - 5, tensor_type + broadcast_res1, broadcast_res2 = ( + z3.Const(4, tensor_type), + z3.Const(5, tensor_type), ) # print(s.model()) diff --git a/test/higher_order_ops/test_invoke_subgraph.py b/test/higher_order_ops/test_invoke_subgraph.py index 0fee2cf0953b..f7efc393697f 100644 --- a/test/higher_order_ops/test_invoke_subgraph.py +++ b/test/higher_order_ops/test_invoke_subgraph.py @@ -251,10 +251,13 @@ class TestInvokeSubgraphCompile(TestCase): x_clone = x.detach().clone().requires_grad_(True) y_clone = y.detach().clone().requires_grad_(True) backend = EagerAndRecordGraphs() - with mock.patch( - "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", - True, - ), torch.no_grad(): + with ( + mock.patch( + "torch._dynamo.variables.higher_order_ops.InvokeSubgraphHigherOrderVariable.supports_input_mutation", + True, + ), + torch.no_grad(), + ): res = torch.compile(fn, backend=backend, fullgraph=True)( mod, x_clone, y_clone ) @@ -2399,7 +2402,9 @@ class GraphModule(torch.nn.Module): {"strict": False}, {"strict": True}, ], - class_name_func=lambda cls, _, params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}", + class_name_func=lambda cls, + _, + params: f"{cls.__name__}{'Strict' if params['strict'] else 'Nonstrict'}", ) class TestInvokeSubgraphExport(TestCase): def test_simple_func(self): diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 88f04145f899..590445341580 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -38,7 +38,6 @@ USE_BLACK_FILELIST = re.compile( # torchgen/** # test/** # test/[a-h]*/** - "test/[a-h]*/**", # test/[i-j]*/** "test/[i-j]*/**", # test/[k-m]*/**