mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
reland of https://github.com/pytorch/pytorch/pull/133113

I had to create a new PR because the previously reverted PR could not be rebased or imported successfully :(

----

This moves DTensor into the public namespace and formally adds a documentation page that covers all the public APIs. It includes:

* many path renames and path import fixes
* a dedicated doc page, without much content yet (more will be added in follow-up PRs)
* a shim script that redirects old torch.distributed._tensor imports to the new module, preserving backward compatibility for users still on the old path

That BC is preserved is evidenced by the fact that all DTensor tests still pass without changing their public imports, so the change is safe to land.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/134203
Approved by: https://github.com/tianyu-l
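Roughly, the shim keeps imports on the legacy private path working while new code uses the public module (a minimal sketch; the specific names shown here are only illustrative of the re-exported API):

    # new public namespace introduced by this change
    from torch.distributed.tensor import DTensor, distribute_tensor

    # legacy private path, redirected to torch.distributed.tensor by the shim
    from torch.distributed._tensor import DTensor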
258 lines · 7.6 KiB · Python
# mypy: allow-untyped-defs
"""
The following example demonstrates how to train a ConvNeXt model
with intermediate activations sharded across multiple GPUs via DTensor

To run the example, use the following command:
torchrun --standalone --nnodes=1 --nproc-per-node=4 convnext_example.py
"""
import os
import time

import torch
import torch.distributed as dist
import torch.nn as nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    init_device_mesh,
    Replicate,
    Shard,
)


WORLD_SIZE = 4
ITER_TIME = 20


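# Channels-first LayerNorm: normalizes over the channel dimension of an
# (N, C, H, W) tensor, with the affine weight/bias broadcast over H and W.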
class LayerNorm(nn.Module):
    def __init__(self, normalized_shape, eps=1e-6, data_format=torch.contiguous_format):
        super().__init__()
        self.weight = nn.Parameter(torch.ones(normalized_shape))
        self.bias = nn.Parameter(torch.zeros(normalized_shape))
        self.eps = eps
        self.data_format = data_format
        if self.data_format not in [torch.contiguous_format]:
            raise NotImplementedError
        self.normalized_shape = (normalized_shape,)

    def forward(self, x):
        u = x.mean(1, keepdim=True)
        s = (x - u).pow(2).mean(1, keepdim=True)
        x = (x - u) / torch.sqrt(s + self.eps)
        x = self.weight[:, None, None] * x + self.bias[:, None, None]
        return x


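# ConvNeXt block: 7x7 depthwise conv -> LayerNorm -> 1x1 conv expansion (4x) ->
# GELU -> 1x1 conv projection, scaled by a learnable per-channel gamma and added
# back to the input as a residual.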
class Block(nn.Module):
    def __init__(self, dim, drop_path=0.0, layer_scale_init_value=1e-6):
        super().__init__()
        self.dwconv = nn.Conv2d(
            dim, dim, kernel_size=7, padding=3, groups=dim
        )  # depthwise conv
        self.norm = LayerNorm(dim, eps=1e-6, data_format=torch.contiguous_format)
        self.pwconv1 = nn.Conv2d(
            dim, 4 * dim, kernel_size=1, stride=1
        )  # nn.Linear(dim, 4 * dim) # pointwise/1x1 convs, implemented with linear layers
        self.act = nn.GELU()
        self.pwconv2 = nn.Conv2d(
            4 * dim, dim, kernel_size=1, stride=1
        )  # nn.Linear(4 * dim, dim)
        self.gamma = (
            nn.Parameter(
                layer_scale_init_value * torch.ones((dim, 1, 1)), requires_grad=True
            )
            if layer_scale_init_value > 0
            else None
        )
        self.drop_path = nn.Identity()

    def forward(self, x):
        input_x = x
        x = self.dwconv(x)

        x = self.norm(x)
        x = self.pwconv1(x)
        x = self.act(x)
        x = self.pwconv2(x)

        if self.gamma is not None:
            x = self.gamma * self.drop_path(x)
        x = input_x + x
        return x


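# Downsampling layer: a strided conv paired with LayerNorm. The stem convolves
# first and normalizes after; later stages normalize first, then convolve.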
class DownSampling(nn.Module):
    def __init__(self, dim_in=3, dim_out=2, down_scale=4, norm_first=False):
        super().__init__()
        self.norm_first = norm_first
        if norm_first:
            self.norm = LayerNorm(dim_in, eps=1e-6, data_format=torch.contiguous_format)
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
        else:
            self.conv = nn.Conv2d(
                dim_in, dim_out, kernel_size=down_scale, stride=down_scale
            )
            self.norm = LayerNorm(
                dim_out, eps=1e-6, data_format=torch.contiguous_format
            )

    def forward(self, x):
        if self.norm_first:
            return self.conv(self.norm(x))
        else:
            return self.norm(self.conv(x))


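# Deterministic weight init used by this example: conv/linear weights are set to
# ones and biases to zeros.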
@torch.no_grad()
def init_weights(m):
    if type(m) == nn.Conv2d or type(m) == nn.Linear:
        nn.init.ones_(m.weight)
        if m.bias is not None:
            nn.init.zeros_(m.bias)


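# Simplified ConvNeXt: a stem plus downsampling layers interleaved with stages of
# Blocks, followed by global average pooling and a linear classification head.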
class ConvNeXt(nn.Module):
    def __init__(
        self,
        in_chans=3,
        num_classes=10,
        depths=[1, 1],  # noqa: B006
        dims=[2, 4],  # noqa: B006
        drop_path_rate=0.0,
        layer_scale_init_value=1e-6,
        head_init_scale=1.0,
    ):
        super().__init__()

        self.downsample_layers = nn.ModuleList()
        stem = DownSampling(in_chans, dims[0], 4, norm_first=False)
        self.downsample_layers.append(stem)
        for i in range(len(dims) - 1):
            downsample_layer = DownSampling(dims[i], dims[i + 1], 2, norm_first=True)
            self.downsample_layers.append(downsample_layer)

        self.stages = nn.ModuleList()
        dp_rates = [x.item() for x in torch.linspace(0, drop_path_rate, sum(depths))]
        cur = 0
        for i in range(len(dims)):
            stage = nn.Sequential(
                *[
                    Block(
                        dim=dims[i],
                        drop_path=dp_rates[cur + j],
                        layer_scale_init_value=layer_scale_init_value,
                    )
                    for j in range(depths[i])
                ]
            )
            self.stages.append(stage)
            cur += depths[i]

        self.head = nn.Linear(dims[-1], num_classes)
        self.apply(init_weights)

    def forward(self, x):
        for i in range(len(self.stages)):
            x = self.downsample_layers[i](x)
            x = self.stages[i](x)
        x = x.mean([-2, -1])
        x = self.head(x)
        return x


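# Partition function passed to distribute_module below: it replaces every
# parameter with a DTensor replicated across the mesh and registers a gradient
# hook that redistributes gradients back to the same placement.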
def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    for name, param in module.named_parameters():
        dist_spec = [Replicate()]
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        dist_param.register_hook(lambda grad: grad.redistribute(placements=dist_spec))
        name = "_".join(name.split("."))
        module.register_parameter(name, dist_param)


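# End-to-end training loop: builds a 1-D device mesh over all ranks, replicates
# the model parameters via distribute_module, shards the input batch along its
# width dimension (Shard(3)), and times forward/backward over ITER_TIME iterations.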
def train_convnext_example():
    device_type = "cuda"
    world_size = int(os.environ["WORLD_SIZE"])
    mesh = init_device_mesh(device_type, (world_size,))
    rank = mesh.get_rank()

    in_shape = [7, 3, 512, 1024]
    output_shape = [7, 1000]

    torch.manual_seed(12)
    model = ConvNeXt(
        depths=[3, 3, 27, 3],
        dims=[256, 512, 1024, 2048],
        drop_path_rate=0.0,
        num_classes=1000,
    ).to(device_type)
    model = distribute_module(model, mesh, _conv_fn, input_fn=None, output_fn=None)

    criterion = torch.nn.CrossEntropyLoss()
    optimizer = torch.optim.Adam(model.parameters(), lr=1e-4, amsgrad=False)

    x = torch.randn(*in_shape).to(device_type).requires_grad_()
    y_target = (
        torch.empty(output_shape[0], dtype=torch.long)
        .random_(output_shape[1])
        .to(device_type)
    )
    x = distribute_tensor(x, mesh, [Shard(3)])
    y_target = distribute_tensor(y_target, mesh, [Replicate()])

    # warm up
    y = model(x)
    loss = criterion(y, y_target)
    optimizer.zero_grad()
    loss.backward()
    optimizer.step()
    torch.cuda.synchronize()

    forward_time = 0.0
    backward_time = 0.0
    start = time.time()
    for i in range(ITER_TIME):
        t1 = time.time()
        y = model(x)
        torch.cuda.synchronize()
        t2 = time.time()

        loss = criterion(y, y_target)
        optimizer.zero_grad()

        t3 = time.time()
        loss.backward()
        torch.cuda.synchronize()
        t4 = time.time()

        optimizer.step()

        forward_time += t2 - t1
        backward_time += t4 - t3
    torch.cuda.synchronize()
    end = time.time()
    max_reserved = torch.cuda.max_memory_reserved()
    max_allocated = torch.cuda.max_memory_allocated()
    print(
        f"rank {rank}, {ITER_TIME} iterations, average latency {(end - start)/ITER_TIME*1000:10.2f} ms"
    )
    print(
        f"rank {rank}, forward {forward_time/ITER_TIME*1000:10.2f} ms, backward {backward_time/ITER_TIME*1000:10.2f} ms"
    )
    print(
        f"rank {rank}, max reserved {max_reserved/1024/1024/1024:8.2f} GiB, max allocated {max_allocated/1024/1024/1024:8.2f} GiB"
    )
    dist.destroy_process_group()


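# Entry point: each rank launched by torchrun (see the module docstring) runs the
# same training loop.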
train_convnext_example()