mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-21 05:34:18 +08:00
Add check in test_cow_input
to ensure COW data is never changed (#150723)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/150723 Approved by: https://github.com/Skylion007
This commit is contained in:
committed by
PyTorch MergeBot
parent
24aadb40fb
commit
164d2c887b
@ -1825,6 +1825,7 @@ class TestCompositeCompliance(TestCase):
|
||||
def check_cow_input(
|
||||
arg,
|
||||
arg_copy,
|
||||
arg_raw,
|
||||
idx_or_kw,
|
||||
backward_or_forward="forward",
|
||||
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_forward,
|
||||
@ -1837,6 +1838,13 @@ class TestCompositeCompliance(TestCase):
|
||||
) + f" during {backward_or_forward} call"
|
||||
|
||||
if is_strided_tensor(arg):
|
||||
self.assertTrue(
|
||||
torch._C._is_cow_tensor(arg_raw),
|
||||
msg=(
|
||||
f"{arg_name} raw input should remain COW, but it "
|
||||
"unexpectedly materialized."
|
||||
),
|
||||
)
|
||||
is_cow = torch._C._is_cow_tensor(arg)
|
||||
|
||||
if supports_cow_input_no_materialize and not check_ignore_materialize(
|
||||
@ -1861,6 +1869,17 @@ class TestCompositeCompliance(TestCase):
|
||||
"but the operation mutated its data."
|
||||
),
|
||||
)
|
||||
else:
|
||||
self.assertTrue(
|
||||
torch.allclose(
|
||||
arg_raw, arg_copy, rtol=0, atol=0, equal_nan=True
|
||||
),
|
||||
msg=(
|
||||
f"{arg_name} materialized, which is allowed in this "
|
||||
"case, but the COW input data was mutated, which is "
|
||||
"not allowed."
|
||||
),
|
||||
)
|
||||
|
||||
for sample in samples:
|
||||
args_raw = [sample.input] + list(sample.args)
|
||||
@ -1901,10 +1920,10 @@ class TestCompositeCompliance(TestCase):
|
||||
|
||||
# Check that COW inputs remain COW after the forward op is executed
|
||||
for idx, arg in enumerate(args):
|
||||
check_cow_input(arg, args_copy[idx], idx)
|
||||
check_cow_input(arg, args_copy[idx], args_raw[idx], idx)
|
||||
|
||||
for kw, arg in kwargs.items():
|
||||
check_cow_input(arg, kwargs_copy[kw], kw)
|
||||
check_cow_input(arg, kwargs_copy[kw], kwargs_raw[kw], kw)
|
||||
|
||||
# Call backward op if it is supported. This part of the test is
|
||||
# based on `composite_compliance.check_backward_formula`
|
||||
@ -1954,6 +1973,7 @@ class TestCompositeCompliance(TestCase):
|
||||
check_cow_input(
|
||||
arg,
|
||||
args_copy[idx],
|
||||
args_raw[idx],
|
||||
idx,
|
||||
backward_or_forward="backward",
|
||||
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
|
||||
@ -1965,6 +1985,7 @@ class TestCompositeCompliance(TestCase):
|
||||
check_cow_input(
|
||||
output_grad,
|
||||
output_grads_copy[idx],
|
||||
output_grads_raw[idx],
|
||||
f"output grad {idx}",
|
||||
backward_or_forward="backward",
|
||||
supports_cow_input_no_materialize=op.supports_cow_input_no_materialize_backward,
|
||||
|
Reference in New Issue
Block a user