The current implementation hardcodes 4D input and output tensor shapes. Change that by computing `output_conv_shape` for any number of input dims, and replace `[.., .., .., slice]` with `[..., slice]`.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/165241
Approved by: https://github.com/ezyang
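The `[..., slice]` idiom this change adopts is what makes the shape handling rank-agnostic: `...` (Ellipsis) expands to however many leading dimensions the tensor has, so a single indexing expression covers the 3D (conv1d), 4D (conv2d), and 5D (conv3d) inputs exercised by the tests below. A minimal sketch of the idiom itself, using plain PyTorch indexing only (the shapes are illustrative, not the DTensor internals the PR touches):

import torch

sl = slice(1, 3)  # an arbitrary slice along the last dim
for shape in [(2, 4, 8), (2, 4, 8, 8), (2, 4, 8, 8, 8)]:  # conv1d/2d/3d-style inputs
    t = torch.randn(shape)
    # `...` stands in for all leading dims, so there is no need to hardcode
    # a rank-specific form such as t[:, :, :, sl] for 4D tensors.
    assert torch.equal(t[..., sl], t[(slice(None),) * (t.dim() - 1) + (sl,)])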
237 lines
8.8 KiB
Python
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]

import copy

import torch
import torch.nn as nn
from torch.distributed import DeviceMesh
from torch.distributed.tensor import (
    distribute_module,
    distribute_tensor,
    DTensor,
    Replicate,
    Shard,
)
from torch.nn import functional as F
from torch.testing._internal.common_utils import run_tests
from torch.testing._internal.distributed._tensor.common_dtensor import (
    DTensorTestBase,
    skip_if_lt_x_gpu,
    with_comms,
)


ITER_TIME = 10
LR = 0.001


def _conv_fn(
    name: str,
    module: nn.Module,
    device_mesh: DeviceMesh,
) -> None:
    # Partition function for `distribute_module`: replicate every parameter of
    # the module across the device mesh. Note the loop variable rebinds the
    # (unused) `name` argument.
    for name, param in module.named_parameters():
        dist_spec = [Replicate()]
        dist_param = torch.nn.Parameter(
            distribute_tensor(param, device_mesh, dist_spec)
        )
        name = "_".join(name.split("."))
        module.register_parameter(name, dist_param)


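# The tests below compare convolutions on DTensors (parameters replicated via
# `_conv_fn`, inputs sharded or replicated) against plain-tensor ground truth.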
class DistConvolutionOpsTest(DTensorTestBase):
    @property
    def world_size(self) -> int:
        # hard code world size to 2
        return 2

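    # Shard the inputs and output grads along the width dim (`Shard(3)` on
    # NCHW tensors), run ITER_TIME SGD steps, and require the DTensor-trained
    # weights and biases to match plain-tensor training within MSE tolerances.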
    @with_comms
    def test_downsampling_convolution(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(3)]

        input_list = torch.rand(ITER_TIME, 7, 3, 512, 1024)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(3, 256, kernel_size=4, stride=4, padding=0).to(
            self.device_type
        )
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            output = model(inp_dtensor)
            grad_output = grad_output_list[i].to(self.device_type)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output.backward(grad_output_dtensor)
            optimizer.step()

        # training with plain tensor
        optimizer_gt = torch.optim.SGD(model_gt.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer_gt.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            output = model_gt(inp)
            grad_output = grad_output_list[i].to(self.device_type)
            output.backward(grad_output)
            optimizer_gt.step()

        weight_diff_abs = model.weight.to_local() - model_gt.weight
        bias_diff_abs = model.bias.to_local() - model_gt.bias
        weight_diff_rel = weight_diff_abs / (torch.abs(model_gt.weight) + 1e-8)
        bias_diff_rel = bias_diff_abs / (torch.abs(model_gt.bias) + 1e-8)
        weight_mse_abs = torch.mean(weight_diff_abs * weight_diff_abs).item()
        bias_mse_abs = torch.mean(bias_diff_abs * bias_diff_abs).item()
        weight_mse_rel = torch.mean(weight_diff_rel * weight_diff_rel).item()
        bias_mse_rel = torch.mean(bias_diff_rel * bias_diff_rel).item()
        self.assertTrue(
            weight_mse_abs <= 1e-6,
            f"Too large absolute mse for weight tensor, expected <= 1e-6, got {weight_mse_abs}",
        )
        self.assertTrue(
            bias_mse_abs <= 1e-6,
            f"Too large absolute mse for bias tensor, expected <= 1e-6, got {bias_mse_abs}",
        )
        self.assertTrue(
            weight_mse_rel <= 1e-6,
            f"Too large relative mse for weight tensor, expected <= 1e-6, got {weight_mse_rel}",
        )
        self.assertTrue(
            bias_mse_rel <= 1e-6,
            f"Too large relative mse for bias tensor, expected <= 1e-6, got {bias_mse_rel}",
        )

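    # Same DTensor-vs-plain-tensor parity check, but for a depthwise
    # convolution (groups == in_channels == out_channels == 256).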
    # TODO: test_depthwise_convolution is broken in CI with gloo backend.
    # Temporarily disable it to unblock CI.
    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_depthwise_convolution(self):
        device_mesh = self.build_device_mesh()
        shard_spec = [Shard(3)]

        input_list = torch.rand(ITER_TIME, 7, 256, 128, 256)
        grad_output_list = torch.rand(ITER_TIME, 7, 256, 128, 256) * 1e-3

        model = nn.Conv2d(256, 256, kernel_size=7, padding=3, groups=256).to(
            self.device_type
        )
        nn.init.ones_(model.weight)
        nn.init.zeros_(model.bias)
        model_gt = copy.deepcopy(model).to(self.device_type)

        # training with dtensor
        model = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        optimizer = torch.optim.SGD(model.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            inp_dtensor = distribute_tensor(inp, device_mesh, shard_spec)
            output = model(inp_dtensor)
            grad_output = grad_output_list[i].to(self.device_type)
            grad_output_dtensor = distribute_tensor(
                grad_output, device_mesh, shard_spec
            )
            output.backward(grad_output_dtensor)
            optimizer.step()

        # training with plain tensor
        optimizer_gt = torch.optim.SGD(model_gt.parameters(), lr=LR)
        for i in range(ITER_TIME):
            optimizer_gt.zero_grad()
            inp = input_list[i].to(self.device_type).requires_grad_()
            output = model_gt(inp)
            grad_output = grad_output_list[i].to(self.device_type)
            output.backward(grad_output)
            optimizer_gt.step()

        weight_diff_abs = model.weight.to_local() - model_gt.weight
        bias_diff_abs = model.bias.to_local() - model_gt.bias
        weight_diff_rel = weight_diff_abs / (torch.abs(model_gt.weight) + 1e-8)
        bias_diff_rel = bias_diff_abs / (torch.abs(model_gt.bias) + 1e-8)
        weight_mse_abs = torch.mean(weight_diff_abs * weight_diff_abs).item()
        bias_mse_abs = torch.mean(bias_diff_abs * bias_diff_abs).item()
        weight_mse_rel = torch.mean(weight_diff_rel * weight_diff_rel).item()
        bias_mse_rel = torch.mean(bias_diff_rel * bias_diff_rel).item()
        self.assertTrue(
            weight_mse_abs <= 1e-6,
            f"Too large absolute mse for weight tensor, expected <= 1e-6, got {weight_mse_abs}",
        )
        self.assertTrue(
            bias_mse_abs <= 1e-6,
            f"Too large absolute mse for bias tensor, expected <= 1e-6, got {bias_mse_abs}",
        )
        self.assertTrue(
            weight_mse_rel <= 1e-6,
            f"Too large relative mse for weight tensor, expected <= 1e-6, got {weight_mse_rel}",
        )
        self.assertTrue(
            bias_mse_rel <= 1e-6,
            f"Too large relative mse for bias tensor, expected <= 1e-6, got {bias_mse_rel}",
        )

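    # `x` below is created without requires_grad, so only the weight and bias
    # parameters should receive gradients; the input grad must stay None.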
    @with_comms
    @skip_if_lt_x_gpu(2)
    def test_conv_backward_none_grad_inp(self):
        device_mesh = self.build_device_mesh()
        conv = nn.Conv2d(64, 64, 3, padding=1).train()
        x = torch.randn(1, 64, 32, 32)
        x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
        w = conv.weight
        w_dt = torch.nn.Parameter(DTensor.from_local(w, device_mesh, [Replicate()]))

        b = conv.bias
        b_dt = torch.nn.Parameter(DTensor.from_local(b, device_mesh, [Replicate()]))

        res = F.conv2d(x_dt, w_dt, b_dt, padding=1)
        dres = torch.rand_like(res)
        res.backward(dres)
        self.assertTrue(w_dt.grad is not None)
        self.assertTrue(b_dt.grad is not None)
        self.assertTrue(x_dt.grad is None)

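    # 3D input (N, C, L): exercises the rank-agnostic shape handling from
    # https://github.com/pytorch/pytorch/pull/165241 for conv1d.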
    @with_comms
    def test_conv1d(self):
        device_mesh = self.build_device_mesh()
        model = nn.Conv1d(64, 64, 3, padding=1)
        model_gt = copy.deepcopy(model).to(device=self.device_type)
        x = torch.randn(1, 64, 8, device=self.device_type)
        x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
        model_dt = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        out_dt = model_dt(x_dt)
        out = model_gt(x)
        self.assertEqual(out_dt.shape, out.shape)

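    # 5D input (N, C, D, H, W): the same shape check for conv3d.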
    @with_comms
    def test_conv3d(self):
        device_mesh = self.build_device_mesh()
        model = nn.Conv3d(64, 64, 3, padding=1)
        model_gt = copy.deepcopy(model).to(device=self.device_type)
        x = torch.randn(1, 64, 8, 8, 8, device=self.device_type)
        x_dt = DTensor.from_local(x, device_mesh, [Replicate()])
        model_dt = distribute_module(
            model, device_mesh, _conv_fn, input_fn=None, output_fn=None
        )
        out_dt = model_dt(x_dt)
        out = model_gt(x)
        self.assertEqual(out_dt.shape, out.shape)


if __name__ == "__main__":
    run_tests()