From 07ca1fc88b310da5bb20143df05500a65f2a3c61 Mon Sep 17 00:00:00 2001
From: rraminen
Date: Tue, 25 Jan 2022 12:19:32 -0800
Subject: [PATCH] remove hasPrimaryContext workaround on ROCm (#71146)

Summary:
Now that issue https://github.com/pytorch/pytorch/issues/59750 is fixed, this PR removes the workaround implemented for it on ROCm. It also re-enables the hasPrimaryContext()-related PyTorch unit tests on ROCm.

cc: amathews-amd, jithunnair-amd

cc jeffdaily sunway513 jithunnair-amd ROCmSupport KyleCZH

Pull Request resolved: https://github.com/pytorch/pytorch/pull/71146

Reviewed By: anjali411

Differential Revision: D33754615

Pulled By: albanD

fbshipit-source-id: b3a5c65a20c6d52d5f2ffc9e6f9628c819329b5d
(cherry picked from commit cfdd12166cfd1365de0ebe5a75ce40ac7fde15cc)
---
 test/test_cuda_primary_ctx.py  | 4 ++--
 torch/csrc/autograd/engine.cpp | 2 +-
 2 files changed, 3 insertions(+), 3 deletions(-)

diff --git a/test/test_cuda_primary_ctx.py b/test/test_cuda_primary_ctx.py
index f194deb4da56..74005515a495 100644
--- a/test/test_cuda_primary_ctx.py
+++ b/test/test_cuda_primary_ctx.py
@@ -1,7 +1,7 @@
 # Owner(s): ["module: cuda"]
 
 import torch
-from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocm
+from torch.testing._internal.common_utils import TestCase, run_tests, skipIfRocmVersionLessThan
 import sys
 import unittest
 
@@ -25,7 +25,7 @@ class TestCudaPrimaryCtx(TestCase):
         "where CUDA contexts are never created. Use either run_test.py or add "
         "--subprocess to run each test in a different subprocess.")
 
-    @skipIfRocm
+    @skipIfRocmVersionLessThan((4, 4, 21504))
     def setUp(self):
         for device in range(torch.cuda.device_count()):
             # Ensure context has not been created beforehand
diff --git a/torch/csrc/autograd/engine.cpp b/torch/csrc/autograd/engine.cpp
index e231f4b527c4..22f138e2a14f 100644
--- a/torch/csrc/autograd/engine.cpp
+++ b/torch/csrc/autograd/engine.cpp
@@ -1262,7 +1262,7 @@ void GraphTask::stash_current_streams() {
   caller_current_streams_.resize(num_gpus);
   if (num_gpus > 0) {
     for (c10::DeviceIndex idx = 0; idx < num_gpus; idx++) {
-#if defined(USE_ROCM)
+#if defined(USE_ROCM) && (ROCM_VERSION < 50000)
       // If the build targets ROCM, stash streams for all visible devices unconditionally, to work around
       // https://github.com/pytorch/pytorch/issues/59750.
       // TODO: Remove ROCM-specific behavior when https://github.com/pytorch/pytorch/issues/59750 is fixed.
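
Note (editorial, not part of the patch): for context on what the re-enabled
test asserts, below is a minimal Python sketch of the invariant checked by
test/test_cuda_primary_ctx.py. It is illustrative only; it assumes a
CUDA/ROCm build of PyTorch with at least two visible devices, and it uses
the private torch._C._cuda_hasPrimaryContext binding that the test file
itself calls.

    import torch

    # Before any GPU work is enqueued, no device should have a primary
    # context (this is why the test must run in a fresh subprocess).
    for device in range(torch.cuda.device_count()):
        assert not torch._C._cuda_hasPrimaryContext(device)

    # Allocating a tensor on device 0 lazily creates its primary context ...
    x = torch.ones(1, device="cuda:0")
    assert torch._C._cuda_hasPrimaryContext(0)

    # ... but should not create contexts on the other visible devices. On
    # ROCm versions predating the fix for issue #59750 this could not be
    # relied on, which is what the engine.cpp workaround (stashing streams
    # for all visible devices unconditionally) papered over.
    for device in range(1, torch.cuda.device_count()):
        assert not torch._C._cuda_hasPrimaryContext(device)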