Mirror of https://github.com/pytorch/pytorch.git, synced 2025-10-20 21:14:14 +08:00
[PP] [BE] Remove runtime tests (#164962)
BE: clean up dead code now that the multi-stage schedules have been migrated to the schedule execution runtime.

Pull Request resolved: https://github.com/pytorch/pytorch/pull/164962
Approved by: https://github.com/Skylion007
ghstack dependencies: #162016
Committed by: PyTorch MergeBot
Parent: 1d182dd81c
Commit: f0c9f3bddb
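For context, the dead code removed below is the `use_new_runtime` test path, which wrapped an existing schedule's `pipeline_order` in `_PipelineScheduleRuntime` and round-tripped the resulting compute/comms plan through CSV. The following is a minimal sketch of that pattern, not the exact test code: the helper name is illustrative, the import path is assumed, and `stages`, `num_microbatches`, and `loss_fn` are assumed to come from the surrounding test setup.

import tempfile

# Assumed import path for the private runtime class exercised by the removed tests.
from torch.distributed.pipelining.schedules import _PipelineScheduleRuntime


def convert_to_runtime_schedule(old_schedule, stages, num_microbatches, loss_fn):
    # Wrap the legacy schedule's pipeline order in the execution runtime.
    runtime = _PipelineScheduleRuntime(
        stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
    )
    runtime._prepare_schedule_with_comms(old_schedule.pipeline_order)

    # Round-trip the compute/comms plan through CSV, as the removed test did.
    reloaded = _PipelineScheduleRuntime(
        stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
    )
    with tempfile.NamedTemporaryFile() as f:
        runtime._dump_csv(f.name)
        f.seek(0)
        reloaded._load_csv(f.name, format="compute_comms")
    return reloaded

Per the commit message, the multi-stage schedules now use this runtime internally, so parametrizing the tests over `use_new_runtime` no longer exercises a distinct code path, which is why the flag and its branches are deleted in the diff below.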
@@ -2,7 +2,6 @@
 # Owner(s): ["oncall: distributed"]
 import copy
 import logging
-import tempfile
 from dataclasses import dataclass
 
 from model_registry import ModelWithKwargs, MultiMLP, MultiMLPKwargs, MultiMLPWithDw
@@ -523,8 +522,7 @@ class ScheduleTest(MultiProcContinuousTest):
             ScheduleInterleavedZeroBubble,
         ],
     )
-    @parametrize("use_new_runtime", [False, True])
-    def test_grad_with_manual_interleaved(self, ScheduleClass, use_new_runtime):
+    def test_grad_with_manual_interleaved(self, ScheduleClass):
         stages_per_rank = 2
         n_stages = stages_per_rank * self.world_size
         mod, ref_mod, x, target, loss_fn = setup_models_and_data(
@@ -551,46 +549,6 @@ class ScheduleTest(MultiProcContinuousTest):
             stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
         )
 
-        # Handle new runtime testing
-        if use_new_runtime:
-            old_schedule = schedule
-            tmp_schedule = _PipelineScheduleRuntime(
-                stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
-            )
-            tmp_schedule._prepare_schedule_with_comms(old_schedule.pipeline_order)
-
-            # Test CSV round-trip for compute_comms schedule
-            schedule = _PipelineScheduleRuntime(
-                stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
-            )
-            with tempfile.NamedTemporaryFile() as f:
-                tmp_schedule._dump_csv(f.name)
-                f.seek(0)
-                schedule._load_csv(f.name, format="compute_comms")
-
-            one_more_schedule = _PipelineScheduleRuntime(
-                stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
-            )
-            one_more_schedule._prepare_schedule_with_comms(
-                schedule.pipeline_order_with_comms, format="compute_comms"
-            )
-
-            # Verify schedule consistency
-            self.assertEqual(
-                len(schedule.pipeline_order_with_comms),
-                len(one_more_schedule.pipeline_order_with_comms),
-            )
-            for rank in schedule.pipeline_order_with_comms:
-                self.assertEqual(
-                    len(schedule.pipeline_order_with_comms[rank]),
-                    len(one_more_schedule.pipeline_order_with_comms[rank]),
-                )
-                for a, b in zip(
-                    schedule.pipeline_order_with_comms[rank],
-                    one_more_schedule.pipeline_order_with_comms[rank],
-                ):
-                    self.assertEqual(a, b)
-
         # Run pipeline with tensor leak checking
         out = None
         losses = []
@@ -1050,8 +1008,7 @@ class CustomSchedulesTest(MultiProcContinuousTest):
         "schedule_class",
         [ScheduleVShaped, ScheduleUnbalanced],
     )
-    @parametrize("use_new_runtime", [False, True])
-    def test_non_symmetric_stage_ids(self, schedule_class, use_new_runtime):
+    def test_non_symmetric_stage_ids(self, schedule_class):
         n_stages = schedule_class.n_stages
         rank_stages = schedule_class.rank_stages
 
@@ -1074,13 +1031,6 @@ class CustomSchedulesTest(MultiProcContinuousTest):
             stages, num_microbatches, loss_fn=loss_fn, scale_grads=False
         )
 
-        if use_new_runtime:
-            old_schedule = schedule
-            schedule = _PipelineScheduleRuntime(
-                stages, num_microbatches, loss_fn=loss_fn
-            )
-            schedule._prepare_schedule_with_comms(old_schedule.pipeline_order)
-
         # Run pipeline - special case where first and last stage are on rank 0
         out = None
         losses = []