Fix missing mandatory device_type argument in autocast docstring (#97223)
Fixes [#92803](https://github.com/pytorch/pytorch/issues/92803)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/97223
Approved by: https://github.com/albanD, https://github.com/malfet
Committed by: PyTorch MergeBot
Parent: 6d2da6106d
Commit: 41866a2ead
@@ -41,7 +41,7 @@ class autocast:
     optimizer.zero_grad()

     # Enables autocasting for the forward pass (model + loss)
-    with autocast():
+    with torch.autocast(device_type="cuda"):
         output = model(input)
         loss = loss_fn(output, target)

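The hunk above only updates the call site shown in the docstring. For context, here is a minimal sketch of the full AMP training step that pattern comes from, assuming a CUDA device (the model, loss, optimizer, and data below are placeholders, not part of the change):

```python
import torch

# Placeholder setup; any float32 model works here.
model = torch.nn.Linear(8, 8).cuda()
loss_fn = torch.nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
scaler = torch.cuda.amp.GradScaler()

input = torch.randn(4, 8, device="cuda")
target = torch.randn(4, 8, device="cuda")

optimizer.zero_grad()

# device_type is a mandatory argument, so it is spelled out explicitly.
with torch.autocast(device_type="cuda"):
    output = model(input)
    loss = loss_fn(output, target)

# Standard gradient-scaling steps that usually accompany autocast.
scaler.scale(loss).backward()
scaler.step(optimizer)
scaler.update()
```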
@@ -56,7 +56,7 @@ class autocast:

     class AutocastModel(nn.Module):
         ...
-        @autocast()
+        @torch.autocast(device_type="cuda")
         def forward(self, input):
             ...

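A sketch of how the decorated forward from this hunk is used end to end, again assuming a CUDA device (the layer and tensor shapes are illustrative):

```python
import torch
import torch.nn as nn

class AutocastModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.linear = nn.Linear(8, 8)

    # Using autocast as a decorator wraps the whole forward in an autocast region.
    @torch.autocast(device_type="cuda")
    def forward(self, input):
        return self.linear(input)

model = AutocastModel().cuda()
out = model(torch.randn(4, 8, device="cuda"))
print(out.dtype)  # torch.float16: linear layers run in lower precision under autocast
```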
@@ -74,7 +74,7 @@ class autocast:
     c_float32 = torch.rand((8, 8), device="cuda")
     d_float32 = torch.rand((8, 8), device="cuda")

-    with autocast():
+    with torch.autocast(device_type="cuda"):
         # torch.mm is on autocast's list of ops that should run in float16.
         # Inputs are float32, but the op runs in float16 and produces float16 output.
         # No manual casts are required.
@@ -153,9 +153,9 @@ class autocast:
     c_float32 = torch.rand((8, 8), device="cuda")
     d_float32 = torch.rand((8, 8), device="cuda")

-    with autocast():
+    with torch.autocast(device_type="cuda"):
         e_float16 = torch.mm(a_float32, b_float32)
-        with autocast(enabled=False):
+        with torch.autocast(device_type="cuda", enabled=False):
             # Calls e_float16.float() to ensure float32 execution
             # (necessary because e_float16 was created in an autocasted region)
             f_float32 = torch.mm(c_float32, e_float16.float())
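Because device_type is a required argument, the same corrected spelling covers CPU autocast too. A small sketch with an explicit dtype (shapes are arbitrary):

```python
import torch

a_float32 = torch.rand((8, 8))
b_float32 = torch.rand((8, 8))

# On CPU the autocast dtype is bfloat16; passing it explicitly makes that visible.
with torch.autocast(device_type="cpu", dtype=torch.bfloat16):
    c = torch.mm(a_float32, b_float32)

print(c.dtype)  # torch.bfloat16
```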
@@ -79,6 +79,9 @@ Using `@autocast` is not currently supported in script mode (a diagnostic
 will be emitted)

 ```python
+import torch
+from torch.cpu.amp import autocast
+
 @autocast(enabled=True)
 def helper(x):
     ...
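The limitation above is specific to script mode; in eager mode the decorator form with the newly added imports works as written. A minimal eager-mode sketch (`helper` is just an illustrative name):

```python
import torch
from torch.cpu.amp import autocast

@autocast(enabled=True)
def helper(x):
    # In eager mode the decorator simply runs the body inside a CPU autocast region.
    return torch.mm(x, x)

y = helper(torch.rand(4, 4))
print(y.dtype)  # torch.bfloat16 under CPU autocast
```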
@@ -91,6 +94,9 @@ def foo(x):
 Another example

 ```python
+import torch
+from torch.cpu.amp import autocast
+
 @torch.jit.script
 @autocast() # not supported
 def foo(a, b, c, d):
@@ -100,6 +106,9 @@ def foo(a, b, c, d):
 #### Autocast argument must be a compile-time constant

 ```python
+import torch
+from torch.cpu.amp import autocast
+
 @torch.jit.script
 def fn(a, b, use_amp: bool):
     # runtime values for autocast enable argument are not supported
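By contrast, a literal enabled flag is a compile-time constant, so a with-statement form like the following is expected to script. A hedged sketch, assuming the TorchScript autocast pass is active:

```python
import torch
from torch.cpu.amp import autocast

@torch.jit.script
def fn(a, b):
    # enabled=True is a literal, i.e. a compile-time constant.
    with autocast(enabled=True):
        return torch.mm(a, b)

out = fn(torch.rand(4, 4), torch.rand(4, 4))
print(out.dtype)  # expected torch.bfloat16 when the JIT autocast pass is enabled
```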
@@ -111,6 +120,9 @@ def fn(a, b, use_amp: bool):
 #### Uncommon autocast usage patterns may not be supported

 ```python
+import torch
+from torch.cpu.amp import autocast
+
 @torch.jit.script
 def fn(a, b, c, d):
     with autocast(enabled=True) as autocast_instance: # not supported
@@ -140,6 +152,9 @@ stripped from the TorchScript IR so it's effectively ignored:
 > This is one known limitation where we don't have a way to emit a diagnostic!

 ```python
+import torch
+from torch.cpu.amp import autocast
+
 def helper(a, b):
     with autocast(enabled=False):
         return torch.mm(a, b) * 2.0
@@ -158,6 +173,9 @@ Calling a scripted function from a trace is similar to calling the scripted
 function from eager mode:

 ```python
+import torch
+from torch.cpu.amp import autocast
+
 @torch.jit.script
 def fn(a, b):
     return torch.mm(a, b)
@@ -176,6 +194,9 @@ If eager-mode autocast is enabled and we try to disable autocasting from
 within a scripted function, autocasting will still occur.

 ```python
+import torch
+from torch.cuda.amp import autocast
+
 @torch.jit.script
 def fn(a, b):
     with autocast(enabled=False):
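To round out the truncated snippet above, a sketch of the behavior being described, assuming a CUDA device: the eager-mode autocast context still applies even though the scripted function asks for enabled=False:

```python
import torch
from torch.cuda.amp import autocast

@torch.jit.script
def fn(a, b):
    # Disabling autocast here has no effect when eager-mode autocast
    # is already enabled around the call.
    with autocast(enabled=False):
        return torch.mm(a, b)

x = torch.rand(4, 4, device="cuda")
y = torch.rand(4, 4, device="cuda")

with autocast(enabled=True):
    out = fn(x, y)

print(out.dtype)  # torch.float16: autocasting still occurred
```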