mirror of
https://github.com/pytorch/pytorch.git
synced 2025-10-20 21:14:14 +08:00
[dynamo] Support warnings.catch_warnings (#123511)
Pull Request resolved: https://github.com/pytorch/pytorch/pull/123511 Approved by: https://github.com/anijain2305
This commit is contained in:
committed by
PyTorch MergeBot
parent
6951626735
commit
d8e0c26e64
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,30
|
||||
detectron2_fcos_r_50_fpn,pass,21
|
||||
|
||||
|
||||
|
||||
|
|
@ -54,47 +54,47 @@ densenet121,pass,0
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_101_c4,pass,54
|
||||
detectron2_fasterrcnn_r_101_c4,pass,42
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_101_dc5,pass,54
|
||||
detectron2_fasterrcnn_r_101_dc5,pass,42
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_101_fpn,pass,58
|
||||
detectron2_fasterrcnn_r_101_fpn,pass,46
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_50_c4,pass,54
|
||||
detectron2_fasterrcnn_r_50_c4,pass,42
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_50_dc5,pass,54
|
||||
detectron2_fasterrcnn_r_50_dc5,pass,42
|
||||
|
||||
|
||||
|
||||
detectron2_fasterrcnn_r_50_fpn,pass,58
|
||||
detectron2_fasterrcnn_r_50_fpn,pass,46
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,33
|
||||
detectron2_fcos_r_50_fpn,pass,23
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_101_c4,fail_accuracy,75
|
||||
detectron2_maskrcnn_r_101_c4,fail_accuracy,57
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_101_fpn,pass,82
|
||||
detectron2_maskrcnn_r_101_fpn,pass,64
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_50_c4,pass,75
|
||||
detectron2_maskrcnn_r_50_c4,pass,57
|
||||
|
||||
|
||||
|
||||
detectron2_maskrcnn_r_50_fpn,pass,82
|
||||
detectron2_maskrcnn_r_50_fpn,pass,64
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,30
|
||||
detectron2_fcos_r_50_fpn,pass,21
|
||||
|
||||
|
||||
|
||||
|
|
@ -54,7 +54,7 @@ densenet121,pass,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,33
|
||||
detectron2_fcos_r_50_fpn,pass,23
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,31
|
||||
detectron2_fcos_r_50_fpn,pass,21
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,30
|
||||
detectron2_fcos_r_50_fpn,pass,21
|
||||
|
||||
|
||||
|
||||
|
|
@ -86,7 +86,7 @@ detectron2_fasterrcnn_r_50_fpn,eager_fail_to_run,0
|
||||
|
||||
|
||||
|
||||
detectron2_fcos_r_50_fpn,pass,31
|
||||
detectron2_fcos_r_50_fpn,pass,21
|
||||
|
||||
|
||||
|
||||
|
|
@ -7,14 +7,12 @@ import dis
|
||||
import enum
|
||||
import functools
|
||||
import gc
|
||||
import io
|
||||
import itertools
|
||||
import logging
|
||||
import math
|
||||
import operator
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import sys
|
||||
import tempfile
|
||||
import threading
|
||||
@ -27,8 +25,6 @@ import weakref
|
||||
from unittest.mock import patch
|
||||
|
||||
import numpy as np
|
||||
import pytest
|
||||
import sympy
|
||||
import torch
|
||||
|
||||
import torch._dynamo.test_case
|
||||
@ -2109,6 +2105,33 @@ utils_device.CURRENT_DEVICE == None""".split(
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertEqual(cnt.op_count, 2)
|
||||
|
||||
def test_catch_watchings1(self):
|
||||
cnt = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnt, fullgraph=True)
|
||||
def fn(x):
|
||||
with warnings.catch_warnings(record=True):
|
||||
return x.sin()
|
||||
|
||||
x = torch.randn(8)
|
||||
self.assertEqual(fn(x), x.sin())
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
|
||||
def test_catch_watchings2(self):
|
||||
cnt = CompileCounter()
|
||||
|
||||
@torch.compile(backend=cnt, fullgraph=True)
|
||||
def fn(x):
|
||||
return x.sin(), warnings.catch_warnings(record=True)
|
||||
|
||||
x = torch.randn(8)
|
||||
_, a = fn(x)
|
||||
_, b = fn(x)
|
||||
self.assertEqual(cnt.frame_count, 1)
|
||||
self.assertIsInstance(a, warnings.catch_warnings)
|
||||
self.assertIsInstance(b, warnings.catch_warnings)
|
||||
self.assertIsNot(a, b)
|
||||
|
||||
def test_tensor_build_list_unpack(self):
|
||||
def fn(x):
|
||||
# seen in fastNLP_Bert
|
||||
|
@ -11,6 +11,7 @@ import inspect
|
||||
import itertools
|
||||
import random
|
||||
import unittest
|
||||
import warnings
|
||||
import weakref
|
||||
from abc import ABC
|
||||
from collections import namedtuple
|
||||
@ -4485,9 +4486,8 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
raise NotImplementedError("Empty Instances does not support __len__!")
|
||||
|
||||
def set(self, name: str, value: Any) -> None:
|
||||
# TODO(jansel): support catch_warnings
|
||||
# with warnings.catch_warnings(record=True):
|
||||
data_len = len(value)
|
||||
with warnings.catch_warnings(record=True):
|
||||
data_len = len(value)
|
||||
if len(self._fields):
|
||||
assert (
|
||||
len(self) == data_len
|
||||
@ -4499,8 +4499,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
|
||||
@staticmethod
|
||||
def cat(instance_lists: List["Instances"]) -> "Instances":
|
||||
# TODO(jansel): support all isinstance generator
|
||||
# assert all(isinstance(i, Instances) for i in instance_lists)
|
||||
assert all(isinstance(i, Instances) for i in instance_lists)
|
||||
assert len(instance_lists) > 0
|
||||
if len(instance_lists) == 1:
|
||||
return instance_lists[0]
|
||||
@ -4529,7 +4528,7 @@ def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
||||
return ret
|
||||
|
||||
instances = [
|
||||
Instances((16, 16), a=[torch.randn(16, 16)], b=[torch.randn(16, 16)])
|
||||
Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16))
|
||||
for _ in range(3)
|
||||
]
|
||||
|
||||
|
@ -4,6 +4,7 @@ from .base import VariableTracker
|
||||
from .builtin import BuiltinVariable
|
||||
from .constant import ConstantVariable, EnumVariable
|
||||
from .ctx_manager import (
|
||||
CatchWarningsCtxManagerVariable,
|
||||
ContextWrappingVariable,
|
||||
DeterministicAlgorithmsVariable,
|
||||
DisabledSavedTensorsHooksVariable,
|
||||
@ -101,6 +102,7 @@ __all__ = [
|
||||
"BackwardHookVariable",
|
||||
"BaseListVariable",
|
||||
"BuiltinVariable",
|
||||
"CatchWarningsCtxManagerVariable",
|
||||
"ClosureVariable",
|
||||
"ConstantVariable",
|
||||
"ConstDictVariable",
|
||||
|
@ -1,7 +1,7 @@
|
||||
# mypy: ignore-errors
|
||||
|
||||
import dataclasses
|
||||
import inspect
|
||||
import warnings
|
||||
from typing import Callable, Dict, List, Optional
|
||||
|
||||
import torch._C
|
||||
@ -347,6 +347,37 @@ class GradIncrementNestingCtxManagerVariable(ContextWrappingVariable):
|
||||
return variables.ConstantVariable.create(None)
|
||||
|
||||
|
||||
class CatchWarningsCtxManagerVariable(ContextWrappingVariable):
|
||||
"""Delay a call to warnings.catch_warnings"""
|
||||
|
||||
@staticmethod
|
||||
def create(tx, catch_warnings_args):
|
||||
return CatchWarningsCtxManagerVariable(
|
||||
catch_warnings_args=catch_warnings_args,
|
||||
target_values=None,
|
||||
initial_values=None,
|
||||
)
|
||||
|
||||
def __init__(self, catch_warnings_args, **kwargs):
|
||||
assert isinstance(catch_warnings_args, dict), catch_warnings_args
|
||||
super().__init__(**kwargs)
|
||||
self.catch_warnings_args = catch_warnings_args
|
||||
|
||||
def enter(self, tx):
|
||||
kwargs = {
|
||||
k: v.as_python_constant() for k, v in self.catch_warnings_args.items()
|
||||
}
|
||||
ctx_val = warnings.catch_warnings(**kwargs)
|
||||
self.set_cleanup_hook(tx, lambda: ctx_val.__exit__(None, None, None))
|
||||
return variables.ConstantVariable.create(ctx_val.__enter__())
|
||||
|
||||
def reconstruct(self, cg):
|
||||
cg.load_import_from("warnings", "catch_warnings")
|
||||
cg.foreach(self.catch_warnings_args.values())
|
||||
keys = tuple(self.catch_warnings_args.keys())
|
||||
cg.extend_output(cg.create_call_function_kw(len(keys), keys, True))
|
||||
|
||||
|
||||
class VmapIncrementNestingCtxManagerVariable(ContextWrappingVariable):
|
||||
"""represents torch VMap increment/decrement nesting"""
|
||||
|
||||
|
@ -10,6 +10,8 @@ import random
|
||||
import sys
|
||||
import threading
|
||||
import types
|
||||
import warnings
|
||||
|
||||
from typing import Dict, Generic, List
|
||||
|
||||
from ..bytecode_transformation import create_call_function
|
||||
@ -303,6 +305,8 @@ class UserDefinedClassVariable(UserDefinedVariable):
|
||||
return variables.functions.FunctoolsPartialVariable(
|
||||
fn, args=rest_args, keywords=kwargs
|
||||
)
|
||||
elif self.value is warnings.catch_warnings and not args:
|
||||
return variables.CatchWarningsCtxManagerVariable.create(tx, kwargs)
|
||||
elif (
|
||||
issubclass(type(self.value), type)
|
||||
and hasattr(
|
||||
|
Reference in New Issue
Block a user