remove bucket_size_limit property from bucket struct
Pull Request resolved: https://github.com/pytorch/pytorch/pull/73731

When rebuilding buckets, per_bucket_limits would have to be synced across ranks in addition to bucket_indices before calling initialize_buckets(). Syncing per_bucket_limits would increase communication volume and add code complexity. On closer inspection, the per_bucket_limits passed into initialize_buckets() is not actually useful: it only assigns the bucket_size_limit property on the bucket struct, and that property is never read anywhere. It is therefore safe to remove the property and avoid syncing per_bucket_limits altogether.

Differential Revision: [D34605513](https://our.internmc.facebook.com/intern/diff/D34605513/)

Approved by: https://github.com/rohan-varma
committed by PyTorch MergeBot
parent c1ced8ff72
commit 1c35f37c9f
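For context, a minimal, self-contained C++ sketch of the rebuild path this commit simplifies. It is not the actual reducer implementation: `ReducerSketch`, `variable_indices`, and the commented-out `sync()` calls are illustrative stand-ins, while `Bucket::bucket_size_limit`, `initialize_buckets()`, and `rebuild_buckets()` mirror the names in the diff below.

```cpp
#include <cstddef>
#include <utility>
#include <vector>

// Sketch of the bucket bookkeeping touched by this commit.
struct Bucket {
  std::vector<std::size_t> variable_indices;
  // std::size_t bucket_size_limit;  // removed: it was assigned but never read
};

class ReducerSketch {
 public:
  bool rebuild_buckets() {
    // Before this commit, keeping ranks consistent would have required
    // shipping per-bucket size limits alongside the bucket indices:
    //   sync(rebuilt_bucket_indices_);
    //   sync(per_bucket_size_limits_);  // extra communication + complexity
    //   initialize_buckets(std::move(rebuilt_bucket_indices_),
    //                      std::move(per_bucket_size_limits_));
    //
    // After: only the bucket assignment itself has to be synced and passed on.
    initialize_buckets(std::move(rebuilt_bucket_indices_));
    return true;
  }

 private:
  void initialize_buckets(std::vector<std::vector<std::size_t>> bucket_indices) {
    buckets_.clear();
    buckets_.reserve(bucket_indices.size());
    for (auto& indices : bucket_indices) {
      Bucket bucket;
      bucket.variable_indices = std::move(indices);
      // No bucket.bucket_size_limit assignment: that field was the only
      // consumer of the per-bucket limits, and nothing ever read it.
      buckets_.push_back(std::move(bucket));
    }
  }

  std::vector<std::vector<std::size_t>> rebuilt_bucket_indices_;
  std::vector<Bucket> buckets_;
};
```

The point of the sketch is that the removed field was write-only, so dropping it also removes the only reason to broadcast per-bucket size limits between ranks during a rebuild.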
@@ -138,14 +138,6 @@ std::vector<int64_t> Logger::get_bucket_sizes() {
   return bucket_sizes;
 }
 
-std::vector<int> Logger::get_bucket_size_limits() {
-  std::vector<int> bucket_size_limits;
-  for (const auto& bucket : reducer_->buckets_) {
-    bucket_size_limits.push_back(bucket.bucket_size_limit);
-  }
-  return bucket_size_limits;
-}
-
 // Communication hook. Empty string if not set, in which case it will not be
 // logged.
 void Logger::set_comm_hook(const std::string& hook) {

@@ -190,9 +182,6 @@ void Logger::set_construction_data_and_log(
   // A list of bucket sizes (Bytes) calculated during construction time
   ddp_logging_data_->strs_map["bucket_sizes"] =
       c10::Join(", ", get_bucket_sizes());
-  // A list of bucket size limits (bytes) specified during construction time
-  ddp_logging_data_->strs_map["initial_bucket_size_limits"] =
-      c10::Join(", ", get_bucket_size_limits());
   set_env_variables();
 
   // DistributedDataParallel constructor input parameters

@@ -299,8 +288,6 @@ void Logger::set_runtime_stats_and_log() {
       reducer_->has_rebuilt_bucket_;
   ddp_logging_data_->strs_map["rebuilt_bucket_sizes"] =
       c10::Join(", ", get_bucket_sizes());
-  ddp_logging_data_->strs_map["rebuilt_bucket_size_limits"] =
-      c10::Join(", ", get_bucket_size_limits());
   // Log per-bucket variable indices
   std::vector<std::string> per_bucket_variable_indices;
   auto indices = get_per_bucket_variable_indices();
@@ -42,8 +42,6 @@ class TORCH_API Logger {
   void set_parameter_stats();
   // Get size of each bucket (Bytes).
   std::vector<int64_t> get_bucket_sizes();
-  // Get bucket size limits specified during DDP construction.
-  std::vector<int> get_bucket_size_limits();
   // Get variable indices for each bucket.
   std::vector<std::vector<size_t>> get_per_bucket_variable_indices();
   // Set comm. hook, if used
@@ -143,8 +143,7 @@ Reducer::Reducer(
   // This can be reinitialized later after capturing runtime information.
   {
     std::lock_guard<std::mutex> lock(mutex_);
-    initialize_buckets(
-        std::move(bucket_indices), std::move(per_bucket_size_limits));
+    initialize_buckets(std::move(bucket_indices));
   }
 
   // All variables are expected to have their `grad_fn` set to the gradient

@@ -928,9 +927,7 @@ void Reducer::install_futures(c10::List<c10::intrusive_ptr<c10::ivalue::Future>>
   }
 }
 
-void Reducer::initialize_buckets(
-    std::vector<std::vector<size_t>> bucket_indices,
-    std::vector<size_t> per_bucket_sizes) {
+void Reducer::initialize_buckets(std::vector<std::vector<size_t>> bucket_indices) {
   // If initialize_buckets is called inside DDP constructor, then
   // it does not matter rpc context ptr is nullptr or not, as grad
   // will not be mutated.

@@ -960,10 +957,8 @@ void Reducer::initialize_buckets(
   // Iterate over buckets.
   const auto bucket_count = bucket_indices.size();
   buckets_.reserve(bucket_count);
-  TORCH_INTERNAL_ASSERT(bucket_count == per_bucket_sizes.size());
   for (const auto bucket_index : c10::irange(bucket_count)) {
     Bucket bucket;
-    bucket.bucket_size_limit = per_bucket_sizes[bucket_index];
 
     // TODO(@pietern): Validate indices.
     // Must be non-empty, unique, and unique across buckets.

@@ -1697,8 +1692,7 @@ bool Reducer::rebuild_buckets() {
   rebuilt_params_.clear();
   rebuilt_param_indices_.clear();
 
-  initialize_buckets(
-      std::move(rebuilt_bucket_indices), std::move(per_bucket_size_limits));
+  initialize_buckets(std::move(rebuilt_bucket_indices));
 
   return true;
 }
@@ -64,10 +64,8 @@ class TORCH_API Reducer {
   // To (re-)initialize bucket assignment, pass a list of buckets, each of
   // which is specified by a list of indices in the bucket's `variables` list.
   // This function performs validation that the variables within a bucket
-  // have the same dtype and device.
-  void initialize_buckets(
-      std::vector<std::vector<size_t>> bucket_indices,
-      std::vector<size_t> per_bucket_sizes);
+  // all live on the same device and have the same dimensionality.
+  void initialize_buckets(std::vector<std::vector<size_t>> bucket_indices);
 
   // This function is called when the forward function has produced an output,
   // and the user wishes to reduce gradients in the backwards pass.

@@ -357,11 +355,6 @@ class TORCH_API Reducer {
     // If `true`, then this implies that `bucket.variables.size() == 1`.
     bool expect_sparse_gradient = false;
 
-    // "Limit" of the cumulative gradient sizes that this bucket manages.
-    // It is actually a soft limit because we do not shard parameters across
-    // buckets, so a single parameter may push the bucket size over the limit.
-    size_t bucket_size_limit;
-
     // TODO(@pietern)
     // Memory copies from gradient tensors into the bucket are potentially
     // done on different CUDA streams. We record an event for every copy
@@ -5408,16 +5408,6 @@ class DistributedTest:
             # type if it didn't exist.
             self.assertEqual(ddp_logging_data.get("unused_parameter_size", 0), 0)
             self.assertEqual(ddp_logging_data.get("has_rebuilt_buckets"), 1)
-            init_bucket_lims = ddp_logging_data.get("initial_bucket_size_limits")
-            rebuilt_bucket_lims = ddp_logging_data.get("rebuilt_bucket_size_limits")
-            self.assertEqual(
-                int(init_bucket_lims),
-                -1,
-            )
-            self.assertEqual(
-                int(rebuilt_bucket_lims),
-                dist._DEFAULT_FIRST_BUCKET_BYTES,
-            )
             self.assertEqual(
                 ddp_logging_data.get("rebuilt_bucket_sizes"), str(param_size)
             )

@@ -8393,55 +8383,6 @@ class DistributedTest:
                 self.assertEqual(opt[i]["tensor"].grad_fn, None)
             out.mean().backward()
 
-        @skip_if_lt_x_gpu(2)
-        @sandcastle_skip_if(
-            BACKEND not in DistTestCases.backend_feature["ddp"],
-            f"The {BACKEND} backend does not support DistributedDataParallel"
-        )
-        def test_ddp_get_bucket_sizes(self):
-            torch.cuda.set_device(self.rank)
-            default_bucket_cap_mb = 25 * (1024 ** 2)
-            first_bucket_bytes_mb = dist._DEFAULT_FIRST_BUCKET_BYTES
-            os.environ["DDP_SET_LAST_BUCKET_CAP"] = "1"
-
-            class MyModel(nn.Module):
-                def __init__(self):
-                    super().__init__()
-                    self.model = nn.Sequential(
-                        nn.Linear(2, 4000, bias=False),
-                        *[nn.Linear(4000, 4000, bias=False) for _ in range(10)]
-                    )
-
-                def forward(self, x):
-                    return self.model(x)
-
-            ddp = torch.nn.parallel.DistributedDataParallel(
-                MyModel().cuda(),
-                device_ids=[self.rank]
-            )
-            inp = torch.randn(10, 2)
-            rebuilt_bucket_index = 2
-            for i in range(6):
-                out = ddp(inp).sum()
-                out.backward()
-                logging_data = ddp._get_ddp_logging_data()
-                bucket_size_limits = [
-                    int(b) for b in logging_data[
-                        "{}_bucket_size_limits".format(
-                            "initial" if i < rebuilt_bucket_index else "rebuilt"
-                        )
-                    ].split(", ")
-                ]
-                # first_bucket_bytes is actually the last because we reverse
-                # parameter bucket order under the DDP_SET_LAST_BUCKET_CAP flag.
-                if i <= 1:
-                    self.assertEqual(bucket_size_limits[-1], -1)
-                else:
-                    self.assertEqual(bucket_size_limits[-1], first_bucket_bytes_mb)
-                for j, bucket_size in enumerate(bucket_size_limits):
-                    if j != len(bucket_size_limits) - 1:
-                        self.assertEqual(bucket_size, default_bucket_cap_mb)
-
         @skip_if_lt_x_gpu(2)
         @sandcastle_skip_if(
             BACKEND not in DistTestCases.backend_feature["ddp"],