pytorch/test/distributed/tensor/test_xla_integration.py
# Copyright (c) Meta Platforms, Inc. and affiliates
# Owner(s): ["oncall: distributed"]
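"""Integration tests for DTensor APIs (``distribute_tensor`` / ``distribute_module``)
on the ``xla`` device type, exercised through torch_xla's SPMD mode."""
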
import os
import unittest
from collections.abc import Callable
from functools import wraps
from typing import Any

import numpy as np

import torch
from torch import nn
from torch.distributed.tensor import (
    DeviceMesh,
    distribute_module,
    distribute_tensor,
    Replicate,
    Shard,
)
from torch.testing._internal.common_utils import run_tests, TestCase


# wrapper to check xla test requirements
def with_xla(func: Callable) -> Callable:
    assert func is not None

    @wraps(func)  # pyre-ignore[6]
    def wrapper(
        self,
        *args: tuple[object],
        **kwargs: dict[str, Any],  # type: ignore[misc]
    ) -> None:
        # TODO(yeounoh) replace this with xr.use_spmd() when we deprecate the flag.
        os.environ["XLA_USE_SPMD"] = "1"
        try:
            import torch_xla  # type:ignore[import]  # noqa: F401
        except ImportError as exc:
            raise unittest.SkipTest("torch_xla is not installed.") from exc
        self.device_type = "xla"
        func(self, *args, **kwargs)  # type: ignore[misc]
        os.environ["XLA_USE_SPMD"] = "0"

    return wrapper


class DTensorXLAIntegrationTest(TestCase):
    class SimpleLinear(nn.Module):
        def __init__(self) -> None:
            super().__init__()
            self.fc1 = nn.Linear(128, 64)
            self.relu = nn.ReLU()
            self.fc2 = nn.Linear(64, 1)

        def forward(self, x):
            y = self.relu(self.fc1(x))
            z = self.fc2(y)
            return z

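    # Shard a tensor along dim 0 across all XLA devices and verify the
    # resulting global shape and per-device shard shape.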
    @with_xla
    def test_xla_distribute_tensor_1d_shard(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh("xla", list(range(device_count)))
            shard_spec = [Shard(0)]
            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

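    # Replicate a tensor across the mesh: each local shard should match the
    # global shape.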
    @with_xla
    def test_xla_distribute_tensor_1d_replicate(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))
        shard_spec = [Replicate()]
        for requires_grad in [True, False]:
            tensor_to_shard = torch.randn(
                3 * device_count, 3, requires_grad=requires_grad
            )
            dist_tensor = distribute_tensor(tensor_to_shard, device_mesh, shard_spec)
            # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
            assert type(dist_tensor).__name__ == "XLAShardedTensor"
            global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
            self.assertEqual(global_tensor.size(), torch.Size([3 * device_count, 3]))
            local_tensor = dist_tensor.local_shards[0].data
            self.assertEqual(local_tensor.size(), torch.Size([3 * device_count, 3]))
            if requires_grad:
                self.assertTrue(dist_tensor.global_tensor.requires_grad)
                self.assertTrue(dist_tensor.is_leaf)

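    # Distribute over a 2-D (2 x device_count // 2) mesh: replicate on the
    # first mesh dimension and shard dim 0 on the second.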
    @with_xla
    def test_xla_distribute_tensor_2d(self):
        import torch_xla.runtime as xr  # type:ignore[import]

        device_count = xr.global_runtime_device_count()
        if device_count > 1:
            device_mesh = DeviceMesh(
                "xla", np.array(range(device_count)).reshape(2, device_count // 2)
            )
            shard_spec = [Replicate(), Shard(0)]
            for requires_grad in [True, False]:
                tensor_to_shard = torch.randn(
                    3 * device_count // 2, 3, requires_grad=requires_grad
                )
                dist_tensor = distribute_tensor(
                    tensor_to_shard, device_mesh, shard_spec
                )
                # TODO(yeounoh) switch to DTensor API when XLAShardedTensor inherits DTensor
                assert type(dist_tensor).__name__ == "XLAShardedTensor"
                global_tensor = dist_tensor.global_tensor  # type:ignore[attr-defined]
                self.assertEqual(
                    global_tensor.size(), torch.Size([3 * device_count // 2, 3])
                )
                local_tensor = dist_tensor.local_shards[0].data
                self.assertEqual(local_tensor.size(), torch.Size([3, 3]))
                if requires_grad:
                    self.assertTrue(dist_tensor.global_tensor.requires_grad)
                    self.assertTrue(dist_tensor.is_leaf)

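    # distribute_module should attach an XLA sharding spec to each nn.Linear
    # parameter via the shard_params callback.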
    @with_xla
    def test_xla_distribute_module(self):
        import torch_xla  # type:ignore[import]
        import torch_xla.core.xla_model as xm  # type:ignore[import]
        import torch_xla.runtime as xr  # type:ignore[import]

        model = self.SimpleLinear().to(xm.xla_device())
        device_count = xr.global_runtime_device_count()
        device_mesh = DeviceMesh("xla", list(range(device_count)))

        def shard_params(mod_name, mod, mesh):
            shard_spec = [Shard(0)]
            # annotate fc1 and fc2
            if isinstance(mod, nn.Linear):
                for _, param in mod.named_parameters():
                    # annotate the parameter tensors directly
                    distribute_tensor(param, mesh, shard_spec)

        sharded_model = distribute_module(model, device_mesh, shard_params)
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc1.weight) != ""
        )
        self.assertTrue(
            torch_xla._XLAC._get_xla_sharding_spec(sharded_model.fc2.weight) != ""
        )


if __name__ == "__main__":
    run_tests()