pytorch/torch/distributed/tensor/examples/convnext_example.py
Wanchao Liang cfc227ad43 [reland][dtensor] move DTensor to public namespace (#134203)
reland of https://github.com/pytorch/pytorch/pull/133113

I have to create a new PR because the previously reverted PR could not be rebased or imported successfully :(

----

This moves DTensor into the public namespace, so we can formally add a documentation page covering all the public APIs. The change includes:

* many path renames and import fixes (an old-vs-new import comparison follows this list)
* a dedicated doc page, without much content yet (more content is coming in follow-up PRs)
* a shim module that redirects calls on the old `torch.distributed._tensor` path to the new module, to preserve BC for users still importing from the old path (see the sketch further below)
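
For illustration, the path rename mostly reduces to an import change like the one below; the old private path keeps working through the shim mentioned in the last bullet:

```python
# Old (private) path -- still resolves through the BC shim:
from torch.distributed._tensor import DTensor, distribute_tensor

# New public path introduced by this PR:
from torch.distributed.tensor import DTensor, distribute_tensor
```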

BC preservation is evidenced by the fact that all DTensor tests still pass without any import changes, so the change is safe to land.
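
A minimal sketch of how such a redirect shim can look (an illustration, not necessarily the exact file added in this PR): the old package's `__init__.py` simply re-exports the public surface of the new module, so old-path imports resolve to the very same objects as the new public path.

```python
# torch/distributed/_tensor/__init__.py -- illustrative sketch only,
# not necessarily the exact shim shipped in this PR.

# Re-export the public surface so old-path imports such as
# `from torch.distributed._tensor import DTensor` keep resolving.
from torch.distributed.tensor import (  # noqa: F401
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    DTensor,
    init_device_mesh,
    Partial,
    Placement,
    Replicate,
    Shard,
)
```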

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
2024-09-08 17:08:40 +00:00


# mypy: allow-untyped-defs
"""
The following example demonstrates how to train a ConvNeXt model
with intermediate activations sharded across multiple GPUs via DTensor.
To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""
import os
import time
import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import (
DeviceMesh,
distribute_module,
distribute_tensor,
init_device_mesh,
Replicate,
Shard,
)
WORLD_SIZE = 4
ITER_TIME = 20
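# LayerNorm over the channel dimension of channels-first (N, C, H, W) inputs;
# only torch.contiguous_format is supported.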
class LayerNorm(nn.Module):
def __init__(self, normalized_shape, eps=1e-6, data_format=torch.contiguous_format):
super().__init__()
self.weight = nn.Parameter(torch.ones(normalized_shape))
self.bias = nn.Parameter(torch.zeros(normalized_shape))
self.eps = eps
self.data_format = data_format
if self.data_format not in [torch.contiguous_format]:
raise NotImplementedError
self.normalized_shape = (normalized_shape,)
def forward(self, x):
u = x.mean(1, keepdim=True)
s = (x - u).pow(2).mean(1, keepdim=True)
x = (x - u) / torch.sqrt(s + self.eps)
x = self.weight[:, None, None] * x + self.bias[:, None, None]
return x
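# ConvNeXt block: 7x7 depthwise conv -> channel LayerNorm -> 1x1 conv expansion (4x)
# -> GELU -> 1x1 conv projection, with optional layer scale (gamma) and a residual
# connection. Stochastic depth is disabled here (drop_path is an Identity).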
class Block(nn.Module):
def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
super().__init__()
self.dwconv = nn.Conv2d(
dim, dim, kernel_size=7, padding=3, groups=dim
) # depthwise conv
self.norm = LayerNorm(dim, eps=1e-6, data_format=torch.contiguous_format)
self.pwconv1 = nn.Conv2d(
dim, 4 * dim, kernel_size=1, stride=1
) # nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
self.act = nn.GELU()
self.pwconv2 = nn.Conv2d(
4 * dim, dim, kernel_size=1, stride=1
) # nn.Linear(4 * dim, dim)
self.gamma = (
nn.Parameter(
layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
)
if layer_scale_init_value > 0
else None
)
self.drop_path = nn.Identity()
def forward(self, x):
input_x = x
x = self.dwconv(x)
x = self.norm(x)
x = self.pwconv1(x)
x = self.act(x)
x = self.pwconv2(x)
if self.gamma is not None:
x = self.gamma * self.drop_path(x)
x = input_x + x
return x
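# Downsampling layer: a strided conv with kernel_size == stride == down_scale,
# with LayerNorm applied before the conv (norm_first=True) or after it (norm_first=False).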
class DownSampling(nn.Module):
def __init__(self, dim_in=3, dim_out=2, down_scale=4, norm_first=False):
super().__init__()
self.norm_first = norm_first
if norm_first:
self.norm = LayerNorm(dim_in, eps=1e-6, data_format=torch.contiguous_format)
self.conv = nn.Conv2d(
dim_in, dim_out, kernel_size=down_scale, stride=down_scale
)
else:
self.conv = nn.Conv2d(
dim_in, dim_out, kernel_size=down_scale, stride=down_scale
)
self.norm = LayerNorm(
dim_out, eps=1e-6, data_format=torch.contiguous_format
)
def forward(self, x):
if self.norm_first:
return self.conv(self.norm(x))
else:
return self.norm(self.conv(x))
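# Deterministic initialization used below: every Conv2d/Linear weight is set to ones
# and every bias to zeros.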
@torch.no_grad()
def init_weights(m):
if isinstance(m, (nn.Conv2d, nn.Linear)):
nn.init.ones_(m.weight)
if m.bias is not None:
nn.init.zeros_(m.bias)
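# ConvNeXt backbone: a patchify stem, per-stage downsampling, depths[i] Blocks per stage,
# global average pooling over (H, W), and a linear classification head.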
class ConvNeXt(nn.Module):
def __init__(
self,
in_chans=3,
num_classes=10,
depths=[1, 1], # noqa: B006
dims=[2, 4], # noqa: B006
drop_path_rate=0.0,
layer_scale_init_value=1e-6,
head_init_scale=1.0,
):
super().__init__()
self.downsample_layers = nn.ModuleList()
stem = DownSampling(in_chans, dims[0], 4, norm_first=False)
self.downsample_layers.append(stem)
for i in range(len(dims) - 1):
downsample_layer = DownSampling(dims[i], dims[i + 1], 2, norm_first=True)
self.downsample_layers.append(downsample_layer)
self.stages = nn.ModuleList()
dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
cur = 0
for i in range(len(dims)):
stage = nn.Sequential(
*[
Block(
dim=dims[i],
drop_path=dp_rates[cur + j],
layer_scale_init_value=layer_scale_init_value,
)
for j in range(depths[i])
]
)
self.stages.append(stage)
cur += depths[i]
self.head = nn.Linear(dims[-1], num_classes)
self.apply(init_weights)
def forward(self, x):
for i in range(len(self.stages)):
x = self.downsample_layers[i](x)
x = self.stages[i](x)
x = x.mean([-2, -1])
x = self.head(x)
return x
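# Partition function for distribute_module: for each parameter of the submodule, build a
# replicated DTensor, register a hook that redistributes gradients back to the replicated
# placement, and re-register it on the module with "." replaced by "_" in the parameter name.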
def _conv_fn(
name: str,
module: nn.Module,
device_mesh: DeviceMesh,
) -> None:
for name, param in module.named_parameters():
dist_spec = [Replicate()]
dist_param = torch.nn.Parameter(
distribute_tensor(param, device_mesh, dist_spec)
)
dist_param.register_hook(lambda grad: grad.redistribute(placements=dist_spec))
name = "_".join(name.split("."))
module.register_parameter(name, dist_param)
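# End-to-end benchmark: build a 1-D device mesh over all ranks, replicate the model
# parameters via distribute_module, shard the input along its width dimension (Shard(3)),
# replicate the targets, and time forward/backward over ITER_TIME iterations.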
def train_convnext_example():
device_type = "cuda"
world_size = int(os.environ["WORLD_SIZE"])
mesh = init_device_mesh(device_type, (world_size,))
rank = mesh.get_rank()
in_shape = [7, 3, 512, 1024]
output_shape = [7, 1000]
torch.manual_seed(12)
model = ConvNeXt(
depths=[3, 3, 27, 3],
dims=[256, 512, 1024, 2048],
drop_path_rate=0.0,
num_classes=1000,
).to(device_type)
model = distribute_module(model, mesh, _conv_fn, input_fn=None, output_fn=None)
criterion = torch.nn.CrossEntropyLoss()
optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=False)
x = torch.randn(*in_shape).to(device_type).requires_grad_()
y_target = (
torch.empty(output_shape[0], dtype=torch.long)
.random_(output_shape[1])
.to(device_type)
)
x = distribute_tensor(x, mesh, [Shard(3)])
y_target = distribute_tensor(y_target, mesh, [Replicate()])
# warm up
y = model(x)
loss = criterion(y, y_target)
optimizer.zero_grad()
loss.backward()
optimizer.step()
torch.cuda.synchronize()
forward_time = 0.0
backward_time = 0.0
start = time.time()
for i in range(ITER_TIME):
t1 = time.time()
y = model(x)
torch.cuda.synchronize()
t2 = time.time()
loss = criterion(y, y_target)
optimizer.zero_grad()
t3 = time.time()
loss.backward()
torch.cuda.synchronize()
t4 = time.time()
optimizer.step()
forward_time += t2 - t1
backward_time += t4 - t3
torch.cuda.synchronize()
end = time.time()
max_reserved = torch.cuda.max_memory_reserved()
max_allocated = torch.cuda.max_memory_allocated()
print(
f"rank {rank}, {ITER_TIME} iterations, average latency {(end - start)/ITER_TIME*1000:10.2f} ms"
)
print(
f"rank {rank}, forward {forward_time/ITER_TIME*1000:10.2f} ms, backward {backward_time/ITER_TIME*1000:10.2f} ms"
)
print(
f"rank {rank}, max reserved {max_reserved/1024/1024/1024:8.2f} GiB, max allocated {max_allocated/1024/1024/1024:8.2f} GiB"
)
dist.destroy_process_group()
train_convnext_example()