Add check in test_cow_input to ensure COW data is never changed (#150723)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/150723
Approved by: https://github.com/Skylion007
This commit is contained in:
Kurt Mohler
2025-04-05 00:46:49 +00:00
committed by PyTorch MergeBot
parent 24aadb40fb
commit 164d2c887b

View File

@ -1825,6 +1825,7 @@ class TestCompositeCompliance(TestCase):
def check_cow_input(
arg,
arg_copy,
arg_raw,
idx_or_kw,
backward_or_forward="forward",
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward,
@ -1837,6 +1838,13 @@ class TestCompositeCompliance(TestCase):
) + f" during {backward_or_forward} call"
if is_strided_tensor(arg):
self.assertTrue(
torch._C._is_cow_tensor(arg_raw),
msg=(
f"{arg_name} raw input should remain COW, but it "
"unexpectedly materialized."
),
)
is_cow = torch._C._is_cow_tensor(arg)
if supports_cow_input_no_materialize and not check_ignore_materialize(
@ -1861,6 +1869,17 @@ class TestCompositeCompliance(TestCase):
"but the operation mutated its data."
),
)
else:
self.assertTrue(
torch.allclose(
arg_raw, arg_copy, rtol=0, atol=0, equal_nan=True
),
msg=(
f"{arg_name} materialized, which is allowed in this "
"case, but the COW input data was mutated, which is "
"not allowed."
),
)
for sample in samples:
args_raw = [sample.input] + list(sample.args)
@ -1901,10 +1920,10 @@ class TestCompositeCompliance(TestCase):
# Check that COW inputs remain COW after the forward op is executed
for idx, arg in enumerate(args):
check_cow_input(arg, args_copy[idx], idx)
check_cow_input(arg, args_copy[idx], args_raw[idx], idx)
for kw, arg in kwargs.items():
check_cow_input(arg, kwargs_copy[kw], kw)
check_cow_input(arg, kwargs_copy[kw], kwargs_raw[kw], kw)
# Call backward op if it is supported. This part of the test is
# based on `composite_compliance.check_backward_formula`
@ -1954,6 +1973,7 @@ class TestCompositeCompliance(TestCase):
check_cow_input(
arg,
args_copy[idx],
args_raw[idx],
idx,
backward_or_forward="backward",
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
@ -1965,6 +1985,7 @@ class TestCompositeCompliance(TestCase):
check_cow_input(
output_grad,
output_grads_copy[idx],
output_grads_raw[idx],
f"output grad {idx}",
backward_or_forward="backward",
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,