diff --git a/torch/_inductor/graph.py b/torch/_inductor/graph.py
index db68320f3677..54c55ddb349b 100644
--- a/torch/_inductor/graph.py
+++ b/torch/_inductor/graph.py
@@ -243,6 +243,15 @@ class GraphLowering(torch.fx.Interpreter):
         if nconv == 0:
             return False
 
+        # On ROCm we currently see slowdowns on gcnArchs that do not have
+        # optimal NHWC implementations. On the ROCm MI200 series we keep the
+        # default of enforcing channels-last, but on non-MI200 series we
+        # disable the forced layout.
+        if torch.version.hip and torch.cuda.is_available():
+            gpu_name = torch.cuda.get_device_name(0)
+            if not re.search(r"MI2\d\d", gpu_name):
+                return False
+
         # For cpu backend and mkldnn enabled, we always using channels_last for a better performance.
         if (
             all(
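
For reference, the gating heuristic added above can be exercised on its own. The sketch below is a minimal standalone version of the same check, assuming a PyTorch install with at least one visible GPU; the helper name `should_force_channels_last` is illustrative and not part of this patch.

```python
# Minimal standalone sketch of the gating heuristic in this patch.
# The helper name `should_force_channels_last` is illustrative only.
import re

import torch


def should_force_channels_last() -> bool:
    """Keep forced channels-last everywhere except on non-MI200 ROCm GPUs."""
    # torch.version.hip is None on CUDA builds, so this branch only runs on ROCm.
    if torch.version.hip and torch.cuda.is_available():
        gpu_name = torch.cuda.get_device_name(0)
        # On ROCm the reported name typically includes the product name
        # (e.g. "AMD Instinct MI250X"); r"MI2\d\d" matches the MI200 series.
        if not re.search(r"MI2\d\d", gpu_name):
            return False
    return True


if __name__ == "__main__":
    print("force channels_last:", should_force_channels_last())
```

Note that the patch relies on `re` already being imported in torch/_inductor/graph.py, which is why no import change appears in the hunk.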