pytorch/torch/distributed/tensor/examples/convnext_example.py
Wanchao Liang 2ee6b97464 [dtensor] move DTensor to public namespace (#133113)
This moves DTensor into the public namespace so that we can formally add a
documentation page covering all of its public APIs. This includes:

* many path renames and path import fixes
* a dedicated doc page, still light on content (more will be added in the next
  PRs)
* a shim to preserve BC for users still importing from
  `torch.distributed._tensor`; it redirects the old import path to the new
  module (a minimal sketch follows below)

That BC is preserved is evidenced by the fact that all DTensor tests still pass
without any changes to their imports, so it should be safe to land these
changes.
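
As a quick illustration of the BC shim (a minimal sketch, assuming the old
`torch.distributed._tensor` module simply re-exports the objects from the new
`torch.distributed.tensor` module rather than copying them):

    # old (private) import path, kept working by the shim
    from torch.distributed._tensor import DTensor as LegacyDTensor

    # new public import path introduced by this PR
    from torch.distributed.tensor import DTensor

    # if the shim re-exports rather than copies, both names resolve to the same class
    assert LegacyDTensor is DTensor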

Pull Request resolved: https://github.com/pytorch/pytorch/pull/133113
Approved by: https://github.com/XilunWu
ghstack dependencies: #133305, #133306
2024-08-17 05:09:52 +00:00

# mypy: allow-untyped-defs
"""
The following example demonstrates how to train a ConvNeXt model
with intermediate activations sharded across multiple GPUs via DTensor.

To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""

import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)
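
# WORLD_SIZE matches the --nproc-per-node value in the torchrun command above;
# ITER_TIME is the number of timed training iterations.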
WORLD_SIZE = 4
ITER_TIME = 20


class LayerNorm(nn.Module):
    def __init__(
        self, normalized_shape, eps=1e-6, data_format=torch.contiguous_format
    ):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in [torch.contiguous_format]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


class Block(nn.Module):
    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6, data_format=torch.contiguous_format)
        self.pwconv1 = nn.Conv2d(
            dim, 4 * dim, kernel_size=1, stride=1
        )  # pointwise/1x1 conv, equivalent to nn.Linear(dim, 4 * dim)
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(
            4 * dim, dim, kernel_size=1, stride=1
        )  # pointwise/1x1 conv, equivalent to nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = nn.Identity()

    def forward(self, x):
        input_x = x
        x = self.dwconv(x)
        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)
        if self.gamma is not None:
            x = self.gamma * self.drop_path(x)
        x = input_x + x
        return x


class DownSampling(nn.Module):
    def __init__(self, dim_in=3, dim_out=2, down_scale=4, norm_first=False):
        super().__init__()
        self.norm_first = norm_first
        if norm_first:
            self.norm = LayerNorm(
                dim_in, eps=1e-6, data_format=torch.contiguous_format
            )
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
        else:
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
            self.norm = LayerNorm(
                dim_out, eps=1e-6, data_format=torch.contiguous_format
            )

    def forward(self, x):
        if self.norm_first:
            return self.conv(self.norm(x))
        else:
            return self.norm(self.conv(x))


@torch.no_grad()
def init_weights(m):
    if type(m) == nn.Conv2d or type(m) == nn.Linear:
        nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


class ConvNeXt(nn.Module):
    def __init__(
        self,
        in_chans=3,
        num_classes=10,
        depths=[1, 1],  # noqa: B006
        dims=[2, 4],  # noqa: B006
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        head_init_scale=1.0,
    ):
        super().__init__()
        self.downsample_layers = nn.ModuleList()
        stem = DownSampling(in_chans, dims[0], 4, norm_first=False)
        self.downsample_layers.append(stem)
        for i in range(len(dims) - 1):
            downsample_layer = DownSampling(dims[i], dims[i + 1], 2, norm_first=True)
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(len(dims)):
            stage = nn.Sequential(
                *[
                    Block(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(init_weights)

    def forward(self, x):
        for i in range(len(self.stages)):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        x = x.mean([-2, -1])
        x = self.head(x)
        return x
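

# Partition function passed to distribute_module below: it is called for each
# submodule and replaces every parameter with a DTensor replicated across the
# device mesh. The gradient hook redistributes each parameter's incoming
# gradient back to the replicated placement.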
def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    for name, param in module.named_parameters():
        dist_spec = [Replicate()]
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        dist_param.register_hook(lambda grad: grad.redistribute(placements=dist_spec))
        name = "_".join(name.split("."))
        module.register_parameter(name, dist_param)
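

# Training driver: parameters are replicated on every rank (see _conv_fn above),
# while the input batch is sharded along its width dimension (Shard(3)), so the
# intermediate activations are split across the GPUs.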
def train_convnext_example():
    device_type = "cuda"
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh(device_type, (world_size,))
    rank = mesh.get_rank()

    in_shape = [7, 3, 512, 1024]
    output_shape = [7, 1000]

    torch.manual_seed(12)
    model = ConvNeXt(
        depths=[3, 3, 27, 3],
        dims=[256, 512, 1024, 2048],
        drop_path_rate=0.0,
        num_classes=1000,
    ).to(device_type)
    model = distribute_module(model, mesh, _conv_fn, input_fn=None, output_fn=None)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=False)

    x = torch.randn(*in_shape).to(device_type).requires_grad_()
    y_target = (
        torch.empty(output_shape[0], dtype=torch.long)
        .random_(output_shape[1])
        .to(device_type)
    )
    # shard the input along its width dimension; replicate the target labels
    x = distribute_tensor(x, mesh, [Shard(3)])
    y_target = distribute_tensor(y_target, mesh, [Replicate()])

    # warm up
    y = model(x)
    loss = criterion(y, y_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()

    forward_time = 0.0
    backward_time = 0.0
    start = time.time()
    for i in range(ITER_TIME):
        t1 = time.time()
        y = model(x)
        torch.cuda.synchronize()
        t2 = time.time()
        loss = criterion(y, y_target)
        optimizer.zero_grad()
        t3 = time.time()
        loss.backward()
        torch.cuda.synchronize()
        t4 = time.time()
        optimizer.step()
        forward_time += t2 - t1
        backward_time += t4 - t3
    torch.cuda.synchronize()
    end = time.time()

    max_reserved = torch.cuda.max_memory_reserved()
    max_allocated = torch.cuda.max_memory_allocated()
    print(
        f"rank {rank}, {ITER_TIME} iterations, average latency {(end - start)/ITER_TIME*1000:10.2f} ms"
    )
    print(
        f"rank {rank}, forward {forward_time/ITER_TIME*1000:10.2f} ms, backward {backward_time/ITER_TIME*1000:10.2f} ms"
    )
    print(
        f"rank {rank}, max reserved {max_reserved/1024/1024/1024:8.2f} GiB, max allocated {max_allocated/1024/1024/1024:8.2f} GiB"
    )
    dist.destroy_process_group()


if __name__ == "__main__":
    train_convnext_example()