[Graph Partition] remove weak dep from partition_input_names (#152863)

Graph partition analyzes `read_writes` to collect partition input names. However, a weak dep is a fake dependency: it only enforces ordering and is never actually read or written. So weak deps should not be included in the graph partition input names.
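
As a rough sketch of the idea (illustrative only; the helper shape below is an assumption, not the actual inductor code), collecting input names amounts to skipping `WeakDep` entries in a node's `read_writes`:

```python
# Illustrative sketch, not the actual PR code: assumes `read_writes.reads`
# holds inductor dependency objects, where weak deps are instances of
# torch._inductor.dependencies.WeakDep.
from torch._inductor.dependencies import WeakDep

def partition_input_names(read_writes):
    # A WeakDep only enforces ordering; its buffer is never actually
    # read or written, so it must not become a partition input name.
    return {
        dep.name
        for dep in read_writes.reads
        if not isinstance(dep, WeakDep)
    }
```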

Removing weak deps from `partition_input_names` fixes the following test failure:
`PYTORCH_TEST_WITH_INDUCTOR=1 python test/test_torch.py TestTorchDeviceTypeCUDA.test_params_invalidated_with_grads_invalidated_between_unscale_and_step_Adam_cuda_float32`

Pull Request resolved: https://github.com/pytorch/pytorch/pull/152863
Approved by: https://github.com/eellison
Author: Boyuan Feng
Date: 2025-05-09 17:20:00 +00:00
Committed by: PyTorch MergeBot
Parent: 286de0d601
Commit: ffda46e3be

2 changed files with 28 additions and 7 deletions

test/test_torch.py:

```diff
@@ -6011,12 +6011,7 @@ else:
     # Make sure that the parameters become nonsense when scaled gradients are finite
     # but they get invalidated before `optimizer.step`, after `GradScaler.unscale_`
-    @onlyNativeDeviceTypes
-    @optims(
-        [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
-        dtypes=[torch.float32]
-    )
-    def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info):
+    def _test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info):
         optimizer_ctor = optim_info.optim_cls
         all_optim_inputs = _get_optim_inputs_including_global_cliquey_kwargs(
             device, dtype, optim_info, skip=("differentiable",))
@@ -6044,6 +6039,23 @@ else:
         self.assertTrue(all((p.isnan().any() or p.isinf().any()) for p in model.parameters()))
 
+    @onlyNativeDeviceTypes
+    @optims(
+        [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
+        dtypes=[torch.float32]
+    )
+    def test_params_invalidated_with_grads_invalidated_between_unscale_and_step(self, device, dtype, optim_info):
+        self._test_params_invalidated_with_grads_invalidated_between_unscale_and_step(device, dtype, optim_info)
+
+    @onlyNativeDeviceTypes
+    @optims(
+        [optim for optim in optim_db if optim.optim_cls in [torch.optim.AdamW, torch.optim.Adam, torch.optim.SGD]],
+        dtypes=[torch.float32]
+    )
+    @torch._inductor.config.patch("graph_partition", True)
+    def test_params_invalidated_with_grads_invalidated_and_graph_partition(self, device, dtype, optim_info):
+        self._test_params_invalidated_with_grads_invalidated_between_unscale_and_step(device, dtype, optim_info)
+
     @onlyNativeDeviceTypes
     def test_grad_scale_will_not_overflow(self, device):
         device = torch.device(device)
```