diff --git a/docs/source/autograd.rst b/docs/source/autograd.rst
index f3899fd0a60d..046ee42717bf 100644
--- a/docs/source/autograd.rst
+++ b/docs/source/autograd.rst
@@ -33,6 +33,9 @@ for detailed steps on how to use this API.
     forward_ad.dual_level
     forward_ad.make_dual
     forward_ad.unpack_dual
+    forward_ad.enter_dual_level
+    forward_ad.exit_dual_level
+    forward_ad.UnpackedDualTensor
 
 .. _functional-api:
 
@@ -209,6 +212,27 @@ When creating a new :class:`Function`, the following methods are available to `c
     function.FunctionCtx.save_for_backward
     function.FunctionCtx.set_materialize_grads
 
+Custom Function utilities
+^^^^^^^^^^^^^^^^^^^^^^^^^
+Decorator for backward method.
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    function.once_differentiable
+
+Base custom :class:`Function` used to build PyTorch utilities
+
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    function.BackwardCFunction
+    function.InplaceFunction
+    function.NestedIOFunction
+
+
 .. _grad-check:
 
 Numerical gradient checking
@@ -224,6 +248,7 @@ Numerical gradient checking
 
     gradcheck
     gradgradcheck
+    GradcheckError
 
 .. Just to reset the base path for the rest of this file
 .. currentmodule:: torch.autograd
@@ -249,6 +274,14 @@ and vtune profiler based using
     profiler.profile.key_averages
     profiler.profile.self_cpu_time_total
     profiler.profile.total_average
+    profiler.parse_nvprof_trace
+    profiler.EnforceUnique
+    profiler.KinetoStepTracker
+    profiler.record_function
+    profiler_util.Interval
+    profiler_util.Kernel
+    profiler_util.MemRecordsAcc
+    profiler_util.StringTable
 
 .. autoclass:: torch.autograd.profiler.emit_nvtx
 .. autoclass:: torch.autograd.profiler.emit_itt
@@ -260,13 +293,20 @@
 
     profiler.load_nvprof
 
 
-Anomaly detection
+Debugging and anomaly detection
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
 
 .. autoclass:: detect_anomaly
 
 .. autoclass:: set_detect_anomaly
 
+.. autosummary::
+    :toctree: generated
+    :nosignatures:
+
+    grad_mode.set_multithreading_enabled
+
+
 Autograd graph
 ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
@@ -286,6 +326,7 @@ enabled and at least one of the inputs required gradients), or ``None`` otherwis
     graph.Node.next_functions
     graph.Node.register_hook
     graph.Node.register_prehook
+    graph.increment_version
 
 Some operations need intermediary results to be saved during the forward pass
 in order to execute the backward pass.
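The forward-mode AD names added to the autosummary above (`enter_dual_level`, `exit_dual_level`, `UnpackedDualTensor`) belong to the existing `torch.autograd.forward_ad` API. As reviewer context only (not part of the patch), a minimal usage sketch of that API:

```python
import torch
import torch.autograd.forward_ad as fwAD

primal = torch.randn(3)
tangent = torch.randn(3)

# dual_level() wraps enter_dual_level()/exit_dual_level() as a context manager.
with fwAD.dual_level():
    dual = fwAD.make_dual(primal, tangent)
    out = dual.sin()
    # unpack_dual returns an UnpackedDualTensor namedtuple (primal, tangent).
    unpacked = fwAD.unpack_dual(out)
    jvp = unpacked.tangent  # equals primal.cos() * tangent
```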
diff --git a/docs/source/conf.py b/docs/source/conf.py
index 3ece6fcc65a4..662757dd7155 100644
--- a/docs/source/conf.py
+++ b/docs/source/conf.py
@@ -91,9 +91,6 @@ templates_path = ["_templates"]
 coverage_ignore_functions = [
     # torch
     "typename",
-    # torch.autograd
-    "register_py_tensor_class_for_device",
-    "variable",
     # torch.cuda
     "check_error",
     "cudart",
@@ -390,20 +387,6 @@ coverage_ignore_functions = [
     "weight_dtype",
     "weight_is_quantized",
     "weight_is_statically_quantized",
-    # torch.autograd.forward_ad
-    "enter_dual_level",
-    "exit_dual_level",
-    # torch.autograd.function
-    "once_differentiable",
-    "traceable",
-    # torch.autograd.gradcheck
-    "get_analytical_jacobian",
-    "get_numerical_jacobian",
-    "get_numerical_jacobian_wrt_specific_input",
-    # torch.autograd.graph
-    "increment_version",
-    # torch.autograd.profiler
-    "parse_nvprof_trace",
     # torch.backends.cudnn.rnn
     "get_cudnn_mode",
     "init_dropout_state",
@@ -2530,40 +2513,6 @@ coverage_ignore_classes = [
     "QuantWrapper",
     # torch.ao.quantization.utils
     "MatchAllNode",
-    # torch.autograd.forward_ad
-    "UnpackedDualTensor",
-    # torch.autograd.function
-    "BackwardCFunction",
-    "Function",
-    "FunctionCtx",
-    "FunctionMeta",
-    "InplaceFunction",
-    "NestedIOFunction",
-    # torch.autograd.grad_mode
-    "inference_mode",
-    "set_grad_enabled",
-    "set_multithreading_enabled",
-    # torch.autograd.gradcheck
-    "GradcheckError",
-    # torch.autograd.profiler
-    "EnforceUnique",
-    "KinetoStepTracker",
-    "profile",
-    "record_function",
-    # torch.autograd.profiler_legacy
-    "profile",
-    # torch.autograd.profiler_util
-    "EventList",
-    "FormattedTimesMixin",
-    "FunctionEvent",
-    "FunctionEventAvg",
-    "Interval",
-    "Kernel",
-    "MemRecordsAcc",
-    "StringTable",
-    # torch.autograd.variable
-    "Variable",
-    "VariableMeta",
     # torch.backends.cudnn.rnn
     "Unserializable",
     # torch.cuda.amp.grad_scaler
diff --git a/docs/source/torch.rst b/docs/source/torch.rst
index 0ac0d96cb1fe..71d467a9e354 100644
--- a/docs/source/torch.rst
+++ b/docs/source/torch.rst
@@ -274,9 +274,9 @@ Examples::
 
     no_grad
     enable_grad
-    set_grad_enabled
+    autograd.grad_mode.set_grad_enabled
     is_grad_enabled
-    inference_mode
+    autograd.grad_mode.inference_mode
     is_inference_mode_enabled
 
 Math operations
diff --git a/test/allowlist_for_publicAPI.json b/test/allowlist_for_publicAPI.json
index 196df6175fdf..44b64324281f 100644
--- a/test/allowlist_for_publicAPI.json
+++ b/test/allowlist_for_publicAPI.json
@@ -35,20 +35,6 @@
     "torch.nn.quantizable.modules.activation": "torch.ao.nn.quantizable.modules.activation",
     "torch.nn.quantizable.modules.rnn": "torch.ao.nn.quantizable.modules.rnn"
   },
-  "torch.autograd": [
-    "NestedIOFunction",
-    "detect_anomaly",
-    "enable_grad",
-    "grad",
-    "gradcheck",
-    "gradgradcheck",
-    "inference_mode",
-    "no_grad",
-    "set_detect_anomaly",
-    "set_grad_enabled",
-    "set_multithreading_enabled",
-    "variable"
-  ],
   "torch.backends": [
     "contextmanager"
   ],
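The `torch.rst` hunk above only changes where Sphinx resolves `set_grad_enabled` and `inference_mode` (now via `torch.autograd.grad_mode`); the user-facing calls are unchanged. A short sketch of the documented behaviour (illustrative only, not part of the patch):

```python
import torch

x = torch.ones(2, requires_grad=True)

# set_grad_enabled works as a context manager (and as a plain function call).
with torch.set_grad_enabled(False):
    y = x * 2
assert not y.requires_grad

# inference_mode is a stricter, faster variant of no_grad.
with torch.inference_mode():
    z = x * 2
assert not z.requires_grad
```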
diff --git a/torch/autograd/function.py b/torch/autograd/function.py
index 7d7e94ced984..94a7dec60189 100644
--- a/torch/autograd/function.py
+++ b/torch/autograd/function.py
@@ -279,7 +279,14 @@ class _HookMixin:
 
 
 class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
+    r"""
+    This class is used for internal autograd work. Do not use.
+    """
+
     def apply(self, *args):
+        r"""
+        Apply method used when executing this Node during the backward pass.
+        """
         # _forward_cls is defined by derived class
         # The user should define either backward or vjp but never both.
         backward_fn = self._forward_cls.backward  # type: ignore[attr-defined]
@@ -294,6 +301,9 @@ class BackwardCFunction(_C._FunctionBase, FunctionCtx, _HookMixin):
         return user_fn(self, *args)
 
     def apply_jvp(self, *args):
+        r"""
+        Apply method used when executing forward-mode AD during the forward pass.
+        """
         # _forward_cls is defined by derived class
         return self._forward_cls.jvp(self, *args)  # type: ignore[attr-defined]
 
@@ -378,10 +388,10 @@ class _SingleLevelFunction(
 
     Either:
 
-    1. Override forward with the signature forward(ctx, *args, **kwargs).
+    1. Override forward with the signature ``forward(ctx, *args, **kwargs)``.
        ``setup_context`` is not overridden. Setting up the ctx for backward
        happens inside the ``forward``.
-    2. Override forward with the signature forward(*args, **kwargs) and
+    2. Override forward with the signature ``forward(*args, **kwargs)`` and
        override ``setup_context``. Setting up the ctx for backward happens
        inside ``setup_context`` (as opposed to inside the ``forward``)
 
@@ -639,6 +649,11 @@ def traceable(fn_cls):
 
 
 class InplaceFunction(Function):
+    r"""
+    This class is here only for backward compatibility reasons.
+    Use :class:`Function` instead of this for any new use case.
+    """
+
     def __init__(self, inplace=False):
         super().__init__()
         self.inplace = inplace
@@ -754,6 +769,10 @@ _map_tensor_data = _nested_map(
 
 
 class NestedIOFunction(Function):
+    r"""
+    This class is here only for backward compatibility reasons.
+    Use :class:`Function` instead of this for any new use case.
+    """
     # The 'type: ignore' statements are needed here because these functions are declared as '@staticmethod' in the
     # superclass (Function) but are instance methods here, which mypy reports as incompatible.
 
@@ -774,6 +793,9 @@ class NestedIOFunction(Function):
         return result
 
     def backward(self, *gradients: Any) -> Any:  # type: ignore[override]
+        r"""
+        Shared backward utility.
+        """
         nested_gradients = _unflatten(gradients, self._nested_output)
         result = self.backward_extended(*nested_gradients)  # type: ignore[func-returns-value]
         return tuple(_iter_None_tensors(result))
@@ -781,6 +803,9 @@ class NestedIOFunction(Function):
     __call__ = _do_forward
 
     def forward(self, *args: Any) -> Any:  # type: ignore[override]
+        r"""
+        Shared forward utility.
+        """
         nested_tensors = _map_tensor_data(self._nested_input)
         result = self.forward_extended(*nested_tensors)  # type: ignore[func-returns-value]
         del self._nested_input
@@ -788,22 +813,40 @@ class NestedIOFunction(Function):
         return tuple(_iter_tensors(result))
 
     def save_for_backward(self, *args: Any) -> None:
+        r"""
+        See :meth:`Function.save_for_backward`.
+        """
         self.to_save = tuple(_iter_tensors(args))
         self._to_save_nested = args
 
     @property
     def saved_tensors(self):
+        r"""
+        See :meth:`Function.saved_tensors`.
+        """
         flat_tensors = super().saved_tensors  # type: ignore[misc]
         return _unflatten(flat_tensors, self._to_save_nested)
 
     def mark_dirty(self, *args: Any, **kwargs: Any) -> None:
+        r"""
+        See :meth:`Function.mark_dirty`.
+        """
         self.dirty_tensors = tuple(_iter_tensors((args, kwargs)))
 
     def mark_non_differentiable(self, *args: Any, **kwargs: Any) -> None:
+        r"""
+        See :meth:`Function.mark_non_differentiable`.
+        """
         self.non_differentiable = tuple(_iter_tensors((args, kwargs)))
 
     def forward_extended(self, *input: Any) -> None:
+        r"""
+        User-defined forward.
+        """
         raise NotImplementedError
 
     def backward_extended(self, *grad_output: Any) -> None:
+        r"""
+        User-defined backward.
+        """
         raise NotImplementedError
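To ground the new docstrings and the ``forward(ctx, *args, **kwargs)`` wording, here is a minimal custom :class:`Function` that also exercises the `once_differentiable` decorator documented in `autograd.rst` (illustrative sketch, not part of the patch):

```python
import torch
from torch.autograd import Function
from torch.autograd.function import once_differentiable


class Square(Function):
    @staticmethod
    def forward(ctx, x):  # option 1: forward(ctx, *args, **kwargs), no setup_context
        ctx.save_for_backward(x)
        return x * x

    @staticmethod
    @once_differentiable  # backward itself will not be differentiated further
    def backward(ctx, grad_out):
        (x,) = ctx.saved_tensors
        return 2 * x * grad_out


x = torch.randn(3, requires_grad=True)
Square.apply(x).sum().backward()
print(x.grad)  # equals 2 * x
```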
+ """ raise NotImplementedError diff --git a/torch/autograd/grad_mode.py b/torch/autograd/grad_mode.py index 258b0e96a4d9..4232d589e8de 100644 --- a/torch/autograd/grad_mode.py +++ b/torch/autograd/grad_mode.py @@ -196,6 +196,9 @@ class set_grad_enabled(_DecoratorContextManager): torch._C._set_grad_enabled(self.prev) def clone(self) -> "set_grad_enabled": + r""" + Create a copy of this class + """ return self.__class__(self.mode) @@ -272,6 +275,9 @@ class inference_mode(_DecoratorContextManager): self._inference_mode_context.__exit__(exc_type, exc_value, traceback) def clone(self) -> "inference_mode": + r""" + Create a copy of this class + """ return self.__class__(self.mode) @@ -315,6 +321,9 @@ class set_multithreading_enabled(_DecoratorContextManager): torch._C._set_multithreading_enabled(self.prev) def clone(self) -> "set_multithreading_enabled": + r""" + Create a copy of this class + """ return self.__class__(self.mode) diff --git a/torch/autograd/profiler.py b/torch/autograd/profiler.py index 9c5f8bd15169..90b47dda461c 100644 --- a/torch/autograd/profiler.py +++ b/torch/autograd/profiler.py @@ -876,6 +876,9 @@ class EnforceUnique: self.seen = set() def see(self, *key): + r""" + Observe a key and raise an error if it is seen multiple times. + """ if key in self.seen: raise RuntimeError("duplicate key: " + str(key)) self.seen.add(key) @@ -962,23 +965,27 @@ class KinetoStepTracker: We fix this by adding a layer of abstraction before calling step() to the kineto library. The idea is to maintain steps per requester in a dict: - ``` - { - "ProfilerStep": 100, # triggered by profiler step() call - "Optimizer1Step": 100, # Optimizer 1 or 2 are just examples, could be SGD, Adam etc - "Optimizer2Step": 100, - } - ``` + + .. code-block:: + + { + "ProfilerStep": 100, # triggered by profiler step() call + "Optimizer1Step": 100, # Optimizer 1 or 2 are just examples, could be SGD, Adam etc + "Optimizer2Step": 100, + } + To figure out the global step count just take the max of dict values (100). If one of the count increments the max will go up. - ``` - { - "ProfilerStep": 100, - "Optimizer1Step": 101, # Optimizer1 got incremented first say - "Optimizer2Step": 100, - } - ``` + + .. code-block:: + + { + "ProfilerStep": 100, + "Optimizer1Step": 101, # Optimizer1 got incremented first say + "Optimizer2Step": 100, + } + Then global step count is 101 We only call the kineto step() function when global count increments. @@ -991,10 +998,16 @@ class KinetoStepTracker: @classmethod def init_step_count(cls, requester: str): + r""" + Initialize for a given requester. + """ cls._step_dict[requester] = cls._current_step @classmethod def erase_step_count(cls, requester: str) -> bool: + r""" + Remove a given requester. 
+ """ return cls._step_dict.pop(requester, None) is not None @classmethod @@ -1023,4 +1036,7 @@ class KinetoStepTracker: @classmethod def current_step(cls) -> int: + r""" + Get the latest step for any requester + """ return cls._current_step diff --git a/torch/autograd/profiler_util.py b/torch/autograd/profiler_util.py index 5e35b8604b86..331b5c77f659 100644 --- a/torch/autograd/profiler_util.py +++ b/torch/autograd/profiler_util.py @@ -423,6 +423,9 @@ class Interval: self.end = end def elapsed_us(self): + r""" + Returns the length of the interval + """ return self.end - self.start @@ -781,6 +784,9 @@ class MemRecordsAcc: self._start_uses, self._indices = zip(*tmp) # type: ignore[assignment] def in_interval(self, start_us, end_us): + r""" + Return all records in the given interval + """ start_idx = bisect.bisect_left(self._start_uses, start_us) end_idx = bisect.bisect_right(self._start_uses, end_us) for i in range(start_idx, end_idx):