Ashok Kumar Kannan
2023-06-27 01:54:54 +00:00
committed by PyTorch MergeBot
parent 6d2da6106d
commit 41866a2ead
2 changed files with 26 additions and 5 deletions

View File

@ -41,7 +41,7 @@ class autocast:
optimizer.zero_grad()
# Enables autocasting for the forward pass (model + loss)
with autocast():
with torch.autocast(device_type="cuda"):
output = model(input)
loss = loss_fn(output, target)
@ -56,7 +56,7 @@ class autocast:
class AutocastModel(nn.Module):
...
@autocast()
@torch.autocast(device_type="cuda")
def forward(self, input):
...
@ -74,7 +74,7 @@ class autocast:
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with autocast():
with torch.autocast(device_type="cuda"):
# torch.mm is on autocast's list of ops that should run in float16.
# Inputs are float32, but the op runs in float16 and produces float16 output.
# No manual casts are required.
@ -153,9 +153,9 @@ class autocast:
c_float32 = torch.rand((8, 8), device="cuda")
d_float32 = torch.rand((8, 8), device="cuda")
with autocast():
with torch.autocast(device_type="cuda"):
e_float16 = torch.mm(a_float32, b_float32)
with autocast(enabled=False):
with torch.autocast(device_type="cuda", enabled=False):
# Calls e_float16.float() to ensure float32 execution
# (necessary because e_float16 was created in an autocasted region)
f_float32 = torch.mm(c_float32, e_float16.float())

View File

@ -79,6 +79,9 @@ Using `@autocast` is not currently supported in script mode (a diagnostic
will be emitted)
```python
import torch
from torch.cpu.amp import autocast
@autocast(enabled=True)
def helper(x):
...
@ -91,6 +94,9 @@ def foo(x):
Another example
```python
import torch
from torch.cpu.amp import autocast
@torch.jit.script
@autocast() # not supported
def foo(a, b, c, d):
@ -100,6 +106,9 @@ def foo(a, b, c, d):
#### Autocast argument must be a compile-time constant
```python
import torch
from torch.cpu.amp import autocast
@torch.jit.script
def fn(a, b, use_amp: bool):
# runtime values for autocast enable argument are not supported
@ -111,6 +120,9 @@ def fn(a, b, use_amp: bool):
#### Uncommon autocast usage patterns may not be supported
```python
import torch
from torch.cpu.amp import autocast
@torch.jit.script
def fn(a, b, c, d):
with autocast(enabled=True) as autocast_instance: # not supported
@ -140,6 +152,9 @@ stripped from the TorchScript IR so it's effectively ignored:
> This is one known limitation where we don't have a way to emit a diagnostic!
```python
import torch
from torch.cpu.amp import autocast
def helper(a, b):
with autocast(enabled=False):
return torch.mm(a, b) * 2.0
@ -158,6 +173,9 @@ Calling a scripted function from a trace is similar to calling the scripted
function from eager mode:
```python
import torch
from torch.cpu.amp import autocast
@torch.jit.script
def fn(a, b):
return torch.mm(a, b)
@ -176,6 +194,9 @@ If eager-mode autocast is enabled and we try to disable autocasting from
within a scripted function, autocasting will still occur.
```python
import torch
from torch.cuda.amp import autocast
@torch.jit.script
def fn(a, b):
with autocast(enabled=False):