mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
Unconditionally enable python dispatcher in AOTAutograd (#88365)
Signed-off-by: Edward Z. Yang <ezyang@fb.com> Pull Request resolved: https://github.com/pytorch/pytorch/pull/88365 Approved by: https://github.com/Chillee
This commit is contained in:
committed by
PyTorch MergeBot
parent
a689502275
commit
97d3b200ca
@ -391,24 +391,25 @@ def aot_dispatch_autograd(flat_fn, flat_args: List[Tensor], aot_config: AOTConfi
|
|||||||
disable_amp = torch._C._is_any_autocast_enabled()
|
disable_amp = torch._C._is_any_autocast_enabled()
|
||||||
|
|
||||||
if config.use_functionalize:
|
if config.use_functionalize:
|
||||||
# Trace once without decompositions, into a graph of ATen ops.
|
with enable_python_dispatcher():
|
||||||
# NB: tracing_mode is real, as it's assumed the calling context setup
|
# Trace once without decompositions, into a graph of ATen ops.
|
||||||
# fake tensor mode / symbolic shapes if that is needed
|
# NB: tracing_mode is real, as it's assumed the calling context setup
|
||||||
fx_g = make_fx(joint_forward_backward)(*joint_inputs)
|
# fake tensor mode / symbolic shapes if that is needed
|
||||||
|
fx_g = make_fx(joint_forward_backward)(*joint_inputs)
|
||||||
|
|
||||||
context = disable_autocast_manager if disable_amp else nullcontext
|
context = disable_autocast_manager if disable_amp else nullcontext
|
||||||
|
|
||||||
def fake_fn(primals, tangents):
|
def fake_fn(primals, tangents):
|
||||||
with torch.fx.traceback.override_stack_trace():
|
with torch.fx.traceback.override_stack_trace():
|
||||||
return torch.fx.Interpreter(fx_g).run(primals, tangents)
|
return torch.fx.Interpreter(fx_g).run(primals, tangents)
|
||||||
|
|
||||||
# Trace a second time, running functionalization, and THEN running decompositions.
|
# Trace a second time, running functionalization, and THEN running decompositions.
|
||||||
# functionalization only acts on ATen today, and doesn't currently handle
|
# functionalization only acts on ATen today, and doesn't currently handle
|
||||||
# view and inplace ops that come from primtorch.
|
# view and inplace ops that come from primtorch.
|
||||||
# Eventually, functionalization should support primtorch view/inplace ops,
|
# Eventually, functionalization should support primtorch view/inplace ops,
|
||||||
# which will make it ok to run decompositions before functionalization.
|
# which will make it ok to run decompositions before functionalization.
|
||||||
with context():
|
with context():
|
||||||
fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
|
fx_g = make_fx(functionalize(fake_fn), aot_config.decompositions)(*joint_inputs)
|
||||||
fx_g.graph.eliminate_dead_code()
|
fx_g.graph.eliminate_dead_code()
|
||||||
fx_g.recompile()
|
fx_g.recompile()
|
||||||
else:
|
else:
|
||||||
|
@ -103,7 +103,7 @@ def resolve_key(op: PyOperatorABC, k: DispatchKey): # type: ignore[valid-type]
|
|||||||
# The dispatch key itself will implicitly route to backend fallback.
|
# The dispatch key itself will implicitly route to backend fallback.
|
||||||
# This is probably not great for the pure Python implementation.
|
# This is probably not great for the pure Python implementation.
|
||||||
return k
|
return k
|
||||||
raise RuntimeError("could not find kernel")
|
raise NotImplementedError(f"could not find kernel for {op} at dispatch key {k}")
|
||||||
|
|
||||||
|
|
||||||
pyop_namespace = {}
|
pyop_namespace = {}
|
||||||
|
Reference in New Issue
Block a user