Mirror of https://github.com/pytorch/pytorch.git (synced 2025-10-20 21:14:14 +08:00)
[map] filter none gradients and add autograd inductor tests (#160548)
The None outputs in autograd backward will be filtered for the other HOPs in follow-ups.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/160548
Approved by: https://github.com/zou3519
committed by PyTorch MergeBot
parent fa75ba9303
commit ff86509a06
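For orientation, here is a minimal, hypothetical sketch of the pattern this commit applies in the HOP backward paths (the names `toy_bw_fn`, `remove_none`, and `masks` are made up for illustration; the helpers actually added by the diff are `filter_with_masks` and `fill_none_with_masks`): record which backward outputs are real tensors, drop the `None`s before the backward graph is materialized, and re-insert `None` placeholders at the original positions when returning gradients from `autograd.Function.backward`.

```python
import torch


# Hypothetical backward function: one gradient slot per input,
# with None for inputs that receive no gradient.
def toy_bw_fn(*args):
    return (torch.ones(3), None, torch.zeros(3))


def remove_none(fn, masks_out):
    # Wrap fn so the traced backward graph only returns real tensors;
    # remember where the Nones were via masks_out.
    def wrapped(*args):
        outs = fn(*args)
        masks_out[:] = [isinstance(o, torch.Tensor) for o in outs]
        return [o for o, keep in zip(outs, masks_out) if keep]

    return wrapped


masks: list[bool] = []
filtered = remove_none(toy_bw_fn, masks)()  # [ones(3), zeros(3)]; masks == [True, False, True]

# When returning from autograd.Function.backward, refill Nones by mask:
it = iter(filtered)
grads = [next(it) if kept else None for kept in masks]
assert grads[1] is None and len(grads) == 3
```

The diff below implements this with `create_fn_remove_none` in the `CondAutogradOp` backward, a `bw_f_wrapper` in the `MapAutogradOp` backward, and the `filter_with_masks` / `fill_none_with_masks` helpers near the end.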
@@ -737,8 +737,7 @@ def forward(self, pred_1, x_1):
     getitem_1 = cond_1[0]; getitem_1 = None
     getitem_2 = cond_1[1]
     getitem_3 = cond_1[2]; getitem_3 = None
-    getitem_4 = cond_1[3]; getitem_4 = None
-    getitem_5 = cond_1[4]; cond_1 = getitem_5 = None
+    getitem_4 = cond_1[3]; cond_1 = getitem_4 = None
     return (getitem_2,)""", # noqa: B950
         )
@@ -854,10 +853,7 @@ def forward(self, pred_1, a_1, b_1, c_1):
     cond_1 = torch.ops.higher_order.cond(pred_1, true_graph_1, false_graph_1, (a_1, b_1, sym_size_int, sym_size_int_1, c_1, sym_size_int_2, ones_like)); pred_1 = true_graph_1 = false_graph_1 = a_1 = b_1 = sym_size_int = sym_size_int_1 = c_1 = sym_size_int_2 = ones_like = None
     getitem_1 = cond_1[0]
     getitem_2 = cond_1[1]
-    getitem_3 = cond_1[2]; getitem_3 = None
-    getitem_4 = cond_1[3]; getitem_4 = None
-    getitem_5 = cond_1[4]; getitem_5 = None
-    getitem_6 = cond_1[5]; cond_1 = getitem_6 = None
+    getitem_3 = cond_1[2]; cond_1 = getitem_3 = None
     return (getitem_1, getitem_2)""", # noqa: B950
         )
         # Forward
@@ -877,7 +873,7 @@ def forward(self, arg0_1, arg1_1, arg2_1, arg3_1, arg4_1, arg5_1, arg6_1):
     clone = torch.ops.aten.clone.default(arg6_1)
     clone_1 = torch.ops.aten.clone.default(arg6_1); arg6_1 = None
     zeros_like = torch.ops.aten.zeros_like.default(arg4_1, pin_memory = False); arg4_1 = None
-    return [clone, clone_1, None, None, zeros_like, None]""",
+    return [clone, clone_1, zeros_like]""",
         )

     def test_cond_autograd_pytree_input(self):
@@ -1302,15 +1298,11 @@ def forward(self, pred_1, x_1):
         return cond_outputs, cond_inputs


-    # TODO: The compile_mode = `compile_dynamic_shape` raises the Error
-    # torch._inductor.exc.LoweringException: NotImplementedError: get_size() is not
-    # implemented by <class 'torch._inductor.ir.NoneAsConstantBuffer'>!
     @skipIfTorchDynamo("don't test compile on compile")
     @unittest.skipIf(not SM70OrLater, "triton")
     @unittest.skipIf(not torch.cuda.is_available(), "Test requires CUDA.")
     @parametrize("compile_mode", ["compile_dynamic_shape"])
     @parametrize("scalar", [False])
-    @unittest.expectedFailure
     def test_cond_autograd_zeros_unused_branch_complex_compile_fail(
         self, compile_mode, scalar
     ):
@@ -5,6 +5,7 @@ import unittest

 import torch
 import torch._dynamo.testing
+import torch.utils._pytree as pytree
 from torch._higher_order_ops.associative_scan import associative_scan
 from torch._higher_order_ops.map import _fake_map
 from torch._higher_order_ops.scan import _fake_scan, scan
@@ -37,6 +38,24 @@ def prepend_counters(inputs, num_counters=1, counter_values=(0, 1, 5)):
     return _prepend_product_of_values(inputs, counter_values, num_counters)


+# a testing loss_fn
+def loss_fn(result) -> torch.Tensor:
+    flat_results, _ = pytree.tree_flatten(result)
+    total_loss = torch.tensor(
+        0.0, device=flat_results[0].device if flat_results else torch.device("cpu")
+    )
+
+    for res in flat_results:
+        # Convert to float if integer tensor to avoid numerical issues
+        if not res.dtype.is_floating_point:
+            res = res.float()
+
+        # Simple robust loss: abs values + small constant to avoid inf/nan
+        total_loss = total_loss + (torch.abs(res) / (1.0 + torch.abs(res))).sum()
+
+    return total_loss
+
+
 class CondModels:
     class Simple(torch.nn.Module):
         def forward(self, p, a, b):
@@ -1036,8 +1055,6 @@ class WhileLoopTests(TestCase):
         dynamic=False,
         num_counters=1,
     ):
-        import torch.utils._pytree as pytree
-
         cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
         compiled_model = torch.compile(backend=cnt, fullgraph=True)(model)
@@ -1566,8 +1583,6 @@ class ScanModels:

         def forward(self, scan_op, _input, weight, bias):
             def combine_fn(carry, x):
-                from torch.utils import _pytree as pytree
-
                 new_carry = {
                     "param": carry["param"] @ x + carry["bias"],
                     "bias": carry["bias"].sin(),
@@ -1977,51 +1992,88 @@ class MapTests(TestCase):
         inputs,
         device,
         dynamic=False,
+        autograd=False,
     ):
-        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
-        compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
-            model
-        )
+        import copy

         inputs = [inp.to(device=device) for inp in inputs]
         model = model.to(device=device)
+        model_eager = copy.deepcopy(model)
+        model_compiled = copy.deepcopy(model)
+        cnt = torch._dynamo.testing.CompileCounterWithBackend("inductor")
+        compiled_model = torch.compile(backend=cnt, fullgraph=True, dynamic=dynamic)(
+            model_compiled
+        )
+
+        if autograd:
+            pytree.tree_map_only(torch.Tensor, lambda t: t.requires_grad_(True), inputs)

         cloned_inputs = [inp.clone() for inp in inputs]
         result = model(torch._higher_order_ops.map, *cloned_inputs)
-        result_exp = model(_fake_map, *cloned_inputs)
+        result_exp = model_eager(_fake_map, *cloned_inputs)
         result_compiled = compiled_model(torch._higher_order_ops.map, *cloned_inputs)

         self.assertEqual(result, result_exp)
         self.assertEqual(result, result_compiled)

+        if autograd:
+
+            def loss_fn(result) -> torch.Tensor:
+                flat_results, _ = pytree.tree_flatten(result)
+                return sum(
+                    [
+                        torch.sqrt(torch.pow(res.sum() / res.max(), 2)).sum()
+                        for res in flat_results
+                    ]
+                )
+
+            loss_fn(result).backward()
+            loss_fn(result_exp).backward()
+            loss_fn(result_compiled).backward()
+
+            model_params = dict(model.named_parameters())
+            model_eager_params = dict(model_eager.named_parameters())
+            model_compiled_params = dict(model_compiled.named_parameters())
+            for name, param in model_eager_params.items():
+                self.assertEqual(param, model_params[name])
+                self.assertEqual(param, model_compiled_params[name])
+                self.assertEqual(param.grad, model_params[name].grad)
+                self.assertEqual(param.grad, model_compiled_params[name].grad)
+
     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @parametrize("dynamic", [True, False])
+    @parametrize("autograd", [True, False])
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_map_simple(self, device, dynamic):
+    def test_map_simple(self, device, dynamic, autograd):
         self._run_test(
             model=MapModels.Simple(),
             inputs=(torch.randn(3, 4),),
             device=device,
             dynamic=dynamic,
+            autograd=autograd,
         )

     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @parametrize("dynamic", [True, False])
+    @parametrize("autograd", [True, False])
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_map_simple_linear_with_view(self, device, dynamic):
+    def test_map_simple_linear_with_view(self, device, dynamic, autograd):
         self._run_test(
             model=MapModels.SimpleWithLinearWithView(),
             inputs=(torch.randn(3, 4),),
             device=device,
             dynamic=dynamic,
+            autograd=autograd,
         )

     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @parametrize("dynamic", [True, False])
+    @parametrize("autograd", [True, False])
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_map_pytree_in_out(self, device, dynamic):
+    def test_map_pytree_in_out(self, device, dynamic, autograd):
         self._run_test(
             model=MapModels.PytreeInOut(),
             inputs=(
@@ -2031,13 +2083,15 @@ class MapTests(TestCase):
             ),
             device=device,
             dynamic=dynamic,
+            autograd=autograd,
         )

     @requires_gpu
     @parametrize("device", ["cpu", GPU_TYPE])
     @parametrize("dynamic", [True, False])
+    @parametrize("autograd", [True, False])
     @torch._dynamo.config.patch("capture_scalar_outputs", True)
-    def test_map_nested_with_cond(self, device, dynamic):
+    def test_map_nested_with_cond(self, device, dynamic, autograd):
         self._run_test(
             model=MapModels.NestedWithCond(),
             inputs=(
@@ -2047,6 +2101,7 @@ class MapTests(TestCase):
             ),
             device=device,
             dynamic=dynamic,
+            autograd=autograd,
         )

@@ -1,6 +1,7 @@
 # mypy: allow-untyped-decorators
 # mypy: allow-untyped-defs
 import contextlib
+import functools
 import logging
 import warnings
 from typing import Any, Callable, Optional, Union
@@ -20,6 +21,8 @@ from torch._higher_order_ops.utils import (
     _set_compilation_env,
     check_input_alias_and_mutation_return_outputs,
     create_bw_fn,
+    fill_none_with_masks,
+    filter_with_masks,
     materialize_as_graph,
     reenter_make_fx,
     save_tensors_and_symints_for_backward,
@@ -342,15 +345,32 @@ class CondAutogradOp(torch.autograd.Function):
         args = operands + flat_grads
         # TODO: we need to materialize the bw graphs because dynamo is unable to
         # trace through the joint function when torch.compile torch.autograd.grad.
+
+        grads_tensor_masks = []
+
+        def create_fn_remove_none(fn):
+            @functools.wraps(fn)
+            def wrapped(*args):
+                nonlocal grads_tensor_masks
+
+                true_outputs = fn(*args)
+                grads_tensor_masks = [
+                    True if isinstance(out, torch.Tensor) else False
+                    for out in true_outputs
+                ]
+                return filter_with_masks(true_outputs, grads_tensor_masks)
+
+            return wrapped
+
         true_bw_gm = materialize_as_graph(
-            ctx._true_bw_fn,
+            create_fn_remove_none(ctx._true_bw_fn),
             args,
             ctx._fw_include_key_set,
             ctx._fw_exclude_key_set,
             force_enable_grad=True,
         )
         false_bw_gm = materialize_as_graph(
-            ctx._false_bw_fn,
+            create_fn_remove_none(ctx._false_bw_fn),
             args,
             ctx._fw_include_key_set,
             ctx._fw_exclude_key_set,
@@ -362,7 +382,7 @@ class CondAutogradOp(torch.autograd.Function):
             false_bw_gm,
             args,
         )
-        return None, None, None, *grads
+        return None, None, None, *fill_none_with_masks(grads, grads_tensor_masks)


 # Note:
@@ -22,6 +22,8 @@ from .utils import (
     _stack_pytree,
     _unstack_pytree,
     create_bw_fn,
+    fill_none_with_masks,
+    filter_with_masks,
     materialize_as_graph,
     save_tensors_and_symints_for_backward,
     saved_tensors_and_symints,
@@ -154,8 +156,12 @@ class MapAutogradOp(torch.autograd.Function):

         bw_f = create_bw_fn(ctx._f, fw_args)

+        grads_tensor_masks = []
+
         # Create a wrapper around the bw_f
         def bw_f_wrapper(*args):
+            nonlocal grads_tensor_masks
+
             # Dissect args and re-order them for the ``ctx._bw_f``
             # args provided to the wrapper are composed of [*fw_mapped_args, *flat_grads, *pos_args]
             # The content of ``bw_f_tangents`` are the upstream gradients, i.e. flat_grads
@@ -165,7 +171,11 @@ class MapAutogradOp(torch.autograd.Function):
                 args, [num_mapped_args, num_grads, num_pos_args]
             )
             bw_f_primals = *fw_m_args, *pos_args
-            return bw_f(*bw_f_primals, *bw_f_tangents)
+            gradients = bw_f(*bw_f_primals, *bw_f_tangents)
+            grads_tensor_masks = [
+                True if isinstance(out, torch.Tensor) else out for out in gradients
+            ]
+            return filter_with_masks(gradients, grads_tensor_masks)

         def construct_args_single_step_bw():
             unwrapped_mapped_xs = pytree.tree_map(_from_fun, fw_mapped_args)
@@ -194,7 +204,7 @@ class MapAutogradOp(torch.autograd.Function):

         grads = map_impl(fn_bw_gm, fw_mapped_args + flat_grads, pos_args)

-        return None, None, *grads
+        return None, None, *fill_none_with_masks(grads, grads_tensor_masks)


 def trace_map(proxy_mode, func_overload, f, xs, pos_args):
@@ -1217,3 +1217,13 @@ def _has_gen_schema(op: HigherOrderOperator):
     return hasattr(type(op), method) and getattr(type(op), method) is not getattr(
         HigherOrderOperator, method
     )
+
+
+def filter_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
+    assert len(data) == len(masks)
+    return [item for item, keep in zip(data, masks) if keep]
+
+
+def fill_none_with_masks(data: list[Optional[torch.Tensor]], masks: list[bool]):
+    data_iter = iter(data)
+    return [next(data_iter) if kept else None for kept in masks]
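A short usage sketch of the two helpers defined above, assuming they are importable from `torch._higher_order_ops.utils` once this change lands; the tensors and mask values here are made up for illustration.

```python
import torch

from torch._higher_order_ops.utils import fill_none_with_masks, filter_with_masks

grads = [torch.ones(2), None, torch.zeros(2)]
masks = [isinstance(g, torch.Tensor) for g in grads]  # [True, False, True]

kept = filter_with_masks(grads, masks)        # Nones dropped -> [ones(2), zeros(2)]
restored = fill_none_with_masks(kept, masks)  # Nones restored -> [ones(2), None, zeros(2)]

assert restored[1] is None and torch.equal(restored[0], grads[0])
```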