Mirror of https://github.com/pytorch/pytorch.git
[ca] suggest to disable compiled autograd for trace-time NotImplementedErrors (#156509)
Example:

```python
  File "/home/xmfan/core/a/pytorch/torch/autograd/graph.py", line 829, in _engine_run_backward
    return Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
           ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
NotImplementedError: TorchDispatchMode not yet implemented for compiled autograd. You can disable compiled autograd for this operation by:
  1. Relocating the unsupported autograd call outside the compiled region.
  2. Wrapping the unsupported autograd call within a scope that disables compiled autograd.
  3. Configuring the specific compilation unit to disable compiled autograd.
  4. Globally disabling compiled autograd at the application's initialization.
```

No duplicate error messages for Python-side trace-time errors:

```python
  File "/home/xmfan/core/a/pytorch/torch/_dynamo/compiled_autograd.py", line 344, in begin_capture
    raise NotImplementedError(
NotImplementedError: Found tensor of type <class 'torch.nn.utils._expanded_weights.expanded_weights_impl.ExpandedWeight'>, which is not supported by FakeTensorMode. You can turn off compiled autograd by either:
  1. Moving the unsupported autograd call outside of the torch.compile'd region.
  2. Wrapping the unsupported autograd call in the torch._dynamo.compiled_autograd._disable() context manager.
  3. Setting torch._dynamo.config.compiled_autograd=False for the torch.compile call containing the unsupported autograd call.
  4. Setting torch._dynamo.config.compiled_autograd=False at the start of the program.
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/156509
Approved by: https://github.com/jansel
ghstack dependencies: #156374
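For concreteness, here is a minimal sketch of remediation options 2 through 4 from the error messages above. `model`, `x`, and `train_step` are hypothetical stand-ins; `torch._dynamo.compiled_autograd._disable()` and `torch._dynamo.config.compiled_autograd` are the knobs the error message itself names, and `torch._dynamo.config.patch` is used here as one way to scope the flag to a single compilation unit:

```python
import torch

# Hypothetical stand-ins for whatever model/inputs hit the trace-time error.
model = torch.nn.Linear(4, 4)
x = torch.randn(2, 4, requires_grad=True)

# Option 2: wrap only the offending backward call in the private context
# manager named by the error message, so compiled autograd is skipped for
# this one .backward() while everything else stays compiled.
loss = model(x).sum()
with torch._dynamo.compiled_autograd._disable():
    loss.backward()

# Option 3: disable compiled autograd for a single compilation unit by
# patching the config flag around just that function.
@torch._dynamo.config.patch(compiled_autograd=False)
def train_step(inp):
    out = model(inp).sum()
    out.backward()

# Option 4: disable compiled autograd globally at program start.
torch._dynamo.config.compiled_autograd = False
```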
Committed by: PyTorch MergeBot
Parent: f1968a5e76
Commit: 5f2f343e1e
```diff
@@ -594,8 +594,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   // Implementations in subclasses should call args.collect() with all node
   // attrs. These functions are only called durring backward.
   virtual void compiled_args(CompiledNodeArgs& args) const {
-    throw std::runtime_error(
-        std::string("compiled_args not implemented: ") + name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string("compiled_args not implemented: ") + name());
   }
 
   // Used by compiled autograd to call apply() with different saved tensors
@@ -604,8 +604,8 @@ struct TORCH_API Node : std::enable_shared_from_this<Node> {
   virtual variable_list apply_with_saved(
       const variable_list& inputs,
       SwapSavedVariables& saved) {
-    throw std::runtime_error(
-        std::string("apply_with_saved not implemented: ") + name());
+    TORCH_CHECK_NOT_IMPLEMENTED(
+        false, std::string("apply_with_saved not implemented: ") + name());
   }
 
   // If this node is the AOTBackward node produced by torch.compile.
```
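The point of the change above is the exception type: `TORCH_CHECK_NOT_IMPLEMENTED` raises `c10::NotImplementedError` rather than a generic `std::runtime_error`, and PyTorch's Python bindings surface that class as Python's `NotImplementedError`, which is what lets the Python-side handler attach the "disable compiled autograd" suggestions shown in the commit message. Below is a minimal standalone sketch of that behavior (not PyTorch source; the node name string is made up):

```cpp
// Sketch: TORCH_CHECK_NOT_IMPLEMENTED(cond, ...) throws c10::NotImplementedError
// when cond is false, concatenating the remaining arguments into the message.
#include <c10/util/Exception.h>
#include <iostream>

int main() {
  try {
    // Mirrors the pattern in the diff; "MyCustomNode" is a made-up node name.
    TORCH_CHECK_NOT_IMPLEMENTED(
        false, "compiled_args not implemented: ", "MyCustomNode");
  } catch (const c10::NotImplementedError& e) {
    // Callers can catch the specific subclass instead of a generic error,
    // which is what the compiled-autograd frontend relies on.
    std::cout << "caught NotImplementedError: " << e.what() << "\n";
  }
  return 0;
}
```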