diff --git a/tools/linter/adapters/pyfmt_linter.py b/tools/linter/adapters/pyfmt_linter.py index 18f7d46af21c..88f04145f899 100644 --- a/tools/linter/adapters/pyfmt_linter.py +++ b/tools/linter/adapters/pyfmt_linter.py @@ -60,7 +60,6 @@ USE_BLACK_FILELIST = re.compile( "torch/[b-c]*/**", # torch/d*/** # torch/[e-m]*/** - "torch/[e-m]*/**", # torch/optim/** # torch/[p-z]*/** "torch/[p-z]*/**", diff --git a/torch/export/__init__.py b/torch/export/__init__.py index 8382defdfdb3..13b675ead4b4 100644 --- a/torch/export/__init__.py +++ b/torch/export/__init__.py @@ -358,22 +358,24 @@ def save( import torch import io + class MyModule(torch.nn.Module): def forward(self, x): return x + 10 + ep = torch.export.export(MyModule(), (torch.randn(5),)) # Save to file - torch.export.save(ep, 'exported_program.pt2') + torch.export.save(ep, "exported_program.pt2") # Save to io.BytesIO buffer buffer = io.BytesIO() torch.export.save(ep, buffer) # Save with extra files - extra_files = {'foo.txt': b'bar'.decode('utf-8')} - torch.export.save(ep, 'exported_program.pt2', extra_files=extra_files) + extra_files = {"foo.txt": b"bar".decode("utf-8")} + torch.export.save(ep, "exported_program.pt2", extra_files=extra_files) """ if not isinstance(ep, ExportedProgram): @@ -427,18 +429,18 @@ def load( import io # Load ExportedProgram from file - ep = torch.export.load('exported_program.pt2') + ep = torch.export.load("exported_program.pt2") # Load ExportedProgram from io.BytesIO object - with open('exported_program.pt2', 'rb') as f: + with open("exported_program.pt2", "rb") as f: buffer = io.BytesIO(f.read()) buffer.seek(0) ep = torch.export.load(buffer) # Load with extra files. - extra_files = {'foo.txt': ''} # values will be replaced with data - ep = torch.export.load('exported_program.pt2', extra_files=extra_files) - print(extra_files['foo.txt']) + extra_files = {"foo.txt": ""} # values will be replaced with data + ep = torch.export.load("exported_program.pt2", extra_files=extra_files) + print(extra_files["foo.txt"]) print(ep(torch.randn(5))) """ if isinstance(f, (str, os.PathLike)): @@ -572,24 +574,29 @@ def register_dataclass( import torch from dataclasses import dataclass + @dataclass class InputDataClass: feature: torch.Tensor bias: int + @dataclass class OutputDataClass: res: torch.Tensor + torch.export.register_dataclass(InputDataClass) torch.export.register_dataclass(OutputDataClass) + class Mod(torch.nn.Module): def forward(self, x: InputDataClass) -> OutputDataClass: res = x.feature + x.bias return OutputDataClass(res=res) - ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1), )) + + ep = torch.export.export(Mod(), (InputDataClass(torch.ones(2, 2), 1),)) print(ep) """ diff --git a/torch/export/_draft_export.py b/torch/export/_draft_export.py index 2c77df8ade0d..9a9ed922c83e 100644 --- a/torch/export/_draft_export.py +++ b/torch/export/_draft_export.py @@ -43,7 +43,7 @@ def prettify_stack(stack: list[dict[str, str]], str_to_filename: dict[int, str]) continue res += f""" - File {str_to_filename[frame['filename']]}, lineno {frame['line']}, in {frame['name']}""" # type: ignore[index] + File {str_to_filename[frame["filename"]]}, lineno {frame["line"]}, in {frame["name"]}""" # type: ignore[index] res += f"\n {stack[-1]['loc']}" return res @@ -327,12 +327,12 @@ class CaptureStructuredTrace(torch._logging._internal.LazyTraceHandler): # We don't want to log all expression_created logs, only # the ones that are relevant to the # guards/propagate_real_tensor - self.expression_created_logs[ - 
metadata[key]["result_id"] - ] = ExpressionCreatedNode( - metadata[key]["result_id"], - metadata[key].get("argument_ids", []), - record, + self.expression_created_logs[metadata[key]["result_id"]] = ( + ExpressionCreatedNode( + metadata[key]["result_id"], + metadata[key].get("argument_ids", []), + record, + ) ) return @@ -374,10 +374,13 @@ def draft_export( capture_structured_log = CaptureStructuredTrace() - with torch._functorch.config.patch( - fake_tensor_propagate_real_tensors=True, - generate_fake_kernels_from_real_mismatches=True, - ), capture_structured_log: + with ( + torch._functorch.config.patch( + fake_tensor_propagate_real_tensors=True, + generate_fake_kernels_from_real_mismatches=True, + ), + capture_structured_log, + ): try: new_shapes = None ep = _export( @@ -424,10 +427,10 @@ def draft_export( continue elif log_name == "propagate_real_tensors_provenance": - log_contents[ - "occurrences" - ] = capture_structured_log.log_record.get_log_count( - (log_name, log_contents) + log_contents["occurrences"] = ( + capture_structured_log.log_record.get_log_count( + (log_name, log_contents) + ) ) failure_type = FailureType.DATA_DEPENDENT_ERROR diff --git a/torch/export/_swap.py b/torch/export/_swap.py index 74b564c9fccb..df003403569a 100644 --- a/torch/export/_swap.py +++ b/torch/export/_swap.py @@ -26,9 +26,9 @@ def _get_getitem_users(node: torch.fx.Node) -> set[torch.fx.Node]: if user.op == "output": continue - assert ( - user.op == "call_function" and user.target == operator.getitem - ), f"Expected getitem node as user for {node}, instead got {user}" + assert user.op == "call_function" and user.target == operator.getitem, ( + f"Expected getitem node as user for {node}, instead got {user}" + ) getitem_users.update(list(user.users.keys())) return getitem_users @@ -63,9 +63,9 @@ def _try_remove_connecting_pytrees(curr_module_node: torch.fx.Node) -> None: log.debug("Trying to remove pytrees for module call %s", curr_module_node) curr_module_users = list(curr_module_node.users.keys()) - assert ( - len(curr_module_users) == 1 - ), f"Expected only one user for module node, instead got {list(curr_module_users)}" + assert len(curr_module_users) == 1, ( + f"Expected only one user for module node, instead got {list(curr_module_users)}" + ) flatten_node = curr_module_users[0] assert ( flatten_node.op == "call_function" diff --git a/torch/export/_trace.py b/torch/export/_trace.py index d03d5f1efb8c..835a943515b7 100644 --- a/torch/export/_trace.py +++ b/torch/export/_trace.py @@ -268,9 +268,9 @@ def _extract_fake_inputs(gm, args, kwargs): if detected_fake_mode: if detected_shape_env: - assert ( - detected_shape_env is detected_fake_mode.shape_env - ), "Detected shape env does not match fake mode's shape env" + assert detected_shape_env is detected_fake_mode.shape_env, ( + "Detected shape env does not match fake mode's shape env" + ) fake_mode = detected_fake_mode elif detected_shape_env: fake_mode = FakeTensorMode(shape_env=detected_shape_env, export=True) @@ -864,13 +864,19 @@ def _export_to_aten_ir( # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, # otherwise aot_export_module will error out because it sees a mix of fake_modes. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. 
- with torch.nn.utils.stateless._reparametrize_module( - mod, - fake_params_buffers, - tie_weights=True, - strict=True, - stack_weights=True, - ), grad_safe_guard, _ignore_backend_decomps(), _compiling_state_context(), custom_triton_ops_decomposition_ctx(): # type: ignore[attr-defined] + with ( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), + grad_safe_guard, + _ignore_backend_decomps(), + _compiling_state_context(), + custom_triton_ops_decomposition_ctx(), + ): gm, graph_signature = transform(aot_export_module)( mod, fake_args, @@ -1229,9 +1235,9 @@ def _get_module_call_graph( """ gm: torch.fx.GraphModule = export_artifact.aten.gm export_graph_signature: ExportGraphSignature = export_artifact.aten.sig - module_call_specs: dict[ - str, dict[str, TreeSpec] - ] = export_artifact.module_call_specs + module_call_specs: dict[str, dict[str, TreeSpec]] = ( + export_artifact.module_call_specs + ) in_spec: TreeSpec = export_artifact.in_spec out_spec: TreeSpec = export_artifact.out_spec @@ -1365,7 +1371,8 @@ def _convert_ts_to_export_experimental(traced_callable, args, kwargs=None): ).module() elif isinstance(traced_callable, torch.ScriptMethod) and isinstance( - traced_callable.owner(), (torch._C.ScriptModule, torch.nn.Module) # type: ignore[operator] + traced_callable.owner(), # type: ignore[operator] + (torch._C.ScriptModule, torch.nn.Module), ): with patch_forward(traced_callable.owner(), traced_callable): # type: ignore[operator] return _export( @@ -1430,9 +1437,9 @@ def _strict_export( attr = getattr(gm_torch_level, node.target) # Checks if it is not a HigherOrderOp branch or a module if not isinstance(attr, torch.nn.Module): - assert ( - dynamo_fake_mode is not None - ), "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + assert dynamo_fake_mode is not None, ( + "Cannot find dynamo_fake_mode. This could be due to the exported graph module have no placeholders." + ) node.meta["val"] = dynamo_fake_mode.from_tensor( attr, static_shapes=True ) @@ -1749,13 +1756,17 @@ def _export_to_aten_ir_make_fx( # This _reparametrize_module makes sure inputs and module.params/buffers have the same fake_mode, # otherwise aot_export_module will error out because it sees a mix of fake_modes. # And we want aot_export_module to use the fake_tensor mode in dynamo to keep the pipeline easy to reason about. 
- with torch.nn.utils.stateless._reparametrize_module( - mod, - fake_params_buffers, - tie_weights=True, - strict=True, - stack_weights=True, - ), _ignore_backend_decomps(), _compiling_state_context(): # type: ignore[attr-defined] + with ( + torch.nn.utils.stateless._reparametrize_module( + mod, + fake_params_buffers, + tie_weights=True, + strict=True, + stack_weights=True, + ), + _ignore_backend_decomps(), + _compiling_state_context(), + ): gm, graph_signature = transform(_make_fx_helper)( mod, fake_args, @@ -1944,22 +1955,27 @@ def _non_strict_export( # We also need to attach dynamo configs as these will be used in HOOs that # use torch.compile, like cond dynamo_config = dataclasses.asdict(DEFAULT_EXPORT_DYNAMO_CONFIG) - dynamo_config[ - "do_not_emit_runtime_asserts" - ] = False # We want to emit runtime asserts + dynamo_config["do_not_emit_runtime_asserts"] = ( + False # We want to emit runtime asserts + ) - with fake_mode, _NonStrictTorchFunctionHandler(), tracing( - tx - ), torch._dynamo.config.patch(dynamo_config): - with _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( - patched_mod, - new_fake_args, - new_fake_kwargs, - new_fake_constant_attrs, - map_fake_to_real, - ), _fakify_module_inputs( - fake_args, fake_kwargs, fake_mode - ), _override_builtin_ops(): + with ( + fake_mode, + _NonStrictTorchFunctionHandler(), + tracing(tx), + torch._dynamo.config.patch(dynamo_config), + ): + with ( + _fakify_script_objects(mod, fake_args, fake_kwargs, fake_mode) as ( + patched_mod, + new_fake_args, + new_fake_kwargs, + new_fake_constant_attrs, + map_fake_to_real, + ), + _fakify_module_inputs(fake_args, fake_kwargs, fake_mode), + _override_builtin_ops(), + ): aten_export_artifact = _to_aten_func( # type: ignore[operator] patched_mod, new_fake_args, diff --git a/torch/export/dynamic_shapes.py b/torch/export/dynamic_shapes.py index a75bbdc7035a..af63551d8e46 100644 --- a/torch/export/dynamic_shapes.py +++ b/torch/export/dynamic_shapes.py @@ -666,7 +666,7 @@ class ShapesCollection: Example:: - args = ({"x": tensor_x, "others": [tensor_y, tensor_z]}) + args = {"x": tensor_x, "others": [tensor_y, tensor_z]} dim = torch.export.Dim(...) dynamic_shapes = torch.export.ShapesCollection() @@ -682,7 +682,7 @@ class ShapesCollection: Example:: - args = ({"x": tensor_x, "others": [int_x, int_y]}) + args = {"x": tensor_x, "others": [int_x, int_y]} # Wrap all ints with _IntWrapper mapped_args = pytree.tree_map_only(int, lambda a: _IntWrapper(a), args) @@ -700,18 +700,18 @@ class ShapesCollection: self._shapes = {} def __setitem__(self, t, shape): - assert isinstance( - t, (torch.Tensor, _IntWrapper) - ), f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}" + assert isinstance(t, (torch.Tensor, _IntWrapper)), ( + f"Cannot assign shape to non-tensor or non-_IntWrapper type {type(t)}" + ) # TODO(avik): check that shape is indeed a Shape t_id = id(t) if t_id in self._shapes: _shape = self._shapes[t_id] - assert ( - shape == _shape - ), f"Shapes assigned to input do not match: expected {_shape}, got {shape}" + assert shape == _shape, ( + f"Shapes assigned to input do not match: expected {_shape}, got {shape}" + ) else: self._shapes[id(t)] = shape @@ -766,7 +766,7 @@ class AdditionalInputs: Example:: - args0, kwargs0 = ... # example inputs for export + args0, kwargs0 = ... 
# example inputs for export # other representative inputs that the exported program will run on dynamic_shapes = torch.export.AdditionalInputs() @@ -786,9 +786,9 @@ class AdditionalInputs: """ assert type(args) is tuple, f"Representative args {args} must be a tuple" - assert ( - kwargs is None or type(kwargs) is dict - ), f"Representative kwargs {kwargs} must be None or a dict" + assert kwargs is None or type(kwargs) is dict, ( + f"Representative kwargs {kwargs} must be None or a dict" + ) self._examples.append((args, kwargs)) def dynamic_shapes(self, m, args, kwargs=None): @@ -1075,7 +1075,8 @@ def _process_dynamic_shapes( i, dim.__name__, StrictMinMaxConstraint( - vr=ValueRanges(lower=dim.value, upper=dim.value), warn_only=False # type: ignore[attr-defined] + vr=ValueRanges(lower=dim.value, upper=dim.value), # type: ignore[attr-defined] + warn_only=False, ), ) else: @@ -1085,7 +1086,8 @@ def _process_dynamic_shapes( i, dim.__name__, StrictMinMaxConstraint( - vr=ValueRanges(lower=dim.min, upper=dim.max), warn_only=False # type: ignore[attr-defined] + vr=ValueRanges(lower=dim.min, upper=dim.max), # type: ignore[attr-defined] + warn_only=False, ), ) return constraint @@ -1161,7 +1163,7 @@ def _process_dynamic_shapes( def _get_dim_name_mapping( - dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None] + dynamic_shapes: Union[dict[str, Any], tuple[Any], list[Any], None], ): name_to_dim = {} for dim in tree_flatten( diff --git a/torch/export/experimental/__init__.py b/torch/export/experimental/__init__.py index 99fc0d6995cb..a2c46108189a 100644 --- a/torch/export/experimental/__init__.py +++ b/torch/export/experimental/__init__.py @@ -137,16 +137,11 @@ class _ExportPackage: "decoder": ExportMethod( overloads={ "prefill": ExportedProgram(...), - "decode": ExportedProgram(...) + "decode": ExportedProgram(...), }, - fallbacks=[] + fallbacks=[], ), - "encoder": ExportMethod( - overloads={}, - fallbacks=[ - ExportedProgram(...) 
- ] - ) + "encoder": ExportMethod(overloads={}, fallbacks=[ExportedProgram(...)]), }, ) ``` @@ -212,15 +207,18 @@ class _ExportPackage: ``` package = ExportPackage() + def prefill(x, xa, kv_cache): assert x.shape[1] == 3 assert kv_cache == {} + def decode(x, xa, kv_cache): assert x.shape[1] > 1 assert len(kv_cache) > 0 return {...} # dynamic shape specs here + exporter = ( package.exporter(decoder) .define_overload("prefill", prefill) diff --git a/torch/export/exported_program.py b/torch/export/exported_program.py index 6baac896fb1f..00f775ce5221 100644 --- a/torch/export/exported_program.py +++ b/torch/export/exported_program.py @@ -272,7 +272,7 @@ def _override_composite_implicit_decomp(cia_ops_to_callable): def _split_decomp_table_to_cia_and_python_decomp( - decomp_table: dict[torch._ops.OperatorBase, Callable] + decomp_table: dict[torch._ops.OperatorBase, Callable], ) -> tuple[dict[torch._ops.OperatorBase, Callable], ...]: all_preservable_cia_ops = set(_collect_all_valid_cia_ops()) cia_ops_to_callable = {} @@ -443,9 +443,14 @@ def _decompose_and_get_gm_with_new_signature_constants( tx = TracingContext(fake_mode) - with fake_mode, _override_composite_implicit_decomp( - cia_to_decomp, - ), _enable_graph_inputs_of_type_nn_module(ep.example_inputs), tracing(tx): + with ( + fake_mode, + _override_composite_implicit_decomp( + cia_to_decomp, + ), + _enable_graph_inputs_of_type_nn_module(ep.example_inputs), + tracing(tx), + ): retracing_args_unwrapped = pytree.tree_unflatten( retracing_args, mod._in_spec ) @@ -573,9 +578,12 @@ def _decompose_and_get_gm_with_new_signature_constants( if decompose_custom_triton_ops else _disable_custom_triton_op_functional_decomposition ) - with _ignore_backend_decomps(), fake_mode, _override_composite_implicit_decomp( - cia_to_decomp - ), custom_triton_ops_decomposition_ctx(): + with ( + _ignore_backend_decomps(), + fake_mode, + _override_composite_implicit_decomp(cia_to_decomp), + custom_triton_ops_decomposition_ctx(), + ): gm, graph_signature = aot_export_module( ep.graph_module, fake_args, @@ -1514,9 +1522,9 @@ class ExportedProgram: if node.op != "placeholder": break - assert i < len( - old_signature.input_specs - ), "Number of inputs changed after transformation" + assert i < len(old_signature.input_specs), ( + "Number of inputs changed after transformation" + ) old_input_spec = old_signature.input_specs[i] arg = ( old_input_spec.arg @@ -1539,9 +1547,9 @@ class ExportedProgram: new_output_specs = [] for i, node in enumerate(output_node.args[0]): - assert i < len( - old_signature.output_specs - ), "Number of outputs changed after transformation" + assert i < len(old_signature.output_specs), ( + "Number of outputs changed after transformation" + ) old_output_spec = old_signature.output_specs[i] arg = ( old_output_spec.arg @@ -1599,9 +1607,9 @@ class ExportedProgram: # TODO: remove this @final def _validate(self): - assert ( - len(self.verifiers) > 0 - ), "ExportedProgram must have at least one verifier." + assert len(self.verifiers) > 0, ( + "ExportedProgram must have at least one verifier." + ) for v in self.verifiers: v().check(self) diff --git a/torch/export/graph_signature.py b/torch/export/graph_signature.py index d3c4e07b09c1..36e902b3838a 100644 --- a/torch/export/graph_signature.py +++ b/torch/export/graph_signature.py @@ -95,9 +95,9 @@ class InputSpec: def __post_init__(self): if self.kind == InputKind.BUFFER: - assert ( - self.persistent is not None - ), "Failed to specify persistent flag on BUFFER." 
+ assert self.persistent is not None, ( + "Failed to specify persistent flag on BUFFER." + ) assert isinstance( self.arg, ( @@ -187,15 +187,17 @@ class ExportGraphSignature: self.my_parameter = nn.Parameter(torch.tensor(2.0)) # Define two buffers - self.register_buffer('my_buffer1', torch.tensor(3.0)) - self.register_buffer('my_buffer2', torch.tensor(4.0)) + self.register_buffer("my_buffer1", torch.tensor(3.0)) + self.register_buffer("my_buffer2", torch.tensor(4.0)) def forward(self, x1, x2): # Use the parameter, buffers, and both inputs in the forward method - output = (x1 + self.my_parameter) * self.my_buffer1 + x2 * self.my_buffer2 + output = ( + x1 + self.my_parameter + ) * self.my_buffer1 + x2 * self.my_buffer2 # Mutate one of the buffers (e.g., increment it by 1) - self.my_buffer2.add_(1.0) # In-place addition + self.my_buffer2.add_(1.0) # In-place addition return output @@ -520,9 +522,9 @@ def _make_argument_spec(node, token_names) -> ArgumentSpec: # For const outputs we just directly return this return ConstantArgument(name="", value=node) - assert ( - "val" in node.meta - ), f"{node} is not a constant or a node with a 'val' metadata field" + assert "val" in node.meta, ( + f"{node} is not a constant or a node with a 'val' metadata field" + ) val = node.meta["val"] if node.name in token_names: return TokenArgument(name=node.name) @@ -565,9 +567,21 @@ def _convert_to_export_graph_signature( user_outputs = set(graph_signature.user_outputs) buffer_mutations = graph_signature.buffers_to_mutate user_input_mutations = graph_signature.user_inputs_to_mutate - grad_params = graph_signature.backward_signature.gradients_to_parameter if is_joint else {} # type: ignore[union-attr] - grad_user_inputs = graph_signature.backward_signature.gradients_to_user_inputs if is_joint else {} # type: ignore[union-attr] - loss_output = graph_signature.backward_signature.loss_output if is_joint else None # type: ignore[union-attr] + grad_params = ( + graph_signature.backward_signature.gradients_to_parameter # type: ignore[union-attr] + if is_joint + else {} + ) + grad_user_inputs = ( + graph_signature.backward_signature.gradients_to_user_inputs # type: ignore[union-attr] + if is_joint + else {} + ) + loss_output = ( + graph_signature.backward_signature.loss_output # type: ignore[union-attr] + if is_joint + else None + ) input_tokens = graph_signature.input_tokens output_tokens = graph_signature.output_tokens diff --git a/torch/export/pt2_archive/_package.py b/torch/export/pt2_archive/_package.py index 5c9f1efa6855..e36258daf07b 100644 --- a/torch/export/pt2_archive/_package.py +++ b/torch/export/pt2_archive/_package.py @@ -155,9 +155,9 @@ class PT2ArchiveReader: def __init__(self, archive_path_or_buffer: FileLike): self.archive_file = torch._C.PyTorchFileReader(archive_path_or_buffer) # type: ignore[arg-type] - assert ( - self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE - ), "Invalid archive format" + assert self.read_string(ARCHIVE_FORMAT_PATH) == ARCHIVE_FORMAT_VALUE, ( + "Invalid archive format" + ) def __enter__(self) -> "PT2ArchiveReader": return self diff --git a/torch/export/unflatten.py b/torch/export/unflatten.py index 1cdefba579aa..54e698822b30 100644 --- a/torch/export/unflatten.py +++ b/torch/export/unflatten.py @@ -104,9 +104,9 @@ def _assign_attr( assert isinstance(from_obj, torch.Tensor) to_module.register_buffer(field, from_obj, persistent=persistent) elif attr_kind == _AttrKind.CONSTANT: - assert not isinstance( - from_obj, FakeScriptObject - ), "FakeScriptObject should only exist 
during tracing." + assert not isinstance(from_obj, FakeScriptObject), ( + "FakeScriptObject should only exist during tracing." + ) assert isinstance( from_obj, ( @@ -461,9 +461,9 @@ class UnflattenedModule(torch.nn.Module): # add constants that are aliased and don't appear in graph signature for const_name, const in export_module.constants.items(): if const_name not in consts_targets: - assert ( - id(const) in consts_map - ), "Constants should be either aliased or appear in graph signature" + assert id(const) in consts_map, ( + "Constants should be either aliased or appear in graph signature" + ) ph_name, _ = consts_map[id(const)][0] add_to_consts_map(id(const), ph_name, const_name) added_params_buffers.add(s.target) @@ -1041,9 +1041,9 @@ class _ModuleFrame: if arg.name in self.seen_nodes: flat_arg_node.meta = copy.copy(self.seen_nodes[arg.name].meta) - self.node_to_placeholder[ - self.seen_nodes[arg.name] - ] = flat_arg_node + self.node_to_placeholder[self.seen_nodes[arg.name]] = ( + flat_arg_node + ) with self.parent.graph.inserting_before(self.parent_call_module): input_nodes: list[Optional[torch.fx.Node]] = [] @@ -1125,8 +1125,7 @@ class _ModuleFrame: if x in self.node_to_placeholder: return self.node_to_placeholder[x] elif ( - x.op == "placeholder" - or self.module_call_graph.get(self.fqn) is None + x.op == "placeholder" or self.module_call_graph.get(self.fqn) is None # allow placeholder creation if we are not preserving module call signature ): self.add_placeholder(x) diff --git a/torch/fft/__init__.py b/torch/fft/__init__.py index 3ad1748bab1a..b48cd28bb17d 100644 --- a/torch/fft/__init__.py +++ b/torch/fft/__init__.py @@ -82,9 +82,7 @@ Example: >>> t = torch.tensor([0.+1.j, 2.+3.j, 4.+5.j, 6.+7.j]) >>> torch.fft.fft(t) tensor([12.+16.j, -8.+0.j, -4.-4.j, 0.-8.j]) -""".format( - **common_args - ), +""".format(**common_args), ) ifft = _add_docstr( @@ -125,9 +123,7 @@ Example: >>> t = torch.tensor([ 6.+0.j, -2.+2.j, -2.+0.j, -2.-2.j]) >>> torch.fft.ifft(t) tensor([0.+0.j, 1.+0.j, 2.+0.j, 3.+0.j]) -""".format( - **common_args - ), +""".format(**common_args), ) fft2 = _add_docstr( @@ -188,9 +184,7 @@ Example: >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) >>> torch.testing.assert_close(fft2, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) ifft2 = _add_docstr( @@ -243,9 +237,7 @@ Example: >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) >>> torch.testing.assert_close(ifft2, two_iffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) fftn = _add_docstr( @@ -305,9 +297,7 @@ Example: >>> two_ffts = torch.fft.fft(torch.fft.fft(x, dim=0), dim=1) >>> torch.testing.assert_close(fftn, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) ifftn = _add_docstr( @@ -359,9 +349,7 @@ Example: >>> two_iffts = torch.fft.ifft(torch.fft.ifft(x, dim=0), dim=1) >>> torch.testing.assert_close(ifftn, two_iffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) rfft = _add_docstr( @@ -417,9 +405,7 @@ Example: Notice that the symmetric element ``T[-1] == T[1].conj()`` is omitted. At the Nyquist frequency ``T[-2] == T[2]`` is it's own symmetric pair, and therefore must always be real-valued. 
-""".format( - **common_args - ), +""".format(**common_args), ) irfft = _add_docstr( @@ -496,9 +482,7 @@ Example: >>> roundtrip = torch.fft.irfft(T, t.numel()) >>> torch.testing.assert_close(roundtrip, t, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) rfft2 = _add_docstr( @@ -565,9 +549,7 @@ Example: >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) >>> torch.testing.assert_close(rfft2, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) irfft2 = _add_docstr( @@ -649,9 +631,7 @@ Example: torch.Size([10, 9]) >>> torch.testing.assert_close(roundtrip, t, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) rfftn = _add_docstr( @@ -718,9 +698,7 @@ Example: >>> two_ffts = torch.fft.fft(torch.fft.rfft(t, dim=1), dim=0) >>> torch.testing.assert_close(rfftn, two_ffts, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) irfftn = _add_docstr( @@ -801,9 +779,7 @@ Example: torch.Size([10, 9]) >>> torch.testing.assert_close(roundtrip, t, check_stride=False) -""".format( - **common_args - ), +""".format(**common_args), ) hfft = _add_docstr( @@ -894,9 +870,7 @@ Example: >>> torch.fft.hfft(T[:3]) tensor([0.1250, 0.2809, 0.6250, 0.9691]) -""".format( - **common_args - ), +""".format(**common_args), ) ihfft = _add_docstr( @@ -951,9 +925,7 @@ Example: >>> torch.fft.ifft(t) tensor([ 2.0000-0.0000j, -0.5000-0.6882j, -0.5000-0.1625j, -0.5000+0.1625j, -0.5000+0.6882j]) -""".format( - **common_args - ), +""".format(**common_args), ) hfft2 = _add_docstr( @@ -1025,9 +997,7 @@ Example: >>> torch.allclose(roundtrip, T) True -""".format( - **common_args - ), +""".format(**common_args), ) ihfft2 = _add_docstr( @@ -1092,9 +1062,7 @@ Example: >>> torch.allclose(t, two_ffts) True -""".format( - **common_args - ), +""".format(**common_args), ) hfftn = _add_docstr( @@ -1187,9 +1155,7 @@ Example: >>> torch.allclose(roundtrip, T) True -""".format( - **common_args - ), +""".format(**common_args), ) ihfftn = _add_docstr( @@ -1259,9 +1225,7 @@ Example: >>> torch.allclose(ihfftn, two_iffts) True -""".format( - **common_args - ), +""".format(**common_args), ) fftfreq = _add_docstr( @@ -1310,9 +1274,7 @@ Example: >>> torch.fft.fftfreq(4) tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) -""".format( - **factory_common_args - ), +""".format(**factory_common_args), ) rfftfreq = _add_docstr( @@ -1361,9 +1323,7 @@ Example: >>> torch.fft.fftfreq(4) tensor([ 0.0000, 0.2500, -0.5000, -0.2500]) -""".format( - **factory_common_args - ), +""".format(**factory_common_args), ) fftshift = _add_docstr( diff --git a/torch/futures/__init__.py b/torch/futures/__init__.py index dcca39d06a4e..79533346187d 100644 --- a/torch/futures/__init__.py +++ b/torch/futures/__init__.py @@ -271,9 +271,9 @@ class Future(torch._C.Future, Generic[T], metaclass=_PyFutureMeta): ... ValueError: foo """ - assert isinstance( - result, Exception - ), f"{result} is of type {type(result)}, not an Exception." + assert isinstance(result, Exception), ( + f"{result} is of type {type(result)}, not an Exception." 
+ ) def raise_error(fut_result): raise fut_result diff --git a/torch/fx/_graph_pickler.py b/torch/fx/_graph_pickler.py index e723046bf37c..97e5755d7d52 100644 --- a/torch/fx/_graph_pickler.py +++ b/torch/fx/_graph_pickler.py @@ -253,9 +253,9 @@ class _TensorPickleData: for k in MetaTensorDesc._UNSERIALIZABLE: if k in ("fake_mode", "view_func"): continue - assert ( - getattr(self.metadata, k) is None - ), f"not None: {k}: {getattr(self.metadata, k)}" + assert getattr(self.metadata, k) is None, ( + f"not None: {k}: {getattr(self.metadata, k)}" + ) def unpickle(self, unpickle_state: _UnpickleState) -> FakeTensor: # TODO: make common w/ _output_from_cache_entry() in fake_tensor.py? diff --git a/torch/fx/_symbolic_trace.py b/torch/fx/_symbolic_trace.py index 2509de1d2076..5a712ea3a1e3 100644 --- a/torch/fx/_symbolic_trace.py +++ b/torch/fx/_symbolic_trace.py @@ -755,9 +755,9 @@ class Tracer(TracerBase): self.root = root - assert hasattr( - type(root), self.traced_func_name - ), f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + assert hasattr(type(root), self.traced_func_name), ( + f"traced_func_name={self.traced_func_name} doesn't exist in {type(root).__name__}" + ) fn = getattr(type(root), self.traced_func_name) self.root_module_name = root._get_name() @@ -1164,9 +1164,9 @@ def _maybe_revert_all_patches(): finally: if current_patcher is not None: patches_made = current_patcher.reapply_all_patches() - assert ( - patches_made == patches_removed - ), "CURRENT_PATCHER was changed during a revert_all_patches" + assert patches_made == patches_removed, ( + "CURRENT_PATCHER was changed during a revert_all_patches" + ) def _patch_wrapped_functions(patcher: _Patcher): @@ -1248,9 +1248,9 @@ def wrap(fn_or_name: Union[str, Callable]): assert not isinstance(fn_or_name, str) # to make mypy happy fn_name = fn_or_name.__name__ else: - assert isinstance( - fn_or_name, str - ), "fn_or_name must be a global function or string name" + assert isinstance(fn_or_name, str), ( + "fn_or_name must be a global function or string name" + ) fn_name = fn_or_name currentframe = inspect.currentframe() @@ -1308,7 +1308,9 @@ def symbolic_trace( return out - f = fx.symbolic_trace(f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}}) + f = fx.symbolic_trace( + f, concrete_args={"x": {"a": fx.PH, "b": fx.PH, "c": fx.PH}} + ) assert f({"a": 1, "b": 2, "c": 4}) == 7 diff --git a/torch/fx/experimental/accelerator_partitioner.py b/torch/fx/experimental/accelerator_partitioner.py index 29b8d4541b81..c29d05f511a7 100644 --- a/torch/fx/experimental/accelerator_partitioner.py +++ b/torch/fx/experimental/accelerator_partitioner.py @@ -450,9 +450,9 @@ class Partitioner: device = find_device_based_on_size(node) occupied_devices.append(device) # Update partition and its left mem size - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) # Update available mem for the current partition partition.logical_device_ids.append(device.logical_id) else: @@ -475,9 +475,9 @@ class Partitioner: total_size_of_input_nodes = get_extra_size_of( node, partition.nodes ) - partition_to_left_mem_bytes[ - partition - ] = device.available_mem_bytes + partition_to_left_mem_bytes[partition] = ( + device.available_mem_bytes + ) partition.logical_device_ids.append(device.logical_id) partition.add_node(node) partition_to_left_mem_bytes[partition] -= total_size_of_input_nodes @@ -509,9 +509,9 @@ class Partitioner: 
no_device_partitions, ) = get_device_partition_stats(self.partitions, self.devices) - assert ( - len(no_device_partitions) == 0 - ), f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + assert len(no_device_partitions) == 0, ( + f"Expect no_device_partitions has 0 device, but get {len(no_device_partitions)}" + ) # Devices that hold partitions used_devices = [d for d in self.devices if len(device_to_partitions[d]) > 0] diff --git a/torch/fx/experimental/optimization.py b/torch/fx/experimental/optimization.py index 6d2312b39d32..3e406b57a96d 100644 --- a/torch/fx/experimental/optimization.py +++ b/torch/fx/experimental/optimization.py @@ -368,12 +368,12 @@ def optimize_for_inference( supports_mkldnn = MklSupport.YES sample_parameter = next(cur_module.parameters(), None) if sample_parameter is not None: - assert ( - sample_parameter.dtype == torch.float - ), "this pass is only for torch.float modules" - assert sample_parameter.device == torch.device( - "cpu" - ), "this pass is only for CPU modules" + assert sample_parameter.dtype == torch.float, ( + "this pass is only for torch.float modules" + ) + assert sample_parameter.device == torch.device("cpu"), ( + "this pass is only for CPU modules" + ) elif node.op == "call_function": if node.target in mkldnn_supported: supports_mkldnn = MklSupport.YES diff --git a/torch/fx/experimental/proxy_tensor.py b/torch/fx/experimental/proxy_tensor.py index 6556bc1ce067..c8d8e58d952e 100644 --- a/torch/fx/experimental/proxy_tensor.py +++ b/torch/fx/experimental/proxy_tensor.py @@ -182,22 +182,19 @@ def is_sym_node(node: _HasMeta) -> bool: @overload -def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: - ... +def set_proxy_slot(obj: Tensor, tracer: _ProxyTracer, proxy: _ProxyTensor) -> None: ... @overload def set_proxy_slot( obj: _AnyScriptObjectType, tracer: _ProxyTracer, proxy: Proxy -) -> None: - ... +) -> None: ... @overload def set_proxy_slot( obj: PySymType, tracer: _ProxyTracer, proxy: _PySymProxyType -) -> None: - ... +) -> None: ... def set_proxy_slot( @@ -256,8 +253,7 @@ _PySymProxyType = Thunk[Proxy] def get_proxy_slot( obj: Tensor, tracer: _ProxyTracer, -) -> _ProxyTensor: - ... +) -> _ProxyTensor: ... @overload @@ -265,8 +261,7 @@ def get_proxy_slot( obj: Tensor, tracer: _ProxyTracer, default: U, -) -> Union[_ProxyTensor, U]: - ... +) -> Union[_ProxyTensor, U]: ... @overload @@ -275,16 +270,14 @@ def get_proxy_slot( tracer: _ProxyTracer, default: U, transform: Callable[[_ProxyTensor], R], -) -> Union[R, U]: - ... +) -> Union[R, U]: ... @overload def get_proxy_slot( obj: _AnyScriptObjectType, tracer: _ProxyTracer, -) -> Proxy: - ... +) -> Proxy: ... @overload @@ -292,8 +285,7 @@ def get_proxy_slot( obj: _AnyScriptObjectType, tracer: _ProxyTracer, default: U, -) -> Union[Proxy, U]: - ... +) -> Union[Proxy, U]: ... @overload @@ -302,16 +294,14 @@ def get_proxy_slot( tracer: _ProxyTracer, default: U, transform: Callable[[Proxy], R], -) -> Union[R, U]: - ... +) -> Union[R, U]: ... @overload def get_proxy_slot( obj: PySymType, tracer: _ProxyTracer, -) -> _PySymProxyType: - ... +) -> _PySymProxyType: ... @overload @@ -319,8 +309,7 @@ def get_proxy_slot( obj: PySymType, tracer: _ProxyTracer, default: T, -) -> Union[T, _PySymProxyType]: - ... +) -> Union[T, _PySymProxyType]: ... @overload @@ -329,8 +318,7 @@ def get_proxy_slot( tracer: _ProxyTracer, default: U, transform: Callable[[_PySymProxyType], R], -) -> Union[R, U]: - ... +) -> Union[R, U]: ... 
# the default argument is what to return if the slot is not set. @@ -717,22 +705,21 @@ def fetch_sym_proxy( @overload -def fetch_object_proxy(tracer: _ProxyTracer, t: Tensor) -> Union[_ProxyTensor, Tensor]: - ... +def fetch_object_proxy( + tracer: _ProxyTracer, t: Tensor +) -> Union[_ProxyTensor, Tensor]: ... @overload def fetch_object_proxy( tracer: _ProxyTracer, t: _AnyScriptObjectType -) -> Union[Proxy, _AnyScriptObjectType]: - ... +) -> Union[Proxy, _AnyScriptObjectType]: ... @overload def fetch_object_proxy( tracer: _ProxyTracer, t: PySymType -) -> Union[_PySymProxyType, PySymType]: - ... +) -> Union[_PySymProxyType, PySymType]: ... def fetch_object_proxy( @@ -815,7 +802,10 @@ def proxy_call( if func is torch.ops.aten.is_nonzero.default: with proxy_mode: - torch._check(args[0].numel() == 1, lambda: "Boolean value of Tensor with more than one value is ambiguous") # type: ignore[attr-defined] + torch._check( + args[0].numel() == 1, # type: ignore[attr-defined] + lambda: "Boolean value of Tensor with more than one value is ambiguous", + ) return (args[0] != 0).item() # type: ignore[attr-defined] tracer = proxy_mode.tracer @@ -1079,18 +1069,15 @@ class PythonKeyTracer(Tracer): return super().create_arg(a) # type: ignore[return-value] @overload - def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: - ... + def unwrap_proxy(self, e: Tensor) -> Union[Proxy, Tensor]: ... @overload - def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: - ... + def unwrap_proxy(self, e: PySymType) -> Union[Proxy, PySymType]: ... @overload def unwrap_proxy( self, e: _AnyScriptObjectType - ) -> Union[Proxy, _AnyScriptObjectType]: - ... + ) -> Union[Proxy, _AnyScriptObjectType]: ... def unwrap_proxy(self, e: T) -> object: if isinstance(e, Tensor): @@ -1608,7 +1595,10 @@ class DecompositionInterpreter(fx.Interpreter): self.mode = ProxyTorchDispatchMode(self.tracer, tracing_mode="real") def placeholder( - self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: out = super().placeholder(target, args, kwargs) # type: ignore[arg-type] proxy = fx.Proxy(self.new_graph.placeholder(target), self.tracer) @@ -1617,7 +1607,10 @@ class DecompositionInterpreter(fx.Interpreter): return out def get_attr( - self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: out = super().get_attr(target, args, kwargs) # type: ignore[arg-type] proxy = fx.Proxy(self.new_graph.get_attr(target), self.tracer) @@ -1627,7 +1620,10 @@ class DecompositionInterpreter(fx.Interpreter): # call_function, call_method, call_module get traced automatically by the outer mode. def output( - self, target: str, args: tuple[object, ...], kwargs: dict[str, object] # type: ignore[override] + self, + target: str, # type: ignore[override] + args: tuple[object, ...], + kwargs: dict[str, object], ) -> object: out = super().output(target, args, kwargs) # type: ignore[arg-type] @@ -1989,14 +1985,14 @@ class _MakefxTracer: # adding new modes in _MakefxTracer. 
self.fake_tensor_mode: Optional[FakeTensorMode] = None self.proxy_mode: Union[nullcontext, ProxyTorchDispatchMode] = nullcontext() - self.proxy_function_mode: Union[ - nullcontext, PreDispatchTorchFunctionMode - ] = nullcontext() + self.proxy_function_mode: Union[nullcontext, PreDispatchTorchFunctionMode] = ( + nullcontext() + ) self.fx_tracer: Optional[PythonKeyTracer] = None self.python_dispatcher_mode: Union[nullcontext, Any] = nullcontext() - self.torch_fn_metadata_mode: Union[ - nullcontext, TorchFunctionMetadataMode - ] = nullcontext() + self.torch_fn_metadata_mode: Union[nullcontext, TorchFunctionMetadataMode] = ( + nullcontext() + ) self.stack_trace = stack_trace def _checkpoint_modes(self) -> list[Any]: @@ -2071,9 +2067,9 @@ class _MakefxTracer: allow_non_fake_inputs=self._allow_non_fake_inputs, shape_env=shape_env, ) - assert ( - fake_tensor_mode.shape_env is not None - ), "shape_env should be set if tracing with 'symbolic'" + assert fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) self.fake_tensor_mode = fake_tensor_mode else: if not self.tracing_mode == "real": @@ -2161,9 +2157,9 @@ class _MakefxTracer: return self.fake_tensor_mode.from_tensor(x, source=source) # NB: don't match on bools elif type(x) is int and self.tracing_mode == "symbolic": - assert ( - self.fake_tensor_mode.shape_env is not None - ), "shape_env should be set if tracing with 'symbolic'" + assert self.fake_tensor_mode.shape_env is not None, ( + "shape_env should be set if tracing with 'symbolic'" + ) return self.fake_tensor_mode.shape_env.create_symintnode( self.fake_tensor_mode.shape_env.create_symbol( x, source, positive=None @@ -2176,9 +2172,9 @@ class _MakefxTracer: self.fake_tensor_mode, x ) - assert not isinstance( - x, FakeScriptObject - ), f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + assert not isinstance(x, FakeScriptObject), ( + f"ScriptObject {x} has been fakified. Cannot wrap_fake it again." + ) return x wrap_fn_map = { @@ -2344,9 +2340,9 @@ def get_proxy_mode() -> Optional[ProxyTorchDispatchMode]: torch._C._TorchDispatchModeKey.PROXY ) mode = torch._C._get_dispatch_mode(torch._C._TorchDispatchModeKey.PROXY) - assert ( - pre_dispatch_mode is None or mode is None - ), f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + assert pre_dispatch_mode is None or mode is None, ( + f"pre_dispatch_mode={pre_dispatch_mode}, mode={mode}" + ) return pre_dispatch_mode or mode diff --git a/torch/fx/experimental/recording.py b/torch/fx/experimental/recording.py index bb54eba11384..c41c34158f54 100644 --- a/torch/fx/experimental/recording.py +++ b/torch/fx/experimental/recording.py @@ -460,7 +460,7 @@ def shape_env_check_state_equal(env1, env2, non_state_variable_names, map_value) # Here, we allow the value of each field to be mapped, so that we appropriately # compare the two values. def compare_vars( - map_value: Callable[[str, Any], Any] + map_value: Callable[[str, Any], Any], ) -> list[tuple[str, str, str]]: env1_set, env2_set = set(env1_vars), set(env2_vars) diff --git a/torch/fx/experimental/schema_type_annotation.py b/torch/fx/experimental/schema_type_annotation.py index 335c027c9321..b1b2f1680d64 100644 --- a/torch/fx/experimental/schema_type_annotation.py +++ b/torch/fx/experimental/schema_type_annotation.py @@ -103,7 +103,7 @@ class AnnotateTypesWithSchema(Transformer): for i, atom in enumerate(atoms): if not hasattr(module_itr, atom): raise RuntimeError( - f'Node referenced nonextent target {".".join(atoms[:i])}!' 
+ f"Node referenced nonextent target {'.'.join(atoms[:i])}!" ) module_itr = getattr(module_itr, atom) diff --git a/torch/fx/experimental/sym_node.py b/torch/fx/experimental/sym_node.py index a37a08c8b4c1..5b8c98a1968f 100644 --- a/torch/fx/experimental/sym_node.py +++ b/torch/fx/experimental/sym_node.py @@ -149,9 +149,9 @@ class SymNode: # This is technically not TV, but this assert is expensive so # let's only do it when we're already doing expensive things computed_hint = compute_hint() - assert ( - hint == computed_hint - ), f"{hint} != {computed_hint} (for {self.expr})" + assert hint == computed_hint, ( + f"{hint} != {computed_hint} (for {self.expr})" + ) else: hint = compute_hint() self._hint = hint @@ -460,7 +460,9 @@ class SymNode: return self.float_pow(other) def is_non_overlapping_and_dense(self, sizes, strides): - return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq(to_node(self, 1)) # type: ignore[attr-defined] + return self.is_non_overlapping_and_dense_indicator(sizes, strides).eq( + to_node(self, 1) + ) # type: ignore[attr-defined] def int_(self): return self.guard_int("", 0) # NB: uses Python backtrace diff --git a/torch/fx/experimental/symbolic_shapes.py b/torch/fx/experimental/symbolic_shapes.py index 531b5c02f10c..90d14dfe17c1 100644 --- a/torch/fx/experimental/symbolic_shapes.py +++ b/torch/fx/experimental/symbolic_shapes.py @@ -182,7 +182,9 @@ CURRENT_NODE_KEY = "current_node" def log_lru_cache_stats(wrapped_f: functools._lru_cache_wrapper[object]) -> None: log.debug( - "lru_cache_stats %s: %s", wrapped_f.__name__, wrapped_f.cumulative_cache_info() # type: ignore[attr-defined] + "lru_cache_stats %s: %s", + wrapped_f.__name__, # type: ignore[attr-defined] + wrapped_f.cumulative_cache_info(), # type: ignore[attr-defined] ) @@ -497,9 +499,9 @@ def check_consistent(new: _T, old: _T) -> None: torch._check(i == j, lambda: f"{old.shape} != {new.shape} (old != new)") # NB: bool is subclass of int elif isinstance(new, scalar_types) and not isinstance(new, bool): - assert isinstance(old, scalar_types) and not isinstance( - old, bool - ), f"{old} != {new}" + assert isinstance(old, scalar_types) and not isinstance(old, bool), ( + f"{old} != {new}" + ) torch._check(old == new, lambda: f"{old} != {new} (old != new)") @@ -629,9 +631,9 @@ def rebind_unbacked( raw_u1 = new_raw_u1 if not isinstance(raw_u1, sympy.Symbol): - assert ( - not raw_u1.free_symbols - ), f"should have been constant, but got {raw_u1}" + assert not raw_u1.free_symbols, ( + f"should have been constant, but got {raw_u1}" + ) continue # The old and new could be the same if you improperly hit the memo @@ -1975,12 +1977,12 @@ class EqualityConstraint(Constraint): def _assert_symbol_context(symbolic_context: object) -> TypeGuard[SymbolicContext]: - assert isinstance( - symbolic_context, SymbolicContext - ), "Invalid symbolic_context object" - assert ( - type(symbolic_context) is not SymbolicContext - ), "Illegal usage of symbolic_context ABC" + assert isinstance(symbolic_context, SymbolicContext), ( + "Invalid symbolic_context object" + ) + assert type(symbolic_context) is not SymbolicContext, ( + "Illegal usage of symbolic_context ABC" + ) return True @@ -2519,9 +2521,9 @@ def _lru_cache( prior_version = self._version_counter prior_key = self._get_key() else: - assert ( - prior_key == self._get_key() - ), "ShapeEnv cache key changed without version being updated!" + assert prior_key == self._get_key(), ( + "ShapeEnv cache key changed without version being updated!" 
+ ) return fn_cache(self, *args, **kwargs) @@ -2772,9 +2774,9 @@ class DynamicDimConstraintPrinter(PythonPrinter): def _print_Symbol(self, expr: sympy.Symbol) -> str: assert isinstance(expr, sympy.Symbol), str(type(expr)) - assert self.symbol_to_source.get( - expr - ), f"Unknown symbol {expr} created by constraints solver" + assert self.symbol_to_source.get(expr), ( + f"Unknown symbol {expr} created by constraints solver" + ) return self.symbol_to_source[expr][0].name() @@ -2792,9 +2794,9 @@ class DimConstraints: source_name_to_debug_name: Mapping[str, str], ) -> None: # We try to solve systems of inequalities with 1 free variable. - self._univariate_inequalities: dict[ - sympy.Symbol, set[SympyBoolean] - ] = defaultdict(set) + self._univariate_inequalities: dict[sympy.Symbol, set[SympyBoolean]] = ( + defaultdict(set) + ) # Among them, we prioritize solving for a free variable that has equalities. # NOTE: _symbols_with_equalities is always a subset of _univariate_inequalities.keys() # and removing a symbol from the former => removing it from the latter. @@ -2877,9 +2879,10 @@ class DimConstraints: # With any hint (say) s = k, we'd rewrite this to: 3*s % (s + 1) == k - 2. But, substituting, we # would then get k - 2 == s - 2, and thus s = k as the (only, constant) solution! base, divisor = args - base, divisor = self.rewrite_with_congruences( - s, base - ), self.rewrite_with_congruences(s, divisor) + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( self._var_to_val ) @@ -2896,9 +2899,10 @@ class DimConstraints: # NOTE(avik): This is exactly equivalent to rewriting b // d as (b - (b % d)) / d # and eliminating b % d as above. base, divisor = args - base, divisor = self.rewrite_with_congruences( - s, base - ), self.rewrite_with_congruences(s, divisor) + base, divisor = ( + self.rewrite_with_congruences(s, base), + self.rewrite_with_congruences(s, divisor), + ) mod_reduced = base.xreplace(self._var_to_val) % divisor.xreplace( self._var_to_val ) @@ -3060,9 +3064,9 @@ class DimConstraints: (arg for arg in solution.args if isinstance(arg, sympy.Eq)), solution, ) - assert isinstance( - solution, sympy.Eq - ), f"Expected an equality constraint for {s}, got {solution}" + assert isinstance(solution, sympy.Eq), ( + f"Expected an equality constraint for {s}, got {solution}" + ) symbol, val = solution.args assert symbol == s, f"Expected a constraint on {s} instead of on {symbol}" # because this is univariate, the solution is a specialization @@ -3340,7 +3344,8 @@ class DimConstraints: "max": try_solve(sympy.Eq(expr, c["max"]), s)[1], # type: ignore[arg-type, index] } if not _check_same_range( - result, name_to_dim[mroot] # type: ignore[index, arg-type] + result, + name_to_dim[mroot], # type: ignore[index, arg-type] ): # ignore if unchanged modified_root_values[mroot] = result # type: ignore[index] break @@ -4124,9 +4129,9 @@ class ShapeEnv: if not isinstance(b, SymInt): assert a == b else: - assert isinstance( - b.node.expr, sympy.Symbol - ), "constraining non-Symbols NYI" + assert isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) assert b.node.shape_env is self self.replacements[b.node.expr] = sympy.Integer(a) else: @@ -4139,9 +4144,9 @@ class ShapeEnv: self.replacements[a.node.expr] = sympy.Integer(b) else: assert a.node.shape_env is b.node.shape_env - assert isinstance( - b.node.expr, sympy.Symbol - ), "constraining non-Symbols NYI" + assert 
isinstance(b.node.expr, sympy.Symbol), ( + "constraining non-Symbols NYI" + ) new_var = self._find(a.node.expr) self.replacements[b.node.expr] = new_var @@ -4234,9 +4239,9 @@ class ShapeEnv: # If translation validation is enabled, all arguments must have its # own FX node. - assert all( - a is not None for a in args - ), f"missing arg in FX graph ({op.__name__}): {args}" + assert all(a is not None for a in args), ( + f"missing arg in FX graph ({op.__name__}): {args}" + ) node = self.fx_node_cache[node_key] = self.graph.call_function(op, args) self.name_to_node[node.name] = node @@ -4354,9 +4359,9 @@ class ShapeEnv: source: Source, symbolic_context: SymbolicContext, ) -> list[sympy.Expr]: - assert all( - not is_symbolic(val) for val in tensor_size - ), f"Expect size to be a plain tuple of ints but got {tensor_size}" + assert all(not is_symbolic(val) for val in tensor_size), ( + f"Expect size to be a plain tuple of ints but got {tensor_size}" + ) from torch._dynamo.source import TensorProperty, TensorPropertySource _assert_symbol_context(symbolic_context) @@ -4398,7 +4403,11 @@ class ShapeEnv: source: Source, *, symbolic_context: Optional[SymbolicContext] = None, - ) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]: + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: """ Returns a list of symbolic sizes and strides for the given tensor. We try our best to express stride in terms of the sizes, so as to not @@ -4463,9 +4472,9 @@ class ShapeEnv: ) -> IntLikeType: assert isinstance(maybe_sym, (int, torch.SymInt)) if is_symbolic(maybe_sym): - assert ( - maybe_sym.node.shape_env is not self - ), "expect the symbol is created from an shape env other than current one." + assert maybe_sym.node.shape_env is not self, ( + "expect the symbol is created from an shape env other than current one." 
+ ) return maybe_sym.node.require_hint() return maybe_sym @@ -4481,7 +4490,11 @@ class ShapeEnv: source: Source, *, symbolic_context: Optional[SymbolicContext] = None, - ) -> tuple[tuple[IntLikeType, ...], tuple[IntLikeType, ...], IntLikeType,]: + ) -> tuple[ + tuple[IntLikeType, ...], + tuple[IntLikeType, ...], + IntLikeType, + ]: dim = len(ex_size) # Reimplement the legacy behavior @@ -5045,9 +5058,9 @@ class ShapeEnv: sloc, ) else: - self.var_to_range[ - sympy_expr - ] = self._default_unspecified_value_range() + self.var_to_range[sympy_expr] = ( + self._default_unspecified_value_range() + ) self.var_to_range_sloc[sympy_expr] = ValueRangesSLoc(sloc, sloc) # Small performance optimization: if we have a min-max constraint, @@ -5238,9 +5251,9 @@ class ShapeEnv: shape_env = replay_shape_env_events(self.events) self.check_equal(shape_env) - assert len(placeholders) == len( - sources - ), f"len({placeholders}) != len({sources})" + assert len(placeholders) == len(sources), ( + f"len({placeholders}) != len({sources})" + ) Tensorlike = (torch.Tensor, FakeTensorMeta) def _create_no_constraints_context(t: Tensor) -> StatelessSymbolicContext: @@ -5336,9 +5349,9 @@ class ShapeEnv: symbol_to_source: dict[sympy.Symbol, list[Source]] = collections.defaultdict( list ) - symbol_to_constraints: defaultdict[ - sympy.Symbol, set[Constraint] - ] = collections.defaultdict(set) + symbol_to_constraints: defaultdict[sympy.Symbol, set[Constraint]] = ( + collections.defaultdict(set) + ) constraint_violations: list[tuple[bool, str, Callable[[], str]]] = [] printers: list[_ShapeGuardPrinter] = [] @@ -6528,7 +6541,7 @@ class ShapeEnv: f"Caused by: {sloc}\n" 'For more information, run with TORCH_LOGS="dynamic"\n' "For extended logs when we create symbols, also add " - f"TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL=\"{','.join(map(str, expr.free_symbols))}\"\n" + f'TORCHDYNAMO_EXTENDED_DEBUG_CREATE_SYMBOL="{",".join(map(str, expr.free_symbols))}"\n' "If you suspect the guard was triggered from C++, add TORCHDYNAMO_EXTENDED_DEBUG_CPP=1\n" "For more debugging help, see " "https://docs.google.com/document/d/1HSuTTVvYH1pTew89Rtpeu84Ht3nQEFTYhAX3Ypa_xJs/edit?usp=sharing\n" @@ -6662,9 +6675,9 @@ class ShapeEnv: ) self._update_var_to_range(b, b_bound, self.var_to_range_sloc[a]) tgt_bound = self.bound_sympy(tgt) - assert tgt_bound.issubset( - src_bound - ), f"{tgt_bound=} not a subset of {src_bound=}" + assert tgt_bound.issubset(src_bound), ( + f"{tgt_bound=} not a subset of {src_bound=}" + ) # TODO: Should we propagate size-like-ness? 
# @@ -6751,9 +6764,9 @@ class ShapeEnv: for source in self.var_to_sources.get(a, []): if user_tb: self.user_specialization_stacks[source] = user_tb - self.framework_specialization_stacks[ - source - ] = CapturedTraceback.extract(cpp=True) + self.framework_specialization_stacks[source] = ( + CapturedTraceback.extract(cpp=True) + ) if config.print_specializations: self.log.warning( @@ -6820,9 +6833,9 @@ class ShapeEnv: free = list(expr.free_symbols) - assert ( - len(free) > 0 - ), f"The expression should not be static by this point: {expr}" + assert len(free) > 0, ( + f"The expression should not be static by this point: {expr}" + ) # In case of really gnarly expression, we don't blow up if len(free) > 5: return diff --git a/torch/fx/experimental/validator.py b/torch/fx/experimental/validator.py index 17a814b233c6..9d70973225db 100644 --- a/torch/fx/experimental/validator.py +++ b/torch/fx/experimental/validator.py @@ -203,9 +203,7 @@ try: return _Z3Ops.to_real(result) if cast_result_to_real else result def ceil(self, number: z3.ArithRef) -> z3.ArithRef: - return z3.If( - self.floor(number) < number, self.floor(number + 1), number - ) # type: ignore[return-value] + return z3.If(self.floor(number) < number, self.floor(number + 1), number) # type: ignore[return-value] def trunc(self, number: z3.ArithRef) -> z3.ArithRef: return z3.If(number >= 0, self.floor(number), self.ceil(number)) # type: ignore[return-value] @@ -363,9 +361,9 @@ try: return super().call_function(z3op(target, self.validator), args, kwargs) # type: ignore[arg-type] # Adds the Z3 expression corresponding to the first argument # as a validator input. - assert ( - len(args) == 1 - ), f"expected 1 argument on assertion. Got: {len(args)} " + assert len(args) == 1, ( + f"expected 1 argument on assertion. Got: {len(args)} " + ) self.validator.add_source_expr(args[0]) # type: ignore[arg-type] # Translates SymPy expressions into Z3 expressions. @@ -536,9 +534,9 @@ try: def to_z3_boolean_expr(self, e: sympy.Basic) -> z3.BoolRef: z3expr = SympyToZ3(self).run(e) - assert isinstance( - z3expr, z3.BoolRef - ), f"expected boolean expression. Got: {z3expr}" + assert isinstance(z3expr, z3.BoolRef), ( + f"expected boolean expression. 
Got: {z3expr}" + ) return z3expr def add_source_expr(self, e: z3.BoolRef) -> None: diff --git a/torch/fx/graph.py b/torch/fx/graph.py index 83b288196d30..6f815925e4a2 100644 --- a/torch/fx/graph.py +++ b/torch/fx/graph.py @@ -449,7 +449,7 @@ class CodeGen: # This code-path used in Python < 3.9 return origin_typename - return f'{origin_typename}[{",".join(args)}]' + return f"{origin_typename}[{','.join(args)}]" else: # Bare type, such as `typing.Tuple` with no subscript # This code-path used in Python 3.9+ @@ -573,7 +573,7 @@ class CodeGen: summary_str = parsed_stack_trace.get_summary_str() else: summary_str = "" - body.append(f'\n {dim(f"# {summary_str}")}\n') + body.append(f"\n {dim(f'# {summary_str}')}\n") elif prev_stacktrace != "": prev_stacktrace = "" no_stacktrace_msg = "# No stacktrace found for following nodes" @@ -842,7 +842,7 @@ class _PyTreeCodeGen(CodeGen): if len(has_annotation) > 0: fn_definition += "\n " + "".join(has_annotation) + "\n" fn_definition += f""" - {', '.join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" + {", ".join(without_annotation)}, = fx_pytree.tree_flatten_spec({fn_signature})""" return fn_definition def generate_output(self, output_args): @@ -1877,7 +1877,9 @@ class Graph: # through `insert_pdb`: gm.graph.on_generate_code( lambda current_trans: ( - lambda body: insert_pdb(current_trans(body) if current_trans else body) + lambda body: insert_pdb( + current_trans(body) if current_trans else body + ) ) ) @@ -1916,7 +1918,7 @@ class Graph: @contextmanager def _override_sym_repr( - override: Callable[["torch.types.PySymType"], str] + override: Callable[["torch.types.PySymType"], str], ) -> Iterator[None]: tmp = CodeGen._sym_repr try: diff --git a/torch/fx/graph_module.py b/torch/fx/graph_module.py index 1089cab1ee53..83155e7ad565 100644 --- a/torch/fx/graph_module.py +++ b/torch/fx/graph_module.py @@ -324,9 +324,9 @@ def _print_readable( colored=False, ): graph = module.graph - assert graph is not None and isinstance( - graph, torch.fx.Graph - ), "print_readable must be used on a module with a graph" + assert graph is not None and isinstance(graph, torch.fx.Graph), ( + "print_readable must be used on a module with a graph" + ) verbose_python_code = graph.python_code( root_module="self", @@ -873,9 +873,9 @@ class {module_name}(torch.nn.Module): for node in self.graph.nodes if "stack_trace" in node.meta } - dict_without_graph[ - "_graphmodule_graph_node_meta_stack_trace" - ] = node_meta_stack_trace + dict_without_graph["_graphmodule_graph_node_meta_stack_trace"] = ( + node_meta_stack_trace + ) generated_module_name = f"fx-generated._{exporter.get_unique_id()}" python_code = self.recompile() diff --git a/torch/fx/interpreter.py b/torch/fx/interpreter.py index 86648541e342..e2d2f9d7466d 100644 --- a/torch/fx/interpreter.py +++ b/torch/fx/interpreter.py @@ -51,7 +51,9 @@ class Interpreter: method equivalents). 
We could subclass Interpreter like so:: class NegSigmSwapInterpreter(Interpreter): - def call_function(self, target: Target, args: Tuple, kwargs: Dict) -> Any: + def call_function( + self, target: Target, args: Tuple, kwargs: Dict + ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) @@ -405,7 +407,7 @@ class Interpreter: for i, atom in enumerate(target_atoms): if not hasattr(attr_itr, atom): raise RuntimeError( - f"Node referenced nonexistent target {'.'.join(target_atoms[:i + 1])}" + f"Node referenced nonexistent target {'.'.join(target_atoms[: i + 1])}" ) attr_itr = getattr(attr_itr, atom) return attr_itr @@ -468,14 +470,20 @@ class Transformer(Interpreter): class NegSigmSwapXformer(Transformer): def call_function( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], ) -> Any: if target == torch.sigmoid: return torch.neg(*args, **kwargs) return super().call_function(target, args, kwargs) def call_method( - self, target: "Target", args: Tuple[Argument, ...], kwargs: Dict[str, Any] + self, + target: "Target", + args: Tuple[Argument, ...], + kwargs: Dict[str, Any], ) -> Any: if target == "neg": call_self, *args_tail = args diff --git a/torch/fx/node.py b/torch/fx/node.py index 7f51fe20201c..5b3a2abb9adf 100644 --- a/torch/fx/node.py +++ b/torch/fx/node.py @@ -514,9 +514,9 @@ class Node(_NodeBase): idx (int): The index of the element in ``self.args`` to be inserted before. arg (Argument): The new argument value to insert into ``args`` """ - assert ( - 0 <= idx <= len(self.args) - ), "insert_args index must be between 0 and len(self.args)" + assert 0 <= idx <= len(self.args), ( + "insert_args index must be between 0 and len(self.args)" + ) args_left = self.args[:idx] args_right = self.args[idx:] @@ -747,13 +747,13 @@ class Node(_NodeBase): # Check if an impure module. if self.op == "call_module": - assert ( - self.graph.owning_module is not None - ), "self.graph.owning_module not set for purity check" + assert self.graph.owning_module is not None, ( + "self.graph.owning_module not set for purity check" + ) target_mod = self.graph.owning_module.get_submodule(self.target) - assert ( - target_mod is not None - ), f"Did not find expected submodule target {self.target}" + assert target_mod is not None, ( + f"Did not find expected submodule target {self.target}" + ) return getattr(target_mod, "_is_impure", False) return False diff --git a/torch/fx/passes/net_min_base.py b/torch/fx/passes/net_min_base.py index edcb842cc892..548d2786feea 100644 --- a/torch/fx/passes/net_min_base.py +++ b/torch/fx/passes/net_min_base.py @@ -770,9 +770,9 @@ class _MinimizerBase: node_name = node.name if node_name is not None and isinstance(node_name, tuple): node_name = node_name[0] - assert node_name is not None and isinstance( - node_name, str - ), f"minimize: node_name: {node_name}" + assert node_name is not None and isinstance(node_name, str), ( + f"minimize: node_name: {node_name}" + ) report.append(f"Add node: {node_name}") diff --git a/torch/fx/passes/pass_manager.py b/torch/fx/passes/pass_manager.py index ddb1410f6840..48dfe702fedb 100644 --- a/torch/fx/passes/pass_manager.py +++ b/torch/fx/passes/pass_manager.py @@ -93,9 +93,9 @@ def loop_pass( predicate (Callable[Object, bool], optional): """ - assert (n_iter is not None) ^ ( - predicate is not None - ), "Exactly one of `n_iter`or `predicate` must be specified." 
+ assert (n_iter is not None) ^ (predicate is not None), ( + "Exactly one of `n_iter`or `predicate` must be specified." + ) @wraps(base_pass) def new_pass(source): diff --git a/torch/fx/passes/runtime_assert.py b/torch/fx/passes/runtime_assert.py index b0479fd84b02..38c64c527aff 100644 --- a/torch/fx/passes/runtime_assert.py +++ b/torch/fx/passes/runtime_assert.py @@ -397,7 +397,9 @@ def insert_deferred_runtime_asserts( nn_module_stack=node.meta.get("nn_module_stack"), ), ): - expr_to_proxy[sym_expr] = _sympy_interp(expr_to_proxy, sym_expr) # type: ignore[arg-type] + expr_to_proxy[sym_expr] = _sympy_interp( + expr_to_proxy, sym_expr + ) # type: ignore[arg-type] # won't try DCE-ing tensor compute here hash_node = expr_to_proxy[sym_expr].node # type: ignore[arg-type] node.replace_all_uses_with(hash_node) diff --git a/torch/fx/passes/split_utils.py b/torch/fx/passes/split_utils.py index 926747b2a41f..079b1b4364bd 100644 --- a/torch/fx/passes/split_utils.py +++ b/torch/fx/passes/split_utils.py @@ -199,9 +199,9 @@ def split_by_tags( mx = max((c.order for c in upstream_components), default=0) # Expect the component for `node` has higher order then its upstream components. - assert ( - comp.order >= mx - ), f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}" + assert comp.order >= mx, ( + f"Component {comp.name} order must be >= max of its upstream components, order={comp.order} and max={mx}" + ) # Map a input of `node` to nodes in the component's graph. def remap_func(x): diff --git a/torch/fx/passes/utils/fuser_utils.py b/torch/fx/passes/utils/fuser_utils.py index 7487bc2c6631..f2524b6190d4 100644 --- a/torch/fx/passes/utils/fuser_utils.py +++ b/torch/fx/passes/utils/fuser_utils.py @@ -36,9 +36,9 @@ def topo_sort(nodes: NodeList) -> NodeList: if indegree_map[n] == 0: candidates.put(n) - assert len(nodes) == len( - sorted_nodes - ), "topological sorted nodes doesn't have same length as input nodes" + assert len(nodes) == len(sorted_nodes), ( + "topological sorted nodes doesn't have same length as input nodes" + ) return sorted_nodes @@ -127,13 +127,13 @@ def fuse_as_graphmodule( # assumption: nodes are already sorted in topo order for node in nodes: - assert ( - node.graph.owning_module is gm - ), f"{node} doesn't belong to passed in graph module {gm._get_name()}" + assert node.graph.owning_module is gm, ( + f"{node} doesn't belong to passed in graph module {gm._get_name()}" + ) assert not node._erased, f"{node} has been removed from owning graph" - assert ( - node in gm.graph._find_nodes_lookup_table - ), f"{node} is not found in graph module {gm._get_name()}" + assert node in gm.graph._find_nodes_lookup_table, ( + f"{node} is not found in graph module {gm._get_name()}" + ) # validates partition doesn't introduce dependency circles in the graph assert validate_partition(nodes), "Invalid partition, found dependency cycles" diff --git a/torch/fx/passes/utils/matcher_utils.py b/torch/fx/passes/utils/matcher_utils.py index 27d24ed29945..4f63935875d6 100644 --- a/torch/fx/passes/utils/matcher_utils.py +++ b/torch/fx/passes/utils/matcher_utils.py @@ -96,9 +96,9 @@ class SubgraphMatcher: for node in pattern.nodes: if node.op != "output": - assert ( - len(node.users) > 0 - ), "SubgraphMatcher cannot be initialized with an pattern with dead code" + assert len(node.users) > 0, ( + "SubgraphMatcher cannot be initialized with an pattern with dead code" + ) # TODO: assert pattern is a connected graph @@ -192,9 +192,9 @@ class SubgraphMatcher: return 
non_overlapping_matches def _match_literals(self, pn: Any, gn: Any, match: InternalMatch) -> bool: - assert not ( - isinstance(pn, Node) and isinstance(gn, Node) - ), "pn and gn cannot both be Node" + assert not (isinstance(pn, Node) and isinstance(gn, Node)), ( + "pn and gn cannot both be Node" + ) if isinstance(pn, Node) and not isinstance(gn, Node): if pn.op == "placeholder": diff --git a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py index 1fa9b721e9cc..091ec7f1f82b 100644 --- a/torch/fx/passes/utils/matcher_with_name_node_map_utils.py +++ b/torch/fx/passes/utils/matcher_with_name_node_map_utils.py @@ -18,17 +18,17 @@ def _split_to_graph_and_name_node_map( if n.op == "output": assert gm._out_spec is not None output = tree_unflatten(n.args[0], gm._out_spec) - assert isinstance( - output, tuple - ), "Expecting the pattern graph to return a tuple" - assert ( - len(output) >= 2 - ), "Expecting the pattern graph to have at least two outputs" + assert isinstance(output, tuple), ( + "Expecting the pattern graph to return a tuple" + ) + assert len(output) >= 2, ( + "Expecting the pattern graph to have at least two outputs" + ) *out, name_node_map = output flattened, out_spec = tree_flatten(out) - assert isinstance( - name_node_map, dict - ), "Expecting the input graph to have a dict output as the last element" + assert isinstance(name_node_map, dict), ( + "Expecting the input graph to have a dict output as the last element" + ) n.args = (flattened,) orig_pytree_info = gm._graph._codegen.pytree_info # type: ignore[attr-defined] gm._graph._codegen.pytree_info = _PyTreeInfo( # type: ignore[attr-defined] @@ -53,12 +53,14 @@ class SubgraphMatcherWithNameNodeMap(SubgraphMatcher): relu = F.relu(conv) return relu, {"conv": conv, "relu": relu} + def target_graph(x, weight): conv = F.conv2d(x, weight) relu = F.relu(conv) relu *= 2 return relu + pattern_gm = export_for_training(pattern, example_inputs).module() target_gm = export_for_training(target_graph, example_inputs).module() matcher = SubgraphMatcherWithNameNodeMap(pattern_gm) diff --git a/torch/fx/proxy.py b/torch/fx/proxy.py index 20bef1628bfc..9636f891102c 100644 --- a/torch/fx/proxy.py +++ b/torch/fx/proxy.py @@ -654,9 +654,9 @@ class MetaProxy(Proxy): meta_proxy = arg break - assert ( - meta_proxy is not None - ), "No MetaProxy found in arguments, but one is expected." + assert meta_proxy is not None, ( + "No MetaProxy found in arguments, but one is expected." 
+ ) proxy = super().__torch_function__(orig_method, types, args, kwargs) with meta_proxy.fake_mode: @@ -739,14 +739,14 @@ for method in magic_methods: return tracer.create_proxy("call_function", target, args, kwargs) impl.__name__ = method - as_magic = f'__{method.strip("_")}__' + as_magic = f"__{method.strip('_')}__" setattr(Proxy, as_magic, impl) _scope(method) def _define_reflectable(orig_method_name): - method_name = f'__r{orig_method_name.strip("_")}__' + method_name = f"__r{orig_method_name.strip('_')}__" def impl(self, rhs): target = getattr(operator, orig_method_name) diff --git a/torch/fx/subgraph_rewriter.py b/torch/fx/subgraph_rewriter.py index ae6854f67887..bc1dde686699 100644 --- a/torch/fx/subgraph_rewriter.py +++ b/torch/fx/subgraph_rewriter.py @@ -307,9 +307,9 @@ def _replace_pattern( elif callable(replacement): common_replacement_graph = symbolic_trace(replacement).graph else: - assert ( - replacement_callback is not None - ), "Must provide either a replacement GraphModule or a replacement callback" + assert replacement_callback is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) common_replacement_graph = None # As we progressively replace nodes, we'll need to keep track of how the match results should change @@ -322,9 +322,9 @@ def _replace_pattern( match, original_graph, pattern_graph ) else: - assert ( - common_replacement_graph is not None - ), "Must provide either a replacement GraphModule or a replacement callback" + assert common_replacement_graph is not None, ( + "Must provide either a replacement GraphModule or a replacement callback" + ) replacement_graph = common_replacement_graph replacement_placeholders = [ n for n in replacement_graph.nodes if n.op == "placeholder" diff --git a/torch/jit/_builtins.py b/torch/jit/_builtins.py index fb8ac26471a9..2aa2fae3fde5 100644 --- a/torch/jit/_builtins.py +++ b/torch/jit/_builtins.py @@ -18,7 +18,15 @@ from torch.nn.modules.utils import ( _builtin_table: Optional[dict[int, str]] = None -_modules_containing_builtins = (torch, torch._C._nn, torch._C._fft, torch._C._linalg, torch._C._nested, torch._C._sparse, torch._C._special) # type: ignore[attr-defined] # noqa: B950 +_modules_containing_builtins = ( + torch, + torch._C._nn, + torch._C._fft, # type: ignore[attr-defined] + torch._C._linalg, # type: ignore[attr-defined] + torch._C._nested, # type: ignore[attr-defined] + torch._C._sparse, # type: ignore[attr-defined] + torch._C._special, # type: ignore[attr-defined] +) _builtin_ops = [ # Pairs of (function, op_name) @@ -94,7 +102,10 @@ _builtin_ops = [ (torch.autograd.grad, "aten::grad"), (torch.autograd.backward, "aten::backward"), (torch._C._infer_size, "aten::_infer_size"), - (torch.nn.functional._no_grad_embedding_renorm_, "aten::_no_grad_embedding_renorm_"), # type: ignore[attr-defined] + ( + torch.nn.functional._no_grad_embedding_renorm_, # type: ignore[attr-defined] + "aten::_no_grad_embedding_renorm_", + ), (torch.nn.functional.assert_int_or_pair, "aten::_assert_int_or_pair"), (torch.nn.init._no_grad_fill_, "aten::_no_grad_fill_"), (torch.nn.init._no_grad_normal_, "aten::_no_grad_normal_"), diff --git a/torch/jit/_decomposition_utils.py b/torch/jit/_decomposition_utils.py index 795f9da8e073..3a4b4ceff2cf 100644 --- a/torch/jit/_decomposition_utils.py +++ b/torch/jit/_decomposition_utils.py @@ -4,9 +4,9 @@ from torch._ops import OpOverload, OpOverloadPacket def _register_decomposition(op: OpOverload, graph: torch._C.Graph): - assert not isinstance( - op, OpOverloadPacket - ), 
f"Must pass specific op overload, not overload packet, found {op}" + assert not isinstance(op, OpOverloadPacket), ( + f"Must pass specific op overload, not overload packet, found {op}" + ) assert isinstance(op, OpOverload) torch._C._jit_register_decomposition_for_schema(op._schema, graph) diff --git a/torch/jit/_decompositions.py b/torch/jit/_decompositions.py index ba37fe5f0cac..673a04a552af 100644 --- a/torch/jit/_decompositions.py +++ b/torch/jit/_decompositions.py @@ -23,13 +23,13 @@ def check_decomposition_has_type_annotations(f): inspect_empty = inspect._empty # type: ignore[attr-defined] sig = inspect.signature(f) for param in sig.parameters.values(): - assert ( - param.annotation != inspect_empty - ), f"No signature on param {param.name} for function {f.name}" + assert param.annotation != inspect_empty, ( + f"No signature on param {param.name} for function {f.name}" + ) - assert ( - sig.return_annotation != inspect_empty - ), f"No return annotation for function {f.name}" + assert sig.return_annotation != inspect_empty, ( + f"No return annotation for function {f.name}" + ) def signatures_match(decomposition_sig, torch_op_sig): @@ -75,9 +75,9 @@ def register_decomposition( assert isinstance(aten_op, torch._ops.OpOverload) # Need unique name for jit function serialization - assert ( - f.__name__ not in function_name_set - ), f"Duplicated function name {f.__name__}" + assert f.__name__ not in function_name_set, ( + f"Duplicated function name {f.__name__}" + ) function_name_set.add(f.__name__) scripted_func = torch.jit.script(f) diff --git a/torch/jit/_recursive.py b/torch/jit/_recursive.py index d62f039263c2..e8f99080ff5c 100644 --- a/torch/jit/_recursive.py +++ b/torch/jit/_recursive.py @@ -588,9 +588,9 @@ def create_script_module_impl(nn_module, concrete_type, stubs_fn): # recursively scripting them. 
for name, sub_concrete_type in concrete_type.get_modules(): orig_value = getattr(nn_module, name) - assert isinstance( - orig_value, Module - ), f"Expected Module but got {type(orig_value)}" + assert isinstance(orig_value, Module), ( + f"Expected Module but got {type(orig_value)}" + ) module_type = sub_concrete_type.jit_type if isinstance(module_type, torch._C.InterfaceType): # use the interface inference rule to compile the module diff --git a/torch/jit/_script.py b/torch/jit/_script.py index 5777b047e74e..79442f57d306 100644 --- a/torch/jit/_script.py +++ b/torch/jit/_script.py @@ -318,10 +318,10 @@ class ScriptMeta(type): else: return infer_methods_to_compile(module) - self.__dict__[ - "_actual_script_module" - ] = torch.jit._recursive.create_script_module( - self, make_stubs, share_types=not added_methods_in_init + self.__dict__["_actual_script_module"] = ( + torch.jit._recursive.create_script_module( + self, make_stubs, share_types=not added_methods_in_init + ) ) # Delete the Python attributes that now shadow the ScriptModule diff --git a/torch/jit/_shape_functions.py b/torch/jit/_shape_functions.py index aa0dc2b82d54..7539b52ea885 100644 --- a/torch/jit/_shape_functions.py +++ b/torch/jit/_shape_functions.py @@ -280,15 +280,15 @@ def max_pool2d( dilation: list[int], ceil_mode: bool, ): - assert ( - len(kernel_size) == 1 or len(kernel_size) == 2 - ), "max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + assert len(kernel_size) == 1 or len(kernel_size) == 2, ( + "max_pool2d: kernel_size must either be a single int, or a tuple of two ints" + ) kH = kernel_size[0] kW = kH if len(kernel_size) == 1 else kernel_size[1] - assert ( - len(stride) == 0 or len(stride) == 1 or len(stride) == 2 - ), "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + assert len(stride) == 0 or len(stride) == 1 or len(stride) == 2, ( + "max_pool2d: stride must either be omitted, a single int, or a tuple of two ints" + ) dH = kH if len(stride) == 0 else stride[0] if len(stride) == 0: dW = kW @@ -297,15 +297,15 @@ def max_pool2d( else: dW = stride[1] - assert ( - len(padding) == 1 or len(padding) == 2 - ), "max_pool2d: padding must either be a single int, or a tuple of two ints" + assert len(padding) == 1 or len(padding) == 2, ( + "max_pool2d: padding must either be a single int, or a tuple of two ints" + ) padH = padding[0] padW = padH if len(padding) == 1 else padding[1] - assert ( - len(dilation) == 1 or len(dilation) == 2 - ), "max_pool2d: dilation must be either a single int, or a tuple of two ints" + assert len(dilation) == 1 or len(dilation) == 2, ( + "max_pool2d: dilation must be either a single int, or a tuple of two ints" + ) dilationH = dilation[0] dilationW = dilationH if len(dilation) == 1 else dilation[1] @@ -367,17 +367,17 @@ def upsample_nearest2d( assert 0, "Either output_size or scale_factors must be presented" if output_size is not None: - assert ( - scale_factors is None - ), "Must specify exactly one of output_size and scale_factors" + assert scale_factors is None, ( + "Must specify exactly one of output_size and scale_factors" + ) assert len(output_size) == 2 out.append(output_size[0]) out.append(output_size[1]) if scale_factors is not None: - assert ( - output_size is None - ), "Must specify exactly one of output_size and scale_factors" + assert output_size is None, ( + "Must specify exactly one of output_size and scale_factors" + ) assert len(scale_factors) == 2 out.append(int(input[2] * scale_factors[0])) out.append(int(input[3] * 
scale_factors[1])) @@ -540,9 +540,9 @@ def check_cat_shape_except_dim( assert first_dims == second_dims, "Tensors must have same number of dimensions" for dim in range(0, first_dims): if dim != dimension: - assert ( - first[dim] == second[dim] - ), "Sizes of tensors must match except in dimension" + assert first[dim] == second[dim], ( + "Sizes of tensors must match except in dimension" + ) def cat(tensors: list[list[int]], dim: int): @@ -1088,9 +1088,9 @@ def topk(self: list[int], k: int, dim: int = -1) -> tuple[list[int], list[int]]: if len(self) == 0: result: list[int] = [] else: - assert ( - k <= self[dim] - ), f"k ({k}) is too big for dimension {dim} of size {self[dim]}" + assert k <= self[dim], ( + f"k ({k}) is too big for dimension {dim} of size {self[dim]}" + ) result = _copy(self) result[dim] = k return result, result diff --git a/torch/jit/_trace.py b/torch/jit/_trace.py index eae30f415e9b..23e7db73819c 100644 --- a/torch/jit/_trace.py +++ b/torch/jit/_trace.py @@ -1205,7 +1205,10 @@ def trace_module( # Trace specific methods on a module (specified in `inputs`), constructs # a `ScriptModule` with `forward` and `weighted_kernel_sum` methods - inputs = {"forward": example_forward_input, "weighted_kernel_sum": example_weight} + inputs = { + "forward": example_forward_input, + "weighted_kernel_sum": example_weight, + } module = torch.jit.trace_module(n, inputs) """ diff --git a/torch/masked/_ops.py b/torch/masked/_ops.py index 9371052a4fb3..fb802eba1aa8 100644 --- a/torch/masked/_ops.py +++ b/torch/masked/_ops.py @@ -309,14 +309,14 @@ defined as ``prod(x[:i])``.""", operation_args, operation_kwargs = args_and_kwargs[func.__name__] arg_declarations = [ "\n ".join( - argument_declarations.get(a, f'{a.split("__", 1)[0]}: TBD.').splitlines() + argument_declarations.get(a, f"{a.split('__', 1)[0]}: TBD.").splitlines() ) for a in operation_args ] kwarg_declarations = [ "\n ".join( argument_declarations.get( - a.split("=", 1)[0], f'{a.split("__", 1)[0]}: TBD.' + a.split("=", 1)[0], f"{a.split('__', 1)[0]}: TBD." ) .format(default=a.split("=", 1)[1]) .splitlines() @@ -745,9 +745,9 @@ def _sparse_csr_segment_reduction_helper( ) -> Tensor: # Currently, while sparse CSR is always 2D with no dense dimensions keepdim must be True # FIXME: when dense dimensions are implemented for CSR tensors - assert ( - keepdim - ), "reduction operations on CSR tensors with keepdim=False is unsupported" + assert keepdim, ( + "reduction operations on CSR tensors with keepdim=False is unsupported" + ) reduce = op.__name__ valid_reductions = ["sum", "prod", "mean", "amax", "amin"] if reduce not in valid_reductions: @@ -781,9 +781,9 @@ def _sparse_csr_segment_reduction_helper( ) new_shape = [1, mask_input.size(1)] else: - assert ( - dims[0] == 1 - ), "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." + assert dims[0] == 1, ( + "Sparse CSR tensors are 2D and only support reduction along dim 0 or 1." 
+ ) # all intervals new_crow_indices[i] - new_crow_indices[i-1] are 1 # except for where crow_indices[i] == crow_indices[i-1] where the interval remains as 0 new_crow_indices = torch.cat( @@ -1598,9 +1598,9 @@ def _std_var( mask: Optional[Tensor], take_sqrt: Optional[bool], ) -> Tensor: - assert ( - unbiased is None or correction_opt is None - ), "Only one of unbiased and correction may be given" + assert unbiased is None or correction_opt is None, ( + "Only one of unbiased and correction may be given" + ) correction = 1.0 if unbiased is not None: correction = 1.0 if unbiased else 0.0 @@ -1636,7 +1636,11 @@ def _std_var( total = sum(x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype) else: total = sum( - x * x.conj(), dim, keepdim=keepdim, dtype=compute_dtype, mask=inmask # type: ignore[possibly-undefined] + x * x.conj(), + dim, + keepdim=keepdim, + dtype=compute_dtype, + mask=inmask, # type: ignore[possibly-undefined] ) if not keepdim: count = count.reshape(total.shape) diff --git a/torch/masked/maskedtensor/core.py b/torch/masked/maskedtensor/core.py index 46ff1eaa3c83..2e3608b3e6d3 100644 --- a/torch/masked/maskedtensor/core.py +++ b/torch/masked/maskedtensor/core.py @@ -25,7 +25,7 @@ def is_masked_tensor(obj: Any, /) -> TypeIs["MaskedTensor"]: >>> # xdoctest: +SKIP >>> from torch.masked import MaskedTensor - >>> data = torch.arange(6).reshape(2,3) + >>> data = torch.arange(6).reshape(2, 3) >>> mask = torch.tensor([[True, False, False], [True, True, False]]) >>> mt = MaskedTensor(data, mask) >>> is_masked_tensor(mt) diff --git a/torch/mps/__init__.py b/torch/mps/__init__.py index b0a62c182578..cdbf6b16ac44 100644 --- a/torch/mps/__init__.py +++ b/torch/mps/__init__.py @@ -5,6 +5,7 @@ Metal is Apple's API for programming metal GPU (graphics processor unit). Using performance can be achieved, by running work on the metal GPU(s). See https://developer.apple.com/documentation/metalperformanceshaders for more details. """ + from typing import Union import torch diff --git a/torch/mtia/__init__.py b/torch/mtia/__init__.py index a9296539d58e..4761969fc286 100644 --- a/torch/mtia/__init__.py +++ b/torch/mtia/__init__.py @@ -198,7 +198,7 @@ def snapshot() -> dict[str, Any]: def attach_out_of_memory_observer( - observer: Callable[[int, int, int, int], None] + observer: Callable[[int, int, int, int], None], ) -> None: r"""Attach an out-of-memory observer to MTIA memory allocator""" torch._C._mtia_attachOutOfMemoryObserver(observer) diff --git a/torch/multiprocessing/__init__.py b/torch/multiprocessing/__init__.py index 745c180d8c41..3e37b8c1947d 100644 --- a/torch/multiprocessing/__init__.py +++ b/torch/multiprocessing/__init__.py @@ -14,6 +14,7 @@ memory. Because of the similarity of APIs we do not document most of this package contents, and we recommend referring to very good docs of the original module. """ + import multiprocessing import sys
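
Note: the hunks above (and throughout this patch) apply one mechanical rewrite: instead of wrapping the assert *condition* in parentheses and leaving the message trailing, only the *message* is parenthesized, so long messages can wrap while the condition stays on one line. The two layouts are semantically identical, and both are distinct from the `assert (cond, msg)` tuple pitfall. A minimal, hypothetical sketch (not taken from the patch; the `check` helper is invented purely for illustration)::

    def check(value):
        # Pre-patch layout: the condition is wrapped, the message trails.
        # This still compiles to plain ``assert cond, msg``.
        assert (
            isinstance(value, int)
        ), f"expected int, got {type(value)}"

        # Post-patch layout used throughout this diff: only the message is
        # parenthesized, so the statement again means ``assert cond, msg``.
        assert isinstance(value, int), (
            f"expected int, got {type(value)}"
        )

        # The pitfall both layouts avoid: parenthesizing condition and message
        # together builds a non-empty tuple, which is always truthy, so the
        # assert can never fire (recent CPython emits a SyntaxWarning for it).
        # assert (isinstance(value, int), "expected int")


    check(3)        # passes under both layouts
    try:
        check("3")
    except AssertionError as e:
        print(e)    # expected int, got <class 'str'>

Because the parentheses enclose only the message operand, the reformatted statements keep the exact runtime behavior of the originals; only the line-wrapping changes.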