Revert "[BE] Simplify code interacting with get_proxy_mode/enable_tracing (#132675)"

This reverts commit c2bccfd4311fe905ff78c0977281b8e642bb10d6.

Reverted https://github.com/pytorch/pytorch/pull/132675 on behalf of https://github.com/PaliC due to We need to now revert https://github.com/pytorch/pytorch/pull/132216 in OSS and there is a dependency on this pr ([comment](https://github.com/pytorch/pytorch/pull/132674#issuecomment-2274062785))
PyTorch MergeBot
2024-08-07 18:25:31 +00:00
parent f2ad3c89b0
commit 9d476fee53
15 changed files with 192 additions and 106 deletions
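The diff below restores, in each higher-order op's ProxyTorchDispatchMode implementation, the guard on mode.enable_tracing that PR #132675 had removed. As a rough illustration only (a minimal sketch, not code from this commit; the operator my_hop and its impl function are hypothetical), the restored pattern looks like this:

import torch.utils._pytree as pytree
from torch._ops import HigherOrderOperator
from torch.fx.experimental.proxy_tensor import ProxyTorchDispatchMode, track_tensor_tree

my_hop = HigherOrderOperator("my_hop")  # hypothetical op, for illustration only


@my_hop.py_impl(ProxyTorchDispatchMode)
def my_hop_proxy_mode(mode, *args, **kwargs):
    # The guard this revert brings back: skip proxy tracing when the mode has it disabled.
    if not mode.enable_tracing:
        return my_hop(*args, **kwargs)
    # Otherwise unwrap proxies, record the call in the FX graph, and run the op.
    p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
    out_proxy = mode.tracer.create_proxy("call_function", my_hop, p_args, p_kwargs)
    out = my_hop(*args, **kwargs)
    return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)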

View File

@@ -17,6 +17,8 @@ _export_tracepoint = HigherOrderOperator("_export_tracepoint")
 @_export_tracepoint.py_impl(ProxyTorchDispatchMode)
 def export_tracepoint_dispatch_mode(mode, *args, **kwargs):
+    if not mode.enable_tracing:
+        return _export_tracepoint(*args, **kwargs)
     p_args, p_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, (args, kwargs))
     proxy = mode.tracer.create_proxy(
         "call_function", _export_tracepoint, p_args, p_kwargs

View File

@@ -168,7 +168,10 @@ associative_scan_op.py_impl(DispatchKey.Autograd)(
 @associative_scan_op.py_impl(ProxyTorchDispatchMode)
 def associative_scan_proxy_mode(mode, combine_fn, input, dim):
-    return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim)
+    if mode.enable_tracing:
+        return trace_associative_scan(mode, associative_scan_op, combine_fn, input, dim)
+    else:
+        return associative_scan_op(mode, associative_scan_op, combine_fn, input, dim)
 @associative_scan_op.py_impl(FakeTensorMode)

View File

@@ -170,6 +170,9 @@ def auto_functionalized_proxy(
     _mutable_op: OpOverload,
     **kwargs: Any,
 ) -> Tuple[Any, Tuple[Tensor, ...]]:
+    if not mode.enable_tracing:
+        return auto_functionalized(_mutable_op, **kwargs)
     with disable_proxy_modes_tracing():
         out = auto_functionalized(_mutable_op, **kwargs)

View File

@@ -404,7 +404,10 @@ def cond_autograd(pred, true_fn, false_fn, operands):
 @cond_op.py_impl(ProxyTorchDispatchMode)
 def inner(mode, pred, true_fn, false_fn, operands):
-    return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
+    if mode.enable_tracing:
+        return trace_cond(mode, cond_op, pred, true_fn, false_fn, operands)
+    else:
+        return cond_op(pred, true_fn, false_fn, operands)
 @cond_op.py_impl(FakeTensorMode)

View File

@@ -151,6 +151,9 @@ def with_effects_proxy(
     *args: Tuple[Any, ...],
     **kwargs: Dict[str, Any],
 ) -> Tuple[torch.Tensor, ...]:
+    if not mode.enable_tracing:
+        return with_effects(token, op, *args, **kwargs)
     with disable_proxy_modes_tracing():
         out = with_effects(token, op, *args, **kwargs)

View File

@@ -315,18 +315,31 @@ def flex_attention_proxy_torch_dispatch_mode(
     mask_mod_other_buffers: Tuple = (),
 ) -> Tuple[torch.Tensor, torch.Tensor]:
     assert mode is not None, "Mode should always be enabled for python fallback key"
-    return trace_flex_attention(
-        mode,
-        query,
-        key,
-        value,
-        score_mod,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
-    )
+    if mode.enable_tracing:
+        return trace_flex_attention(
+            mode,
+            query,
+            key,
+            value,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    else:
+        return flex_attention(
+            query,
+            key,
+            value,
+            score_mod,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
 @flex_attention.py_functionalize_impl
@@ -834,22 +847,39 @@ def flex_attention_backward_proxy_torch_dispatch_mode(
     mask_mod_other_buffers: Tuple = (),
 ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
     assert mode is not None, "Mode should always be enabled for python fallback key"
-    return trace_flex_attention_backward(
-        mode,
-        query,
-        key,
-        value,
-        out,
-        logsumexp,
-        grad_out,
-        fw_graph,
-        joint_graph,
-        block_mask,
-        scale,
-        kernel_options,
-        score_mod_other_buffers,
-        mask_mod_other_buffers,
-    )
+    if mode.enable_tracing:
+        return trace_flex_attention_backward(
+            mode,
+            query,
+            key,
+            value,
+            out,
+            logsumexp,
+            grad_out,
+            fw_graph,
+            joint_graph,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
+    else:
+        return flex_attention_backward(
+            query,
+            key,
+            value,
+            out,
+            logsumexp,
+            grad_out,
+            fw_graph,
+            joint_graph,
+            block_mask,
+            scale,
+            kernel_options,
+            score_mod_other_buffers,
+            mask_mod_other_buffers,
+        )
 @flex_attention_backward.py_functionalize_impl

View File

@@ -218,7 +218,10 @@ def map_autograd(f, xs, pos_args):
 @map_impl.py_impl(ProxyTorchDispatchMode)
 def map_proxy_torch_dispatch_mode(mode, f, xs, args):
-    return trace_map(mode, map_impl, f, xs, args)
+    if mode.enable_tracing:
+        return trace_map(mode, map_impl, f, xs, args)
+    else:
+        return map_impl(f, xs, args)
 @map_impl.py_impl(FakeTensorMode)

View File

@@ -147,7 +147,10 @@ def out_dtype_proxy(
     output_dtype: torch.dtype,
     *args,
 ):
-    return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
+    if mode.enable_tracing:
+        return trace_out_dtype(mode, out_dtype, op, output_dtype, *args)
+    else:
+        return out_dtype(op, output_dtype, *args)
 @out_dtype.py_impl(FakeTensorMode)

View File

@@ -13,6 +13,8 @@ run_const_graph = HigherOrderOperator("run_const_graph")
 @run_const_graph.py_impl(ProxyTorchDispatchMode)
 def run_const_graph_dispatch_mode(mode, *args):
+    if not mode.enable_tracing:
+        return run_const_graph(*args)
     const_gm, weights = args
     p_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
     assert isinstance(const_gm, torch.fx.GraphModule)

View File

@@ -45,7 +45,10 @@ strict_mode_op.py_impl(DispatchKey.Autograd)(
 @strict_mode_op.py_impl(ProxyTorchDispatchMode)
 def inner(mode, callable, operands):
-    return trace_strict_mode(mode, strict_mode_op, callable, operands)
+    if mode.enable_tracing:
+        return trace_strict_mode(mode, strict_mode_op, callable, operands)
+    else:
+        return strict_mode_op(callable, operands)
 def trace_strict_mode(mode, strict_mode_op, callable, operands):

View File

@@ -68,39 +68,42 @@ def call_torchbind_impl(obj, method, *args, **kwargs):
 @call_torchbind.py_impl(ProxyTorchDispatchMode)
 def inner(mode, *args, **kwargs):
-    proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
-    proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+    if mode.enable_tracing:
+        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, args)
+        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
 
-    out_proxy = mode.tracer.create_proxy(
-        "call_function",
-        call_torchbind,
-        proxy_args,
-        proxy_kwargs,
-    )
-    out = call_torchbind(*args, **kwargs)
+        out_proxy = mode.tracer.create_proxy(
+            "call_function",
+            call_torchbind,
+            proxy_args,
+            proxy_kwargs,
+        )
+        out = call_torchbind(*args, **kwargs)
 
-    obj, method, *rest_args = args
-    if isinstance(obj, torch.ScriptObject):
-        ns, class_name = _ns_and_class_name(
-            obj._type().qualified_name()  # type: ignore[attr-defined]
-        )
-        log.warning(
-            "Tracing torchbind method %s.%s with real ScriptObject. This may"
-            " cause the original object being mutated. If this is not intended,"
-            ' You can register a fake class with torch._library.register_fake_class("%s::%s").',
-            class_name,
-            method,
-            ns,
-            class_name,
-        )
+        obj, method, *rest_args = args
+        if isinstance(obj, torch.ScriptObject):
+            ns, class_name = _ns_and_class_name(
+                obj._type().qualified_name()  # type: ignore[attr-defined]
+            )
+            log.warning(
+                "Tracing torchbind method %s.%s with real ScriptObject. This may"
+                " cause the original object being mutated. If this is not intended,"
+                ' You can register a fake class with torch._library.register_fake_class("%s::%s").',
+                class_name,
+                method,
+                ns,
+                class_name,
+            )
 
-    ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
-    if "val" not in out_proxy.node.meta:
-        assert out is None or isinstance(
-            out, (int, float, bool)
-        ), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
-        out_proxy.node.meta["val"] = out
-    return ret
+        ret = track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+        if "val" not in out_proxy.node.meta:
+            assert out is None or isinstance(
+                out, (int, float, bool)
+            ), "Currently, only these constant dtypes are supported to be returned from torchbind methods."
+            out_proxy.node.meta["val"] = out
+        return ret
+    else:
+        return call_torchbind(*args, **kwargs)
 
 # When tracing with fake script object, the call_torchbind op will return a fake tensor

View File

@@ -590,16 +590,24 @@ def trace_triton_kernel_wrapper(proxy_mode, func_overload, node_args):
 def triton_kernel_wrapper_mutation_proxy_torch_dispatch_mode(
     mode, *, kernel_idx, constant_args_idx, grid, kwargs
 ):
-    trace_triton_kernel_wrapper(
-        mode,
-        triton_kernel_wrapper_mutation,
-        {
-            "kernel_idx": kernel_idx,
-            "constant_args_idx": constant_args_idx,
-            "grid": grid,
-            "kwargs": kwargs,
-        },
-    )
+    if mode.enable_tracing:
+        trace_triton_kernel_wrapper(
+            mode,
+            triton_kernel_wrapper_mutation,
+            {
+                "kernel_idx": kernel_idx,
+                "constant_args_idx": constant_args_idx,
+                "grid": grid,
+                "kwargs": kwargs,
+            },
+        )
+    else:
+        triton_kernel_wrapper_mutation(
+            kernel_idx=kernel_idx,
+            constant_args_idx=constant_args_idx,
+            grid=grid,
+            kwargs=kwargs,
+        )
 
     return None
@@ -683,17 +691,25 @@ def triton_kernel_wrapper_functional_fake_tensor_mode(
 def triton_kernel_wrapper_functional_proxy_torch_dispatch_mode(
     mode, *, kernel_idx, constant_args_idx, grid, kwargs, tensors_to_clone
 ):
-    return trace_triton_kernel_wrapper(
-        mode,
-        triton_kernel_wrapper_functional,
-        {
-            "kernel_idx": kernel_idx,
-            "constant_args_idx": constant_args_idx,
-            "grid": grid,
-            "kwargs": kwargs,
-            "tensors_to_clone": tensors_to_clone,
-        },
-    )
+    if mode.enable_tracing:
+        return trace_triton_kernel_wrapper(
+            mode,
+            triton_kernel_wrapper_functional,
+            {
+                "kernel_idx": kernel_idx,
+                "constant_args_idx": constant_args_idx,
+                "grid": grid,
+                "kwargs": kwargs,
+                "tensors_to_clone": tensors_to_clone,
+            },
+        )
+    else:
+        return triton_kernel_wrapper_functional(
+            kernel_idx=kernel_idx,
+            grid=grid,
+            kwargs=kwargs,
+            tensors_to_clone=tensors_to_clone,
+        )
 
 @triton_kernel_wrapper_functional.py_functionalize_impl

View File

@@ -222,9 +222,12 @@ def while_loop_tracing(mode, cond_fn, body_fn, carried_inputs, additional_inputs
             out, out_proxy, constant=None, tracer=proxy_mode.tracer
         )
 
-    return _trace_while_loop(
-        mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
-    )
+    if mode.enable_tracing:
+        return _trace_while_loop(
+            mode, while_loop_op, cond_fn, body_fn, carried_inputs, additional_inputs
+        )
+    else:
+        return while_loop_op(cond_fn, body_fn, carried_inputs, additional_inputs)
 
 @while_loop_op.py_impl(FakeTensorMode)

View File

@@ -178,13 +178,16 @@ def register_run_and_save_rng_state_op():
     @run_and_save_rng_state.py_impl(ProxyTorchDispatchMode)
     def impl_proxy_dispatch_mode(mode, op, *args, **kwargs):
-        out = impl_backend_select(op, *args, **kwargs)
-        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
-        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
-        out_proxy = mode.tracer.create_proxy(
-            "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
-        )
-        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+        if mode.enable_tracing:
+            out = impl_backend_select(op, *args, **kwargs)
+            proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (op, *args))
+            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+            out_proxy = mode.tracer.create_proxy(
+                "call_function", run_and_save_rng_state, proxy_args, proxy_kwargs
+            )
+            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+        else:
+            return run_and_save_rng_state(op, *args, **kwargs)
 
     return run_and_save_rng_state
@@ -214,16 +217,19 @@ def register_run_with_rng_state_op():
     @run_with_rng_state.py_impl(ProxyTorchDispatchMode)
     def impl_proxy_dispatch_mode(mode, rng_state, op, *args, **kwargs):
-        # TODO: you don't need to do this, the dispatch here already disabled
-        # it
-        with disable_proxy_modes_tracing():
-            out = run_with_rng_state(rng_state, op, *args, **kwargs)
-        proxy_args = pytree.tree_map(mode.tracer.unwrap_proxy, (rng_state, op, *args))
-        proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
-        out_proxy = mode.tracer.create_proxy(
-            "call_function", run_with_rng_state, proxy_args, proxy_kwargs
-        )
-        return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+        if mode.enable_tracing:
+            with disable_proxy_modes_tracing():
+                out = run_with_rng_state(rng_state, op, *args, **kwargs)
+            proxy_args = pytree.tree_map(
+                mode.tracer.unwrap_proxy, (rng_state, op, *args)
+            )
+            proxy_kwargs = pytree.tree_map(mode.tracer.unwrap_proxy, kwargs)
+            out_proxy = mode.tracer.create_proxy(
+                "call_function", run_with_rng_state, proxy_args, proxy_kwargs
+            )
+            return track_tensor_tree(out, out_proxy, constant=None, tracer=mode.tracer)
+        else:
+            return run_with_rng_state(rng_state, op, *args, **kwargs)
 
     @run_with_rng_state.py_impl(DispatchKey.BackendSelect)
     def impl_backend_select(rng_state, op, *args, **kwargs):

View File

@@ -7,7 +7,7 @@ import torch
 import torch.distributed as dist
 import torch.distributed.distributed_c10d as c10d
 from torch.distributed.device_mesh import DeviceMesh
-from torch.fx.experimental.proxy_tensor import get_proxy_mode
+from torch.fx.experimental.proxy_tensor import get_innermost_proxy_mode
 
 from . import _functional_collectives_impl as fun_col_impl
@@ -806,7 +806,10 @@ def _are_we_tracing() -> bool:
         is not None
     ):
         return True
-    return get_proxy_mode() is not None
+    mode = get_innermost_proxy_mode()
+    if mode is None:
+        return False
+    return mode.tracer is not None
 
 def _maybe_wrap_tensor(self) -> torch.Tensor: