[dynamo] Support warnings.catch_warnings (#123511)

Pull Request resolved: https://github.com/pytorch/pytorch/pull/123511
Approved by: https://github.com/anijain2305
This commit is contained in:
Jason Ansel
2024-04-07 19:06:16 -07:00
committed by PyTorch MergeBot
parent 6951626735
commit d8e0c26e64
21 changed files with 87 additions and 28 deletions

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,30
detectron2_fcos_r_50_fpn,pass,21

1 name accuracy graph_breaks
86 timm_efficientnet pass 0
87 timm_regnet pass 0
88 timm_resnest pass 0
89 timm_vision_transformer pass 0
90 timm_vision_transformer_large pass_due_to_skip 0
91 timm_vovnet pass 0
92 torch_multimodal_clip pass 0

View File

@ -54,47 +54,47 @@ densenet121,pass,0
detectron2_fasterrcnn_r_101_c4,pass,54
detectron2_fasterrcnn_r_101_c4,pass,42
detectron2_fasterrcnn_r_101_dc5,pass,54
detectron2_fasterrcnn_r_101_dc5,pass,42
detectron2_fasterrcnn_r_101_fpn,pass,58
detectron2_fasterrcnn_r_101_fpn,pass,46
detectron2_fasterrcnn_r_50_c4,pass,54
detectron2_fasterrcnn_r_50_c4,pass,42
detectron2_fasterrcnn_r_50_dc5,pass,54
detectron2_fasterrcnn_r_50_dc5,pass,42
detectron2_fasterrcnn_r_50_fpn,pass,58
detectron2_fasterrcnn_r_50_fpn,pass,46
detectron2_fcos_r_50_fpn,pass,33
detectron2_fcos_r_50_fpn,pass,23
detectron2_maskrcnn_r_101_c4,fail_accuracy,75
detectron2_maskrcnn_r_101_c4,fail_accuracy,57
detectron2_maskrcnn_r_101_fpn,pass,82
detectron2_maskrcnn_r_101_fpn,pass,64
detectron2_maskrcnn_r_50_c4,pass,75
detectron2_maskrcnn_r_50_c4,pass,57
detectron2_maskrcnn_r_50_fpn,pass,82
detectron2_maskrcnn_r_50_fpn,pass,64

1 name accuracy graph_breaks
54 nvidia_deeprecommender pass 0
55 opacus_cifar10 pass 0
56 phlippe_densenet pass 0
57 phlippe_resnet pass 0
58 pyhpc_equation_of_state pass 0
59 pyhpc_isoneutral_mixing pass 0
60 pyhpc_turbulent_kinetic_energy pass 0
61 pytorch_CycleGAN_and_pix2pix pass 0
62 pytorch_stargan pass 0
63 pytorch_unet pass 0
64 resnet152 pass 0
65 resnet18 pass 0
66 resnet50 pass 0
67 resnet50_quantized_qat pass 2
68 resnext50_32x4d pass 0
69 shufflenet_v2_x1_0 pass 0
70 soft_actor_critic pass 0
71 speech_transformer pass 10
72 squeezenet1_1 pass 0
73 stable_diffusion_unet pass_due_to_skip 0
74 timm_efficientdet model_fail_to_load 0
75 timm_efficientnet pass 0
76 timm_nfnet pass 0
77 timm_regnet pass 0
78 timm_resnest pass 0
79 timm_vision_transformer pass 0
80 timm_vision_transformer_large pass_due_to_skip 0
81 timm_vovnet pass 0
82 torch_multimodal_clip pass 0
83 tts_angular pass 2
84 vgg16 pass 0
85 vision_maskrcnn pass 28
86 yolov3 pass 2
87
88
89
90
91
92
93
94
95
96
97
98
99
100

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,30
detectron2_fcos_r_50_fpn,pass,21

1 name accuracy graph_breaks
86 timm_regnet pass 0
87 timm_resnest pass 0
88 timm_vision_transformer pass 0
89 timm_vision_transformer_large pass_due_to_skip 0
90 timm_vovnet pass 0
91 torch_multimodal_clip pass 0
92 tts_angular pass 2

View File

@ -54,7 +54,7 @@ densenet121,pass,0
detectron2_fcos_r_50_fpn,pass,33
detectron2_fcos_r_50_fpn,pass,23

1 name accuracy graph_breaks
54 resnet152 pass 0
55 resnet18 pass 0
56 resnet50 pass 0
57 resnet50_quantized_qat pass 2
58 resnext50_32x4d pass 0
59 shufflenet_v2_x1_0 pass 0
60 soft_actor_critic pass 0

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,31
detectron2_fcos_r_50_fpn,pass,21

1 name accuracy graph_breaks
86 timm_regnet pass 0
87 timm_resnest pass 0
88 timm_vision_transformer pass 0
89 timm_vision_transformer_large pass_due_to_skip 0
90 timm_vovnet pass 0
91 torch_multimodal_clip pass 0
92 tts_angular pass 2

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,30
detectron2_fcos_r_50_fpn,pass,21

1 name accuracy graph_breaks
86 timm_efficientnet pass 0
87 timm_regnet pass 0
88 timm_resnest pass 0
89 timm_vision_transformer pass 0
90 timm_vision_transformer_large pass_due_to_skip 0
91 timm_vovnet pass 0
92 torch_multimodal_clip pass 0

View File

@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
detectron2_fcos_r_50_fpn,pass,31
detectron2_fcos_r_50_fpn,pass,21

1 name accuracy graph_breaks
86 timm_efficientnet pass 0
87 timm_regnet pass 0
88 timm_resnest pass 0
89 timm_vision_transformer pass 0
90 timm_vision_transformer_large pass_due_to_skip 0
91 timm_vovnet pass 0
92 torch_multimodal_clip pass 0

View File

@ -7,14 +7,12 @@ import dis
import enum
import functools
import gc
import io
import itertools
import logging
import math
import operator
import os
import random
import re
import sys
import tempfile
import threading
@ -27,8 +25,6 @@ import weakref
from unittest.mock import patch
import numpy as np
import pytest
import sympy
import torch
import torch._dynamo.test_case
@ -2109,6 +2105,33 @@ utils_device.CURRENT_DEVICE == None""".split(
self.assertEqual(cnt.frame_count, 1)
self.assertEqual(cnt.op_count, 2)
def test_catch_watchings1(self):
cnt = CompileCounter()
@torch.compile(backend=cnt, fullgraph=True)
def fn(x):
with warnings.catch_warnings(record=True):
return x.sin()
x = torch.randn(8)
self.assertEqual(fn(x), x.sin())
self.assertEqual(cnt.frame_count, 1)
def test_catch_watchings2(self):
cnt = CompileCounter()
@torch.compile(backend=cnt, fullgraph=True)
def fn(x):
return x.sin(), warnings.catch_warnings(record=True)
x = torch.randn(8)
_, a = fn(x)
_, b = fn(x)
self.assertEqual(cnt.frame_count, 1)
self.assertIsInstance(a, warnings.catch_warnings)
self.assertIsInstance(b, warnings.catch_warnings)
self.assertIsNot(a, b)
def test_tensor_build_list_unpack(self):
def fn(x):
# seen in fastNLP_Bert

View File

@ -11,6 +11,7 @@ import inspect
import itertools
import random
import unittest
import warnings
import weakref
from abc import ABC
from collections import namedtuple
@ -4485,9 +4486,8 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
raise NotImplementedError("Empty Instances does not support __len__!")
def set(self, name: str, value: Any) -> None:
# TODO(jansel): support catch_warnings
# with warnings.catch_warnings(record=True):
data_len = len(value)
with warnings.catch_warnings(record=True):
data_len = len(value)
if len(self._fields):
assert (
len(self) == data_len
@ -4499,8 +4499,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
@staticmethod
def cat(instance_lists: List["Instances"]) -> "Instances":
# TODO(jansel): support all isinstance generator
# assert all(isinstance(i, Instances) for i in instance_lists)
assert all(isinstance(i, Instances) for i in instance_lists)
assert len(instance_lists) > 0
if len(instance_lists) == 1:
return instance_lists[0]
@ -4529,7 +4528,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
return ret
instances = [
Instances((16, 16), a=[torch.randn(16, 16)], b=[torch.randn(16, 16)])
Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16))
for _ in range(3)
]

View File

@ -4,6 +4,7 @@ from .base import VariableTracker
from .builtin import BuiltinVariable
from .constant import ConstantVariable, EnumVariable
from .ctx_manager import (
CatchWarningsCtxManagerVariable,
ContextWrappingVariable,
DeterministicAlgorithmsVariable,
DisabledSavedTensorsHooksVariable,
@ -101,6 +102,7 @@ __all__ = [
"BackwardHookVariable",
"BaseListVariable",
"BuiltinVariable",
"CatchWarningsCtxManagerVariable",
"ClosureVariable",
"ConstantVariable",
"ConstDictVariable",

View File

@ -1,7 +1,7 @@
# mypy: ignore-errors
import dataclasses
import inspect
import warnings
from typing import Callable, Dict, List, Optional
import torch._C
@ -347,6 +347,37 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
return variables.ConstantVariable.create(None)
class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
"""Delay a call to warnings.catch_warnings"""
@staticmethod
def create(tx, catch_warnings_args):
return CatchWarningsCtxManagerVariable(
catch_warnings_args=catch_warnings_args,
target_values=None,
initial_values=None,
)
def __init__(self, catch_warnings_args, **kwargs):
assert isinstance(catch_warnings_args, dict), catch_warnings_args
super().__init__(**kwargs)
self.catch_warnings_args = catch_warnings_args
def enter(self, tx):
kwargs = {
k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
}
ctx_val = warnings.catch_warnings(**kwargs)
self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
return variables.ConstantVariable.create(ctx_val.__enter__())
def reconstruct(self, cg):
cg.load_import_from("warnings", "catch_warnings")
cg.foreach(self.catch_warnings_args.values())
keys = tuple(self.catch_warnings_args.keys())
cg.extend_output(cg.create_call_function_kw(len(keys), keys, True))
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
"""represents torch VMap increment/decrement nesting"""

View File

@ -10,6 +10,8 @@ import random
import sys
import threading
import types
import warnings
from typing import Dict, Generic, List
from ..bytecode_transformation import create_call_function
@ -303,6 +305,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
return variables.functions.FunctoolsPartialVariable(
fn, args=rest_args, keywords=kwargs
)
elif self.value is warnings.catch_warnings and not args:
return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs)
elif (
issubclass(type(self.value), type)
and hasattr(