FIXES https://github.com/pytorch/pytorch/issues/157945

## Explanation

1. Some VTs add additional attrs, e.g. NamedTupleVariable has "dynamic_attributes": a0308edb6c/torch/_dynamo/variables/lists.py (L1048-L1051)
2. VT.clone passes everything by dict, which includes "dynamic_attributes": a0308edb6c/torch/_dynamo/variables/base.py (L255-L259)
3. Non-handled args become kwargs in the VT's `__init__`, and `super().__init__()` passes those kwargs to the base VT: a0308edb6c/torch/_dynamo/variables/lists.py (L1048-L1051)
4. The base VT's `__init__` then gets an unexpected "dynamic_attributes" kwarg: a0308edb6c/torch/_dynamo/variables/base.py (L609-L613)

You could also let the base VT's `__init__` ignore additional kwargs, but that seemed a bit too permissive, and I don't think many VTs add these derived-class-only attrs.

## After fix

```python
===== __compiled_fn_1_7f9541ed_e166_43fe_8322_c5225ce4207f =====
/home/xmfan/core/miniconda3/envs/0712/lib/python3.12/site-packages/torch/fx/_lazy_graph_module.py class GraphModule(torch.nn.Module):
    def forward(self, L_x_: "f32[4, 8, 6][48, 6, 1]cpu"):
        l_x_ = L_x_

        # File: /home/xmfan/core/a/torchtitan/wtf.py:10 in forward, code: U, S = torch.linalg.svd(x)[:2]
        linalg_svd = torch._C._linalg.linalg_svd(l_x_);  l_x_ = None
        U: "f32[4, 8, 8][64, 1, 8]cpu" = linalg_svd[0]
        S: "f32[4, 6][6, 1]cpu" = linalg_svd[1];  linalg_svd = None

        # File: /home/xmfan/core/a/torchtitan/wtf.py:11 in forward, code: reduced = U[:, :, :self.k] @ torch.diag_embed(S[:, :self.k])
        getitem_3: "f32[4, 8, 5][64, 1, 8]cpu" = U[(slice(None, None, None), slice(None, None, None), slice(None, 5, None))];  U = None
        getitem_4: "f32[4, 5][6, 1]cpu" = S[(slice(None, None, None), slice(None, 5, None))];  S = None
        diag_embed: "f32[4, 5, 5][25, 5, 1]cpu" = torch.diag_embed(getitem_4);  getitem_4 = None
        reduced: "f32[4, 8, 5][40, 5, 1]cpu" = getitem_3 @ diag_embed;  getitem_3 = diag_embed = None
        return (reduced,)
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/158190
Approved by: https://github.com/StrongerXi
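For context, here is a minimal, self-contained sketch of the failure mode and the shape of the fix described above. The classes below (`BaseVT`, `NamedTupleLikeVT`) are hypothetical stand-ins, not the actual Dynamo `VariableTracker`/`NamedTupleVariable` implementations:

```python
# Hypothetical sketch of the clone()/kwargs issue; names are illustrative only.

class BaseVT:
    def __init__(self, source=None):
        # The base __init__ only knows about its own attributes; an
        # unexpected keyword such as "dynamic_attributes" would raise TypeError.
        self.source = source

    def clone(self, **kwargs):
        # clone() forwards everything stored on the instance as kwargs,
        # including attributes that only a derived class defines.
        args = dict(self.__dict__)
        args.update(kwargs)
        return self.__class__(**args)


class NamedTupleLikeVT(BaseVT):
    def __init__(self, items, dynamic_attributes=None, **kwargs):
        # Fix: consume the derived-class-only attribute here instead of
        # letting it fall through to BaseVT.__init__ via **kwargs.
        super().__init__(**kwargs)
        self.items = items
        self.dynamic_attributes = dynamic_attributes or {}


vt = NamedTupleLikeVT(items=[1, 2], dynamic_attributes={"k": 5})
copy_vt = vt.clone()  # works: dynamic_attributes is handled by the subclass
assert copy_vt.dynamic_attributes == {"k": 5}
```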
"""
|
|
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
|
|
with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg)
|
|
"""
|
|
|
|
# Owner(s): ["module: dynamo"]
|
|
import collections
|
|
import contextlib
|
|
import copy
|
|
import dataclasses
|
|
import functools
|
|
import gc
|
|
import importlib
|
|
import inspect
|
|
import itertools
|
|
import logging
|
|
import os
|
|
import random
|
|
import sys
|
|
import types
|
|
import typing
|
|
import unittest
|
|
import warnings
|
|
import weakref
|
|
from abc import ABC
|
|
from collections import defaultdict, namedtuple
|
|
from collections.abc import Iterator
|
|
from copy import deepcopy
|
|
from enum import Enum, IntEnum
|
|
from functools import wraps
|
|
from typing import Any, Literal, TypedDict
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
import torch._dynamo.utils
|
|
import torch._functorch.config
|
|
import torch.distributed as dist
|
|
import torch.library
|
|
import torch.utils._pytree as pytree
|
|
from torch import nn
|
|
from torch._dynamo.debug_utils import same_two_models
|
|
from torch._dynamo.testing import CompileCounter, rand_strided, same, skipIfPy312
|
|
from torch._inductor.utils import fresh_cache
|
|
from torch.nn import functional as F
|
|
from torch.profiler import profile, ProfilerActivity
|
|
from torch.testing._internal.common_cuda import (
|
|
PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
|
PLATFORM_SUPPORTS_FP8,
|
|
SM70OrLater,
|
|
TEST_CUDA,
|
|
)
|
|
from torch.testing._internal.common_device_type import instantiate_device_type_tests
|
|
from torch.testing._internal.common_utils import (
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
serialTest,
|
|
skipIfHpu,
|
|
skipIfWindows,
|
|
TEST_WITH_ROCM,
|
|
)
|
|
from torch.testing._internal.logging_utils import LoggingTestCase, make_logging_test
|
|
from torch.testing._internal.two_tensor import TwoTensor
|
|
from torch.utils._python_dispatch import TorchDispatchMode
|
|
|
|
|
|
_orig_module_call = torch.nn.Module.__call__
|
|
|
|
# Custom operator that only supports CPU and Meta
|
|
lib = torch.library.Library("test_sample", "DEF") # noqa: TOR901
|
|
lib.define("foo(Tensor self) -> Tensor")
|
|
lib.impl("foo", torch.sin, "CPU")
|
|
|
|
|
|
requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")
|
|
|
|
|
|
_GLOBAL_CPU_TENSOR = torch.randn(3)
|
|
|
|
HAS_MSGSPEC = importlib.util.find_spec("msgspec")
|
|
if HAS_MSGSPEC:
|
|
import msgspec
|
|
|
|
|
|
HAS_OMEGACONG = importlib.util.find_spec("omegaconf")
|
|
if HAS_OMEGACONG:
|
|
from omegaconf import OmegaConf
|
|
|
|
|
|
def exists(val):
|
|
return val is not None
|
|
|
|
|
|
def maybe(fn):
|
|
@wraps(fn)
|
|
def inner(x, *args, **kwargs):
|
|
if not exists(x):
|
|
return x
|
|
return fn(x, *args, **kwargs)
|
|
|
|
return inner
|
|
|
|
|
|
def is_fx_tracing_test() -> bool:
|
|
"""
|
|
Copied from the hpc trainer codebase
|
|
"""
|
|
return torch.nn.Module.__call__ is not _orig_module_call
|
|
|
|
|
|
def has_detectron2():
|
|
try:
|
|
from detectron2.layers.mask_ops import _paste_masks_tensor_shape
|
|
|
|
return _paste_masks_tensor_shape is not None
|
|
except ImportError:
|
|
return False
|
|
|
|
|
|
def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
|
|
# from detectron2 mask_ops.py
|
|
|
|
device = masks.device
|
|
|
|
if skip_empty and not torch.jit.is_scripting():
|
|
x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
|
|
dtype=torch.int32
|
|
)
|
|
x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(
|
|
dtype=torch.int32
|
|
)
|
|
y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(
|
|
dtype=torch.int32
|
|
)
|
|
else:
|
|
x0_int, y0_int = 0, 0
|
|
x1_int, y1_int = img_w, img_h
|
|
x0, y0, x1, y1 = torch.split(boxes, 1, dim=1) # each is Nx1
|
|
|
|
N = masks.shape[0]
|
|
|
|
img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
|
|
img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
|
|
img_y = (img_y - y0) / (y1 - y0) * 2 - 1
|
|
img_x = (img_x - x0) / (x1 - x0) * 2 - 1
|
|
# img_x, img_y have shapes (N, w), (N, h)
|
|
|
|
gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
|
|
gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
|
|
grid = torch.stack([gx, gy], dim=3)
|
|
|
|
if not torch.jit.is_scripting():
|
|
if not masks.dtype.is_floating_point:
|
|
masks = masks.float()
|
|
img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)
|
|
|
|
if skip_empty and not torch.jit.is_scripting():
|
|
return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
|
|
else:
|
|
return img_masks[:, 0], ()
|
|
|
|
|
|
def global_fn(x):
|
|
return torch.sin(x)
|
|
|
|
|
|
def cat(tensors, dim=0):
|
|
# from detectron2 wrappers.py
|
|
assert isinstance(tensors, (list, tuple))
|
|
if len(tensors) == 1:
|
|
return tensors[0]
|
|
return torch.cat(tensors, dim)
|
|
|
|
|
|
def shapes_to_tensor(x, device=None):
|
|
# from detectron2 wrappers.py
|
|
if torch.jit.is_scripting():
|
|
return torch.as_tensor(x, device=device)
|
|
if torch.jit.is_tracing():
|
|
assert all(isinstance(t, torch.Tensor) for t in x), (
|
|
"Shape should be tensor during tracing!"
|
|
)
|
|
# as_tensor should not be used in tracing because it records a constant
|
|
ret = torch.stack(x)
|
|
if ret.device != device: # avoid recording a hard-coded device if not necessary
|
|
ret = ret.to(device=device)
|
|
return ret
|
|
return torch.as_tensor(x, device=device)
|
|
|
|
|
|
fw_graph = [None]
|
|
bw_graph = [None]
|
|
|
|
|
|
def aot_graph_capture_backend(gm, args):
|
|
from functorch.compile import min_cut_rematerialization_partition
|
|
from torch._functorch.aot_autograd import aot_module_simplified
|
|
|
|
def fw_compiler(gm, _):
|
|
fw_graph[0] = gm
|
|
return gm
|
|
|
|
def bw_compiler(gm, _):
|
|
bw_graph[0] = gm
|
|
return gm
|
|
|
|
return aot_module_simplified(
|
|
gm,
|
|
args,
|
|
fw_compiler,
|
|
bw_compiler,
|
|
partition_fn=min_cut_rematerialization_partition,
|
|
keep_inference_input_mutations=True,
|
|
)
|
|
|
|
|
|
class Boxes:
|
|
# from detectron2 poolers.py
|
|
def __init__(self, tensor: torch.Tensor):
|
|
"""
|
|
Args:
|
|
tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
|
|
"""
|
|
device = (
|
|
tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
|
|
)
|
|
tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
|
|
if tensor.numel() == 0:
|
|
# Use reshape, so we don't end up creating a new tensor that does not depend on
|
|
# the inputs (and consequently confuses jit)
|
|
tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device)
|
|
assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
|
|
self.tensor = tensor
|
|
|
|
def __len__(self) -> int:
|
|
return self.tensor.shape[0]
|
|
|
|
@property
|
|
def device(self):
|
|
return self.tensor.device
|
|
|
|
|
|
def convert_boxes_to_pooler_format(box_lists):
|
|
# from detectron2 structures.py
|
|
boxes = torch.cat([x.tensor for x in box_lists], dim=0)
|
|
# __len__ returns Tensor in tracing.
|
|
sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
|
|
indices = torch.repeat_interleave(
|
|
torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
|
|
)
|
|
return cat([indices[:, None], boxes], dim=1)
|
|
|
|
|
|
ReformerBackwardOutput = namedtuple(
|
|
"ReformerBackwardOutput",
|
|
["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"],
|
|
)
|
|
ReformerEncoderOutput = namedtuple(
|
|
"ReformerEncoderOutput",
|
|
["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
|
|
)
|
|
|
|
|
|
class _ReversibleFunction(torch.autograd.Function):
|
|
# taken from modeling_reformer.py in huggingface
|
|
@staticmethod
|
|
def forward(
|
|
ctx,
|
|
hidden_states,
|
|
layers,
|
|
attention_mask,
|
|
head_mask,
|
|
num_hashes,
|
|
all_hidden_states,
|
|
all_attentions,
|
|
past_buckets_states,
|
|
use_cache,
|
|
orig_sequence_length,
|
|
output_hidden_states,
|
|
output_attentions,
|
|
):
|
|
all_buckets = ()
|
|
|
|
# split duplicated tensor
|
|
hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)
|
|
|
|
for layer in layers:
|
|
if output_hidden_states is True:
|
|
all_hidden_states.append(hidden_states)
|
|
|
|
attn_output = layer(attn_output)
|
|
all_buckets = all_buckets + (attn_output,)
|
|
|
|
# Add last layer
|
|
if output_hidden_states is True:
|
|
all_hidden_states.append(hidden_states)
|
|
|
|
# attach params to ctx for backward
|
|
ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
|
|
ctx.layers = layers
|
|
ctx.all_buckets = all_buckets
|
|
ctx.head_mask = head_mask
|
|
ctx.attention_mask = attention_mask
|
|
|
|
# Concatenate 2 RevNet outputs
|
|
return torch.cat([attn_output, hidden_states], dim=-1)
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_hidden_states):
|
|
grad_attn_output, grad_hidden_states = torch.chunk(
|
|
grad_hidden_states, 2, dim=-1
|
|
)
|
|
|
|
# free memory
|
|
del grad_attn_output
|
|
|
|
# num of return vars has to match num of forward() args
|
|
# return gradient for hidden_states arg and None for other args
|
|
return (
|
|
grad_hidden_states,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
None,
|
|
)
|
|
|
|
|
|
class ReformerEncoder(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.dropout = 0.5
|
|
self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12)
|
|
self.layers = [torch.nn.Linear(256, 256)]
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states,
|
|
attention_mask=None,
|
|
head_mask=[None] * 6,
|
|
num_hashes=None,
|
|
use_cache=False,
|
|
orig_sequence_length=64,
|
|
output_hidden_states=False,
|
|
output_attentions=False,
|
|
):
|
|
# hidden_states and attention lists to be filled if wished
|
|
all_hidden_states = []
|
|
all_attentions = []
|
|
past_buckets_states = [((None), (None)) for i in range(len(self.layers))]
|
|
|
|
# concat same tensor for reversible ResNet
|
|
hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
|
|
hidden_states = _ReversibleFunction.apply(
|
|
hidden_states,
|
|
self.layers,
|
|
attention_mask,
|
|
head_mask,
|
|
num_hashes,
|
|
all_hidden_states,
|
|
all_attentions,
|
|
past_buckets_states,
|
|
use_cache,
|
|
orig_sequence_length,
|
|
output_hidden_states,
|
|
output_attentions,
|
|
)
|
|
|
|
# Apply layer norm to concatenated hidden states
|
|
hidden_states = self.layer_norm(hidden_states)
|
|
|
|
# Apply dropout
|
|
hidden_states = torch.nn.functional.dropout(
|
|
hidden_states, p=self.dropout, training=self.training
|
|
)
|
|
|
|
return ReformerEncoderOutput(
|
|
hidden_states=hidden_states,
|
|
all_hidden_states=all_hidden_states,
|
|
all_attentions=all_attentions,
|
|
past_buckets_states=past_buckets_states,
|
|
)
|
|
|
|
|
|
class ListConfig:
|
|
class ValueNode:
|
|
def __init__(self, value):
|
|
self.value = value
|
|
|
|
def _dereference_node(self):
|
|
return self
|
|
|
|
def _is_missing(self):
|
|
return False
|
|
|
|
def _value(self):
|
|
return self.value
|
|
|
|
# Based on an example from omegaconfig.listconfig
|
|
class ListIterator(Iterator[Any]):
|
|
def __init__(self, lst: Any, resolve: bool) -> None:
|
|
self.resolve = resolve
|
|
self.iterator = iter(lst.__dict__["_content"])
|
|
self.index = 0
|
|
|
|
def __next__(self) -> Any:
|
|
x = next(self.iterator)
|
|
if self.resolve:
|
|
x = x._dereference_node()
|
|
if x._is_missing():
|
|
raise AssertionError
|
|
|
|
self.index = self.index + 1
|
|
if isinstance(x, ListConfig.ValueNode):
|
|
return x._value()
|
|
raise AssertionError
|
|
|
|
def __iter__(self):
|
|
return self._iter_ex(True)
|
|
|
|
def _iter_ex(self, resolve: bool) -> Iterator[Any]:
|
|
try:
|
|
return ListConfig.ListIterator(self, resolve)
|
|
except Exception:
|
|
raise AssertionError from None
|
|
|
|
def __init__(self) -> None:
|
|
self._content = [
|
|
ListConfig.ValueNode(1),
|
|
ListConfig.ValueNode(3),
|
|
ListConfig.ValueNode(torch.tensor([7.0])),
|
|
]
|
|
|
|
|
|
def longformer_chunk(hidden_states, window_overlap=256):
|
|
"""convert into overlapping chunks. Chunk size = 2w, overlap size = w"""
|
|
|
|
# non-overlapping chunks of size = 2w
|
|
hidden_states = hidden_states.view(
|
|
hidden_states.size(0),
|
|
hidden_states.size(1) // (window_overlap * 2),
|
|
window_overlap * 2,
|
|
hidden_states.size(2),
|
|
)
|
|
|
|
# use `as_strided` to make the chunks overlap with an overlap size = window_overlap
|
|
chunk_size = list(hidden_states.size())
|
|
chunk_size[1] = chunk_size[1] * 2 - 1
|
|
|
|
chunk_stride = list(hidden_states.stride())
|
|
chunk_stride[1] = chunk_stride[1] // 2
|
|
return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)
|
|
|
|
|
|
class PartialT5(torch.nn.Module):
|
|
# Highly simplified T5Attention prefix
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.q = torch.nn.Linear(512, 512)
|
|
self.k = torch.nn.Linear(512, 512)
|
|
self.v = torch.nn.Linear(512, 512)
|
|
|
|
def forward(
|
|
self,
|
|
hidden_states,
|
|
key_value_states=None,
|
|
past_key_value=None,
|
|
query_length=None,
|
|
):
|
|
batch_size, seq_length = hidden_states.shape[:2]
|
|
|
|
real_seq_length = seq_length
|
|
|
|
if past_key_value is not None:
|
|
assert len(past_key_value) == 2, (
|
|
f"past_key_value should have 2 past states: keys and values. Got {len(past_key_value)} past states"
|
|
)
|
|
real_seq_length += (
|
|
past_key_value[0].shape[2] if query_length is None else query_length
|
|
)
|
|
|
|
def shape(states):
|
|
"""projection"""
|
|
return states.view(batch_size, -1, 8, 64).transpose(1, 2)
|
|
|
|
def project(hidden_states, proj_layer, key_value_states, past_key_value):
|
|
"""projects hidden states correctly to key/query states"""
|
|
if key_value_states is None:
|
|
# self-attn
|
|
# (batch_size, n_heads, seq_length, dim_per_head)
|
|
hidden_states = shape(proj_layer(hidden_states))
|
|
elif past_key_value is None:
|
|
# cross-attn
|
|
# (batch_size, n_heads, seq_length, dim_per_head)
|
|
hidden_states = shape(proj_layer(key_value_states))
|
|
|
|
if past_key_value is not None:
|
|
if key_value_states is None:
|
|
# self-attn
|
|
# (batch_size, n_heads, key_length, dim_per_head)
|
|
hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
|
|
else:
|
|
# cross-attn
|
|
hidden_states = past_key_value
|
|
return hidden_states
|
|
|
|
# get query states
|
|
query_states = shape(
|
|
self.q(hidden_states)
|
|
) # (batch_size, n_heads, seq_length, dim_per_head)
|
|
|
|
# get key/value states
|
|
key_states = project(
|
|
hidden_states,
|
|
self.k,
|
|
key_value_states,
|
|
past_key_value[0] if past_key_value is not None else None,
|
|
)
|
|
value_states = project(
|
|
hidden_states,
|
|
self.v,
|
|
key_value_states,
|
|
past_key_value[1] if past_key_value is not None else None,
|
|
)
|
|
|
|
# compute scores
|
|
scores = torch.matmul(query_states, key_states.transpose(3, 2))
|
|
|
|
# (truncated here )
|
|
return scores, value_states
|
|
|
|
|
|
class ChunkReformerFeedForward(torch.nn.Module):
|
|
# simplified from HF modeling_reformer.py
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12)
|
|
self.dense = torch.nn.Linear(256, 256)
|
|
self.output = torch.nn.Linear(256, 256)
|
|
|
|
def forward(self, attention_output):
|
|
return apply_chunking_to_forward(
|
|
self.forward_chunk,
|
|
attention_output + 1,
|
|
)
|
|
|
|
def forward_chunk(self, hidden_states):
|
|
hidden_states = self.layer_norm(hidden_states)
|
|
hidden_states = self.dense(hidden_states)
|
|
return self.output(hidden_states)
|
|
|
|
|
|
def apply_chunking_to_forward(forward_fn, *input_tensors):
|
|
# simplified from HF model_utils.py
|
|
assert len(input_tensors) > 0
|
|
tensor_shape = input_tensors[0].shape[1]
|
|
assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
|
|
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
|
|
if num_args_in_forward_chunk_fn != len(input_tensors):
|
|
raise ValueError
|
|
|
|
return forward_fn(*input_tensors)
|
|
|
|
|
|
def _validate_model_kwargs(fn, model_kwargs):
|
|
# simplified from transformers.generation.utils._validate_model_kwargs
|
|
unused_model_args = []
|
|
model_args = set(inspect.signature(fn).parameters)
|
|
for key, value in model_kwargs.items():
|
|
if value is not None and key not in model_args:
|
|
unused_model_args.append(key)
|
|
if unused_model_args:
|
|
raise ValueError(
|
|
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
|
" generate arguments will also show up in this list)"
|
|
)
|
|
|
|
|
|
class FakeMamlInner(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(784, 5)
|
|
|
|
def forward(self, x, ignored=None, bn_training=False):
|
|
return self.linear(x.view(x.shape[0], -1))
|
|
|
|
|
|
class PartialMaml(torch.nn.Module):
|
|
# Highly simplified version of maml.meta.Meta.finetuning
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.net = FakeMamlInner()
|
|
self.update_step_test = 10
|
|
self.update_lr = 0.4
|
|
|
|
def forward(self, x_spt, y_spt, x_qry, y_qry):
|
|
querysz = x_qry.size(0)
|
|
|
|
corrects = [0 for _ in range(self.update_step_test + 1)]
|
|
|
|
# in order to not ruin the state of running_mean/variance and bn_weight/bias
|
|
# we finetuning on the copied model instead of self.net
|
|
net = deepcopy(self.net)
|
|
|
|
# 1. run the i-th task and compute loss for k=0
|
|
logits = net(x_spt)
|
|
loss = F.cross_entropy(logits, y_spt)
|
|
grad = torch.autograd.grad(loss, net.parameters())
|
|
fast_weights = [
|
|
p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters())
|
|
]
|
|
|
|
# this is the loss and accuracy before first update
|
|
with torch.no_grad():
|
|
# [setsz, nway]
|
|
logits_q = net(x_qry, net.parameters(), bn_training=True)
|
|
# [setsz]
|
|
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
|
|
# scalar
|
|
correct = torch.eq(pred_q, y_qry).sum().item()
|
|
corrects[0] = corrects[0] + correct
|
|
|
|
# this is the loss and accuracy after the first update
|
|
with torch.no_grad():
|
|
# [setsz, nway]
|
|
logits_q = net(x_qry, fast_weights, bn_training=True)
|
|
# [setsz]
|
|
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
|
|
# scalar
|
|
correct = torch.eq(pred_q, y_qry).sum().item()
|
|
corrects[1] = corrects[1] + correct
|
|
|
|
del net
|
|
|
|
accs = torch.tensor(corrects) / querysz
|
|
|
|
return accs
|
|
|
|
|
|
def softmax_backward_data(parent, grad_output, output, dim, self):
|
|
from torch import _softmax_backward_data
|
|
|
|
return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
|
|
|
|
|
|
class XSoftmax(torch.autograd.Function):
|
|
# transformers.models.deberta.modeling_deberta.XSoftmax
|
|
@staticmethod
|
|
def forward(self, input, mask, dim):
|
|
self.dim = dim
|
|
rmask = ~(mask.to(torch.bool))
|
|
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
|
output = torch.softmax(output, self.dim)
|
|
output.masked_fill_(rmask, 0)
|
|
self.save_for_backward(output, rmask)
|
|
return output
|
|
|
|
@staticmethod
|
|
def backward(self, grad_output):
|
|
output, _ = self.saved_tensors
|
|
inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
|
|
return inputGrad, None, None
|
|
|
|
|
|
class ModelOutput(collections.OrderedDict):
|
|
"""based on file_utils.py in HuggingFace"""
|
|
|
|
def __getitem__(self, k):
|
|
if isinstance(k, str):
|
|
inner_dict = dict(self.items())
|
|
return inner_dict[k]
|
|
else:
|
|
return self.to_tuple()[k]
|
|
|
|
def __setattr__(self, name, value):
|
|
if name in self.keys() and value is not None:
|
|
# Don't call self.__setitem__ to avoid recursion errors
|
|
super().__setitem__(name, value)
|
|
super().__setattr__(name, value)
|
|
|
|
def __setitem__(self, key, value):
|
|
# Will raise a KeyException if needed
|
|
super().__setitem__(key, value)
|
|
# Don't call self.__setattr__ to avoid recursion errors
|
|
super().__setattr__(key, value)
|
|
|
|
def to_tuple(self):
|
|
return tuple(self[k] for k in self.keys())
|
|
|
|
|
|
def create_rand_mask_from_inputs(
|
|
from_blocked_mask,
|
|
to_blocked_mask,
|
|
rand_attn,
|
|
num_attention_heads,
|
|
num_rand_blocks,
|
|
batch_size,
|
|
from_seq_length,
|
|
from_block_size,
|
|
):
|
|
"""taken from HF modeling_big_bird.py"""
|
|
num_windows = from_seq_length // from_block_size - 2
|
|
rand_mask = torch.stack(
|
|
[p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]
|
|
)
|
|
rand_mask = rand_mask.view(
|
|
batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size
|
|
)
|
|
rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
|
|
return rand_mask
|
|
|
|
|
|
class SequentialAppendList(torch.nn.Sequential):
|
|
"""from timm/models/vovnet.py"""
|
|
|
|
def forward(self, x: torch.Tensor, concat_list: list[torch.Tensor]) -> torch.Tensor:
|
|
for i, module in enumerate(self):
|
|
if i == 0:
|
|
concat_list.append(module(x))
|
|
else:
|
|
concat_list.append(module(concat_list[-1]))
|
|
x = torch.cat(concat_list, dim=1)
|
|
return x, concat_list
|
|
|
|
|
|
class BatchNormAct2d(torch.nn.BatchNorm2d):
|
|
"""Taken from timm"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_features,
|
|
eps=1e-5,
|
|
momentum=0.1,
|
|
affine=True,
|
|
track_running_stats=True,
|
|
act_layer=torch.nn.ReLU,
|
|
inplace=True,
|
|
):
|
|
super().__init__(
|
|
num_features,
|
|
eps=eps,
|
|
momentum=momentum,
|
|
affine=affine,
|
|
track_running_stats=track_running_stats,
|
|
)
|
|
self.act = act_layer(inplace=inplace)
|
|
|
|
@torch.jit.ignore
|
|
def _forward_python(self, x):
|
|
return super().forward(x)
|
|
|
|
def forward(self, x):
|
|
if torch.jit.is_scripting():
|
|
x = self._forward_jit(x)
|
|
else:
|
|
x = self._forward_python(x)
|
|
x = self.act(x)
|
|
return x
|
|
|
|
|
|
def get_parameter_dtype(parameter):
|
|
"""from huggingface model_utils.py"""
|
|
try:
|
|
return next(parameter.parameters()).dtype
|
|
except StopIteration:
|
|
# For nn.DataParallel compatibility in PyTorch 1.5
|
|
|
|
def find_tensor_attributes(module):
|
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
|
return tuples
|
|
|
|
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
|
first_tuple = next(gen)
|
|
return first_tuple[1].dtype
|
|
|
|
|
|
class DummyConfig:
|
|
attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"]
|
|
lsh_attn_chunk_length = 64
|
|
local_attn_chunk_length = 64
|
|
|
|
|
|
def _get_min_chunk_len(config):
|
|
"""from hf_Reformer"""
|
|
attn_types = config.attn_layers
|
|
attn_types_set = set(attn_types)
|
|
if len(attn_types_set) == 1 and attn_types[0] == "lsh":
|
|
return config.lsh_attn_chunk_length
|
|
elif len(attn_types_set) == 1 and attn_types[0] == "local":
|
|
return config.local_attn_chunk_length
|
|
elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
|
|
return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
|
|
"attn layer types from ['lsh', 'local'] only."
|
|
)
|
|
|
|
|
|
def _stable_argsort(vector, dim):
|
|
"""from hf_Reformer"""
|
|
# this function scales the vector so that torch.argsort is stable.
|
|
# torch.argsort is not stable on its own
|
|
scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
|
|
scale_offset = scale_offset.expand(vector.shape)
|
|
scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
|
|
return torch.argsort(scaled_vector, dim=dim)
|
|
|
|
|
|
def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets):
|
|
"""from hf_Reformer"""
|
|
# no gradients are needed
|
|
with torch.no_grad():
|
|
# hash-based sort
|
|
sorted_bucket_idx = _stable_argsort(buckets, dim=-1)
|
|
|
|
# create simple indices to scatter to, to have undo sort
|
|
indices = (
|
|
torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
|
|
.view(1, 1, -1)
|
|
.expand(sorted_bucket_idx.shape)
|
|
)
|
|
|
|
# get undo sort
|
|
undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
|
|
undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)
|
|
|
|
return sorted_bucket_idx, undo_sorted_bucket_idx
|
|
|
|
|
|
class CustomList1(list):
|
|
def __call__(self, x):
|
|
for processor in self:
|
|
x = processor(x)
|
|
return x
|
|
|
|
def clear(self):
|
|
pass # this prevents RestrictedListSubclassVariable from kicking in
|
|
|
|
|
|
class CustomList2(list):
|
|
def __call__(self, x):
|
|
for processor in self:
|
|
x = processor(x)
|
|
return x
|
|
|
|
def length_times_10(self):
|
|
return len(self) * 10
|
|
|
|
def append_twice(self, x):
|
|
self.extend([x, x])
|
|
|
|
|
|
def _merge_criteria_processor_list(default_list, custom_list):
|
|
# simplified transformers/generation/utils.py
|
|
if len(custom_list) == 0:
|
|
return default_list
|
|
for default in default_list:
|
|
for custom in custom_list:
|
|
if type(custom) is type(default):
|
|
raise ValueError
|
|
default_list.extend(custom_list)
|
|
return default_list
|
|
|
|
|
|
class FeedForwardLayer(nn.Module):
|
|
def __init__(self, d_model, dim_feedforward, activation, dropout) -> None:
|
|
super().__init__()
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
self.activation = activation
|
|
self.dropout1 = nn.Dropout(dropout)
|
|
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
self.dropout2 = nn.Dropout(dropout)
|
|
|
|
def forward(self, x):
|
|
return self.dropout2(
|
|
self.linear2(self.dropout1(self.activation(self.linear1(x))))
|
|
)
|
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
|
|
def __init__(
|
|
self,
|
|
d_model,
|
|
nhead,
|
|
dim_feedforward=2048,
|
|
dropout=0.1,
|
|
activation=nn.ReLU(),
|
|
layer_norm_eps=1e-5,
|
|
):
|
|
super().__init__()
|
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
|
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
|
self.dropout = nn.Dropout(dropout)
|
|
self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout)
|
|
|
|
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
|
x = src
|
|
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
|
|
x = self.norm2(x + self._ff_block(x))
|
|
return x
|
|
|
|
# self-attention block
|
|
def _sa_block(self, x, attn_mask, key_padding_mask):
|
|
x = self.self_attn(
|
|
x,
|
|
x,
|
|
x,
|
|
attn_mask=attn_mask,
|
|
key_padding_mask=key_padding_mask,
|
|
need_weights=False,
|
|
)[0]
|
|
return self.dropout(x)
|
|
|
|
# feed forward block
|
|
def _ff_block(self, x):
|
|
return self.ff_block(x)
|
|
|
|
|
|
class MockModule(torch.nn.Module):
|
|
def inner_fn(self, left, right):
|
|
return tuple(left) == tuple(right)
|
|
|
|
def fn(self, tensor):
|
|
if type(tensor) is int:
|
|
return False
|
|
|
|
torch.add(tensor, tensor)
|
|
return self.inner_fn(tensor.shape, (1, 2, 3))
|
|
|
|
|
|
class IncByOne:
|
|
def __init__(self, x):
|
|
self.x = x + 1
|
|
|
|
|
|
class IncByTwo:
|
|
def __init__(self, x):
|
|
self.x = x + 2
|
|
|
|
|
|
class LRUCacheWarningTests(LoggingTestCase):
|
|
@requires_cuda
|
|
@make_logging_test(dynamo=logging.DEBUG)
|
|
def test_lru_cache_warning_issued_during_tracing(self, records):
|
|
torch.set_default_device("cuda")
|
|
|
|
@torch.compile(backend="eager")
|
|
def f(x):
|
|
torch.get_device_module()
|
|
x = x.cos().sin()
|
|
return x
|
|
|
|
result = f(torch.randn(1024))
|
|
self.assertIsInstance(result, torch.Tensor)
|
|
|
|
for record in records:
|
|
if "call to a lru_cache wrapped function at:" in record.getMessage():
|
|
self.fail("lru_cache warning was incorrectly logged")
|
|
|
|
|
|
class ReproTests(torch._dynamo.test_case.TestCase):
|
|
def setUp(self) -> None:
|
|
try:
|
|
from .utils import install_guard_manager_testing_hook
|
|
except ImportError:
|
|
from utils import install_guard_manager_testing_hook
|
|
|
|
self.exit_stack = contextlib.ExitStack()
|
|
self.exit_stack.enter_context(
|
|
install_guard_manager_testing_hook(self.guard_manager_clone_hook_fn)
|
|
)
|
|
super().setUp()
|
|
|
|
def tearDown(self) -> None:
|
|
self.exit_stack.close()
|
|
super().tearDown()
|
|
|
|
def guard_manager_clone_hook_fn(self, guard_manager_wrapper, f_locals):
|
|
root = guard_manager_wrapper.root
|
|
cloned_root = root.clone_manager(lambda x: True)
|
|
cloned_wrapper = torch._dynamo.guards.GuardManagerWrapper(cloned_root)
|
|
self.assertEqual(str(guard_manager_wrapper), str(cloned_wrapper))
|
|
self.assertTrue(cloned_root.check(f_locals))
|
|
if guard_manager_wrapper.diff_guard_root:
|
|
self.assertTrue(guard_manager_wrapper.diff_guard_root.check(f_locals))
|
|
|
|
def test_do_paste_mask(self):
|
|
torch._dynamo.utils.counters.clear()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 1,
|
|
427,
|
|
640,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 2,
|
|
427,
|
|
640,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 3,
|
|
612,
|
|
612,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 4,
|
|
612,
|
|
612,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 2,
|
|
427,
|
|
640,
|
|
False,
|
|
)
|
|
# (dynamic shapes, static shapes)
|
|
self.assertIn(cnt.frame_count, (5, 7))
|
|
self.assertIn(cnt.op_count, (92, 106, 119))
|
|
|
|
def test_convert_boxes_to_pooler_format(self):
|
|
boxes1 = [
|
|
Boxes(torch.arange(0, 8).reshape((2, 4))),
|
|
Boxes(torch.arange(8, 16).reshape((2, 4))),
|
|
]
|
|
boxes2 = [
|
|
Boxes(torch.arange(16, 20).reshape((1, 4))),
|
|
Boxes(torch.arange(20, 24).reshape((1, 4))),
|
|
]
|
|
correct1 = convert_boxes_to_pooler_format(boxes1)
|
|
correct2 = convert_boxes_to_pooler_format(boxes2)
|
|
fn = convert_boxes_to_pooler_format
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
self.assertTrue(same(opt_fn(boxes1), correct1))
|
|
self.assertTrue(same(opt_fn(boxes2), correct2))
|
|
|
|
# repeat_interleave is a dynamic shape operator we do not execute/
|
|
# In the future, we could reduce the frame_count down to 1
|
|
# by guarding on the exact values of `Tensor repeats` arg
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """4""")
|
|
self.assertExpectedInline(cnt.op_count, """10""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """4""")
|
|
self.assertExpectedInline(cnt.op_count, """14""")
|
|
|
|
def test_boxes_len(self):
|
|
def fn(boxes):
|
|
return len(boxes) + boxes.__len__() + boxes.tensor
|
|
|
|
boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4)))
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """1""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """2""")
|
|
|
|
def _reformer(self, nopython):
|
|
input = torch.randn([1, 64, 256])
|
|
model = ReformerEncoder()
|
|
torch.manual_seed(1337)
|
|
correct = copy.deepcopy(model)(input)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
torch.manual_seed(1337)
|
|
opt_model = torch.compile(model, backend=cnt, fullgraph=nopython)
|
|
self.assertTrue(same(opt_model(input), correct))
|
|
return cnt
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113010
|
|
def test_out_overload_non_contiguous(self):
|
|
def f(x, y):
|
|
return torch.abs(x, out=y.T)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
y_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
x_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
y_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
|
|
out_ref = f(x_ref, y_ref)
|
|
out_test = f_compiled(x_test, y_test)
|
|
self.assertEqual(out_ref, out_test)
|
|
self.assertEqual(y_ref, y_test)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/109053
|
|
def test_view_dtype_overload(self):
|
|
def f(x):
|
|
return x.view(torch.int32)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x1 = torch.ones(4, requires_grad=True)
|
|
out_ref = f(x1)
|
|
out_test = f_compiled(x1)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
x2 = torch.ones(4, requires_grad=False)
|
|
out_ref = f(x2)
|
|
out_test = f_compiled(x2)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/90552
|
|
def test_intermediate_leaf_requires_grad(self):
|
|
def f(x):
|
|
leaf = torch.ones(2, requires_grad=True)
|
|
return leaf, leaf * 2
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
x = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
|
|
leaf, out = f(x)
|
|
leaf_test, out_test = f_compiled(x)
|
|
out.sum().backward()
|
|
out_test.sum().backward()
|
|
self.assertEqual(leaf.grad, leaf_test.grad)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113263
|
|
def test_unpack_hooks_dont_run_during_tracing(self):
|
|
def f(x, y):
|
|
return x * y
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
pack_count = 0
|
|
unpack_count = 0
|
|
|
|
def pack_hook(x):
|
|
nonlocal pack_count
|
|
pack_count += 1
|
|
return x
|
|
|
|
# unpack hook shouldn't run during compilation, while we trace the forward
|
|
def unpack_hook(x):
|
|
nonlocal unpack_count
|
|
unpack_count += 1
|
|
return x
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
y = torch.ones(4, requires_grad=False)
|
|
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
|
|
out_test = f_compiled(x, y)
|
|
self.assertEqual(pack_count, 1)
|
|
self.assertEqual(unpack_count, 0)
|
|
out_test.sum().backward()
|
|
self.assertEqual(pack_count, 1)
|
|
self.assertEqual(unpack_count, 1)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113263
|
|
def test_unpack_hooks_can_be_disabled(self):
|
|
def f(x, y):
|
|
return x * y
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
y = torch.ones(4, requires_grad=False)
|
|
with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
|
|
out_test = f_compiled(x, y)
|
|
out_test.sum().backward()
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113263
|
|
def test_disabling_unpack_hooks_within_compiled_region(self):
|
|
def g(z):
|
|
with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
|
|
return z + 5
|
|
|
|
def f(x, y):
|
|
z = x * y
|
|
return g(z)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
y = torch.ones(4, requires_grad=False)
|
|
out_test = f_compiled(x, y)
|
|
out_test.sum().backward()
|
|
|
|
# See https://github.com/pytorch/pytorch/issues/97745
|
|
def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self):
|
|
def f(a, b):
|
|
c = torch.ones(2, 2)
|
|
d = torch.ones(2, 2)
|
|
e = torch.matmul(a, c)
|
|
g_loss = torch.abs(e - d).mean()
|
|
g_loss.backward()
|
|
fake_d_pred = torch.matmul(b, e.detach())
|
|
d_loss = fake_d_pred.mean()
|
|
d_loss.backward()
|
|
|
|
a_ref = torch.randn(2, 2, requires_grad=True)
|
|
b_ref = torch.randn(2, 2, requires_grad=True)
|
|
out_ref = f(a_ref, b_ref)
|
|
|
|
a_test = a_ref.detach().clone().requires_grad_(True)
|
|
b_test = b_ref.detach().clone().requires_grad_(True)
|
|
out_test = torch.compile(f, backend="aot_eager")(a_test, b_test)
|
|
|
|
self.assertEqual(out_ref, out_test)
|
|
self.assertEqual(a_ref.grad, a_test.grad)
|
|
self.assertEqual(b_ref.grad, b_test.grad)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/111603
|
|
def test_tuple_enum_as_key_dict(self):
|
|
class MyEnum(Enum):
|
|
A = "a"
|
|
|
|
class SomeModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(1, 1)
|
|
|
|
def forward(self, x) -> torch.Tensor:
|
|
return self.linear(x[MyEnum.A])
|
|
|
|
x = {MyEnum.A: torch.rand(8, 1)}
|
|
model_pytorch = SomeModel()
|
|
model = torch.compile(model_pytorch)
|
|
# Executing twice works
|
|
model(x)
|
|
y = model(x)
|
|
self.assertEqual(y, model_pytorch(x))
|
|
|
|
def test_embedding_backward_broadcasting_decomp(self):
|
|
def f(grad_output, indices):
|
|
num_weights = 10
|
|
padding_idx = 1
|
|
scale_grad_by_freq = True
|
|
return torch.ops.aten.embedding_dense_backward(
|
|
grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
|
|
)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
grad_output = torch.ones(2, 4, 3, dtype=torch.float16)
|
|
indices = torch.ones(2, 4, dtype=torch.int64)
|
|
|
|
out_ref = f(grad_output, indices)
|
|
out_test = f_compiled(grad_output, indices)
|
|
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
def test_reformer_eval(self):
|
|
with torch.no_grad():
|
|
cnt = self._reformer(nopython=True)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 10)
|
|
|
|
def test_reformer_train(self):
|
|
with torch.enable_grad():
|
|
cnt = self._reformer(nopython=False)
|
|
expected_op_count = (
|
|
"""10""" if torch._dynamo.config.inline_inbuilt_nn_modules else """4"""
|
|
)
|
|
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, expected_op_count)
|
|
|
|
def test_longformer_chunk(self):
|
|
input1 = torch.randn([1, 4096, 1])
|
|
input2 = torch.randn([12, 4096, 64])
|
|
correct1 = longformer_chunk(input1)
|
|
correct2 = longformer_chunk(input2)
|
|
fn = longformer_chunk
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(input1), correct1))
|
|
self.assertTrue(same(opt_fn(input2), correct2))
|
|
self.assertTrue(same(opt_fn(input1), correct1))
|
|
self.assertTrue(same(opt_fn(input2), correct2))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
if torch._dynamo.config.automatic_dynamic_shapes:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
self.assertExpectedInline(cnt.op_count, """8""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
self.assertExpectedInline(cnt.op_count, """4""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
self.assertExpectedInline(cnt.op_count, """19""")
|
|
|
|
def test_hf_t5_forward(self):
|
|
input = torch.randn([1, 2048, 512])
|
|
model = PartialT5()
|
|
correct = model(input)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
self.assertTrue(same(opt_model(input), correct))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """11""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """11""")
|
|
|
|
def test_module_in_skipfiles(self):
|
|
model = nn.Linear(10, 10)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
torch.compile(model, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_function_in_skipfiles(self):
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
torch.compile(torch.sin, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_slicing_dynamic_shape(self):
|
|
def fn(y):
|
|
x = torch.ones(8)
|
|
idx = y[0]
|
|
out = x[idx:]
|
|
return (out + 3) * 5
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=counter)
|
|
out = opt_fn(torch.ones(10, dtype=torch.long))
|
|
# idx should be 1 -> slicing off [1:] of 8 elem tensor
|
|
self.assertEqual(list(out.shape), [7])
|
|
|
|
self.assertEqual(counter.op_count, 2)
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4])
|
|
|
|
def test_slicing_dynamic_shape_setitem(self):
|
|
def fn(input_lengths: torch.Tensor, new_ones_1):
|
|
getitem_13 = input_lengths[3]
|
|
new_ones_1[(3, slice(getitem_13, None, None))] = 0
|
|
setitem_13 = new_ones_1
|
|
return (setitem_13,)
|
|
|
|
x = torch.randn(10).to(dtype=torch.int64)
|
|
y = torch.randn(10, 204)
|
|
ref = fn(x, y)
|
|
opt_fn = torch.compile(fn, backend="aot_eager")
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
@torch._dynamo.config.patch(error_on_recompile=True)
|
|
@torch.fx.experimental._config.patch(use_duck_shape=False)
|
|
def test_dynamic_shape_disable_duck_size(self):
|
|
# noqa: F841
|
|
|
|
class TestModel(nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
|
|
def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
|
|
return x + val
|
|
|
|
main_model = TestModel().to(memory_format=torch.channels_last)
|
|
opt_model = torch.compile(main_model, backend="eager", dynamic=True)
|
|
|
|
x1 = torch.rand(2, 5, 10, 10).to(memory_format=torch.channels_last)
|
|
x2 = torch.rand(2, 5, 4, 8).to(memory_format=torch.channels_last)
|
|
|
|
main_model(x1, 4)
|
|
opt_model(x1, 4)
|
|
|
|
main_model(x2, 20)
|
|
opt_model(x2, 20)
|
|
|
|
def test_chunk_reformer_ff(self):
|
|
input = torch.randn([1, 4096, 256])
|
|
model = ChunkReformerFeedForward()
|
|
correct = model(input)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
self.assertTrue(same(opt_model(input), correct))
|
|
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertLessEqual(cnt.op_count, 10)
|
|
|
|
# see: https://github.com/pytorch/pytorch/issues/80067
|
|
# NB: When you remove the expectedFailure, don't forget to
|
|
# uncomment/adjust the assertEqual below
|
|
@unittest.expectedFailure
|
|
@torch._dynamo.config.patch(
|
|
fake_tensor_propagation=True, capture_scalar_outputs=True
|
|
)
|
|
def test_maml_item_capture(self):
|
|
a = torch.randn(5, 1, 28, 28)
|
|
b = torch.zeros(5, dtype=torch.int64)
|
|
c = torch.randn(75, 1, 28, 28)
|
|
d = torch.zeros(75, dtype=torch.int64)
|
|
model = PartialMaml()
|
|
correct = model(a, b, c, d)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch.compile(model, backend=cnt)
|
|
for _ in range(10):
|
|
self.assertTrue(same(opt_model(a, b, c, d), correct))
|
|
|
|
# if torch._dynamo.config.assume_static_by_default:
|
|
# self.assertExpectedInline(cnt.frame_count, """2""")
|
|
# else:
|
|
# self.assertExpectedInline(cnt.frame_count, """3""")
|
|
# TODO(jansel): figure out why op count depends on imports
|
|
self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))
|
|
|
|
# see: https://github.com/pytorch/pytorch/issues/80067
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=False)
|
|
def test_maml_no_item_capture(self):
|
|
a = torch.randn(5, 1, 28, 28)
|
|
b = torch.zeros(5, dtype=torch.int64)
|
|
c = torch.randn(75, 1, 28, 28)
|
|
d = torch.zeros(75, dtype=torch.int64)
|
|
model = PartialMaml()
|
|
correct = model(a, b, c, d)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch.compile(model, backend=cnt)
|
|
for _ in range(10):
|
|
self.assertTrue(same(opt_model(a, b, c, d), correct))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """3""")
|
|
|
|
def test_hf_model_output(self):
|
|
ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
|
|
|
|
def fn1(x):
|
|
return x["a"] + 1
|
|
|
|
def fn2(x):
|
|
return x.a + 1
|
|
|
|
def fn3(x):
|
|
return x.to_tuple()[0] + 1
|
|
|
|
def fn4(x):
|
|
return x[0] + 1
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
for fn in (fn1, fn2, fn3, fn4):
|
|
cnt.clear()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(ex), ex.a + 1))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_create_rand_mask_from_inputs(self):
|
|
args = [
|
|
torch.randn([1, 64, 64]),
|
|
torch.randn([1, 64, 64]),
|
|
torch.zeros([1, 12, 62, 3], dtype=torch.int64),
|
|
12,
|
|
3,
|
|
1,
|
|
4096,
|
|
64,
|
|
]
|
|
correct = create_rand_mask_from_inputs(*args)
|
|
fn = create_rand_mask_from_inputs
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(*args), correct))
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """8""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """11""")
|
|
|
|
def test_rng_state(self):
|
|
def fn():
|
|
state = torch.get_rng_state()
|
|
before = torch.rand(1000)
|
|
torch.set_rng_state(state)
|
|
after = torch.rand(1000)
|
|
return before, after
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
|
|
before, after = opt_fn()
|
|
self.assertTrue(same(before, after))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 2) # rand, rand
|
|
try:
|
|
_, _ = torch._dynamo.export(fn)()
|
|
# See https://github.com/pytorch/pytorch/pull/87490
|
|
self.fail("unexpected export success")
|
|
except torch._dynamo.exc.Unsupported:
|
|
pass
|
|
|
|
def test_threading_local(self):
|
|
import threading
|
|
|
|
foo = threading.local()
|
|
foo.x = torch.rand(1)
|
|
|
|
def f(x):
|
|
return torch.cat([x, foo.x])
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
|
|
|
|
inp = torch.ones(1)
|
|
out = f(inp)
|
|
opt_out = opt_f(inp)
|
|
self.assertEqual(opt_out, out)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_seq_append_list(self):
|
|
x = torch.randn(4, 10)
|
|
model = SequentialAppendList(
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
)
|
|
# this one is tricky because it mutates the list provided as an input
|
|
l1 = [x]
|
|
l2 = [x]
|
|
correct, _ = model(x, l1)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
result, l3 = opt_model(x, l2)
|
|
self.assertTrue(same(result, correct))
|
|
self.assertTrue(same(l1, l2))
|
|
self.assertIs(l2, l3)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 5)
|
|
|
|
def test_batch_norm_act(self):
|
|
a = torch.randn(5, 1, 28, 28)
|
|
model = BatchNormAct2d(1).eval()
|
|
correct = model(a)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
if not torch._dynamo.config.specialize_int:
|
|
# _local_scalar_dense causes graph break w 0-dim tensor
|
|
opt_model = torch.compile(model, backend=cnt)
|
|
self.assertTrue(same(opt_model(a), correct))
|
|
return
|
|
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
self.assertTrue(same(opt_model(a), correct))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
def test_get_parameter_dtype(self):
|
|
model = SequentialAppendList(
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
)
|
|
|
|
def fn(model, x):
|
|
return x + torch.randn(10, dtype=get_parameter_dtype(model))
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertEqual(opt_fn(model, torch.randn(10)).dtype, torch.float32)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
def test_nn_parameter(self):
|
|
def test_fn():
|
|
a = torch.nn.Parameter(torch.randn(5, 5))
|
|
# Checks that TensorVariable stores the type information correctly
|
|
self.assertTrue(isinstance(a, torch.nn.Parameter))
|
|
return a
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_test_fn = torch.compile(test_fn, backend=cnt)
|
|
out = opt_test_fn()
|
|
self.assertTrue(isinstance(out, torch.nn.Parameter))
|
|
|
|
def test_Size(self):
|
|
def test_fn():
|
|
a = torch.randn(4)
|
|
x = torch.Size([1, 2, 3])
|
|
# Checks that SizeVariable return torch.Size object
|
|
assert isinstance(x, torch.Size)
|
|
# Causes graph breaks and checks reconstruction of SizeVariable
|
|
# object
|
|
self.assertIsInstance(x, torch.Size)
|
|
return a
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_test_fn = torch.compile(test_fn, backend=cnt)
|
|
opt_test_fn()
|
|
|
|
# See https://github.com/pytorch/pytorch/issues/100067
|
|
def test_copy_weird_strides(self):
|
|
# This test requires inductor's copy() decomp to preserve strides properly.
|
|
def test_fn(a):
|
|
b = torch.zeros(48, 4, 256, 513)
|
|
b[:, 0, 1:256, 1:256] = a
|
|
c = b.view(4, 12, 1024, 513)
|
|
d = c.transpose(2, 1)
|
|
d.add_(1)
|
|
return d
|
|
|
|
sh, st, dt, dev, rg = (
|
|
(48, 255, 255),
|
|
(787968, 513, 1),
|
|
torch.float16,
|
|
"cpu",
|
|
True,
|
|
)
|
|
a = rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
compiled_f = torch.compile(test_fn, backend="aot_eager_decomp_partition")
|
|
out1 = test_fn(a)
|
|
out2 = compiled_f(a)
|
|
self.assertEqual(out1, out2)
|
|
|
|
def test_indexing_with_list(self):
|
|
def test_fn():
|
|
def run_test(tensor, *idx):
|
|
npt = tensor.numpy()
|
|
assert npt[idx].shape == tensor[idx].shape
|
|
|
|
x = torch.arange(0, 10)
|
|
cases = [
|
|
[None, None],
|
|
[1, None],
|
|
]
|
|
|
|
for case in cases:
|
|
run_test(x, *case)
|
|
|
|
return torch.randn(4)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_test_fn = torch.compile(test_fn, backend=cnt)
|
|
opt_test_fn()
|
|
|
|
def test_foreach_decomp_arg_names(self):
|
|
# https://github.com/pytorch/pytorch/issues/138698
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def foreach_pow(**kwargs):
|
|
return torch._foreach_pow(**kwargs)
|
|
|
|
foreach_pow(self=[torch.ones(2, 2, device="cpu")], exponent=2.7)
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def foreach_lerp_(**kwargs):
|
|
return torch._foreach_lerp_(**kwargs)
|
|
|
|
foreach_lerp_(
|
|
self=[torch.ones(2, 2, device="cpu")],
|
|
tensors1=[torch.ones(2, 2, device="cpu")],
|
|
weights=[torch.ones(2, 2, device="cpu")],
|
|
)
|
|
|
|
def test_reformer_min_chunk_len(self):
|
|
def fn(cfg):
|
|
t = torch.empty(10)
|
|
t.fill_(_get_min_chunk_len(cfg))
|
|
return t[0]
|
|
|
|
cfg = DummyConfig()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertEqual(opt_fn(cfg), 64)
|
|
# With unspec int, maximum computation is preserved
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
if torch._dynamo.config.automatic_dynamic_shapes:
|
|
if not torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.op_count, """4""")
|
|
else:
|
|
self.assertExpectedInline(cnt.op_count, """3""")
|
|
else:
|
|
self.assertExpectedInline(cnt.op_count, """3""")
|
|
|
|
def test_reformer_sorting(self):
|
|
x = torch.zeros([1, 12, 4096], dtype=torch.int64)
|
|
correct = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(x)
|
|
fn = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(x), correct))
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """14""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """16""")
|
|
|
|
def test_recursive_map(self):
|
|
# https://github.com/pytorch/torchdynamo/issues/132
|
|
def _recursive_map(struct, batch_dim=0):
|
|
for k, v in struct.items():
|
|
if v is not None:
|
|
if isinstance(v, dict):
|
|
_recursive_map(v)
|
|
else:
|
|
struct[k] = v
|
|
|
|
def toy_example(a, b, v):
|
|
x = a / (torch.abs(a) + 1)
|
|
if v is not None:
|
|
_recursive_map(v)
|
|
return x * b
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_toy_example = torch.compile(toy_example, backend=cnt)
|
|
opt_toy_example(
|
|
torch.randn(10),
|
|
torch.randn(10),
|
|
{"layer0": {"memory_keys": torch.randn(10)}},
|
|
)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 4)
|
|
|
|
def test_issue114171(self):
|
|
device = torch.device("cpu")
|
|
|
|
def fcnn(in_dim, out_dim, hidden_dim, activation=torch.nn.GELU):
|
|
layers = [
|
|
torch.nn.Linear(in_dim, hidden_dim, device=device),
|
|
activation(),
|
|
torch.nn.Linear(hidden_dim, out_dim, device=device),
|
|
]
|
|
return torch.nn.Sequential(*layers)
|
|
|
|
class testmodel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.interaction_networks = torch.nn.ModuleList(
|
|
[fcnn(262, 1174, 400) for _ in range(4)]
|
|
)
|
|
|
|
def interact(self, x, cycle):
|
|
return self.interaction_networks[cycle](x)
|
|
|
|
model = testmodel()
|
|
forward_aot = torch.compile(
|
|
model.interact, fullgraph=True, dynamic=True, backend="eager"
|
|
)
|
|
|
|
x = torch.rand([111, 262], device=device)
|
|
forward_aot(x, 2) # previously failed
|
|
|
|
def test_issue175(self):
|
|
n_heads = 2
|
|
d_model = 64
|
|
model = TransformerEncoderLayer(d_model, n_heads)
|
|
inp = torch.randn(1, d_model)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch.compile(model, backend=cnt, fullgraph=True)
|
|
opt_model(inp)
|
|
opt_model(inp)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(12, cnt.op_count)
|
|
|
|
def test_exec_import(self):
|
|
def fn1():
|
|
exec("import math")
|
|
|
|
def fn2():
|
|
try:
|
|
math.sqrt(4)
|
|
return False
|
|
except NameError:
|
|
return True
|
|
|
|
def fn3():
|
|
fn1()
|
|
return fn2()
|
|
|
|
self.assertTrue(fn3())
|
|
opt_fn3 = torch.compile(fn3, backend="eager")
|
|
self.assertTrue(opt_fn3())
|
|
|
|
def test_exec_wildcard_import(self):
|
|
# Test that globals are not carried over from frame to frame
|
|
def fn1():
|
|
exec("from torch import *")
|
|
|
|
def fn2():
|
|
x = torch.zeros(4)
|
|
for i in range(5):
|
|
x = x + i
|
|
return x
|
|
|
|
def fn3():
|
|
fn1()
|
|
return fn2()
|
|
|
|
ref = fn3()
|
|
opt_fn3 = torch.compile(fn3, backend="eager")
|
|
res = opt_fn3()
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_with_on_graph_break_inst(self):
|
|
def reversible(x):
|
|
print("Hello world") # Cause graph break so inline fails
|
|
return torch.sin(torch.cos(x))
|
|
|
|
def fn(x):
|
|
with torch.enable_grad():
|
|
a = torch.sin(x)
|
|
b = reversible(a)
|
|
c = torch.sigmoid(b)
|
|
c.sum().backward()
|
|
return x.grad
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
x.grad = None
|
|
with torch.no_grad():
|
|
ref = fn(x)
|
|
|
|
x.grad = None
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
with torch.no_grad():
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_with_on_graph_break_nested(self):
|
|
def reversible(x):
|
|
torch._dynamo.graph_break() # Cause graph break so inline fails
|
|
return torch.sin(torch.cos(x))
|
|
|
|
def fn(x):
|
|
# nested context manager failed previously
|
|
with torch.no_grad():
|
|
with torch.enable_grad():
|
|
a = torch.sin(x)
|
|
b = reversible(a)
|
|
c = torch.sigmoid(b)
|
|
c.sum().backward()
|
|
return x.grad
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
x.grad = None
|
|
with torch.no_grad():
|
|
ref = fn(x)
|
|
|
|
x.grad = None
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
with torch.no_grad():
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
# https://github.com/pytorch/torchdynamo/issues/1446
|
|
def test_grad_mode_carrying_correct_state_after_graph_break(self):
|
|
def fn(x):
|
|
with torch.no_grad():
|
|
y = x * 3
|
|
print("Break")
|
|
z = x + 2
|
|
return y, z
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
y, z = opt_fn(x)
|
|
self.assertFalse(y.requires_grad)
|
|
self.assertFalse(z.requires_grad)
|
|
|
|
def test_abc_setattr(self):
|
|
# tests that we correctly bail out of __setattr__ calls
|
|
|
|
# TODO: does not ensure ABC classes are correctly inferred as ClassVariables
|
|
# (doesn't test the fix for 'super()')
|
|
|
|
class BaseModule(torch.nn.Module, ABC):
|
|
def blah(self, x):
|
|
return x + 1
|
|
|
|
class Derived(BaseModule):
|
|
def __setattr__(self, name, value) -> None:
|
|
super().__setattr__(name, value)
|
|
|
|
def forward(self, x):
|
|
# expect a graph break on __setattr__
|
|
self.foo = 0
|
|
return self.blah(x)
|
|
|
|
def blah(self, x):
|
|
return super().blah(x)
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
mod = Derived()
|
|
opt_mod = torch.compile(mod, backend="eager")
|
|
opt_mod(x)
|
|
|
|
# Not sure what this test is testing. It previously graph-broke on
# __dict__, so the counter was >= 2. With __dict__ support, there is no
# graph break.
|
|
self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
|
self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 1)
|
|
|
|
@torch._dynamo.config.patch("suppress_errors", True)
|
|
def test_guard_fail_tensor_bool(self):
|
|
@torch._dynamo.disable(recursive=False)
|
|
def fn():
|
|
condition_shape = (5, 5)
|
|
dtypes = (torch.bool,)
|
|
shapes = (
|
|
(),
|
|
(5,),
|
|
(1, 5),
|
|
)
|
|
|
|
tensors = [
|
|
torch.empty(shape, dtype=dtype).fill_(17)
|
|
for shape, dtype in itertools.product(shapes, dtypes)
|
|
]
|
|
|
|
x_vals = (5.0, *tensors)
|
|
y_vals = (6.0, *tensors)
|
|
|
|
@torch._dynamo.disable
|
|
def get_expected(condition, x, y):
|
|
x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
|
|
y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
|
|
return torch.from_numpy(
|
|
np.where(condition.cpu().numpy(), x_np, y_np)
|
|
).to(common_dtype)
|
|
|
|
for x, y in zip(x_vals, y_vals):
|
|
condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_()
|
|
common_dtype = torch.result_type(x, y)
|
|
|
|
def check_equal(condition, x, y):
|
|
# NumPy aggressively promotes to double, hence cast the output to the correct dtype
|
|
expected = get_expected(condition, x, y)
|
|
result = torch.where(condition, x, y)
|
|
assert torch.allclose(expected, result)
|
|
|
|
check_equal(condition, x, y)
|
|
check_equal(condition, y, x)
|
|
|
|
fn()
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
opt_fn()
|
|
|
|
def test_guard_fail_nested_tuple(self):
|
|
def fn(args):
|
|
return torch.ones(()), args[0] * 2
|
|
|
|
# This adds a tensor check on args[1][0] and args[1][1]
|
|
args1 = (torch.ones(1), (torch.ones(1), torch.ones(1)))
|
|
args2 = (torch.ones(1), torch.ones(1))
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
ref = opt_fn(args1)
|
|
res = opt_fn(args2)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_nullcontext1(self):
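# A contextlib.nullcontext() instance passed in as an argument should be traceable with fullgraph=True.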
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def fn(x, ctx):
|
|
x = x.sin()
|
|
with ctx:
|
|
x = x.cos()
|
|
x = x.sin()
|
|
return x
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(fn(y, contextlib.nullcontext()), y.sin().cos().sin()))
|
|
|
|
def test_nullcontext2(self):
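# Same idea as test_nullcontext1, but the nullcontext class itself is passed in and instantiated inside the compiled function.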
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def fn(x, ctx):
|
|
x = x.sin()
|
|
with ctx():
|
|
x = x.cos()
|
|
x = x.sin()
|
|
return x
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(fn(y, contextlib.nullcontext), y.sin().cos().sin()))
|
|
|
|
def test_no_grad_inline(self):
|
|
@torch.no_grad()
|
|
def a(x):
|
|
return x.sin()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def b(x):
|
|
return a(x).cos()
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(b(y), y.sin().cos()))
|
|
|
|
@skipIfWindows(
|
|
msg="torch._dynamo.exc.TorchRuntimeError: Failed running call_function <class 'torch.LongTensor'>(*(FakeTensor(..., size=(10,), dtype=torch.int32),), **{}):" # noqa: B950
|
|
)
|
|
def test_longtensor_list(self):
|
|
for partition in [0, 5, 10]:
|
|
|
|
@torch._dynamo.disable
|
|
def rand_gen():
|
|
rand_vals = [random.randint(5, 10) for _ in range(10)]
|
|
# List of tensors mixed with np.arrays
|
|
return list(np.array(rand_vals[:partition])) + [
|
|
torch.tensor(val) for val in rand_vals[partition:]
|
|
]
|
|
|
|
def fn(x):
|
|
random_list = rand_gen()
|
|
z = torch.LongTensor(random_list)
|
|
return x * z
|
|
|
|
x = torch.ones(10) * 2
|
|
|
|
random.seed(0)
|
|
ref0 = fn(x)
|
|
ref1 = fn(x)
|
|
|
|
random.seed(0)
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
res0 = opt_fn(x)
|
|
res1 = opt_fn(x)
|
|
|
|
self.assertTrue(same(ref0, res0))
|
|
self.assertTrue(same(ref1, res1))
|
|
|
|
def test_primtorch(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
torch._refs.abs(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
@unittest.expectedFailure
|
|
# inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)]
|
|
def test_primtorch_no_graph_break(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
torch._refs.abs(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_torch_tensor_ops_no_graph_break(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
torch.Tensor.abs_(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
@unittest.skipIf(
|
|
not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket),
|
|
"old pt doesn't work",
|
|
)
|
|
def test_torch_ops_aten(self):
|
|
# Picked an op that doesn't show up in the default list
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return torch.ops.aten.absolute(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_hf_gelu_inline(self):
|
|
class GELUActivation(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.act = nn.functional.gelu
|
|
|
|
def forward(self, input):
|
|
return self.act(input)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return GELUActivation()(x)
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(fn(y), nn.functional.gelu(y)))
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn_returns(x):
|
|
return GELUActivation(), x + 1
|
|
|
|
act, _ = fn_returns(y)
|
|
self.assertIsInstance(act, GELUActivation)
|
|
self.assertIs(act.act, nn.functional.gelu)
|
|
self.assertTrue(hasattr(act, "_buffers")) # check that __init__ got called
|
|
|
|
def test_dropout_inline(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return torch.nn.Dropout(0.1)(x)
|
|
|
|
y = torch.randn(10)
|
|
torch.manual_seed(1337)
|
|
ref = nn.functional.dropout(y, 0.1)
|
|
torch.manual_seed(1337)
|
|
res = fn(y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_setitem_boolean_mask_diff(self):
|
|
def fn(x, b, y):
|
|
x = x.clone()
|
|
x[b] = y
|
|
return x
|
|
|
|
opt_fn = torch.compile(fn, backend="aot_eager")
|
|
x = torch.randn(4, requires_grad=True)
|
|
b = torch.tensor([True, False, True, False])
|
|
y = torch.randn(2, requires_grad=True)
|
|
opt_fn(x, b, y)
|
|
|
|
def test_setitem_tuple_boolean_mask_diff(self):
|
|
def fn(x, b, y):
|
|
x = x.clone()
|
|
x[:, b] = y
|
|
return x
|
|
|
|
opt_fn = torch.compile(fn, backend="aot_eager")
|
|
x = torch.randn(8, 4, requires_grad=True)
|
|
b = torch.tensor([True, False, True, False])
|
|
y = torch.randn(2, requires_grad=True)
|
|
opt_fn(x, b, y)
|
|
|
|
def test_torch_tensor_ops(self):
|
|
def fn(x):
|
|
return torch.Tensor.abs_(x)
|
|
|
|
x = torch.randn(3)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
y = fn(x)
|
|
y_ = opt_fn(x)
|
|
self.assertTrue(same(y, y_))
|
|
|
|
def test_guard_ordering_shape_fail(self):
|
|
# If a function which takes a tensor has an inner function which
|
|
# is compiled and generates a guard on its shape,
|
|
# the guards are evaluated in the wrong order. So if on a subsequent call
|
|
# an int is passed instead of a tensor, guard evaluation will crash
|
|
# with a "no attribute: shape" error
|
|
m = MockModule()
|
|
opt_m = torch.compile(m, backend="eager")
|
|
opt_m.fn(torch.ones((5, 5)))
|
|
opt_m.fn(-3)
|
|
|
|
def test_tensor_isinstance_tuple(self):
|
|
@torch.compile(backend="eager")
|
|
def fn():
|
|
t = torch.ones(5, 5)
|
|
if not isinstance(t, (int, torch.Tensor)):
|
|
msg = str.format(
|
|
"{0} is not an instance of {1}",
|
|
type(t),
|
|
(int, torch.Tensor),
|
|
)
|
|
raise ValueError(msg)
|
|
return True
|
|
|
|
fn()
|
|
|
|
def test_isinstance_dtype(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
isinstance(torch.bfloat16, torch.dtype)
|
|
return x
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_isinstance_storage(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
|
|
bools = torch.BoolStorage.from_buffer(f, "big")
|
|
assert isinstance(bools, torch.BoolStorage)
|
|
return x
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_issue111522(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f(x, y):
|
|
return x + y.a
|
|
|
|
class A:
|
|
a = 2
|
|
|
|
self.assertEqual(f(torch.zeros(2), A()), torch.full([2], 2.0))
|
|
|
|
del A.a
|
|
|
|
# graph break on missing attr
|
|
with self.assertRaises(torch._dynamo.exc.Unsupported):
|
|
f(torch.zeros(2), A())
|
|
|
|
def test_sort_out(self):
|
|
dtype = torch.float32
|
|
device = "cpu"
|
|
|
|
def fn():
|
|
tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
|
|
values1 = torch.tensor(0, dtype=dtype, device=device)
|
|
indices1 = torch.tensor(0, dtype=torch.long, device=device)
|
|
torch.sort(tensor, out=(values1, indices1))
|
|
self.assertEqual(values1.stride(), (1,))
|
|
self.assertEqual(indices1.stride(), (1,))
|
|
|
|
fn()
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
opt_fn()
|
|
|
|
def test_sort_out2(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.sorted = torch.nn.Buffer(torch.ones(4, 4))
|
|
self.indices = torch.nn.Buffer(torch.ones(4, 4, dtype=torch.long))
|
|
|
|
def forward(self, x):
|
|
torch.sort(x, out=(self.sorted, self.indices))
|
|
return (x + 1, self.sorted, self.indices)
|
|
|
|
x = torch.randn(4, 4)
|
|
m = MyModule()
|
|
ref = m(x)
|
|
opt_m = torch.compile(m, backend="eager")
|
|
res = opt_m(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_sigmoid_out(self):
|
|
dtype = torch.float32
|
|
device = "cpu"
|
|
|
|
def fn():
|
|
inp = torch.randn((3, 5), dtype=dtype, device=device)
|
|
out1 = torch.tensor(0, dtype=dtype, device=device)
|
|
torch.sigmoid(inp, out=out1)
|
|
self.assertEqual(out1.numel(), 15)
|
|
|
|
fn()
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
opt_fn()
|
|
|
|
def test_sigmoid_out2(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.base = torch.nn.Buffer(torch.ones(4, 4))
|
|
|
|
def forward(self, x):
|
|
torch.sigmoid(x, out=self.base)
|
|
return x + self.base
|
|
|
|
x = torch.randn(4, 4)
|
|
m = MyModule()
|
|
ref = m(x)
|
|
opt_m = torch.compile(m, backend="eager")
|
|
res = opt_m(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_out_root_cell_shape_change(self):
|
|
@torch.compile(backend="eager")
|
|
def fn():
|
|
out = torch.empty(0)
|
|
|
|
def run():
|
|
x = torch.zeros(3, 5)
|
|
torch.sigmoid(x, out=out)
|
|
return out.size()
|
|
|
|
return run()
|
|
|
|
res = fn()
|
|
self.assertEqual((3, 5), res)
|
|
|
|
def test_out_nested_cell_shape_change(self):
|
|
@torch.compile(backend="eager")
|
|
def fn():
|
|
def run():
|
|
x = torch.zeros(3, 5)
|
|
out = torch.empty(0)
|
|
|
|
def capture():
|
|
return out # Force `out` to be a nested cell
|
|
|
|
torch.sigmoid(x, out=out)
|
|
return out.size()
|
|
|
|
return run()
|
|
|
|
res = fn()
|
|
self.assertEqual((3, 5), res)
|
|
|
|
def test_out_root_cell_tuple_shape_change(self):
|
|
@torch.compile(backend="eager")
|
|
def fn():
|
|
out1 = torch.empty(0)
|
|
out2 = torch.empty(0, dtype=torch.long)
|
|
|
|
def run():
|
|
x = torch.zeros(3, 5)
|
|
torch.sort(x, out=(out1, out2))
|
|
return out1.size(), out2.size()
|
|
|
|
return run()
|
|
|
|
res = fn()
|
|
self.assertEqual(((3, 5), (3, 5)), res)
|
|
|
|
def test_out_nested_cell_tuple_shape_change(self):
|
|
@torch.compile(backend="eager")
|
|
def fn():
|
|
def run():
|
|
x = torch.zeros(3, 5)
|
|
out1 = torch.empty(0)
|
|
out2 = torch.empty(0, dtype=torch.long)
|
|
|
|
def capture():
|
|
# Force `out1` and `out2` to be nested cells
|
|
return out1, out2
|
|
|
|
torch.sort(x, out=(out1, out2))
|
|
return out1.size(), out2.size()
|
|
|
|
return run()
|
|
|
|
res = fn()
|
|
self.assertEqual(((3, 5), (3, 5)), res)
|
|
|
|
def test_slice_into_list_mutable(self):
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, listy):
|
|
x = listy[3:5]
|
|
for _ in range(10):
|
|
z = torch.abs(torch.randn(10)) + 1
|
|
x[0] = z
|
|
return x
|
|
|
|
m = Mod()
|
|
listy = [torch.randn(10)] * 10
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m = torch.compile(m, backend=cnt, fullgraph=True)
|
|
opt_m.forward(listy)
|
|
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_issue111918(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, dynamic=True)
|
|
def fn(x):
|
|
x = x + 1
|
|
y = x.item()
|
|
if y > 2:
|
|
return x * 2
|
|
else:
|
|
return x * 3
|
|
|
|
x = torch.tensor([3.0])
|
|
fn(x)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 4)
|
|
|
|
torch._dynamo.reset()
|
|
fn = torch.compile(fn, fullgraph=True, backend="eager")
|
|
with self.assertRaises(torch._dynamo.exc.UserError):
|
|
fn(x)
|
|
|
|
def test_vdd_duplicate_error(self):
|
|
def fn(a, dt):
|
|
keys = list(dt._jt_dict.keys())
|
|
p = torch.cos(dt._jt_dict[keys[0]]._value)
|
|
q = torch.sin(a)
|
|
r = torch.sigmoid(dt._jt_dict[keys[0]]._value)
|
|
return p + q + r
|
|
|
|
class Value:
|
|
def __init__(self) -> None:
|
|
self._value = torch.randn(4)
|
|
|
|
class Sample:
|
|
def __init__(self) -> None:
|
|
self._jt_dict = {}
|
|
self._jt_dict["POSITION_ID"] = Value()
|
|
|
|
a = torch.randn(4)
|
|
sample = Sample()
|
|
|
|
ref = fn(a, sample)
|
|
|
|
optimized_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = optimized_fn(a, sample)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_specialized_stride(self):
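# Slicing with a step produces a non-contiguous view; the compiled x.stride() result should match eager.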
|
|
def f():
|
|
e = torch.empty(4)
|
|
x = e[::2]
|
|
return x.stride()
|
|
|
|
self.assertEqual(f(), torch.compile(f, backend="eager")())
|
|
|
|
def test_out_none(self):
|
|
# https://github.com/pytorch/pytorch/issues/92814
|
|
def fn(input):
|
|
return torch.nn.functional.normalize(input, dim=0, out=None)
|
|
|
|
x = torch.rand([1])
|
|
self.assertEqual(fn(x), torch.compile(fn, backend="eager")(x))
|
|
|
|
def test_multi_import(self):
|
|
if not has_detectron2():
|
|
raise unittest.SkipTest("requires detectron2")
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def to_bitmasks(boxes):
|
|
from detectron2.layers.mask_ops import (
|
|
_paste_masks_tensor_shape,
|
|
paste_masks_in_image,
|
|
)
|
|
|
|
if (
|
|
paste_masks_in_image is not None
|
|
and _paste_masks_tensor_shape is not None
|
|
):
|
|
return boxes + 1
|
|
|
|
self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all())
|
|
|
|
def test_multi_dot_import(self):
|
|
def fn1(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
import torch.fx
|
|
|
|
_ = torch.fx.symbolic_trace(fn1)
|
|
return x * 2
|
|
|
|
x = torch.randn(10)
|
|
fn(x)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_relative_import(self):
|
|
try:
|
|
from . import utils as _ # noqa: F401
|
|
|
|
def fn(x):
|
|
from .utils import tensor_for_import_testing
|
|
|
|
return x * 2 * tensor_for_import_testing
|
|
|
|
except ImportError:
|
|
|
|
def fn(x):
|
|
from utils import tensor_for_import_testing
|
|
|
|
return x * 2 * tensor_for_import_testing
|
|
|
|
x = torch.randn(10)
|
|
fn(x)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_relative_import_no_modulename(self):
|
|
try:
|
|
from . import utils as _ # noqa: F401
|
|
|
|
def fn(x):
|
|
from . import utils
|
|
|
|
return x * 2 * utils.tensor_for_import_testing
|
|
|
|
except ImportError:
|
|
|
|
def fn(x):
|
|
import utils
|
|
|
|
return x * 2 * utils.tensor_for_import_testing
|
|
|
|
x = torch.randn(10)
|
|
fn(x)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_bigbird_unsqueeze_inplace(self):
|
|
def fn(reshape_2):
|
|
view_2 = reshape_2.clone()
|
|
view_2.unsqueeze_(2)
|
|
cat_11 = torch.cat([view_2], dim=2)
|
|
view_13 = cat_11.view((2, 12, 64, -1))
|
|
return (view_13,)
|
|
|
|
x = torch.randn(2, 12, 64, 64, requires_grad=True)
|
|
ref = fn(x)
|
|
opt_fn = torch.compile(fn, backend="aot_eager")
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_issue1466_size_aot_autograd(self):
|
|
def fn(x):
|
|
# do a tensor op and a size compute
|
|
y = x * 2
|
|
x_size = x.size()
|
|
# trigger a graph break
|
|
print("arf")
|
|
# use the tensor op and size compute
|
|
z = y.view(x_size) + 1
|
|
return z
|
|
|
|
x = torch.randn(2, 3, requires_grad=True)
|
|
ref = fn(x)
|
|
opt_fn = torch.compile(fn, backend="aot_eager")
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_ellipsis(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.lnorm = torch.nn.LayerNorm(
|
|
(256,), eps=1e-06, elementwise_affine=True
|
|
)
|
|
self.linear = torch.nn.Linear(
|
|
in_features=256, out_features=256, bias=True
|
|
)
|
|
|
|
def forward(self, cat_10):
|
|
lnorm = self.lnorm(cat_10)
|
|
getitem_64 = lnorm[
|
|
(slice(None, None, None), slice(0, 1, None), Ellipsis)
|
|
]
|
|
linear = self.linear(getitem_64)
|
|
return (linear,)
|
|
|
|
args = [torch.randn(2, 197, 256)]
|
|
|
|
mod = Repro()
|
|
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
|
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
|
|
def test_reinplacing(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.self_layoutlm_embeddings_x_position_embeddings = (
|
|
torch.nn.Embedding(1024, 768)
|
|
)
|
|
self.self_layoutlm_embeddings_y_position_embeddings = (
|
|
torch.nn.Embedding(1024, 768)
|
|
)
|
|
|
|
def forward(self, getitem_1, getitem_2, add):
|
|
self_layoutlm_embeddings_x_position_embeddings = (
|
|
self.self_layoutlm_embeddings_x_position_embeddings(getitem_1)
|
|
)
|
|
self_layoutlm_embeddings_y_position_embeddings = (
|
|
self.self_layoutlm_embeddings_y_position_embeddings(getitem_2)
|
|
)
|
|
add_1 = add + self_layoutlm_embeddings_x_position_embeddings
|
|
add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings
|
|
return (add_2,)
|
|
|
|
mod = MockModule()
|
|
opt_mod = torch.compile(mod, backend="aot_eager_decomp_partition")
|
|
|
|
args = [
|
|
((2, 512), (2048, 4), torch.int64, "cpu", False),
|
|
((2, 512), (2048, 4), torch.int64, "cpu", False),
|
|
((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True),
|
|
]
|
|
args = [
|
|
rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
for (sh, st, dt, dev, rg) in args
|
|
]
|
|
self.assertTrue(same_two_models(mod, opt_mod, args))
|
|
|
|
def test_optimized_deepcopy(self):
|
|
# See https://github.com/pytorch/pytorch/pull/88629
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True)
|
|
|
|
def forward(self, x):
|
|
return self.fc(x)
|
|
|
|
mod = Foo()
|
|
opt_mod = torch.compile(mod, backend="eager")
|
|
args = [torch.randn(1, 2)]
|
|
self.assertTrue(same_two_models(mod, opt_mod, args))
|
|
|
|
def test_class_member(self):
|
|
class Foo(torch.nn.Module):
|
|
a = 4
|
|
b = torch.ones(3, 4)
|
|
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.c = 4
|
|
|
|
def forward(self, x):
|
|
return x.cos() + self.a + self.b + self.c
|
|
|
|
mod = Foo()
|
|
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
|
args = (torch.randn(3, 4),)
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
|
|
def test_named_buffers(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.x = torch.nn.Buffer(torch.ones(3))
|
|
self.y = torch.nn.Buffer(torch.ones(3))
|
|
|
|
def forward(self, inp):
|
|
res = 0
|
|
for _, buffer in self.named_buffers():
|
|
res += buffer.sum()
|
|
|
|
return inp.cos() + res
|
|
|
|
mod = Foo()
|
|
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
|
args = (torch.randn(3, 4),)
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
|
|
def test_requires_grad_guards_with_grad_mode1(self):
|
|
def f(x):
|
|
if x.requires_grad:
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
x = torch.ones(2, requires_grad=True)
|
|
|
|
f_compiled = torch.compile(f)
|
|
with torch.no_grad():
|
|
# compile an inference graph
|
|
f_compiled(x)
|
|
|
|
# Test: we should fail guards and recompile (even though it's still an inference graph)
|
|
out_ref = f(x.detach())
|
|
out = f_compiled(x.detach())
|
|
|
|
self.assertEqual(out_ref, out)
|
|
self.assertEqual(out_ref.requires_grad, out.requires_grad)
|
|
|
|
def test_requires_grad_guards_with_grad_mode2(self):
|
|
x = torch.ones(2, requires_grad=True)
|
|
x_ref = x.detach().clone().requires_grad_(True)
|
|
|
|
m = torch.nn.Linear(2, 2)
|
|
m_compiled = torch.compile(m)
|
|
|
|
with torch.no_grad():
|
|
# compile an inference graph
|
|
m_compiled(x)
|
|
|
|
# Test: we should fail guards and recompile a training graph
|
|
out_ref = m(x_ref)
|
|
out = m_compiled(x)
|
|
self.assertEqual(out_ref, out)
|
|
self.assertEqual(out_ref.requires_grad, out.requires_grad)
|
|
|
|
def test_is_symbolic_tracing(self):
|
|
# Ensure no graph break here
|
|
def fn(x):
|
|
if is_fx_tracing_test():
|
|
return x * 2
|
|
return x * 4
|
|
|
|
a = torch.randn(4)
|
|
ref = fn(a)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(a)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_tokenization(self):
|
|
from collections import UserDict
|
|
|
|
class BatchEncoding(UserDict):
|
|
"""
|
|
Copied from tokenization
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
data,
|
|
):
|
|
super().__init__(data)
|
|
|
|
def __getattr__(self, item: str):
|
|
try:
|
|
return self.data[item]
|
|
except KeyError as e:
|
|
raise AttributeError from e
|
|
|
|
def tokenization(x):
|
|
encoding = BatchEncoding({"key": x})
|
|
return encoding["key"]
|
|
|
|
opt_fn = torch.compile(tokenization, backend="eager")
|
|
x = torch.rand((1, 4))
|
|
ref = tokenization(x)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_modules(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.fc = torch.nn.Linear(4, 3)
|
|
|
|
def forward(self, inp):
|
|
res = torch.zeros(3, 3)
|
|
for _ in self.modules():
|
|
res += self.fc(inp)
|
|
return res
|
|
|
|
mod = Foo()
|
|
args = (torch.ones(3, 4),)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_mod = torch.compile(mod, backend=cnt, fullgraph=True)
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
self.assertEqual(cnt.op_count, 5)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_omegaconf_listconfig_iter(self):
|
|
obj = ListConfig()
|
|
x = torch.zeros(2)
|
|
|
|
def fn():
|
|
y = x
|
|
for i in obj:
|
|
y += i
|
|
return y
|
|
|
|
expected = fn()
|
|
actual = torch.compile(fn, fullgraph=True, backend="eager")()
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_user_defined_iter(self):
|
|
class MyIter:
|
|
def __init__(self) -> None:
|
|
self.i = 0
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.i < 3:
|
|
self.i += 1
|
|
return self.i
|
|
raise StopIteration
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
for i in MyIter():
|
|
x += i
|
|
return x
|
|
|
|
self.assertEqual(fn(torch.zeros(1)), torch.full([1], 6.0))
|
|
|
|
def test_stop_iteration_reconstruct(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return x.sin(), StopIteration(1, 2, 3)
|
|
|
|
_, res = fn(torch.ones(1))
|
|
self.assertEqual(str(res), str(StopIteration(1, 2, 3)))
|
|
|
|
def test_tensor_data_kwarg(self):
|
|
# https://github.com/pytorch/pytorch/issues/96278
|
|
def f():
|
|
return torch.tensor(data=[[1.0, -1.0]])
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(f, backend=cnt, fullgraph=True)
|
|
self.assertTrue(same(f(), opt_fn()))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_for_loop_graph_break(self):
|
|
def inner(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
for _ in range(100):
|
|
inner(x)
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_for_loop_graph_break_before(self):
|
|
# Checks that the backedge is calculated correctly
|
|
def inner(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
torch._dynamo.graph_break()
|
|
for _ in range(100):
|
|
inner(x)
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 100)
|
|
|
|
def test_avoid_dupe_specialization(self):
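# Both aliased inputs (x, x) and distinct inputs (x, y) should match eager results under aot_eager.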
|
|
def f(x, y):
|
|
return (x + y) * 1
|
|
|
|
opt_f = torch.compile(f, backend="aot_eager")
|
|
|
|
for b in [True, False]:
|
|
x = torch.randn(4, requires_grad=b)
|
|
y = torch.randn(4, requires_grad=b)
|
|
self.assertEqual(f(x, x), opt_f(x, x))
|
|
self.assertEqual(f(x, y), opt_f(x, y))
|
|
|
|
def test_validate_model_kwargs(self):
|
|
cnt = CompileCounter()
|
|
|
|
def f1(a, b):
|
|
return torch.sin(a) + torch.cos(b)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f2(**kwargs):
|
|
_validate_model_kwargs(f1, kwargs)
|
|
return f1(**kwargs)
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
|
|
self.assertEqual(f2(a=x, b=y), f1(x, y))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_swin_base_tensor_attr(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
# NB: not a parameter or buffer
|
|
self.t = torch.randn(3)
|
|
|
|
def forward(self, x):
|
|
return x + torch.cat((self.t, self.t))
|
|
|
|
mod = Foo()
|
|
opt_mod = torch.compile(mod, backend="eager")
|
|
args = [torch.randn(6)]
|
|
self.assertTrue(same_two_models(mod, opt_mod, args))
|
|
opt_mod(*args)
|
|
|
|
def test_pointless_graph_removal(self):
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=cnt)
|
|
def fn(x):
|
|
with torch.no_grad():
|
|
torch._dynamo.graph_break()
|
|
return x + 1
|
|
|
|
fn(torch.randn(4))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_output_aliases_intermediate(self):
|
|
def f(x):
|
|
intermediate = x.mul(2)
|
|
return intermediate.view(-1), intermediate
|
|
|
|
opt_f = torch.compile(f, backend="aot_eager")
|
|
|
|
for b in [True, False]:
|
|
x = torch.randn(4, requires_grad=b)
|
|
out = f(x)
|
|
out_test = opt_f(x)
|
|
self.assertEqual(out[0], out_test[0])
|
|
self.assertEqual(out[1], out_test[1])
|
|
self.assertEqual(out[0].requires_grad, out_test[0].requires_grad)
|
|
self.assertEqual(out[1].requires_grad, out_test[1].requires_grad)
|
|
# test that the aliasing relationship of outputs is preserved
|
|
out[0].mul_(2)
|
|
out_test[0].mul_(2)
|
|
self.assertEqual(out[0], out_test[0])
|
|
self.assertEqual(out[1], out_test[1])
|
|
|
|
def test_while_loop_graph_break(self):
|
|
# Repro of tacotron2 cache_size_recompilation
|
|
def inner(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
i = 20
|
|
while i > 10:
|
|
x = inner(x)
|
|
i -= 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_nested_while_loop_graph_break(self):
|
|
def inner_loop(x):
|
|
i = 3
|
|
while i > 0:
|
|
i -= 1
|
|
x += 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
def inner(x):
|
|
inner_loop(x)
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
i = 20
|
|
while i > 10:
|
|
x = inner(x)
|
|
i -= 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_while_loop_graph_break_inside_call_function(self):
|
|
# Repro of huggingface graph break inside loop in `get_parameter_dtype`.
|
|
# Skip only the inner frame whose loop contains the graph break.
|
|
def inner(x):
|
|
for _ in range(3):
|
|
x += 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
def fn(x):
|
|
x += 2
|
|
inner(x)
|
|
x += 3
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
def test_exception_in_dynamo_handling(self):
|
|
hit_handler = False
|
|
|
|
# See https://github.com/pytorch/pytorch/pull/96488
|
|
@contextlib.contextmanager
|
|
def ctx():
|
|
try:
|
|
yield
|
|
except RuntimeError:
|
|
nonlocal hit_handler
|
|
hit_handler = True
|
|
|
|
@torch.compile(backend="eager")
|
|
def f():
|
|
with ctx():
|
|
h()
|
|
|
|
def h():
|
|
raise RuntimeError("boof")
|
|
|
|
# Should not error
|
|
f()
|
|
self.assertTrue(hit_handler)
|
|
|
|
def test_generator_dealloc(self):
|
|
# See https://github.com/pytorch/pytorch/pull/96488
|
|
#
|
|
# NB: yes, [(...)] is intentional, this is a list containing a
|
|
# generator
|
|
generator_box = [(x for x in [1, 2, 3])]
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
def g(x):
|
|
return x + 2
|
|
|
|
# TODO: This test is pretty delicate. To test if it's actually doing
|
|
# anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1'
|
|
# and then look at the logs for:
|
|
#
|
|
# TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0
|
|
# TRACE[_custom_eval_frame:664] throw <genexpr>
|
|
#
|
|
# This means we're actually hitting the relevant codepath
|
|
|
|
# NB: Make sure we don't actually Dynamo this frame; if we do Dynamo
|
|
# this frame, Dynamo actually DOES understand list.clear and will
|
|
# arrange for the generator deallocation to happen when the eval frame
|
|
# handler is disabled, which will prevent the bug from happening (we
|
|
# specifically want to trigger the generator deallocation WHILE the
|
|
# dynamo eval frame handler is active), as that will cause the
|
|
# generator to become exhausted and trigger the throw_flag == TRUE
|
|
# case.
|
|
@torch._dynamo.disable(recursive=False)
|
|
def f(x):
|
|
generator_box.clear()
|
|
return g(x)
|
|
|
|
self.assertNoUnraisable(
|
|
lambda: torch.compile(f, backend=counter)(torch.randn(3))
|
|
)
|
|
|
|
# Make sure the x + 2 is captured (a previous incorrect implementation
|
|
# of this fix would have disabled the eval frame callback, which means
|
|
# g wouldn't get traced)
|
|
self.assertEqual(counter.op_count, 1)
|
|
|
|
def test_error_return_without_exception_set(self):
|
|
# https://github.com/pytorch/pytorch/issues/93781
|
|
@torch.compile
|
|
def f():
|
|
_generator_type = type(_ for _ in ())
|
|
|
|
self.assertNoUnraisable(f)
|
|
|
|
def common_merge_criteria_processor_list(self, list_cls, fullgraph):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=fullgraph)
|
|
def f(x, left, right):
|
|
combined = _merge_criteria_processor_list(left, right)
|
|
return combined(x)
|
|
|
|
l1 = list_cls([torch.nn.ReLU(), torch.nn.Sigmoid()])
|
|
l2 = list_cls([])
|
|
input = torch.randn(16)
|
|
result = f(input, l1, l2)
|
|
self.assertEqual(result, l1(input))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
cnt.clear()
|
|
l3 = list_cls([torch.nn.SiLU()])
|
|
expected = l3(l1(input))
|
|
result = f(input, l1, l3)
|
|
self.assertEqual(len(l1), 3)
|
|
self.assertEqual(result, expected)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_merge_criteria_processor_list1(self):
|
|
self.common_merge_criteria_processor_list(CustomList1, False)
|
|
|
|
def test_merge_criteria_processor_list2(self):
|
|
self.common_merge_criteria_processor_list(CustomList2, True)
|
|
|
|
def test_restricted_list_subclass1(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a, b):
|
|
l = CustomList2()
|
|
l.extend([True])
|
|
l.append(a)
|
|
l.extend([b])
|
|
l.pop(0)
|
|
l.append(l.length_times_10())
|
|
return sum(l)
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
self.assertEqual(fn(x, y), x + y + 20)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_restricted_list_subclass2(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a, b):
|
|
l1 = CustomList2([a + 1])
|
|
l2 = CustomList2([b + 2])
|
|
l1.extend(l2)
|
|
return l1
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
z = fn(x, y)
|
|
self.assertEqual(type(z), CustomList2)
|
|
self.assertEqual(len(z), 2)
|
|
self.assertEqual(z.length_times_10(), 20)
|
|
self.assertEqual(list(z), [x + 1, y + 2])
|
|
|
|
def test_restricted_list_subclass3(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a: CustomList2, b: CustomList2):
|
|
a.extend(b)
|
|
a.append_twice(b[2] + 1)
|
|
a.append(b[3] + 2)
|
|
return b
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
l = CustomList2([x, y])
|
|
self.assertIs(fn(l, l), l)
|
|
self.assertEqual(len(l), 7)
|
|
self.assertIs(l[0], x)
|
|
self.assertIs(l[1], y)
|
|
self.assertIs(l[2], x)
|
|
self.assertIs(l[3], y)
|
|
self.assertEqual(l[4], x + 1)
|
|
self.assertIs(l[5], l[4])
|
|
self.assertEqual(l[6], y + 2)
|
|
|
|
def test_rewrite_assert_with_msg(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert x[0] == 3, "First dim need to be 3"
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
|
|
self.assertTrue(same(f(*args), opt_f(*args)))
|
|
self.assertEqual(cnt.op_count, 6)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
def test_list_aliasing(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a):
|
|
a.append(torch.sin(a[0]))
|
|
return a
|
|
|
|
x = torch.randn(10)
|
|
l = [x]
|
|
self.assertIs(fn(l), l)
|
|
self.assertEqual(len(l), 2)
|
|
self.assertIs(l[0], x)
|
|
self.assertEqual(l[1], torch.sin(x))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_not_rewrite_assert_for_other_errors(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
if not x.sum() <= 3:
|
|
raise ValueError("input sum needs to be 3")
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
opt_fn = torch.compile(f, backend="eager")
|
|
with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
|
|
opt_fn(*args)
|
|
|
|
def test_rewrite_assert_dont_change_bytecode(self):
|
|
def fn(x):
|
|
with torch.no_grad():
|
|
assert x.max() < 5, f"invalid max {x.max()}"
|
|
x = torch.sin(x)
|
|
return x
|
|
|
|
x = torch.ones(4)
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
self.assertTrue(same(fn(x), opt_fn(x)))
|
|
|
|
def test_rewrite_assert_without_msg(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert x[0] == 3
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "assertion error"):
|
|
exported(torch.Tensor([5, 6, 7]))
|
|
|
|
def test_rewrite_assert_with_non_string_msg(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert x[0] == 2, x.size()
|
|
return x.cos() + b
|
|
|
|
torch._dynamo.utils.counters.clear()
|
|
args = torch.Tensor([3, 4, 5])
|
|
opt_f = torch.compile(f, backend="eager")
|
|
with self.assertRaisesRegex(AssertionError, "torch.Size"):
|
|
opt_f(args)
|
|
for gb, cnt in torch._dynamo.utils.counters["graph_break"].items():
|
|
if "assert with non-string message" in gb:
|
|
self.assertEqual(cnt, 1)
|
|
break
|
|
else:
|
|
# graph break not found
|
|
self.assertTrue(False)
|
|
|
|
def test_rewrite_assert_noop(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert True
|
|
assert x.dtype == torch.float32
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
|
|
self.assertTrue(same(f(*args), opt_f(*args)))
|
|
# torch._assert shouldn't be in the graph
|
|
self.assertEqual(cnt.op_count, 3)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
def test_size_typematch(self):
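# isinstance(x, torch.Size) must distinguish a real torch.Size from a plain tuple, hence two frames are expected.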
|
|
def f(x, y):
|
|
if isinstance(x, torch.Size):
|
|
return y + 1
|
|
else:
|
|
return y + 2
|
|
|
|
y = torch.zeros(1)
|
|
x1 = torch.Size((3,))
|
|
x2 = (3,)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch.compile(f, backend=cnt, fullgraph=True)
|
|
self.assertTrue(same(f(x1, y), opt_f(x1, y)))
|
|
self.assertTrue(same(f(x2, y), opt_f(x2, y)))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_hf_classinstantier(self):
|
|
# hf activations.py
|
|
class ClassInstantier(collections.OrderedDict):
|
|
def __getitem__(self, key):
|
|
content = super().__getitem__(key)
|
|
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
|
return cls(**kwargs)
|
|
|
|
ACT2CLS = ClassInstantier(
|
|
{
|
|
"relu": (nn.ReLU, {"inplace": False}),
|
|
"tanh": nn.Tanh,
|
|
}
|
|
)
|
|
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def f(x, act):
|
|
return ACT2CLS[act](x)
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(f(y, "tanh"), torch.tanh(y)))
|
|
self.assertTrue(same(f(y, "relu"), torch.relu(y)))
|
|
|
|
def test_ephemeral_module(self):
|
|
# hf activations.py
|
|
class ReLUSquaredActivation(nn.Module):
|
|
def forward(self, input):
|
|
relu_applied = torch.nn.functional.relu(input)
|
|
squared = torch.square(relu_applied)
|
|
return squared
|
|
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def f(x):
|
|
x = x + 0.2
|
|
x = ReLUSquaredActivation()(x)
|
|
x = x + 1
|
|
return x
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(f(y), ReLUSquaredActivation()(y + 0.2) + 1))
|
|
|
|
def test_inplace_unsqueeze_input(self):
|
|
def backend(gm, example_inputs):
|
|
self.assertEqual(example_inputs[-1].size(), torch.Size([1, 3, 4]))
|
|
return gm
|
|
|
|
@torch.compile(backend=backend)
|
|
def fn(x):
|
|
x.unsqueeze_(0)
|
|
return x + 1
|
|
|
|
inputs = [torch.randn(3, 4)]
|
|
self.assertEqual(fn(*inputs).size(), torch.Size([1, 3, 4]))
|
|
self.assertEqual(inputs[0].size(), torch.Size([1, 3, 4]))
|
|
|
|
def test_batchnorm_e2e(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.bn = torch.nn.BatchNorm2d(
|
|
64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
|
|
)
|
|
self.conv1 = torch.nn.Conv2d(
|
|
64,
|
|
64,
|
|
kernel_size=(3, 3),
|
|
stride=(1, 1),
|
|
padding=(1, 1),
|
|
bias=False,
|
|
)
|
|
|
|
def forward(self, x):
|
|
x1 = self.bn(x)
|
|
x2 = self.conv1(x1)
|
|
out = torch.nn.functional.relu(x2)
|
|
return (out,)
|
|
|
|
torch.manual_seed(1337)
|
|
|
|
m_ref = Repro()
|
|
m_test = deepcopy(m_ref)
|
|
|
|
@torch.compile(backend="aot_eager_decomp_partition")
|
|
def compiled_fn(x):
|
|
return m_test(x)
|
|
|
|
x_ref = torch.randn(2, 64, 32, 32, requires_grad=True)
|
|
x_test = x_ref.clone()
|
|
|
|
# Loop multiple times: each iteration the running_mean/var on batchnorm will update,
|
|
# which changes the output of the next iteration
|
|
for _ in range(3):
|
|
ref = m_ref(x_ref)
|
|
res = compiled_fn(x_test)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
for r in ref:
|
|
if r.requires_grad:
|
|
r.sum().backward()
|
|
for r in res:
|
|
if r.requires_grad:
|
|
r.sum().backward()
|
|
|
|
for param_ref, param_test in zip(m_ref.parameters(), m_test.parameters()):
|
|
self.assertTrue(same(param_ref, param_test))
|
|
# Assert running_mean/var
|
|
for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()):
|
|
self.assertTrue(same(buffer_ref, buffer_test))
|
|
|
|
@torch._dynamo.config.patch("assume_static_by_default", False)
|
|
def test_dynamic_shapes_right_side(self):
|
|
def f(x):
|
|
return torch.ones(5 * x.shape[0])
|
|
|
|
inp = torch.randn(6, 5)
|
|
|
|
gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
|
|
self.assertEqual(gm(inp).shape, f(inp).shape)
|
|
|
|
@torch._dynamo.config.patch("specialize_int", False)
|
|
def test_maybe_multiply_symint(self):
|
|
# https://github.com/pytorch/pytorch/issues/97346
|
|
from torch._functorch.aot_autograd import aot_module_simplified
|
|
|
|
def my_aot_compiler(gm, example_inputs):
|
|
def my_compiler(gm, example_inputs):
|
|
return gm.forward
|
|
|
|
# Invoke AOTAutograd
|
|
return aot_module_simplified(gm, example_inputs, fw_compiler=my_compiler)
|
|
|
|
def my_example(t1, t2, d):
|
|
out = torch.add(t1, t2, alpha=d)
|
|
return out
|
|
|
|
compiled_fn = torch.compile(backend=my_aot_compiler, dynamic=True)(my_example)
|
|
|
|
t1 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
|
|
t2 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
|
|
|
|
ra = compiled_fn(t1, t2, 5)
|
|
self.assertEqual(ra, torch.tensor([0.0, 6.0, 12.0]))
|
|
|
|
ra = compiled_fn(t1, t2, 6)
|
|
self.assertEqual(ra, torch.tensor([0.0, 7.0, 14.0]))
|
|
|
|
def test_build_map_unpack_with_call(self):
|
|
def forward_with_cond_scale(x, t, cond_scale, self_cond, other1, other2):
|
|
return x.sin() + t + cond_scale + self_cond + other1 + other2
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
d1 = dict(other1=5)
|
|
d2 = dict(other2=4)
|
|
text_cond = {**d1, **d2}
|
|
return forward_with_cond_scale(x, 1, cond_scale=2, self_cond=3, **text_cond)
|
|
|
|
self.assertTrue(same(fn(torch.ones(4)), torch.ones(4).sin() + 15))
|
|
|
|
@torch._dynamo.config.patch(verbose=True)
|
|
def test_graph_break_unsupported_fake(self):
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=counter)
|
|
def f(x):
|
|
return torch.ops.test_sample.foo(x + 1) + 1
|
|
|
|
f(torch.randn(3))
|
|
|
|
self.assertEqual(counter.op_count, 2)
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
def test_delattr(self):
|
|
class MyObj:
|
|
def __init__(self, a, b):
|
|
self.a = a
|
|
self.b = b
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x, obj):
|
|
del obj.a
|
|
obj.c = x + 1
|
|
del obj.c
|
|
tmp = MyObj(x + 2, x + 3)
|
|
del tmp.b
|
|
if hasattr(obj, "a"):
|
|
return x + 1
|
|
return tmp
|
|
|
|
x = torch.zeros([])
|
|
obj1 = MyObj(x, x)
|
|
obj2 = fn(x, obj1)
|
|
self.assertFalse(hasattr(obj1, "a"))
|
|
self.assertFalse(hasattr(obj1, "c"))
|
|
self.assertFalse(hasattr(obj2, "b"))
|
|
self.assertEqual(obj1.b.item(), 0)
|
|
self.assertEqual(obj2.a.item(), 2)
|
|
|
|
def test_delattr_return(self):
|
|
class MyObject:
|
|
def __init__(self, val):
|
|
self.val = val
|
|
self.deletion_attempted = False
|
|
|
|
def __delattr__(self, attr):
|
|
if attr == "val":
|
|
self.deletion_attempted = True
|
|
else:
|
|
super().__delattr__(attr)
|
|
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def test_delattr(input_tensor):
|
|
instance_a = MyObject(1)
|
|
instance_b = MyObject(2)
|
|
del instance_a.val
|
|
del instance_b.val
|
|
exists_a = hasattr(instance_a, "val")
|
|
exists_b = hasattr(instance_b, "val")
|
|
deletion_attempted_a = instance_a.deletion_attempted
|
|
deletion_attempted_b = instance_b.deletion_attempted
|
|
return (
|
|
input_tensor + 1,
|
|
exists_a,
|
|
exists_b,
|
|
deletion_attempted_a,
|
|
deletion_attempted_b,
|
|
)
|
|
|
|
result = test_delattr(torch.ones(1))
|
|
self.assertEqual(result[0], torch.tensor([2.0]))
|
|
self.assertEqual(result[1:], (True, True, True, True))
|
|
|
|
def test_delattr_raises(self):
|
|
class MyObj:
|
|
def __init__(self, a, b):
|
|
self.a = a
|
|
self.b = b
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x, obj):
|
|
del obj.a
|
|
x = x + 1
|
|
obj.a # will raise
|
|
return x
|
|
|
|
x = torch.zeros([])
|
|
obj1 = MyObj(x, x)
|
|
self.assertRaises(AttributeError, lambda: fn(x, obj1))
|
|
|
|
def test_delsubscr(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
del x["a"]
|
|
y = x["b"] + 1
|
|
return y
|
|
|
|
x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
|
|
result = fn(x)
|
|
self.assertFalse(hasattr(x, "a"))
|
|
self.assertEqual(result.item(), 2)
|
|
|
|
def test_delsubscr_raises(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
del x["a"]
|
|
y = x["a"] + 1 # should raise KeyError
|
|
return y
|
|
|
|
x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
|
|
self.assertRaises(KeyError, lambda: fn(x))
|
|
|
|
def test_attached_attribute_in_dir(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(16, 16)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.linear(x))
|
|
|
|
mod = torch.compile(MyModule(), backend="eager")
|
|
mod.is_compiled = True
|
|
self.assertTrue("is_compiled" in dir(mod))
|
|
|
|
@torch._dynamo.config.patch("automatic_dynamic_shapes", False)
|
|
def test_dynamic_shapes_implicit_guard(self):
|
|
def f(x):
|
|
y = x * x.size(x.shape[0])
|
|
torch.sum(y, [y.shape[0]])
|
|
return y
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(f, backend=cnt, fullgraph=True)
|
|
opt_fn(torch.randn(3, 1, 1, 1, 1))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_dalle2_maybe(self):
|
|
def normalize(x):
|
|
return x.cos()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x, normalize_img):
|
|
lowres_cond_img = x.sin()
|
|
lowres_cond_img = maybe(normalize_img)(lowres_cond_img)
|
|
return lowres_cond_img
|
|
|
|
self.assertEqual(fn(torch.ones([]), normalize), torch.ones([]).sin().cos())
|
|
|
|
def test_functools_wraps(self):
|
|
def cool_name(x):
|
|
return x.sin()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
y = x.cos()
|
|
|
|
@functools.wraps(cool_name)
|
|
def uncool_name():
|
|
return cool_name(y)
|
|
|
|
return uncool_name
|
|
|
|
result = fn(torch.ones([]))
|
|
self.assertEqual(result.__name__, "cool_name")
|
|
self.assertEqual(result(), torch.ones([]).cos().sin())
|
|
|
|
def test_dynamic_shapes_float_guard(self):
|
|
def f(x):
|
|
return torch.nn.functional.dropout(x, x.shape[0] / 6)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(f, backend=cnt, fullgraph=True)
|
|
opt_fn(torch.randn(3))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_tensor_item(self):
|
|
def f(x, y):
|
|
val = y.item()
|
|
return x.sum() + val
|
|
|
|
gm, _ = torch._dynamo.export(
|
|
f,
|
|
aten_graph=True,
|
|
)(
|
|
torch.zeros(6, 4),
|
|
torch.tensor(1),
|
|
)
|
|
self.assertEqual(
|
|
f(torch.zeros(6, 4), torch.tensor(1)),
|
|
gm(torch.zeros(6, 4), torch.tensor(1)),
|
|
)
|
|
self.assertEqual(
|
|
f(torch.zeros(6, 4), torch.tensor(2)),
|
|
gm(torch.zeros(6, 4), torch.tensor(2)),
|
|
)
|
|
|
|
def test_dataclass_init_with_default_factory_with_inputs(self):
|
|
@dataclasses.dataclass
|
|
class DClass:
|
|
sharding_contexts: Any = dataclasses.field(default_factory=list)
|
|
a: int = 1
|
|
|
|
def fn(x, inp_list):
|
|
d = DClass(inp_list)
|
|
d.sharding_contexts.append(x.sin() + d.a)
|
|
return d
|
|
|
|
x = torch.randn(4)
|
|
inp_list1 = [1, 2, 3]
|
|
inp_list2 = [2, 3, 4]
|
|
inp_list3 = [1, 2]
|
|
ref1 = fn(x, inp_list1)
|
|
ref2 = fn(x, inp_list2)
|
|
ref3 = fn(x, inp_list3)
|
|
|
|
opt_fn = torch.compile(fn, fullgraph=True)
|
|
|
|
opt_ret1 = opt_fn(x, inp_list1)
|
|
opt_ret2 = opt_fn(x, inp_list2)
|
|
opt_ret3 = opt_fn(x, inp_list3)
|
|
self.assertEqual(ref1.sharding_contexts, opt_ret1.sharding_contexts)
|
|
self.assertEqual(ref2.sharding_contexts, opt_ret2.sharding_contexts)
|
|
self.assertEqual(ref3.sharding_contexts, opt_ret3.sharding_contexts)
|
|
|
|
def test_list_index(self):
|
|
for i, list_type in enumerate(
|
|
(
|
|
list,
|
|
tuple,
|
|
torch.Size,
|
|
collections.deque,
|
|
namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]),
|
|
)
|
|
):
|
|
torch._dynamo.reset()
|
|
for index in ([], [2], [0, 3]):
|
|
|
|
def f(t):
|
|
if i == 4: # namedtuple
|
|
xs = list_type(1, 2, 3, 4)
|
|
else:
|
|
xs = list_type([1, 2, 3, 4])
|
|
res = xs.index(3, *index)
|
|
return t + res
|
|
|
|
res = torch.compile(f, backend="eager", fullgraph=True)(torch.zeros(1))
|
|
|
|
self.assertEqual(res, torch.tensor([2.0]))
|
|
|
|
def test_list_index_not_found(self):
|
|
def f(t):
|
|
xs = ["bar", "foo", "baz", "buzz"]
|
|
res = xs.index("non-existent")
|
|
return t + res
|
|
|
|
# Raising ValueError from item not found is unsupported
|
|
with self.assertRaises(
|
|
torch._dynamo.exc.Unsupported,
|
|
):
|
|
torch.compile(f, backend="eager", fullgraph=True)(torch.zeros(1))
|
|
|
|
def test_list_index_tensor_unsupported(self):
|
|
for index in ([], [2], [0, 3]):
|
|
|
|
def f(t):
|
|
xs = [torch.tensor([i]) for i in range(4)]
|
|
res = xs.index(torch.tensor([2]), *index)
|
|
return t + res
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.Unsupported,
|
|
"Data-dependent branching",
|
|
):
|
|
torch.compile(f, backend="eager", fullgraph=True)(torch.zeros(1))
|
|
|
|
def test_hf_xsoftmax_inference(self):
|
|
def fn(input, mask):
|
|
return XSoftmax.apply(input + 1, mask, 1) + 2
|
|
|
|
fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
|
|
|
|
inputs = [
|
|
torch.randn(4, 10),
|
|
torch.randn(4, 10) < 0,
|
|
]
|
|
expected = fn(*inputs)
|
|
actual = fn_opt(*inputs)
|
|
self.assertTrue(same(actual, expected))
|
|
|
|
@mock.patch("torch._dynamo.config.guard_nn_modules", True)
|
|
def test_hf_xsoftmax_training(self):
|
|
from torch._dynamo.utils import counters
|
|
|
|
counters.clear()
|
|
|
|
def fn(input, mask):
|
|
return XSoftmax.apply(input, mask, 1)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
fn_opt = torch.compile(fn, backend=cnt, fullgraph=False)
|
|
|
|
torch.manual_seed(1234)
|
|
inputs1 = [
|
|
torch.randn(4, 10, requires_grad=True),
|
|
torch.randn(4, 10) < 0,
|
|
]
|
|
torch.manual_seed(1234)
|
|
inputs2 = [
|
|
torch.randn(4, 10, requires_grad=True),
|
|
torch.randn(4, 10) < 0,
|
|
]
|
|
|
|
expected = fn(*inputs1)
|
|
actual = fn_opt(*inputs2)
|
|
self.assertTrue(same(actual, expected))
|
|
self.assertEqual(cnt.op_count, 1)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
cnt.clear()
|
|
counters.clear()
|
|
|
|
expected.sum().backward()
|
|
actual.sum().backward()
|
|
self.assertTrue(same(inputs1[0].grad, inputs2[0].grad))
|
|
|
|
# currently we don't capture the backwards frame
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
self.assertEqual(cnt.op_count, 0)
|
|
self.assertEqual(dict(counters["frames"]), {})
|
|
self.assertEqual(dict(counters["graph_break"]), {})
|
|
|
|
def test_autograd_function_graph_break(self):
|
|
class MySin(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
torch._dynamo.graph_break()
|
|
ctx.save_for_backward(x)
|
|
return x.sin()
|
|
|
|
@staticmethod
|
|
def backward(ctx, gx):
|
|
(x,) = ctx.saved_tensors
|
|
return gx * x.cos()
|
|
|
|
x = torch.randn([], requires_grad=True)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return MySin.apply(x)
|
|
|
|
y = fn(x)
|
|
self.assertEqual(y, x.sin())
|
|
|
|
(gx,) = torch.autograd.grad(y, x)
|
|
self.assertEqual(gx, x.cos())
|
|
|
|
def test_jit_trace_errors(self):
|
|
@torch.compile(backend="eager", dynamic=True)
|
|
def f(x):
|
|
return x + 1
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
torch.jit.trace(f, torch.randn(3))
|
|
|
|
@torch._dynamo.config.patch("assume_static_by_default", False)
|
|
def test_tensor_split(self):
|
|
def f(x):
|
|
return torch.split(x, x.shape[0] // 2, dim=0)[0]
|
|
|
|
gm, _ = torch._dynamo.export(
|
|
f,
|
|
aten_graph=True,
|
|
)(
|
|
torch.zeros(6, 4),
|
|
)
|
|
|
|
self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))
|
|
|
|
def test_optim_state_references_cleared(self):
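# After the optimizer is dropped, the weakref to its state tensor should be dead, i.e. compilation must not keep optimizer state alive.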
|
|
model = torch.nn.Linear(2048, 2048, bias=False)
|
|
x = torch.ones(2048)
|
|
state_ref = 0
|
|
|
|
optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)
|
|
|
|
def opt_step():
|
|
optimizer.step()
|
|
|
|
compiled_opt_step = torch.compile(opt_step, backend="eager")
|
|
|
|
def compiled_model_step(x):
|
|
optimizer.zero_grad()
|
|
y = model(x)
|
|
torch.sum(y).backward()
|
|
compiled_opt_step()
|
|
|
|
compiled_model_step(x)
|
|
|
|
# Picked "square_avg" arbitrarily to check that
|
|
# optimizer state tensors are deallocated
|
|
state_ref = weakref.ref(
|
|
optimizer.state[optimizer.param_groups[0]["params"][0]]["square_avg"]
|
|
)
|
|
optimizer = None
|
|
|
|
self.assertIsNone(state_ref())
|
|
|
|
def test_grad_references_cleared(self):
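# zero_grad(set_to_none=True) after a compiled optimizer step should actually release the .grad tensor.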
|
|
model = torch.nn.Linear(2048, 2048, bias=False)
|
|
x = torch.ones(2048)
|
|
optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)
|
|
|
|
def opt_step():
|
|
optimizer.step()
|
|
|
|
compiled_opt_step = torch.compile(opt_step, backend="eager")
|
|
|
|
def compiled_model_step(x):
|
|
optimizer.zero_grad(True)
|
|
y = model(x)
|
|
torch.sum(y).backward()
|
|
compiled_opt_step()
|
|
|
|
compiled_model_step(x)
|
|
param_grad_ref = weakref.ref(next(iter(model.parameters())).grad)
|
|
optimizer.zero_grad(True)
|
|
self.assertIsNone(param_grad_ref())
|
|
|
|
def test_batch_encoding_clone_inputs(self):
|
|
class BatchEncoding(dict):
|
|
"""
|
|
Copied from test_tokenization
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
data,
|
|
):
|
|
super().__init__(data)
|
|
|
|
def __getattr__(self, item: str):
|
|
try:
|
|
return self.data[item]
|
|
except KeyError as e:
|
|
raise AttributeError from e
|
|
|
|
encoding = BatchEncoding({"key": torch.rand((1, 4))})
|
|
cloned_encoding = torch._dynamo.utils.clone_inputs(encoding)
|
|
self.assertTrue(type(cloned_encoding) is not dict)
|
|
|
|
def test_iadd_graph_break(self):
|
|
def fn(x):
|
|
a = ()
|
|
x = torch.sin(x)
|
|
a += (x,)
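# `a += (x,)` on a tuple is not an in-place mutation; it rebinds `a` to a new tuple containing x.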
|
|
return a
|
|
|
|
x = torch.randn(4)
|
|
ref = fn(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_odict_get_item_index_name(self):
|
|
d = {float: torch.float32, np.float16: torch.float16}
|
|
|
|
@torch.compile(backend="eager")
|
|
def f(x, y1, y2):
|
|
return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2])
|
|
|
|
f(torch.zeros(4), float, np.float16)
|
|
|
|
def test_dedup_global(self):
|
|
@torch.compile()
|
|
def f():
|
|
return _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR
|
|
|
|
self.assertEqual(f(), _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR)
|
|
|
|
def test_randint_out_dynamic(self):
|
|
def randint_fn(high, size, out):
|
|
return torch.randint(high, size, out=out)
|
|
|
|
opt_model = torch.compile(randint_fn)
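# The two calls below pass `out` tensors of different sizes; the compiled out= variant should handle the size change without erroring.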
|
|
|
|
out1 = torch.empty(10, dtype=torch.int32)
|
|
opt_model(17, (10,), out1)
|
|
|
|
out2 = torch.empty(12, dtype=torch.int32)
|
|
opt_model(17, (12,), out2)
|
|
|
|
@requires_cuda
|
|
@serialTest()
|
|
def test_mem_leak_guards(self):
|
|
def gn(x0, x):
|
|
return x0 * x
|
|
|
|
class MyMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
@torch._dynamo.disable(recursive=False)
|
|
def forward(self, running_x):
|
|
# This line creates a temp tensor, which should not be leaked
|
|
running_x = torch.sin(running_x)
|
|
x = running_x
|
|
# This creates a TENSOR_ALIASING guard
|
|
x = gn(running_x, running_x)
|
|
# This creates a NO_TENSOR_ALIASING guard which was leaking memory
|
|
x = gn(running_x, x)
|
|
return x
|
|
|
|
mod = MyMod().cuda()
|
|
|
|
fn = torch.compile(mod, backend="eager")
|
|
x = torch.randn(10, 10, device="cuda")
|
|
torch.cuda.reset_peak_memory_stats()
|
|
|
|
fn(x)
|
|
peak_mem1 = torch.cuda.max_memory_allocated()
|
|
|
|
for _ in range(1000):
|
|
fn(x)
|
|
peak_mem2 = torch.cuda.max_memory_allocated()
|
|
self.assertTrue(peak_mem1 == peak_mem2)
|
|
|
|
@requires_cuda
|
|
def test_guard_default_device(self):
|
|
try:
|
|
torch.set_default_device("cuda")
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=counter)
|
|
def f():
|
|
x = torch.randn(3)
|
|
return x * 2
|
|
|
|
self.assertEqual(f().device.type, "cuda")
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
torch.set_default_device("cpu")
|
|
|
|
self.assertEqual(f().device.type, "cpu")
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
finally:
|
|
torch.set_default_device(None)
|
|
|
|
def test_list_self_reference(self):
|
|
# Issue - https://github.com/pytorch/pytorch/issues/100150
|
|
root = []
|
|
root[:] = [root, root, None, None]
|
|
|
|
@torch.compile(fullgraph=False, backend="eager")
|
|
def test_bug():
|
|
return root[0]
|
|
|
|
test_bug()
|
|
|
|
def test_hf_bigbird_unsqueeze(self):
|
|
def torch_bmm_nd(inp_1, inp_2, ndim=None):
|
|
torch._dynamo.graph_break()
|
|
return torch.bmm(inp_1, inp_2)
|
|
|
|
def fn(inp1, inp2, inp3, inp4, c):
|
|
a = torch_bmm_nd(inp1, inp2, 4)
|
|
a.unsqueeze_(2)
|
|
a = a * 2
|
|
|
|
b = torch_bmm_nd(inp3, inp4, 4)
|
|
b.unsqueeze_(2)
|
|
l = a + b
|
|
|
|
out = torch.cat([a, b, c], dim=2)
|
|
return out, l
|
|
|
|
inp1 = torch.rand(1, 64, 448)
|
|
inp2 = torch.rand(1, 448, 64)
|
|
inp3 = torch.rand(1, 64, 448)
|
|
inp4 = torch.rand(1, 448, 64)
|
|
c = torch.rand(1, 64, 1, 64)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
opt_fn(inp1, inp2, inp3, inp4, c)
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
|
|
def test_torch_variable_type(self):
|
|
# from torchvision
|
|
def check_type(obj, types_or_checks):
|
|
for type_or_check in types_or_checks:
|
|
if (
|
|
isinstance(obj, type_or_check)
|
|
if isinstance(type_or_check, type)
|
|
else type_or_check(obj)
|
|
):
|
|
return True
|
|
return False
|
|
|
|
opt_check_type = torch.compile(check_type, backend="eager")
|
|
ref = check_type(torch.randn(4), [torch.Tensor])
|
|
res = opt_check_type(torch.randn(4), [torch.Tensor])
|
|
self.assertEqual(ref, res)
|
|
|
|
# Test for https://github.com/pytorch/pytorch/issues/103132
|
|
@torch._dynamo.config.patch("assume_static_by_default", False)
|
|
def test_inference_mode_dynamic_shapes(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, param):
|
|
z = torch.matmul(param, param)
|
|
return z
|
|
|
|
model = Repro()
|
|
# Need a 3d tensor to actually cause the error:
|
|
# we go down a path of the C++ matmul decomp that calls sizes().
|
|
inp = torch.randn(4, 4, 4, requires_grad=True)
|
|
model = torch.compile(model, backend="aot_eager", dynamic=True)
|
|
with torch.inference_mode():
|
|
model(inp)
|
|
|
|
def test_kwargs_out_list_variable(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, param):
|
|
z = torch.frexp(**param)
|
|
return z
|
|
|
|
model = Repro()
|
|
params = {"input": torch.tensor([[0.0, 1, 2, 4]])}
|
|
params["out"] = [
|
|
torch.empty(0, dtype=torch.float32), # mantissa
|
|
torch.empty(0, dtype=torch.int32), # exponent
|
|
]
|
|
|
|
model = torch.compile(model, backend="eager")
|
|
mantissa, exponent = model(params)
|
|
ref_mantissa = torch.tensor([[0.0000, 0.5000, 0.5000, 0.5000]])
|
|
ref_exponent = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32)
|
|
self.assertEqual(ref_mantissa, mantissa)
|
|
self.assertEqual(ref_exponent, exponent)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_split_with_sizes_aot_autograd(self):
|
|
def fn(result, split_sizes):
|
|
rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
|
|
return rs
|
|
|
|
example_inputs = (
|
|
torch.randn(32, requires_grad=True),
|
|
torch.tensor((7, 16, 9)),
|
|
)
|
|
actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs)
|
|
expected = fn(*example_inputs)
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_unspecialized_nn_module_with_torch_variable_attribute(self):
|
|
"""
|
|
In this case self.fn = something that should be a TorchVariable.
|
|
When it's not a TorchVariable, dynamo tries to trace through and fails.
|
|
This makes sure that self.fn is handled as a TorchVariable.
|
|
"""
|
|
|
|
class UserModule(torch.nn.Module):
|
|
torchdynamo_force_dynamic = True # forced to be an UnspecializedNNModule
|
|
|
|
def __init__(self, fn):
|
|
super().__init__()
|
|
self.fn = fn
|
|
|
|
def forward(self, **inp):
|
|
return self.fn(**inp)
|
|
|
|
inputs = {
|
|
"input": torch.randn([2, 9]).uniform_(0, 1),
|
|
"target": torch.randn([2, 9]).uniform_(0, 1),
|
|
"reduction": "mean",
|
|
}
|
|
|
|
mod = UserModule(torch.nn.functional.binary_cross_entropy)
|
|
ref = mod(**inputs)
|
|
res = torch.compile(mod, backend="eager", fullgraph=True)(**inputs)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_string_format(self):
|
|
s = "temp{i}"
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
if s.format(i=4) == "temp4":
|
|
return torch.sin(x)
|
|
return torch.cos(x)
|
|
|
|
x = torch.randn(4)
|
|
self.assertEqual(fn(x), torch.sin(x))
|
|
|
|
@unittest.skip("Fails with incorrect result with fullgraph constraints")
|
|
def test_int_format(self):
|
|
def fn(num: int):
|
|
return format(num, "b")
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True, dynamic=False)
|
|
self.assertEqual(fn(10), opt_fn(10))
|
|
|
|
# Repro of torch._dynamo.exc.InternalTorchDynamoError: 'NoneType' object has no attribute 'guards'
|
|
# due to bad empty list handling
|
|
def test_empty_list_contains_with_jump(self):
|
|
def fn(x, l):
|
|
if x in l:
|
|
return x.cos()
|
|
return x.sin()
|
|
|
|
counter = CompileCounter()
|
|
torch.compile(fn, backend=counter)(torch.randn([2, 2]), [])
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_graph_break_on_jit_isinstance(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
if torch.jit.isinstance(x, typing.List[str]): # noqa: UP006
|
|
return x * 2
|
|
return x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
x = torch.rand(4)
|
|
self.assertTrue(same(fn(x), opt_fn(x)))
|
|
|
|
def test_graph_break_on_jit_isinstance_pep585(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
if torch.jit.isinstance(x, list[str]):
|
|
return x * 2
|
|
return x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
x = torch.rand(4)
|
|
self.assertTrue(same(fn(x), opt_fn(x)))
|
|
|
|
def test_add_sub_alpha_out(self):
|
|
inp = torch.randn(2, 3, 4)
|
|
other = 1
|
|
alpha = 2
|
|
for op in [torch.add, torch.sub]:
|
|
out = torch.zeros(2, 3, 4)
|
|
compile_out = torch.zeros(2, 3, 4)
|
|
op(inp, other, alpha=alpha, out=out)
|
|
compiled_fn = torch.compile(op, dynamic=True)
|
|
compiled_fn(inp, other, alpha=alpha, out=compile_out)
|
|
self.assertTrue(same(out, compile_out))
|
|
|
|
def test_negative_shape_guard(self):
|
|
def fn(x):
|
|
if x.size() != (5, 1, 2, 3):
|
|
return x.cos()
|
|
return x.sin()
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=counter, dynamic=True)
|
|
|
|
x = torch.ones(5, 1, 3, 4)
|
|
x2 = torch.ones(5, 1, 2, 3)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
self.assertEqual(fn(x2), opt_fn(x2))
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_deferred_runtime_asserts(self):
|
|
@torch.compile(fullgraph=True)
|
|
def f(x):
|
|
y = x.item()
|
|
torch._check_is_size(y)
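# _check_is_size installs a deferred runtime assert that y is a valid (non-negative) size, so the y >= 0 branch below can be compiled; calling with a negative value then fails at runtime.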
|
|
if y >= 0:
|
|
return x * 2
|
|
else:
|
|
return x * 3
|
|
|
|
f(torch.tensor([3]))
|
|
self.assertRaises(RuntimeError, lambda: f(torch.tensor([-2])))
|
|
|
|
def test_addr_alpha_beta_out(self):
|
|
inp = torch.randn(2, 3)
|
|
vec1 = torch.randn(2)
|
|
vec2 = torch.randn(3)
|
|
alpha = 2
|
|
beta = 5
|
|
|
|
out = torch.zeros(2, 3)
|
|
compile_out = torch.zeros(2, 3)
|
|
|
|
torch.addr(inp, vec1, vec2, alpha=alpha, beta=beta, out=out)
|
|
compiled_fn = torch.compile(torch.addr, dynamic=True)
|
|
compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
|
|
self.assertTrue(same(out, compile_out))
|
|
|
|
def test_setattr_requires_grad_graph_breaks(self):
|
|
def fn(x):
|
|
z = x + 4
|
|
x.requires_grad = True
|
|
y = x * z
|
|
return y
|
|
|
|
for backend in ["count", "eager", "aot_eager"]:
|
|
if backend == "count":
|
|
backend = CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=backend)
|
|
|
|
eager = torch.zeros(5)
|
|
compiled = eager.clone()
|
|
|
|
out_eager = fn(eager)
|
|
out_opt = opt_fn(compiled)
|
|
|
|
self.assertEqual(out_eager, out_opt)
|
|
|
|
out_eager.sum().backward()
|
|
out_opt.sum().backward()
|
|
|
|
self.assertEqual(eager, compiled)
|
|
if isinstance(backend, CompileCounter):
|
|
self.assertEqual(backend.frame_count, 2) # graph breaks
|
|
|
|
def test_dynamic_shapes_double_not_equal(self):
|
|
# https://github.com/pytorch/pytorch/issues/113393
|
|
def fn(x):
|
|
if x.size() != (5, 1, 2, 3):
|
|
return x.cos()
|
|
return x.sin()
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
|
|
x = torch.ones(5, 1, 2, 3)
|
|
x2 = torch.ones(5, 1, 3, 4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
self.assertEqual(fn(x2), opt_fn(x2))
|
|
|
|
def test_inductor_no_recursionerror_on_for_loops(self):
|
|
def forward(x):
|
|
for _ in range(10000):
|
|
x = 1.0 * x
|
|
return x
|
|
|
|
self.assertTrue(
|
|
same(torch.compile(forward)(torch.tensor([1.0])), torch.tensor([1.0]))
|
|
)
|
|
|
|
def test_user_defined_object_callable(self):
|
|
# https://github.com/pytorch/pytorch/issues/114019
|
|
class MyCallable:
|
|
def __call__(self, x):
|
|
return x + 1
|
|
|
|
def fn(x):
|
|
# Create in graph - will not have source
|
|
return MyCallable()(x)
|
|
|
|
fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
|
|
self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1)))
|
|
|
|
@torch._dynamo.config.patch(log_compilation_metrics=True)
|
|
def test_many_views_with_mutation(self):
|
|
# When symbolic storage offsets were added in #113734, tensors_definitely_do_not_overlap
|
|
# began adding shape guards - a quadratic amount relative to the number of inputs.
|
|
# Test this configuration, and test that a reasonable number of guards are added.
|
|
# Note, when dynamic shapes are turned on, this test fails and we still get quadratic guards.
|
|
def fn(x):
|
|
x[0].relu_()
|
|
return torch.cat(x).sum()
|
|
|
|
AMT = 32
|
|
src = torch.rand(16 * (AMT + 1))
|
|
|
|
x = [src.as_strided((4, 4), (4, 1), 3 + 16 * i) for i in range(AMT)]
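# 32 overlapping views of a single buffer; the assertions below check that guard growth stays roughly linear (< AMT * 8) rather than quadratic in the number of inputs.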
|
|
|
|
torch._dynamo.reset()
|
|
torch._dynamo.utils.clear_compilation_metrics()
|
|
|
|
torch.compile(fn, backend="aot_eager")(x)
|
|
|
|
all_metrics = torch._dynamo.utils.get_compilation_metrics()
|
|
|
|
total_guards = sum(metric.guard_count for metric in all_metrics)
|
|
self.assertLess(total_guards, AMT * 8)
|
|
|
|
total_shape_env_guards = sum(
|
|
metric.shape_env_guard_count for metric in all_metrics
|
|
)
|
|
self.assertLess(total_shape_env_guards, AMT * 8)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/118799
|
|
def test_subclass_graph_output_repro(self):
|
|
@torch._dynamo.allow_in_graph
|
|
def to_subclass(x):
|
|
return TwoTensor(x.clone(), x.clone())
|
|
|
|
def f(x):
|
|
tmp_subclass = to_subclass(x)
|
|
return tmp_subclass.view(-1)
|
|
|
|
x = torch.ones(2)
|
|
out_ref = f(x)
|
|
out_test = torch.compile(f, backend="aot_eager")(x)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
def test_numpy_tobytes_no_error(self):
|
|
def fn(x):
|
|
x += 1
|
|
z = x.tobytes()
|
|
x += 1
|
|
return z
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
opt_arg, arg = np.array([1, 2]), np.array([1, 2])
|
|
self.assertEqual(opt_fn(opt_arg), fn(arg))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_numpy_not_ndarray_recompiles(self):
|
|
import torch
|
|
|
|
def fn(x=None):
|
|
if x is None:
|
|
x = np.ones(3)
|
|
elif isinstance(x, int):
|
|
x = np.ones(6)
|
|
elif isinstance(x, str):
|
|
x = np.ones(9)
|
|
return x**2
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
|
|
x = np.zeros((2, 2))
|
|
|
|
self.assertEqual(opt_fn(x), fn(x))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(opt_fn(), fn())
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(opt_fn(10), fn(10))
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
self.assertEqual(opt_fn("10"), fn("10"))
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
@parametrize(
|
|
"backend",
|
|
["eager", "aot_eager", "inductor"],
|
|
)
|
|
@parametrize(
|
|
"func_name",
|
|
["func1", "func2", "func3"],
|
|
)
|
|
def test_tensor_set_data(self, backend, func_name):
|
|
# https://github.com/pytorch/pytorch/issues/113030
|
|
def func1(x, y):
|
|
x.data = y
|
|
x.add_(1)
|
|
return x
|
|
|
|
def func2(x, y):
|
|
x.data = y
|
|
y.data = torch.zeros([0])
|
|
return x
|
|
|
|
def func3(x, y):
|
|
z = x
|
|
x.data = y
|
|
y.data = torch.zeros([0])
|
|
return torch.tensor(x is z)
|
|
|
|
funcs = {"func1": func1, "func2": func2, "func3": func3}
|
|
func = funcs[func_name]
|
|
|
|
if backend != "eager" and func is func1:
|
|
# add_ not working w/ aot_autograd?
|
|
return
|
|
|
|
torch._dynamo.reset()
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
|
|
|
|
compiled_fn = torch.compile(func, backend=cnt, fullgraph=True)
|
|
requires_grad = func is not func1
|
|
for _ in range(0, 5):
|
|
# Inputs
|
|
eager_a = torch.ones([6], requires_grad=requires_grad)
|
|
compiled_a = torch.ones([6], requires_grad=requires_grad)
|
|
|
|
eager_b = torch.ones([6], requires_grad=requires_grad)
|
|
compiled_b = torch.ones([6], requires_grad=requires_grad)
|
|
|
|
# Eager
|
|
out_eager = func(eager_a, eager_b)
|
|
# Compiled
|
|
out_compiled = compiled_fn(compiled_a, compiled_b)
|
|
self.assertEqual(eager_a, compiled_a)
|
|
self.assertEqual(eager_b, compiled_b)
|
|
self.assertTrue(torch.equal(out_eager, out_compiled))
|
|
|
|
# func1 hits a leaf Variable that requires grad is being used in an in-place operation
|
|
if requires_grad:
|
|
bwd_inp_eager = torch.randn([6])
|
|
bwd_inp_compiled = torch.clone(bwd_inp_eager)
|
|
eager_a.backward(bwd_inp_eager)
|
|
compiled_a.backward(bwd_inp_compiled)
|
|
self.assertEqual(eager_a.grad, compiled_a.grad)
|
|
|
|
# Prove guarding works - we run the compiled_fn 5 times
|
|
# frame_count should stay at 1.
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_tensor_set_data_mismatched_dtype(self):
|
|
def func(x, y):
|
|
x.data = y.to(dtype=torch.bfloat16)
|
|
|
|
x1 = torch.tensor([], dtype=torch.float32)
|
|
x2 = torch.tensor([], dtype=torch.float32)
|
|
y1 = torch.tensor([1, 2, 3], dtype=torch.float32)
|
|
y2 = torch.tensor([1, 2, 3], dtype=torch.float32)
|
|
func(x1, y1)
|
|
torch.compile(func, backend="eager")(x2, y2)
|
|
self.assertEqual(x1, x2)
|
|
self.assertEqual(x1.data, x2.data)
|
|
self.assertEqual(y1, y2)
|
|
|
|
def test_user_ctor_ctx_manager(self):
|
|
class UserCtxManager:
|
|
def __enter__(self):
|
|
return 1
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
pass
|
|
|
|
def fn(x, y):
|
|
ucm = UserCtxManager() # noqa: F841
|
|
return x * x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
|
|
x = torch.rand([2, 2])
|
|
opt_fn(x, x)
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_unbacked_arange_in_bounds(self):
|
|
# see https://github.com/pytorch/pytorch/issues/113002
|
|
class PaddingNet(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
|
|
def forward(self, lengths):
|
|
max_seq_len = lengths.max().item()
|
|
row_vector = torch.arange(0, max_seq_len, 1)
|
|
matrix = torch.unsqueeze(lengths, dim=-1)
|
|
mask = row_vector < matrix
|
|
mask = mask.type(torch.float32)
|
|
mask_3d_btd = mask[:, :, None]
|
|
return mask_3d_btd
|
|
|
|
model = PaddingNet()
|
|
lengths = torch.tensor([5, 4, 4, 4], dtype=torch.int32)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(model, backend=cnt, fullgraph=True)
|
|
opt_fn(lengths)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_overlapping_inputs_with_dynamic_shapes_error(self):
|
|
@torch.compile(backend="aot_eager")
|
|
def fn(a, b, c, d, e, f):
|
|
a.mul_(2)
|
|
b.mul_(2)
|
|
c.mul_(2)
|
|
d.mul_(2)
|
|
e.mul_(2)
|
|
f.mul_(2)
|
|
|
|
base = torch.ones(2, 20)
|
|
a = base[:, 0:2]
|
|
b = base[:, 2:4]
|
|
c = base[:, 4:6]
|
|
d = base[:, 6:8]
|
|
e = base[:, 8:10]
|
|
f = base[:, 10:12]
|
|
f2 = base[:, 10:14]
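# f2 aliases the same storage but has a different size than f, so the second call compiles with dynamic shapes and the overlap check is expected to raise.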
|
|
fn(a, b, c, d, e, f)
|
|
with self.assertRaisesRegex(
|
|
AssertionError, "is being compiled with dynamic shapes"
|
|
):
|
|
fn(a, b, c, d, e, f2)
|
|
|
|
def test_user_ctor_ctx_manager_custom_init(self):
|
|
class UserCtxManager:
|
|
def __init__(self, x):
|
|
x[0] = 10
|
|
|
|
def __enter__(self):
|
|
return 1
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
pass
|
|
|
|
def fn(x, y):
|
|
ucm = UserCtxManager(y) # noqa: F841
|
|
return x * y[0]
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt, fullgraph=True)
|
|
x = torch.rand([2, 2])
|
|
self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
|
|
def test_user_ctor_ctx_manager_custom_init_graph_break(self):
|
|
counter = [0]
|
|
|
|
class UserCtxManager:
|
|
def __init__(self, k):
|
|
k[0] += 1
|
|
|
|
def __enter__(self):
|
|
return 1
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
pass
|
|
|
|
def fn(x, counter):
|
|
x = x * x
|
|
ucm = UserCtxManager(counter) # noqa: F841
|
|
return x * x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
x = torch.rand([2, 2])
|
|
self.assertEqual(opt_fn(x, counter), fn(x, counter))
|
|
self.assertEqual(counter[0], 2)
|
|
for _ in range(0, 10):
|
|
opt_fn(x, counter)
|
|
self.assertEqual(counter[0], 12)
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
|
|
def test_many_overlapping_inputs_does_not_explode_guards(self):
|
|
from torch._dynamo.backends.common import aot_autograd
|
|
|
|
# Before, this was (9702, 0)
|
|
num_shape_guards = None
|
|
num_aot_guards = None
|
|
num_compiles = 0
|
|
|
|
def guard_count_backend(gm, *args):
|
|
nonlocal num_shape_guards
|
|
nonlocal num_aot_guards
|
|
nonlocal num_compiles
|
|
num_shape_guards = len(
|
|
torch._guards.TracingContext.try_get().fake_mode.shape_env.guards
|
|
)
|
|
num_aot_guards = len(
|
|
torch._guards.TracingContext.try_get().guards_context.aotautograd_guards
|
|
)
|
|
num_compiles += 1
|
|
return gm
|
|
|
|
aot_guard_counter = aot_autograd(fw_compiler=guard_count_backend)
|
|
|
|
@torch.compile(backend=aot_guard_counter, dynamic=True)
|
|
def f(*args):
|
|
for a in args:
|
|
a.add_(1)
|
|
|
|
x = torch.ones(1000, requires_grad=True)
|
|
args = x.split(10)
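# 1000 elements split into chunks of 10 -> 100 views aliasing one storage, i.e. 100 * 99 / 2 = 4950 aliasing pairs (see the guard-count comment below).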
|
|
|
|
with torch.no_grad():
|
|
f(*args)
|
|
# In this example, there were 4950 guards (roughly (# tensors) ^ 2 // 2),
|
|
# because every pair of aliased inputs needs a guard.
|
|
self.assertTrue(num_aot_guards < 5000)
|
|
# But there are no dynamic shape guards.
|
|
self.assertEqual(num_shape_guards, 0)
|
|
# don't recompile
|
|
with torch.no_grad():
|
|
f(*args)
|
|
self.assertEqual(num_compiles, 1)
|
|
|
|
def test_issue134451(self):
|
|
class BoundingBox2DIndex(IntEnum):
|
|
_X = 0
|
|
_Y = 1
|
|
_HEADING = 2
|
|
_LENGTH = 3
|
|
_WIDTH = 4
|
|
|
|
@classmethod
|
|
def size(cls):
|
|
return 5
|
|
|
|
@classmethod
|
|
@property
|
|
def X(cls):
|
|
return cls._X
|
|
|
|
@classmethod
|
|
@property
|
|
def Y(cls):
|
|
return cls._Y
|
|
|
|
@classmethod
|
|
@property
|
|
def HEADING(cls):
|
|
return cls._HEADING
|
|
|
|
@classmethod
|
|
@property
|
|
def LENGTH(cls):
|
|
return cls._LENGTH
|
|
|
|
@classmethod
|
|
@property
|
|
def WIDTH(cls):
|
|
return cls._WIDTH
|
|
|
|
@classmethod
|
|
@property
|
|
def POINT(cls):
|
|
# assumes X, Y have subsequent indices
|
|
return slice(cls._X, cls._Y + 1)
|
|
|
|
@classmethod
|
|
@property
|
|
def STATE_SE2(cls):
|
|
# assumes X, Y, HEADING have subsequent indices
|
|
return slice(cls._X, cls._HEADING + 1)
|
|
|
|
class SimpleModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self._mlp_states = nn.Sequential(
|
|
nn.Linear(10, 20),
|
|
nn.ReLU(),
|
|
nn.Linear(20, BoundingBox2DIndex.size()),
|
|
)
|
|
|
|
def forward(self, x):
|
|
agent_states = self._mlp_states(x)
|
|
agent_states[..., BoundingBox2DIndex.POINT] = (
|
|
agent_states[..., BoundingBox2DIndex.POINT].tanh() * 32
|
|
)
|
|
agent_states[..., BoundingBox2DIndex.HEADING] = (
|
|
agent_states[..., BoundingBox2DIndex.HEADING].tanh() * torch.pi
|
|
)
|
|
return agent_states
|
|
|
|
model = SimpleModel().eval()
|
|
input_tensor = torch.randn(1, 10, dtype=torch.float32)
|
|
opt = torch.compile(model.eval(), backend="eager", fullgraph=True)
|
|
actual = opt(input_tensor)
|
|
try:
|
|
expected = model(input_tensor)
|
|
except Exception as e:
|
|
raise unittest.SkipTest("eager failed, requires Python>=3.12") from e
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_invalid_seq_unpack(self):
|
|
def myfn(arg):
|
|
(a, b) = arg # noqa: F841
|
|
|
|
def fn():
|
|
return myfn((1, 2, 3))
|
|
|
|
try:
|
|
torch.compile(fn)()
|
|
except ValueError:
|
|
pass
|
|
else:
|
|
self.fail("expected exception")
|
|
|
|
def test_udf_classes_reconstruction(self):
|
|
def fn(x):
|
|
o = T(5)
|
|
return o.x + x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
T = IncByOne
|
|
|
|
x = torch.randn(4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
# This should recompile
|
|
T = IncByTwo
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_contains_range_constprop(self):
|
|
def fn(x):
|
|
# dynamo should const prop to False
|
|
if 3 in range(0, 10):
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
x = torch.zeros(4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
# https://github.com/pytorch/pytorch/issues/104505
|
|
def test_as_strided_on_base_with_mutation_works(self):
|
|
def foo(a):
|
|
f = a.as_strided((2,), (1,), 0)
|
|
f.add_(1.0)
|
|
return a
|
|
|
|
a = torch.randn(2, 4)
|
|
a_ref = a.clone()
|
|
out_ref = foo(a_ref)
|
|
f_compiled = torch.compile(foo, backend="aot_eager")
|
|
out = f_compiled(a)
|
|
self.assertEqual(out_ref, out)
|
|
self.assertEqual(a_ref, a)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/104505
|
|
def test_as_strided_on_existing_view_banned(self):
|
|
def foo(a):
|
|
e = a.diagonal()
|
|
f = e.as_strided((2,), (1,), 0)
|
|
f.add_(1.0)
|
|
return a
|
|
|
|
a = torch.randn(2, 4)
|
|
a_ref = a.clone()
|
|
foo(a_ref)
|
|
f_compiled = torch.compile(foo, backend="aot_eager")
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
|
|
):
|
|
f_compiled(a)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/146598
|
|
@unittest.expectedFailure
|
|
def test_lru_cache_tracing(self):
|
|
from functools import lru_cache
|
|
|
|
counter = 0
|
|
|
|
@lru_cache
|
|
def cached_fn(x):
|
|
nonlocal counter
|
|
counter += 1
|
|
return x + 1
|
|
|
|
compiled_fn = torch.compile(cached_fn, backend="eager")
|
|
|
|
t = torch.randn(2, 2)
|
|
result1 = compiled_fn(t)
|
|
self.assertEqual(counter, 1)
|
|
|
|
result2 = compiled_fn(t)
|
|
self.assertEqual(counter, 1)
|
|
self.assertEqual(result1, result2)
|
|
|
|
def test_dont_aggressively_write_assert(self):
|
|
record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
|
|
|
|
@torch.compile(dynamic=True, backend=record_graph)
|
|
def f(x):
|
|
assert x.shape[0] > 3
|
|
assert x[0].sum() > 0
|
|
assert 1 % (x.shape[0] // 2) != 0
|
|
assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0
|
|
return x.cos()
|
|
|
|
f(torch.ones(6, 4))
|
|
graph = record_graph.graphs[0]
|
|
# It is a bit annoying that we generate useless statements for
|
|
# shape guards, but DCE should be able to remove them since
|
|
# there is no backed assert on them. The reason this is ok is
|
|
# because dynamo will only skip the assert statement, but not
|
|
# the instructions before it.
|
|
self.assertExpectedInline(
|
|
str(graph.code).strip(),
|
|
"""\
|
|
def forward(self, s77 : torch.SymInt, s27 : torch.SymInt, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
getitem_2 = l_x_[0]
|
|
sum_1 = getitem_2.sum(); getitem_2 = None
|
|
gt_1 = sum_1 > 0; sum_1 = None
|
|
_assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = _assert_async = None
|
|
cos = l_x_.cos(); l_x_ = None
|
|
return (cos,)""",
|
|
)
|
|
for node in graph.graph.nodes:
|
|
if "example_value" in node.meta and isinstance(
|
|
node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor
|
|
):
|
|
shape_env = node.meta["example_value"].fake_mode.shape_env
|
|
lower_ranges = [val.lower for val in shape_env.var_to_range.values()]
|
|
self.assertTrue(lower_ranges == [4, 2])
|
|
|
|
@torch.compile(dynamic=True, backend=record_graph)
|
|
def f_fail(x):
|
|
assert x.shape[0] < 3
|
|
|
|
# We graph-break here, so the failure should be eager
|
|
with self.assertRaisesRegex(AssertionError, ""):
|
|
f_fail(torch.ones(6, 4))
|
|
|
|
def test_detectron2_instances_cat(self):
|
|
class Instances:
|
|
def __init__(self, image_size: tuple[int, int], **kwargs: Any):
|
|
self._image_size = image_size
|
|
self._fields: dict[str, Any] = {}
|
|
for k, v in kwargs.items():
|
|
self.set(k, v)
|
|
|
|
@property
|
|
def image_size(self) -> tuple[int, int]:
|
|
return self._image_size
|
|
|
|
def __setattr__(self, name: str, val: Any) -> None:
|
|
if name.startswith("_"):
|
|
super().__setattr__(name, val)
|
|
else:
|
|
self.set(name, val)
|
|
|
|
def __getattr__(self, name: str) -> Any:
|
|
if name == "_fields" or name not in self._fields:
|
|
raise AttributeError(
|
|
f"Cannot find field '{name}' in the given Instances!"
|
|
)
|
|
return self._fields[name]
|
|
|
|
def __len__(self) -> int:
|
|
for v in self._fields.values():
|
|
# use __len__ because len() has to be int and is not friendly to tracing
|
|
return v.__len__()
|
|
raise NotImplementedError("Empty Instances does not support __len__!")
|
|
|
|
def set(self, name: str, value: Any) -> None:
|
|
with warnings.catch_warnings(record=True):
|
|
data_len = len(value)
|
|
if len(self._fields):
|
|
assert len(self) == data_len, (
|
|
f"Adding a field of length {data_len} to a Instances of length {len(self)}"
|
|
)
|
|
self._fields[name] = value
|
|
|
|
def get(self, name: str) -> Any:
|
|
return self._fields[name]
|
|
|
|
@staticmethod
|
|
def cat(instance_lists: list["Instances"]) -> "Instances":
|
|
assert all(isinstance(i, Instances) for i in instance_lists)
|
|
assert len(instance_lists) > 0
|
|
if len(instance_lists) == 1:
|
|
return instance_lists[0]
|
|
|
|
image_size = instance_lists[0].image_size
|
|
if not isinstance(
|
|
image_size, torch.Tensor
|
|
): # could be a tensor in tracing
|
|
for i in instance_lists[1:]:
|
|
assert i.image_size == image_size
|
|
ret = Instances(image_size)
|
|
for k in instance_lists[0]._fields.keys():
|
|
values = [i.get(k) for i in instance_lists]
|
|
v0 = values[0]
|
|
if isinstance(v0, torch.Tensor):
|
|
values = torch.cat(values, dim=0)
|
|
elif isinstance(v0, list):
|
|
values = list(itertools.chain(*values))
|
|
elif hasattr(type(v0), "cat"):
|
|
values = type(v0).cat(values)
|
|
else:
|
|
raise ValueError(
|
|
f"Unsupported type {type(v0)} for concatenation"
|
|
)
|
|
ret.set(k, values)
|
|
return ret
|
|
|
|
instances = [
|
|
Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16))
|
|
for _ in range(3)
|
|
]
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(instances):
|
|
return instances[0].cat(instances)
|
|
|
|
actual = fn(instances)
|
|
expected = instances[0].cat(instances)
|
|
self.assertEqual(type(actual), type(expected))
|
|
self.assertEqual(actual.__dict__, expected.__dict__)
|
|
|
|
def test_weakref_construction(self):
|
|
def fn(x, y):
|
|
x_weak = weakref.ref(x)
|
|
return x_weak() * y
|
|
|
|
x = torch.randn(4)
|
|
y = torch.randn(4)
|
|
|
|
ref = fn(x, y)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x, y)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_weakref(self):
|
|
def fn(x_weak, weight, y):
|
|
if x_weak is not None and x_weak() is not weight:
|
|
return torch.sin(y)
|
|
return torch.cos(y)
|
|
|
|
weight = torch.randn(4)
|
|
y = torch.randn(4)
|
|
x_weak = weakref.ref(weight)
|
|
|
|
ref = fn(x_weak, weight, y)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x_weak, weight, y)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_weakref_reconstruct(self):
|
|
def fn(x_weak, weight, y):
|
|
y = torch.sin(y)
|
|
referent = x_weak()
|
|
torch._dynamo.graph_break()
|
|
if referent is not weight:
|
|
return torch.sin(y)
|
|
return torch.cos(y)
|
|
|
|
weight = torch.randn(4)
|
|
y = torch.randn(4)
|
|
x_weak = weakref.ref(weight)
|
|
|
|
ref = fn(x_weak, weight, y)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
res = opt_fn(x_weak, weight, y)
|
|
self.assertEqual(ref, res)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_return_weakref(self):
|
|
def f(t):
|
|
t = t * 2
|
|
wr = weakref.ref(t)
|
|
return wr, t
|
|
|
|
ref_t = torch.randn(2, 2, requires_grad=True)
|
|
ref_y = f(ref_t)
|
|
|
|
t = ref_t.detach().clone().requires_grad_()
|
|
y = torch.compile(f, backend="eager", fullgraph=True)(t)
|
|
self.assertEqual(ref_y[0](), y[0]())
|
|
|
|
def test_weakref_del(self):
|
|
def fn(x_weak, y):
|
|
x = x_weak()
|
|
if x is not None:
|
|
return torch.sin(y)
|
|
return torch.cos(y)
|
|
|
|
weight = torch.randn(4)
|
|
x_weak = weakref.ref(weight)
|
|
y = torch.randn(4)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
|
|
ref = fn(x_weak, y)
|
|
res = opt_fn(x_weak, y)
|
|
self.assertEqual(ref, res)
|
|
|
|
del weight
|
|
gc.collect()
|
|
ref = fn(x_weak, y)
|
|
res = opt_fn(x_weak, y)
|
|
self.assertEqual(ref, res)
|
|
|
|
# The programming model around (weak)references is that we DO NOT guarantee
|
|
# any behavior that depends on deallocation order. We do guarantee "eventual consistency",
|
|
# that is, after the torch.compile'd function is finished running (including any graph breaks),
|
|
# refcount semantics will match eager's.
|
|
def test_weakref_callback(self):
|
|
called1 = False
|
|
|
|
def callback1(ref):
|
|
nonlocal called1
|
|
called1 = True
|
|
if not torch.compiler.is_compiling():
|
|
raise RuntimeError("callback1 expected to be compiled")
|
|
|
|
# weakref callbacks that should be called in the compiled region will be compiled.
|
|
# But the exact point in the compiled code at which the callback is invoked is undefined.
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
y = x + 1
|
|
ref = weakref.ref(y, callback1)
|
|
torch._dynamo.graph_break()
|
|
return ref
|
|
|
|
fn(torch.ones(3))
|
|
self.assertTrue(called1)
|
|
|
|
called2 = False
|
|
|
|
def callback2(ref):
|
|
nonlocal called2
|
|
called2 = True
|
|
if torch.compiler.is_compiling():
|
|
raise RuntimeError("callback2 expected to not be compiled")
|
|
|
|
# weakref callbacks that fire outside the compiled region work
|
|
@torch.compile(backend="eager")
|
|
def gn(x):
|
|
y = x + 1
|
|
ref = weakref.ref(y, callback2)
|
|
torch._dynamo.graph_break()
|
|
return y, ref
|
|
|
|
y, _ = gn(torch.ones(3))
|
|
del y
|
|
self.assertTrue(called2)
|
|
|
|
def callback3(ref):
|
|
raise RuntimeError("callback3 should not be called")
|
|
|
|
# The callback will NOT be called if both the weakref and the referrent are
|
|
# deleted in the same compiled region (graph breaks act like a "memory sync"
|
|
# and thus make things tricky - the callback is actually expected to be called).
|
|
# This test does NOT mean that this behavior is part of the (weak)ref programming
|
|
# model, but rather reminds us that this is an intentionally allowed weakref-Dynamo behavior.
|
|
@torch.compile(backend="eager")
|
|
def hn(x):
|
|
y = x + 1
|
|
_ = weakref.ref(y, callback3)
|
|
|
|
hn(torch.ones(3))
|
|
|
|
# @torch._functorch.config.patch(
|
|
# recompute_views=True,
|
|
# )
|
|
# def test_storage_resize_forward_full_graph(self):
|
|
# class TestModule(torch.nn.Module):
|
|
# def __init__(self) -> None:
|
|
# super().__init__()
|
|
# self.param = torch.nn.Parameter(torch.randn(4, 4))
|
|
|
|
# def forward(self, x):
|
|
# self.param.untyped_storage().resize_(
|
|
# self.param.numel() * self.param.itemsize
|
|
# )
|
|
# with torch.no_grad():
|
|
# torch._foreach_copy_([self.param], [x])
|
|
# out = torch.matmul(self.param, self.param)
|
|
# self.param.untyped_storage().resize_(0)
|
|
# return out
|
|
|
|
# def post_accumulate_grad_hook(param):
|
|
# param.untyped_storage().resize_(0)
|
|
|
|
# # Beginning of backward, resize and put data into the param
|
|
# def pre_backward_hook(module, grad) -> None:
|
|
# module.param.untyped_storage().resize_(
|
|
# self.param.numel() * self.param.itemsize
|
|
# )
|
|
# with torch.no_grad():
|
|
# # simulates loading data into param from allgather
|
|
# module.param.fill_(2)
|
|
|
|
# def post_forward_hook(module, args, output):
|
|
# output.register_hook(functools.partial(pre_backward_hook, module))
|
|
|
|
# x = torch.randn(4, 4)
|
|
|
|
# mod_ref = TestModule()
|
|
# mod_test = deepcopy(mod_ref)
|
|
|
|
# # Start the param off with zero storage size to mimic fsdp
|
|
# mod_ref.param.untyped_storage().resize_(0)
|
|
# mod_test.param.untyped_storage().resize_(0)
|
|
|
|
# # Resize storage at beginning of backward
|
|
# # Free storage at end of backward
|
|
# mod_ref.register_forward_hook(post_forward_hook, prepend=False)
|
|
# mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
|
|
# mod_test.register_forward_hook(post_forward_hook, prepend=False)
|
|
# mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
|
|
|
|
# mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)
|
|
|
|
# out_ref = mod_ref(x)
|
|
# out_test = mod_test(x)
|
|
# self.assertExpectedInline(
|
|
# str(fw_graph[0].code.strip()),
|
|
# """\
|
|
# def forward(self, primals_1, primals_2):
|
|
# _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None
|
|
# getitem = _foreach_copy[0]; _foreach_copy = None
|
|
# mm = torch.ops.aten.mm.default(getitem, getitem)
|
|
# return [mm, getitem]""",
|
|
# )
|
|
# self.assertEqual(out_ref, out_test)
|
|
|
|
def test_super_in_staticmethod(self):
|
|
class A:
|
|
@staticmethod
|
|
def foo():
|
|
return super().__init__()
|
|
|
|
def fn(obj):
|
|
return obj.foo()
|
|
|
|
obj = A()
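# Zero-argument super() inside a staticmethod fails at runtime ("no arguments"); eager and compiled should raise the same error message.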
|
|
|
|
try:
|
|
fn(obj)
|
|
except Exception as e:
|
|
orig_str = str(e)
|
|
self.assertIn("no arguments", orig_str)
|
|
|
|
try:
|
|
torch.compile(backend="eager")(fn)(obj)
|
|
except Exception as e:
|
|
compiled_str = str(e)
|
|
self.assertEqual(orig_str, compiled_str)
|
|
|
|
def test_super_staticmethod(self):
|
|
class Parent:
|
|
@staticmethod
|
|
def greet():
|
|
return 5
|
|
|
|
class Child(Parent):
|
|
@staticmethod
|
|
def greet(x):
|
|
return x * super(Child, Child).greet()
|
|
|
|
child = Child()
|
|
|
|
def fn(x):
|
|
return child.greet(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
x = torch.ones(4)
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_super_classmethod(self):
|
|
class Parent:
|
|
@classmethod
|
|
def greet(cls):
|
|
if cls == Parent:
|
|
return 4
|
|
if cls == Child:
|
|
return 3
|
|
if cls == GrandChild:
|
|
return 5
|
|
return 2
|
|
|
|
class Child(Parent):
|
|
def greet(self, x):
|
|
return x * super().greet()
|
|
|
|
class GrandChild(Child):
|
|
pass
|
|
|
|
grand_child = GrandChild()
|
|
|
|
def fn(x):
|
|
return grand_child.greet(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
x = torch.ones(4)
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_super_classmethod_inheritance(self):
|
|
class GrandParent:
|
|
@classmethod
|
|
def greet(cls, x):
|
|
return cls.A * x
|
|
|
|
class Parent(GrandParent):
|
|
@classmethod
|
|
def greet(cls, x):
|
|
return super().greet(x)
|
|
|
|
class Child(Parent):
|
|
A = 5
|
|
|
|
@classmethod
|
|
def greet(cls, x):
|
|
return super().greet(x)
|
|
|
|
child = Child()
|
|
|
|
def fn(x):
|
|
return child.greet(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
x = torch.ones(4)
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_super_diamond(self):
|
|
class A:
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.a = 5
|
|
|
|
class Nothing:
|
|
pass
|
|
|
|
class B(Nothing, A):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.b = 10
|
|
|
|
def run(self, x):
|
|
return self.a * self.b * x
|
|
|
|
def fn(x):
|
|
b = B()
|
|
return b.run(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_vc_bumped_in_inference_graph(self):
|
|
@torch.compile
|
|
def f(x):
|
|
return x.mul_(2)
|
|
|
|
x = torch.randn(4)
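# _version is the tensor's version counter; the in-place mul_ inside f must still bump it when run through the compiled inference graph.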
|
|
vc_before = x._version
|
|
f(x)
|
|
vc_after = x._version
|
|
self.assertTrue(vc_after > vc_before)
|
|
|
|
def test_nn_module_callable(self):
|
|
class M(nn.Module):
|
|
def forward(self, x):
|
|
return x.sin()
|
|
|
|
def f(m):
|
|
return callable(m)
|
|
|
|
res = torch.compile(f, fullgraph=True)(M())
|
|
self.assertTrue(res)
|
|
|
|
def test_stk_sdd_is_transposed(self):
|
|
def _is_transposed(x):
|
|
return (
|
|
not x.is_contiguous()
|
|
and x.stride()[0] == 1
|
|
and x.stride()[1] == x.size()[0]
|
|
)
|
|
|
|
class SDD(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, lhs, rhs):
|
|
ctx.save_for_backward(lhs, rhs)
|
|
out = torch.full_like(lhs, 1.0, dtype=lhs.dtype, device=lhs.device)
|
|
return out
|
|
|
|
@staticmethod
|
|
def backward(ctx, dy):
|
|
saved_tensors = ctx.saved_tensors
|
|
lhs, rhs = saved_tensors[:2]
|
|
trans_a = _is_transposed(lhs)
|
|
trans_b = _is_transposed(rhs)
|
|
dlhs = None
|
|
if ctx.needs_input_grad[0]:
|
|
dlhs = torch.full_like(lhs, 1.0 if trans_a else 2.0)
|
|
drhs = None
|
|
if ctx.needs_input_grad[1]:
|
|
drhs = torch.full_like(rhs, 1.0 if trans_b else 2.0)
|
|
return dlhs, drhs, None, None
|
|
|
|
x1 = torch.randn((8, 8), requires_grad=True)
|
|
y1 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
|
|
x2 = torch.randn((8, 8), requires_grad=True)
|
|
y2 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
|
|
|
|
SDD.apply(x1, y1).sum().backward()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn():
|
|
return SDD.apply(x2, y2)
|
|
|
|
fn().sum().backward()
|
|
|
|
self.assertEqual(x1.grad, x2.grad)
|
|
self.assertEqual(y1.grad, y2.grad)
|
|
|
|
def test_partially_initialized_module_property(self):
|
|
class Matrix(torch.nn.Module):
|
|
def __init__(self, data):
|
|
super().__init__()
|
|
self._data = data
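# Reading the `blocking` property below happens while the module is still being constructed, i.e. on a partially initialized nn.Module.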
|
|
self.foo = 10 * self.blocking
|
|
|
|
@property
|
|
def data(self):
|
|
return self._data
|
|
|
|
@property
|
|
def blocking(self):
|
|
return self.data.shape[1]
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn():
|
|
return Matrix(torch.randn(10, 20))
|
|
|
|
v = fn()
|
|
self.assertEqual(v.foo, 200)
|
|
self.assertEqual(v.data.shape, (10, 20))
|
|
self.assertEqual(type(v), Matrix)
|
|
|
|
def test_classmethod_with_slots(self):
|
|
class Mock:
|
|
__slots__ = ("_a",)
|
|
|
|
def __init__(self):
|
|
self._a = 2
|
|
|
|
@classmethod
|
|
def _m(cls):
|
|
return 3
|
|
|
|
def run(self, x):
|
|
return torch.sin(x) * self._a * self._m()
|
|
|
|
def fn(x):
|
|
mock = Mock()
|
|
return mock.run(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_nn_parametrize(self):
|
|
class Module(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.randn(10, 10))
|
|
|
|
def forward(self, x):
|
|
return self.param @ x
|
|
|
|
class Parametrization(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.sin(x)
|
|
|
|
m = Module()
|
|
torch.nn.utils.parametrize.register_parametrization(
|
|
m, "param", Parametrization()
|
|
)
|
|
|
|
sin_found = False
|
|
|
|
def backend(gm, _):
|
|
nonlocal sin_found
|
|
for node in gm.graph.nodes:
|
|
if node.target is torch.sin:
|
|
sin_found = True
|
|
return gm
|
|
|
|
opt_m = torch.compile(m, backend=backend, fullgraph=True)
|
|
inp = torch.randn(10, 10)
|
|
self.assertEqual(m(inp), opt_m(inp))
|
|
self.assertTrue(sin_found)
|
|
|
|
torch.nn.utils.parametrize.remove_parametrizations(m, "param")
|
|
sin_found = False
|
|
self.assertEqual(m(inp), opt_m(inp))
|
|
self.assertFalse(sin_found)
|
|
|
|
def test_nn_module_property_closure(self):
|
|
x = torch.randn(10, 10)
|
|
|
|
class Mod(torch.nn.Module):
|
|
@property
|
|
def y(self):
|
|
return torch.ones(10, 10) + x
|
|
|
|
def forward(self, x):
|
|
return x @ self.y
|
|
|
|
mod = Mod()
|
|
|
|
def fn(x):
|
|
return mod(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
|
|
inp = torch.randn(10, 10)
|
|
self.assertEqual(fn(inp), opt_fn(inp))
|
|
|
|
def test_global_fn_mutation(self):
|
|
def foo(x, y):
|
|
return global_fn(x) + y
|
|
|
|
x = torch.ones(1)
|
|
y = torch.ones(1)
|
|
|
|
opt = torch.compile(foo, fullgraph=True, backend="eager")
|
|
self.assertEqual(opt(x, y), foo(x, y))
|
|
|
|
# Change global_fn
|
|
global global_fn
|
|
|
|
def new_fn(x):
|
|
return torch.cos(x)
|
|
|
|
global_fn = new_fn
|
|
self.assertEqual(opt(x, y), foo(x, y))
|
|
|
|
# ref https://github.com/pytorch/pytorch/issues/123974
|
|
def test_list_reverse(self):
|
|
def ladder(x):
|
|
trail = x.size(-1)
|
|
assert trail > 2
|
|
weights = []
|
|
for s in [trail, trail - 1, trail - 2]:
|
|
weights.append(torch.ones(s, s - 1))
|
|
|
|
for w in weights:
|
|
x = x @ w
|
|
|
|
weights.reverse()
|
|
|
|
for w in weights:
|
|
x = x @ w.t()
|
|
|
|
return x
|
|
|
|
data = torch.randn(3, 4)
|
|
opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
|
|
self.assertEqual(opt_ladder(data), ladder(data))
|
|
|
|
def test_trace_functional_tensor_with(self):
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
from torch._subclasses.functional_tensor import (
|
|
FunctionalTensor,
|
|
FunctionalTensorMode,
|
|
)
|
|
|
|
def f(a, tmp):
|
|
a_view = a.view(-1)
|
|
with torch.no_grad():
|
|
a.set_(tmp)
|
|
a_view.mul_(2)
|
|
return a + tmp
|
|
|
|
fake_mode = FakeTensorMode()
|
|
with FunctionalTensorMode():
|
|
inp = torch.ones(3, 3, requires_grad=True)
|
|
inp = fake_mode.from_tensor(inp, static_shapes=True)
|
|
inp = FunctionalTensor.to_functional(inp)
|
|
|
|
tmp = torch.ones(3, 3, requires_grad=True)
|
|
tmp = fake_mode.from_tensor(tmp, static_shapes=True)
|
|
tmp = FunctionalTensor.to_functional(tmp)
|
|
|
|
opt_f = torch.compile(f, backend="eager")
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "cannot mutate tensors with frozen storage"
|
|
):
|
|
opt_f(inp, tmp)
|
|
|
|
def test_const_dict_keyerror(self):
|
|
d = {}
|
|
|
|
def fn(x):
|
|
try:
|
|
y = d[0]
|
|
except KeyError:
|
|
y = 1
|
|
return x + y
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
inp = torch.randn(3, 3)
|
|
self.assertEqual(fn(inp), opt_fn(inp))
|
|
|
|
def test_nonconst_issubclass(self):
|
|
def fn(x):
|
|
if issubclass(x.__class__, np.ndarray):
|
|
return 1
|
|
return 0
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
opt_fn(np.ones([3, 3]))
|
|
|
|
def test_issue126128(self):
|
|
def fn():
|
|
x = torch.randn(1, 10)
|
|
y = torch.randn(10, 1)
|
|
return torch.mm(x, y).sum()
|
|
|
|
def fn2():
|
|
x = torch.randn(10, 100)
|
|
y = torch.randn(100, 10)
|
|
return torch.mm(x, y).sum()
|
|
|
|
with fresh_cache():
|
|
torch.compile(fn)()
|
|
|
|
torch.compile(fn2)()
|
|
|
|
def test_jit_script_defaults(self):
|
|
@torch.jit.script
|
|
def fast_cos(x, c: float = 2.0):
|
|
return torch.cos(x) * c
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.fast_cos = fast_cos
|
|
|
|
def forward(self, x):
|
|
return self.fast_cos(x)
|
|
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
|
|
self.assertEqual(mod(x), opt_mod(x))
|
|
|
|
def test_enum(self):
|
|
class ExplicitEnum(str, Enum):
|
|
@classmethod
|
|
def _missing_(cls, value):
|
|
raise ValueError(
|
|
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
|
|
)
|
|
|
|
class PaddingStrategy(ExplicitEnum):
|
|
LONGEST = "longest"
|
|
MAX_LENGTH = "max_length"
|
|
DO_NOT_PAD = "do_not_pad"
|
|
|
|
def fn(x):
|
|
a = PaddingStrategy("longest")
|
|
if a == PaddingStrategy.LONGEST:
|
|
return torch.sin(x)
|
|
return torch.cos(x)
|
|
|
|
x = torch.randn(3, 3)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_hasattr_builtin(self):
|
|
class MyClass:
|
|
foo: int = 1
|
|
|
|
def func(x, m):
|
|
if getattr(type(m), "foo", 0):
|
|
return x + MyClass.foo
|
|
return x
|
|
|
|
opt_func = torch.compile(func, backend="eager", fullgraph=True)
|
|
m = MyClass()
|
|
x = torch.zeros(())
|
|
self.assertEqual(func(x, m), opt_func(x, m))
|
|
self.assertEqual(func(x, 0), opt_func(x, 0))
|
|
|
|
def test_grad(self):
|
|
# Writes to `grad` or `_grad` should be reflected when reading from the other,
|
|
# and should be codegen-ed.
|
|
def fn(x, y):
|
|
x._grad = y + 1
|
|
y.grad = x + 2
|
|
return x.grad.data, y._grad.data
|
|
|
|
x0 = torch.randn(4, requires_grad=True)
|
|
y0 = torch.randn(4, requires_grad=True)
|
|
x1 = x0.clone()
|
|
y1 = y0.clone()
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
self.assertEqual(fn(x0, y0), opt_fn(x1, y1))
|
|
self.assertEqual(x0.grad, x1.grad)
|
|
self.assertEqual(y0.grad, y1.grad)
|
|
|
|
def test_nn_module_stack_bc(self):
|
|
from torch._dynamo.mutation_guard import GenerationTracker
|
|
|
|
def compiler(gm, *args):
|
|
module_stacks = [
|
|
node.meta.get("nn_module_stack", None) for node in gm.graph.nodes
|
|
]
|
|
module_stacks, _ = pytree.tree_flatten(module_stacks)
|
|
module_stacks = [x for x in module_stacks if isinstance(x, str)]
|
|
for stack in module_stacks:
|
|
self.assertTrue("_module" not in stack)
|
|
return gm.forward
|
|
|
|
class SubMod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2, 2)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.submod1 = SubMod()
|
|
self.submod2 = SubMod()
|
|
|
|
def forward(self, x):
|
|
return self.submod1(x) + self.submod2(x)
|
|
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend=compiler)
|
|
opt_mod(torch.randn(2, 2))
|
|
|
|
with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend=compiler)
|
|
opt_mod(torch.randn(2, 2))
|
|
|
|
# an example similar to Pippy usecase
|
|
mod = Mod()
|
|
GenerationTracker.tag(mod.submod1)
|
|
GenerationTracker.mark_class_dynamic(type(mod.submod1))
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend=compiler)
|
|
opt_mod(torch.randn(2, 2))
|
|
|
|
def test_is_make_fx_tracing(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
torch.nn.modules.activation._is_make_fx_tracing()
|
|
return torch.sin(x)
|
|
|
|
fn(torch.rand(4))
|
|
|
|
def test_negative_floor_div_solve(self):
|
|
class CompiledClass(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
|
self.t = 5
|
|
|
|
def forward(self):
|
|
self.num = self.nums[self.t // 12]
|
|
self.t += 1
|
|
return self.num
|
|
|
|
m = CompiledClass()
|
|
m = torch.compile(m, backend="eager")
|
|
|
|
# the first call works
|
|
m()
|
|
# the second call used to cause a failure
|
|
m()
|
|
|
|
# https://github.com/pytorch/pytorch/issues/121621
|
|
def test_tensor_random(self):
|
|
def random_op(tensor, args, kwargs):
|
|
res = tensor.random_(*args, **kwargs)
|
|
return res
|
|
|
|
random_op = torch.compile(random_op)
|
|
tensor = torch.randn([2, 3])
|
|
random_op(tensor, [], {"from": -10, "to": 10})
|
|
random_op(tensor, [-10], {"to": 10})
|
|
random_op(tensor, [-10, 10], {})
|
|
|
|
# https://github.com/pytorch/pytorch/issues/131019
|
|
def test_tensor_uniform(self):
|
|
def uniform_op(tensor, args, kwargs):
|
|
res = tensor.uniform_(*args, **kwargs)
|
|
return res
|
|
|
|
uniform_op = torch.compile(uniform_op)
|
|
tensor = torch.randn([2, 3])
|
|
uniform_op(tensor, [], {"from": -10, "to": 10})
|
|
uniform_op(tensor, [-10], {"to": 10})
|
|
uniform_op(tensor, [-10, 10], {})
|
|
|
|
def test_data_attr_mutation_after_saved_for_bw(self):
|
|
def f(x):
|
|
out = x.sin()
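# sin() saves x for backward; the .data mutation below bypasses the version counter, and the compiled gradients should match eager behavior.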
|
|
x.data.mul_(2)
|
|
return out
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
x_test = x.detach().clone().requires_grad_(True)
|
|
|
|
out = f(x)
|
|
out_test = torch.compile(f, backend="aot_eager")(x_test)
|
|
self.assertEqual(out, out_test)
|
|
|
|
out.sum().backward()
|
|
out_test.sum().backward()
|
|
self.assertEqual(x.grad, x_test.grad)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/128072
|
|
def test_map_with_multiple_args(self):
|
|
def f(a, b):
|
|
return a[0] * b[0] + a[1] * b[1]
|
|
|
|
def gen_inps(len_x, len_y):
|
|
x = [torch.randn(5) for _ in range(len_x)]
|
|
y = [torch.randn(5) for _ in range(len_y)]
|
|
return x, y
|
|
|
|
def g(x, y):
|
|
return map(f, x, y)
|
|
|
|
opt_g = torch.compile(g, fullgraph=True, backend="eager")
|
|
|
|
inps = gen_inps(3, 3)
|
|
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
|
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
|
|
|
inps = gen_inps(3, 5)
|
|
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
|
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
|
|
|
def test_staticmethod_allow_in_graph(self):
|
|
class MyClass:
|
|
i = 3
|
|
|
|
@staticmethod
|
|
def foo_inner(x):
|
|
return torch.mul(x, MyClass.i)
|
|
|
|
# if dynamo inlines with fullgraph, will error
|
|
# verify that dynamo doesn't inline
|
|
@staticmethod
|
|
@torch._dynamo.allow_in_graph
|
|
def foo1(x):
|
|
torch._dynamo.graph_break()
|
|
return MyClass.foo_inner(x)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f_bad(x):
|
|
return MyClass.foo1(x)
|
|
|
|
f_bad(torch.ones(2, 2))
|
|
|
|
def test_guard_with_tuple_mutation(self):
|
|
class Foo:
|
|
def __init__(self) -> None:
|
|
self.x = 10
|
|
|
|
foo = Foo()
|
|
d = {
|
|
"a": 2,
|
|
"b": (foo,),
|
|
}
|
|
|
|
def fn(x, d):
|
|
return x * d["a"] * d["b"][0].x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
inp = torch.randn(3, 3)
|
|
self.assertEqual(fn(inp, d), opt_fn(inp, d))
|
|
d["b"][0].x = 12
|
|
self.assertEqual(fn(inp, d), opt_fn(inp, d))
|
|
|
|
def test_compile_complex_conj(self):
|
|
def f(x):
|
|
return torch.mul(x, 2j)
|
|
|
|
x_ref = torch.randn(4, 2, requires_grad=True)
|
|
x_test = x_ref.detach().clone().requires_grad_(True)
|
|
|
|
out_ref = f(torch.view_as_complex(x_ref))
|
|
out_test = torch.compile(f, backend="aot_eager")(torch.view_as_complex(x_test))
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
torch.view_as_real(out_ref).sum().backward()
|
|
torch.view_as_real(out_test).sum().backward()
|
|
self.assertEqual(x_ref.grad, x_test.grad)
|
|
|
|
@unittest.skipIf(
|
|
not SM70OrLater,
|
|
"Triton only supports devices of CUDA capability >= 7.0",
|
|
)
|
|
def test_add_complex_conj(self):
|
|
def f(x):
|
|
return x + x.conj()
|
|
|
|
x = torch.randn(4, dtype=torch.complex64, requires_grad=True)
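# x + x.conj() cancels the imaginary part, so the result should equal 2 * x.real while keeping the complex64 dtype.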
|
|
out = torch.compile(f)(x)
|
|
expected_complex = (2 * x.real).to(dtype=out.dtype)
|
|
|
|
self.assertTrue(out.dtype == torch.complex64)
|
|
self.assertEqual(out, expected_complex)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/132200
|
|
def test_partitioner_cse_respects_mutation_boundaries(self):
|
|
set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
|
|
if not set_available:
|
|
return
|
|
|
|
@torch.compile(backend="aot_eager_decomp_partition")
|
|
def f(x, l):
|
|
# z0 and z1 can be CSEd
|
|
z0 = x.sin()
|
|
z1 = x.sin()
|
|
y = x + 1
|
|
torch.ops.fsdp.copy_.default(x, y)
|
|
# z2 and z3 can be CSEd with each other,
|
|
# but *not* with z0/z1 (they cross a mutation boundary)
|
|
z2 = x.sin()
|
|
z3 = x.sin()
|
|
return z0, z1, z2, z3, l**2
|
|
|
|
x = torch.randn(3)
|
|
x_clone = x.clone()
|
|
l = torch.randn(3, requires_grad=True)
|
|
z0, z1, z2, z3, _ = f(x, l)
|
|
|
|
# the partitioner runs CSE. We expect that of the 4 sin() ops above:
|
|
# - the first 2 are CSE'd
|
|
# - the last 2 are CSE'd
|
|
# - the set_() op in the middle is a mutation barrier, preventing CSE
|
|
self.assertEqual(z0, (x_clone).sin())
|
|
self.assertEqual(z1, (x_clone).sin())
|
|
self.assertEqual(z2, (x_clone + 1).sin())
|
|
self.assertEqual(z3, (x_clone + 1).sin())
|
|
|
|
# https://github.com/pytorch/pytorch/issues/132197
|
|
def test_fsdp_set_input_mutation_applied_when_input_gets_no_gradients(self):
|
|
set_available = hasattr(torch.ops, "fsdp") and hasattr(torch.ops.fsdp, "set_")
|
|
if not set_available:
|
|
return
|
|
|
|
@torch.compile(backend="aot_eager_decomp_partition")
|
|
def f(x, l):
|
|
z = x.sin() # noqa: F841
|
|
y = x + 1
|
|
# graph input has its storage mutated
|
|
torch.ops.fsdp.copy_.default(x, y)
|
|
z2 = x.sin()
|
|
return z2, l**2
|
|
|
|
x = torch.randn(3)
|
|
x_test = x.clone()
|
|
l = torch.randn(3, requires_grad=True)
|
|
result, _ = f(x, l)
|
|
result_test, _ = torch.compile(f, backend="aot_eager_decomp_partition")(
|
|
x_test, l
|
|
)
|
|
|
|
self.assertEqual(result, result_test)
|
|
self.assertEqual(x, x_test)
|
|
|
|
def test_aot_autograd_runtime_wrapper_prologue_profiled(self):
|
|
# Names for prologue profiling event
|
|
prologue_name = "AOTDispatcher Runtime Wrapper Prologue"
|
|
|
|
# Simple linear op to compile
|
|
mod = torch.nn.Linear(4, 4)
|
|
opt_mod = torch.compile(mod)
|
|
x = torch.randn(4, 4)
|
|
|
|
# Run this test with grad and no-grad to test both boolean cases trace_joint
|
|
for c in [contextlib.nullcontext, torch.no_grad]:
|
|
# Run compiled op with profiling
|
|
with c():
|
|
# warmup before profiling
|
|
opt_mod(x)
|
|
with profile(activities=[ProfilerActivity.CPU]) as prof:
|
|
opt_mod(x)
|
|
|
|
# Make sure events are populated, then find the prologue event and the last start time
|
|
events = prof.events()
|
|
self.assertTrue(events is not None)
|
|
|
|
prologue_event = None
|
|
last_start_time = 0
|
|
for event in events:
|
|
if hasattr(event, "name") and prologue_name in event.name:
|
|
prologue_event = event
|
|
if event.time_range.start > last_start_time:
|
|
last_start_time = event.time_range.start
|
|
|
|
            # Make sure the prologue event exists
|
|
self.assertTrue(prologue_event is not None)
|
|
|
|
# Make sure there is at least one other event (compiled function) that starts
|
|
# after prologue starts
|
|
self.assertLess(prologue_event.time_range.end, last_start_time)
|
|
|
|
def test_changing_stride(self):
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=cnt)
|
|
def fn(x, y):
|
|
return x * y
|
|
|
|
for i in range(1, 4):
|
|
x = torch.randn(4, i)
|
|
|
|
# create a view for i > 1
|
|
if i == 1:
|
|
x1 = x
|
|
else:
|
|
x1 = x[:, 0:1]
|
|
|
|
y = torch.randn(4, 1)
|
|
print(x1.shape, y.shape)
|
|
fn(x1, y)
|
|
|
|
self.assertTrue(cnt.frame_count <= 2)
|
|
|
|
def test_unsqueeze_mul_strides(self):
|
|
# This is a case where we had an input that was marked unbacked:
|
|
# size=[2, u0], stride=[1, 1] which is bad. We want it to actually
|
|
# be size=[2, u0], stride=[u0, 1]. See more in the issue below:
|
|
# https://github.com/pytorch/pytorch/issues/142024
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(aot6_sub_58, aot6_mul_170):
|
|
aot6_unsqueeze_14 = torch.ops.aten.unsqueeze.default(aot6_mul_170, 1)
|
|
return torch.ops.aten.mul.Tensor(aot6_sub_58, aot6_unsqueeze_14)
|
|
|
|
aot6_sub_58 = torch.randn(2, 1)
|
|
torch._dynamo.decorators.mark_unbacked(aot6_sub_58, 1)
|
|
aot6_mul_170 = torch.randn(2)
|
|
|
|
# No assert necessary since this used to crash.
|
|
fn(aot6_sub_58, aot6_mul_170)
|
|
|
|
@torch._dynamo.config.patch(guard_nn_modules=False)
|
|
@torch._dynamo.config.patch(inline_inbuilt_nn_modules=False)
|
|
def test_inlining_cornercase(self):
|
|
"""
|
|
nn.Modules can be mapped to either NNModuleVariable or UnspecializedNNModuleVariable. For NNModuleVariable, the
|
|
tensor attributes become part of the Dynamo graph. For unspecialized, they are lifted as inputs.
|
|
|
|
But there is a cornercase. Suppose you have NNModuleVariable with a submodule that is
|
|
UnspecializedNNModuleVariable. Today, Dynamo will still consider the submodule as specialized (courtesy of
|
|
guard.source().is_nn_module()). In retrospect, this is a mistake but there are dependencies of export and also
|
|
cudagraphs which make it harder to fix the corner case right away. The long term solution is
|
|
inline_inbuilt_nn_modules anyways, so we might have to live with this cornercase in the short term.
|
|
|
|
We are starting to annotate the source of each nn module more precisely - NNModuleVariable attribute is marked
|
|
        as NNModuleSource, UnspecializedNNModuleVariable attribute is marked as UnspecializedNNModuleSource. But this
|
|
changes the behavior for the cornercase. And fails some tests which have unfortunately relied on this behavior.
|
|
|
|
|
|
        To solve this, we tag the source only when the inline_inbuilt_nn_modules flag is turned on.
|
|
|
|
In this test, we purposely turn the flag off, testing that the tagging is disabled.
|
|
"""
|
|
|
|
class SubMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(1, 1)
|
|
self.a = torch.randn(1, 1)
|
|
self.counter = 0
|
|
self.multipliers = [2.2, 3.3]
|
|
|
|
def forward(self, x):
|
|
self.counter += 1
|
|
return (
|
|
self.linear(x) * self.a * self.multipliers[0] * self.multipliers[1]
|
|
)
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.submod = SubMod()
|
|
|
|
def forward(self, x):
|
|
return self.submod(x)
|
|
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend="eager")
|
|
|
|
x = torch.randn(1, 1)
|
|
ref = mod(x) # noqa: F841
|
|
res = opt_mod(x) # noqa: F841
|
|
|
|
mod.submod.multipliers = [3.3, 4.4]
|
|
# Since guard_nn_modules is False, this will not recompile
|
|
with torch._dynamo.config.patch(error_on_recompile=True):
|
|
ref = mod(x) # noqa: F841
|
|
res = opt_mod(x) # noqa: F841
|
|
|
|
def test_optimized_module_training(self):
|
|
mod = torch.nn.Linear(3, 3)
|
|
mod.eval()
|
|
|
|
opt_mod = torch.compile(mod, backend="eager")
|
|
self.assertFalse(opt_mod.training)
|
|
|
|
opt_mod.train()
|
|
self.assertTrue(opt_mod.training)
|
|
self.assertTrue(mod.training)
|
|
|
|
mod.eval()
|
|
self.assertFalse(opt_mod.training)
|
|
|
|
def test_optimized_module_patched_init(self):
|
|
        # A regression test for #138157; the pattern came from deepspeed.
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, x):
|
|
return x.mul(5.0)
|
|
|
|
def patch_init(init):
|
|
@functools.wraps(init)
|
|
def wrapper(module, *args, **kwargs):
|
|
if not hasattr(module, "_ds_child_entered"):
|
|
# child's __init__ was called, since parents all see the same object they can now skip post_init
|
|
module._ds_child_entered = True
|
|
init(module, *args, **kwargs)
|
|
|
|
return wrapper
|
|
|
|
def patch_init_for_class(cls):
|
|
if "__init__" in cls.__dict__:
|
|
cls._old_init = cls.__init__
|
|
cls.__init__ = patch_init(cls.__init__)
|
|
|
|
patch_init_for_class(MyModule)
|
|
mod = MyModule()
|
|
opt_mod = torch.compile(mod)
|
|
|
|
x = torch.rand(10)
|
|
ref = mod(x)
|
|
res = opt_mod(x)
|
|
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_os_fspath(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
os.fspath(".")
|
|
return torch.sin(x)
|
|
|
|
fn(torch.randn(4))
|
|
|
|
@requires_cuda
|
|
# test involves custom ops that return unbacked symints
|
|
@torch._dynamo.config.patch(capture_dynamic_output_shape_ops=True)
|
|
# test requires the activation memory budget code to think
|
|
# that j() is banned from recompute
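    # (j()'s fake impl returns x.sum(), i.e. a reduction; the min-cut
    # partitioner treats reductions as too expensive to recompute, which is
    # what makes the memory-budget code see it as banned.)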
|
|
@torch._functorch.config.patch(activation_memory_budget=0.5)
|
|
def test_partitioner_activation_memory_budget_with_unbacked_symints(self):
|
|
@torch.library.custom_op("test_partitioner::f", mutates_args=[])
|
|
def f(x: torch.Tensor) -> torch.Tensor:
|
|
return x.new_zeros(512, 1)
|
|
|
|
@f.register_fake
|
|
def _(x: torch.Tensor) -> torch.Tensor:
|
|
ctx = torch.library.get_ctx()
|
|
s = ctx.new_dynamic_size()
|
|
return torch.empty(s, 1, device=x.device, dtype=x.dtype)
|
|
|
|
@torch.library.custom_op("test_partitioner::g", mutates_args=[])
|
|
def g(x: torch.Tensor) -> torch.Tensor:
|
|
return torch.cat([x, x[0].unsqueeze(-1)])
|
|
|
|
@g.register_fake
|
|
def _(x: torch.Tensor) -> torch.Tensor:
|
|
return torch.cat([x, x[0].unsqueeze(-1)])
|
|
|
|
@torch.library.custom_op("test_partitioner::i", mutates_args=[])
|
|
def i(x: torch.Tensor, sz: int) -> torch.Tensor:
|
|
return torch.ones(sz, 1, dtype=x.dtype, device=x.device)
|
|
|
|
@i.register_fake
|
|
def _(x: torch.Tensor, sz: int) -> torch.Tensor:
|
|
return torch.empty(sz, 1, dtype=x.dtype, device=x.device)
|
|
|
|
@torch.library.custom_op("test_partitioner::j", mutates_args=[])
|
|
def j(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
return x + 1
|
|
|
|
@j.register_fake
|
|
def _(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
|
|
sz1 = x.shape[0] - 1
|
|
sz2 = y.numel()
|
|
torch._check(sz1 == sz2)
|
|
# make this a reduction so partitioner bans recompute of it
|
|
return x.sum()
|
|
|
|
def f(x, param):
|
|
y = torch.ops.test_partitioner.f(x)
|
|
z = torch.ops.test_partitioner.g(y)
|
|
z2 = torch.ops.test_partitioner.i(x, z.shape[0] - 1)
|
|
z2 = torch.ops.test_partitioner.j(z, z2)
|
|
return torch.matmul(x, param).sin() * z2.sum()
|
|
|
|
x = torch.randn(512, 512, device="cuda")
|
|
param = torch.randn(512, 512, device="cuda", requires_grad=True)
|
|
out_ref = f(x, param)
|
|
out_test = torch.compile(f, backend="aot_eager_decomp_partition")(x, param)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
@requires_cuda
|
|
# This test will fail as flip in combination with particular input lengths
|
|
# produces weird results.
|
|
    # This is under investigation in
|
|
# https://github.com/pytorch/pytorch/issues/131805
|
|
@unittest.skip("Skip this flip test for the moment. It is under investigation")
|
|
def test_flip_bad_accuracy(self):
|
|
import torch
|
|
import torch._dynamo.config
|
|
import torch._functorch.config
|
|
import torch._inductor.config
|
|
import torch._inductor.inductor_prims
|
|
import torch.fx.experimental._config
|
|
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, arg0_1):
|
|
rev = torch.ops.prims.rev.default(arg0_1, [0])
|
|
arg0_1 = None
|
|
slice_1 = torch.ops.aten.slice.Tensor(rev, 0, 0, -1, 2)
|
|
slice_2 = torch.ops.aten.slice.Tensor(rev, 0, 1, 9223372036854775807, 2)
|
|
add_1 = torch.ops.aten.add.Tensor(slice_1, slice_2)
|
|
slice_1 = slice_2 = None
|
|
slice_3 = torch.ops.aten.slice.Tensor(add_1, 0, 0, -1, 2)
|
|
slice_4 = torch.ops.aten.slice.Tensor(
|
|
add_1, 0, 1, 9223372036854775807, 2
|
|
)
|
|
add_2 = torch.ops.aten.add.Tensor(slice_3, slice_4)
|
|
slice_3 = slice_4 = None
|
|
slice_5 = torch.ops.aten.slice.Tensor(add_2, 0, 0, -1, 2)
|
|
slice_6 = torch.ops.aten.slice.Tensor(
|
|
add_2, 0, 1, 9223372036854775807, 2
|
|
)
|
|
add_3 = torch.ops.aten.add.Tensor(slice_5, slice_6)
|
|
slice_5 = slice_6 = None
|
|
slice_9 = torch.ops.aten.slice.Tensor(add_2, 0, 0, 1)
|
|
add_2 = None
|
|
unsqueeze = torch.ops.aten.unsqueeze.default(slice_9, 1)
|
|
slice_9 = None
|
|
unsqueeze_1 = torch.ops.aten.unsqueeze.default(add_3, 1)
|
|
add_3 = None
|
|
cat = torch.ops.aten.cat.default([unsqueeze, unsqueeze_1], 1)
|
|
unsqueeze = unsqueeze_1 = None
|
|
view = torch.ops.aten.view.default(cat, [2])
|
|
cat = None
|
|
slice_10 = torch.ops.aten.slice.Tensor(view, 0, 0, -1)
|
|
slice_11 = torch.ops.aten.slice.Tensor(
|
|
add_1, 0, 2, 9223372036854775807, 2
|
|
)
|
|
add_5 = torch.ops.aten.add.Tensor(slice_10, slice_11)
|
|
slice_10 = slice_11 = None
|
|
slice_12 = torch.ops.aten.slice.Tensor(add_1, 0, 0, 1)
|
|
add_1 = None
|
|
cat_1 = torch.ops.aten.cat.default([slice_12, add_5])
|
|
slice_12 = add_5 = None
|
|
unsqueeze_2 = torch.ops.aten.unsqueeze.default(cat_1, 1)
|
|
cat_1 = None
|
|
unsqueeze_3 = torch.ops.aten.unsqueeze.default(view, 1)
|
|
view = None
|
|
cat_2 = torch.ops.aten.cat.default([unsqueeze_2, unsqueeze_3], 1)
|
|
unsqueeze_2 = unsqueeze_3 = None
|
|
view_1 = torch.ops.aten.view.default(cat_2, [4])
|
|
cat_2 = None
|
|
slice_13 = torch.ops.aten.slice.Tensor(
|
|
rev, 0, 2, 9223372036854775807, 2
|
|
)
|
|
add_6 = torch.ops.aten.add.Tensor(view_1, slice_13)
|
|
slice_13 = None
|
|
slice_14 = torch.ops.aten.slice.Tensor(rev, 0, 0, 1)
|
|
rev = None
|
|
cat_3 = torch.ops.aten.cat.default([slice_14, add_6])
|
|
slice_14 = add_6 = None
|
|
constant_pad_nd = torch.ops.aten.constant_pad_nd.default(
|
|
view_1, [0, 1], 0.0
|
|
)
|
|
view_1 = None
|
|
unsqueeze_4 = torch.ops.aten.unsqueeze.default(cat_3, 1)
|
|
cat_3 = None
|
|
unsqueeze_5 = torch.ops.aten.unsqueeze.default(constant_pad_nd, 1)
|
|
constant_pad_nd = None
|
|
cat_4 = torch.ops.aten.cat.default([unsqueeze_4, unsqueeze_5], 1)
|
|
unsqueeze_4 = unsqueeze_5 = None
|
|
view_2 = torch.ops.aten.view.default(cat_4, [10])
|
|
cat_4 = None
|
|
slice_15 = torch.ops.aten.slice.Tensor(view_2, 0, 0, 9)
|
|
view_2 = None
|
|
rev_1 = torch.ops.prims.rev.default(slice_15, [0])
|
|
slice_15 = None
|
|
return (rev_1,)
|
|
|
|
mod = Repro()
|
|
x = torch.arange(9, device=torch.device("cuda"))
|
|
|
|
@torch.compile
|
|
def f(x):
|
|
return mod(x)
|
|
|
|
out = f(x)
|
|
self.assertEqual(torch.flip(torch.cumsum(torch.flip(x, [0]), 0), [0]), out[0])
|
|
|
|
# https://github.com/pytorch/pytorch/issues/88813
|
|
def test_return_value_duplication_tensor(self) -> None:
|
|
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
return val * 2, val * 2
|
|
|
|
x = torch.randn(2, requires_grad=True)
|
|
|
|
expect = fn(x)
|
|
self.assertNotEqual(
|
|
expect[0].untyped_storage().data_ptr(),
|
|
expect[1].untyped_storage().data_ptr(),
|
|
)
|
|
|
|
actual = torch.compile(fn, backend="aot_eager")(x)
|
|
self.assertNotEqual(
|
|
actual[0].untyped_storage().data_ptr(),
|
|
actual[1].untyped_storage().data_ptr(),
|
|
)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/114344
|
|
def test_return_value_duplication_mixed_grad(self) -> None:
|
|
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
with torch.no_grad():
|
|
out0 = val + 1
|
|
out1 = val + 1
|
|
return out0, out1
|
|
|
|
x = torch.randn(2, requires_grad=True)
|
|
|
|
with torch.enable_grad():
|
|
expect = fn(x)
|
|
actual = torch.compile(fn, backend="aot_eager")(x)
|
|
|
|
self.assertEqual(expect[0].requires_grad, actual[0].requires_grad)
|
|
self.assertEqual(expect[1].requires_grad, actual[1].requires_grad)
|
|
|
|
# https://github.com/pytorch/pytorch/pull/134726#discussion_r1738774371
|
|
def test_return_value_duplication_scalar(self) -> None:
|
|
def fn(val: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
|
|
x, y = val * 2, val * 2
|
|
return x[0], y[0]
|
|
|
|
x = torch.randn(2, requires_grad=True)
|
|
|
|
expect = fn(x)
|
|
self.assertNotEqual(
|
|
expect[0].untyped_storage().data_ptr(),
|
|
expect[1].untyped_storage().data_ptr(),
|
|
)
|
|
|
|
actual = torch.compile(fn, backend="aot_eager")(x)
|
|
self.assertNotEqual(
|
|
actual[0].untyped_storage().data_ptr(),
|
|
actual[1].untyped_storage().data_ptr(),
|
|
)
|
|
|
|
def test_torch_compile_in_compile_frame(self):
|
|
def gn(x, c=None):
|
|
if c is None:
|
|
c = 2
|
|
return c * x
|
|
|
|
def outer_func(x):
|
|
return torch.compile(gn, backend="eager")(x)
|
|
|
|
compile_outer = torch.compile(outer_func, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
ref = outer_func(x)
|
|
res = compile_outer(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/136640
|
|
def test_inductor_dynamic_shapes_broadcasting(self) -> None:
|
|
def fn(x, y):
|
|
x_view = x.view(-1, 4)
|
|
y_view = y.view(-1, 4)
|
|
return x_view * y_view
|
|
|
|
x = torch.randn(4)
|
|
y = torch.randn(8)
|
|
out_ref = fn(x, y)
|
|
out_test = torch.compile(fn, dynamic=True)(x, y)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/119162
|
|
def test_inductor_rng_default_dtype(self) -> None:
|
|
@torch.compile
|
|
def fn():
|
|
tmp = torch.randn(4, 4, dtype=torch.bfloat16)
|
|
return tmp
|
|
|
|
try:
|
|
old = torch.get_default_dtype()
|
|
torch.set_default_dtype(torch.bfloat16)
|
|
out = fn()
|
|
finally:
|
|
torch.set_default_dtype(old)
|
|
        # the output should keep the explicitly requested bfloat16 dtype
|
|
self.assertEqual(out.dtype, torch.bfloat16)
|
|
|
|
@unittest.skipIf(not HAS_MSGSPEC, "missing msgspec package")
|
|
def test_c_defined_metaclass(self):
|
|
class User(msgspec.Struct):
|
|
"""A new type describing a User"""
|
|
|
|
name: str
|
|
value: int
|
|
|
|
def fn(x):
|
|
u = User("alice", 10)
|
|
return x * u.value
|
|
|
|
x = torch.randn(4)
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
@unittest.skipIf(not HAS_OMEGACONG, "missing omegaconf package")
|
|
def test_omegaconf_dictconfig(self):
|
|
def fn(cfg, x):
|
|
a = cfg["foo"].a * x
|
|
b = cfg.bar["b"] * a
|
|
cfg.__dict__["baz"] = 4
|
|
return b * cfg.baz
|
|
|
|
config = OmegaConf.create({"foo": {"a": 3}, "bar": {"b": 5}})
|
|
|
|
x = torch.randn(4)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
|
|
fn(config, x)
|
|
cloned_config = copy.deepcopy(config)
|
|
opt_fn(cloned_config, x)
|
|
|
|
self.assertEqual(fn(config, x), opt_fn(config, x))
|
|
self.assertEqual(cloned_config.baz, 4)
|
|
|
|
@unittest.skipIf(not HAS_OMEGACONG, "missing omegaconf package")
|
|
def test_omegaconf_listconfig_contains(self):
|
|
def fn(cfg, x):
|
|
if 1 in cfg:
|
|
return torch.sin(x)
|
|
return torch.cos(x)
|
|
|
|
config = OmegaConf.create([1, 2, 3, {"key": "value"}])
|
|
|
|
x = torch.randn(4)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
self.assertEqual(fn(config, x), opt_fn(config, x))
|
|
|
|
# https://github.com/pytorch/pytorch/issues/136257
|
|
def test_overwriting_params(self):
|
|
class M(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fc1 = torch.nn.Linear(2, 2)
|
|
self.fc2 = torch.nn.Linear(2, 2)
|
|
|
|
def forward(self, x):
|
|
x = self.fc1(x)
|
|
x = self.fc2(x)
|
|
return x
|
|
|
|
class ZeROOrderedDict(collections.OrderedDict):
|
|
def __init__(self, parent_module=None, *args, **kwargs):
|
|
"""A replacement for ``collections.OrderedDict`` to detect external ZeRO params.
|
|
|
|
Args:
|
|
parent_module (``collections.OrderedDict``): the collection to replace
|
|
"""
|
|
|
|
super().__init__(*args, **kwargs)
|
|
self._parent_module = parent_module
|
|
|
|
def __getitem__(self, key):
|
|
param = super().__getitem__(key)
|
|
|
|
# Params can be registered as None (e.g., bias)
|
|
if param is None:
|
|
return param
|
|
|
|
# do something here
|
|
return param
|
|
|
|
def inject_parameters(module, cls):
|
|
for module in module.modules(): # noqa: B020
|
|
if cls == ZeROOrderedDict:
|
|
new_param = cls(parent_module=module)
|
|
else:
|
|
new_param = cls()
|
|
|
|
for key, param in module._parameters.items():
|
|
new_param[key] = param
|
|
module._parameters = new_param
|
|
|
|
model = M()
|
|
|
|
inject_parameters(model, ZeROOrderedDict)
|
|
|
|
model = torch.compile(model, backend="eager", fullgraph=True)
|
|
|
|
x = torch.ones(2)
|
|
with torch.no_grad():
|
|
model(x)
|
|
|
|
def test_ao_fake_quantize_tracing(self):
|
|
import torch.ao.quantization.fake_quantize
|
|
|
|
q = torch.ao.quantization.FusedMovingAvgObsFakeQuantize()
|
|
|
|
def fn(x):
|
|
return q(x)
|
|
|
|
x = torch.ones(2, 2)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x)
|
|
eager_res = fn(x)
|
|
|
|
self.assertEqual(res, eager_res)
|
|
|
|
def test_typed_dict(self):
|
|
class LlavaImagePixelInputs(TypedDict):
|
|
type: Literal["pixel_values"]
|
|
data: torch.Tensor
|
|
"""Shape: `(batch_size, num_channels, height, width)`"""
|
|
|
|
def fn(x, y):
|
|
obj = LlavaImagePixelInputs(type=int, data=y)
|
|
out = x * obj["data"]
|
|
obj["data"] = 3
|
|
return out * obj["data"]
|
|
|
|
x, y = torch.randn(4), torch.randn(4)
|
|
ref = fn(x, y)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x, y)
|
|
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_typed_dict_total(self):
|
|
class LlavaImagePixelInputs(TypedDict):
|
|
type: Literal["pixel_values"]
|
|
data: torch.Tensor
|
|
"""Shape: `(batch_size, num_channels, height, width)`"""
|
|
|
|
def fn(x, y):
|
|
obj = LlavaImagePixelInputs(data=y, total=False)
|
|
return x * obj["data"]
|
|
|
|
x, y = torch.randn(4), torch.randn(4)
|
|
ref = fn(x, y)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x, y)
|
|
|
|
self.assertEqual(ref, res)
|
|
|
|
@skipIfPy312 # listcomp bytecode is optimized
|
|
def test_listcomp(self):
|
|
class Module(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self._num = 4
|
|
|
|
@torch._dynamo.disable(recursive=False)
|
|
def forward(self, x):
|
|
values = [i * torch.cos(x) for i in range(self._num)]
|
|
return sum(values)
|
|
|
|
mod = Module()
|
|
|
|
def fn(x):
|
|
return mod(x)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
x = torch.randn(4)
|
|
|
|
ref = fn(x)
|
|
res = opt_fn(x)
|
|
self.assertEqual(ref, res)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
# Ensure that the listcomp is fully compiled
|
|
self.assertEqual(cnt.op_count, 8)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/140266
|
|
def test_distributions_subclass(self):
|
|
import torch
|
|
from torch.distributions import Categorical
|
|
|
|
class SubCateg(Categorical):
|
|
pass
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def make_dist_and_execute(t, d):
|
|
categ = d(logits=t)
|
|
a = categ.log_prob(categ.sample()) + categ.probs + categ.logits
|
|
return a
|
|
|
|
for _ in range(2):
|
|
make_dist_and_execute(torch.randn(10), SubCateg)
|
|
|
|
def test_bitwise_print_precedence(self):
|
|
import math
|
|
|
|
@torch.compile(fullgraph=True, dynamic=True)
|
|
def f(x):
|
|
torch._check(math.floor((x.size(0) | 3) * 4) == 12)
|
|
return x.sin()
|
|
|
|
f(torch.randn(2))
|
|
|
|
def test_tensor_split_within_device_cm(self):
|
|
@torch.compile(fullgraph=True)
|
|
def split(x):
|
|
return x.split(4, 0)
|
|
|
|
x = torch.zeros(12)
|
|
res = split(x)
|
|
|
|
with torch.device("cpu"):
|
|
self.assertEqual(res, split(x))
|
|
|
|
def test_method_overriding(self):
|
|
class DilateConv(torch.nn.Module):
|
|
def __init__(
|
|
self,
|
|
dilate_func=None,
|
|
):
|
|
super().__init__()
|
|
self.dilate_func = dilate_func
|
|
|
|
def forward(self, x):
|
|
return self.dilate_func() * torch.sin(x)
|
|
|
|
class MainModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.mod = DilateConv(self.dilate_func)
|
|
self.a = 4
|
|
|
|
def dilate_func(self):
|
|
return self.a
|
|
|
|
def forward(self, x):
|
|
return self.mod(x)
|
|
|
|
mod = MainModule()
|
|
|
|
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
ref = mod(x)
|
|
res = opt_mod(x)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_symnode_is_op(self):
|
|
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
|
|
def f(x, xs):
|
|
if x.size(0) is xs:
|
|
return x + 1
|
|
else:
|
|
return x * 2
|
|
|
|
t = torch.randn(2)
|
|
res = f(t, [1, 2])
|
|
self.assertEqual(t * 2, res)
|
|
|
|
def test_compile_copy__int_overload(self):
|
|
@torch.compile(backend="aot_eager", fullgraph=True)
|
|
def f(x):
|
|
return x.copy_(1)
|
|
|
|
t = torch.zeros(2)
|
|
res = f(t)
|
|
self.assertEqual(torch.ones_like(t), res)
|
|
|
|
def test_symnode_is_not_op(self):
|
|
@torch.compile(backend="eager", fullgraph=True, dynamic=True)
|
|
def f(x, xs):
|
|
if x.size(0) is not xs:
|
|
return x + 1
|
|
else:
|
|
return x * 2
|
|
|
|
t = torch.randn(2)
|
|
res = f(t, [1, 2])
|
|
self.assertEqual(t + 1, res)
|
|
|
|
def test_symint_bitwise(self):
|
|
def fn(x):
|
|
z = x.shape[0]
|
|
z |= z >> 1
|
|
z |= z << 1
|
|
z &= z | (z > 1)
|
|
y = (z > 1) | (z <= 1)
|
|
# test composition with non-bitwise ops
|
|
z = (z | z) % 6
|
|
return y, z
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", dynamic=True, fullgraph=True)
|
|
inp = torch.randn(3, 3)
|
|
self.assertEqual(fn(inp), opt_fn(inp))
|
|
|
|
def test_bitwise_op_guard(self):
|
|
# attempt evaluating a guard with BitwiseFn_bitwise_[and/or]
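        # (The `|` / `&` on sizes below should show up as BitwiseFn_* nodes in
        # the symbolic guard expressions; the test only checks that evaluating
        # such a guard doesn't blow up and that results match eager.)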
|
|
def fn(x):
|
|
if x.shape[0] | x.shape[1] > 4:
|
|
x = x + 1
|
|
if x.shape[0] & x.shape[1] > 2:
|
|
return x + 1
|
|
return x - 1
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", dynamic=True, fullgraph=True)
|
|
inp = torch.randn(3, 3)
|
|
self.assertEqual(fn(inp), opt_fn(inp))
|
|
|
|
def test_ones_out_dynamic(self):
|
|
def ones_fn(size, out):
|
|
return torch.ones(size, out=out)
|
|
|
|
opt_model = torch.compile(ones_fn)
|
|
|
|
out1 = torch.empty(2, 3)
|
|
opt_model((2, 3), out1)
|
|
|
|
out2 = torch.empty(3, 4)
|
|
opt_model((3, 4), out2)
|
|
|
|
def test_zeros_out_dynamic(self):
|
|
def zeros_fn(size, out):
|
|
return torch.zeros(size, out=out)
|
|
|
|
opt_model = torch.compile(zeros_fn)
|
|
|
|
out1 = torch.empty(2, 3)
|
|
opt_model((2, 3), out1)
|
|
|
|
out2 = torch.empty(3, 4)
|
|
opt_model((3, 4), out2)
|
|
|
|
def test_empty_out_dynamic(self):
|
|
def empty_fn(size, out):
|
|
return torch.empty(size, out=out)
|
|
|
|
opt_model = torch.compile(empty_fn)
|
|
|
|
out1 = torch.empty(2, 3)
|
|
opt_model((2, 3), out1)
|
|
|
|
out2 = torch.empty(3, 4)
|
|
opt_model((3, 4), out2)
|
|
|
|
def test_dataclass_in_module(self):
|
|
@dataclasses.dataclass
|
|
class MyData:
|
|
value: float
|
|
|
|
class MyModel(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.my_data = MyData(value=3.14)
|
|
|
|
def forward(self, x):
|
|
# Make sure to use the scalar 'value' correctly in tensor operations
|
|
value_tensor = torch.tensor(self.my_data.value)
|
|
return x + value_tensor
|
|
|
|
model = MyModel()
|
|
inputs = torch.randn(2, 2)
|
|
expected = model(inputs)
|
|
compiled_model = torch.compile(model)
|
|
actual = compiled_model(inputs)
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_no_tracing_into_eval_frame(self):
|
|
# test that dynamo doesn't trace into nested calls from eval_frame
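        # The check works by monkeypatching _maybe_set_eval_frame (see `bad`
        # below) to call graph_break(); if Dynamo ever traced into it, the
        # graph break under fullgraph=True would raise, so a clean run means
        # the eval_frame helpers stayed outside the traced region.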
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return x + 1
|
|
|
|
orig_fn = torch._dynamo.eval_frame._maybe_set_eval_frame
|
|
|
|
def bad(*args, **kwargs):
|
|
torch._dynamo.graph_break()
|
|
return orig_fn(*args, **kwargs)
|
|
|
|
with mock.patch("torch._dynamo.eval_frame._maybe_set_eval_frame", bad):
|
|
fn(torch.ones(3))
|
|
|
|
@torch._dynamo.config.patch(raise_on_ctx_manager_usage=False)
|
|
def test_no_tracing_into_eval_frame_ctx_manager(self):
|
|
# Test that dynamo doesn't trace into nested calls from eval_frame
|
|
# when using a context manager.
|
|
# Even though we don't officially support Dynamo context managers, we still
|
|
# have tests that use them, so we should still make sure the eval_frame callback
|
|
# is set at the correct places in these cases.
|
|
def fn(x):
|
|
return x + 1
|
|
|
|
orig_fn = torch._dynamo.eval_frame._maybe_set_eval_frame
|
|
|
|
def bad(*args, **kwargs):
|
|
torch._dynamo.graph_break()
|
|
return orig_fn(*args, **kwargs)
|
|
|
|
with mock.patch("torch._dynamo.eval_frame._maybe_set_eval_frame", bad):
|
|
with torch._dynamo.optimize_assert("eager"):
|
|
fn(torch.ones(3))
|
|
|
|
@torch._dynamo.config.patch(allow_empty_graphs=True)
|
|
@parametrize("fullgraph", [True, False])
|
|
def test_empty_graph_nested_calls(self, fullgraph):
|
|
def k(x):
|
|
return x
|
|
|
|
def g(x):
|
|
return k(x)
|
|
|
|
def f(x):
|
|
return g(x)
|
|
|
|
# TODO clear this on all tests
|
|
torch._dynamo.eval_frame.clear_dynamo_tls()
|
|
|
|
opt_f = torch.compile(f, backend="eager", fullgraph=fullgraph, dynamic=False)
|
|
opt_f(torch.randn(3))
|
|
        # we should not be compiling g or k as top-level functions
|
|
self.assertEqual(len(torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos), 1)
|
|
# no recompilation
|
|
opt_f(torch.randn(3))
|
|
self.assertEqual(len(torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos), 1)
|
|
# recompilation
|
|
opt_f(torch.randn(4))
|
|
self.assertEqual(len(torch._dynamo.eval_frame.dynamo_tls.traced_frame_infos), 2)
|
|
|
|
def test_torchname(self):
|
|
def fn(obj):
|
|
return torch.typename(obj)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
self.assertEqual(fn(typing.Any), opt_fn(typing.Any))
|
|
|
|
@unittest.skipIf(not TEST_CUDA, "test requires CUDA")
|
|
@unittest.skipIf(not dist.is_available(), "test requires distributed")
|
|
# TODO: Remoe this skip once nccl issue if fixed
|
|
@unittest.skip(
|
|
"Failing with ncc update 2.25.1 : https://github.com/pytorch/pytorch/issues/147141"
|
|
)
|
|
def test_ddp_checkpoint(self):
|
|
# https://github.com/pytorch/pytorch/issues/144035
|
|
DIM = 256
|
|
SEQ_LEN = 32
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def mlp_forward(x, w1, w2, b1, b2):
|
|
y = F.linear(x, w1, b1)
|
|
y = F.relu(y)
|
|
y = F.linear(y, w2, b2)
|
|
return y
|
|
|
|
class MLP(nn.Module):
|
|
def __init__(
|
|
self,
|
|
in_features: int,
|
|
hidden_features: int,
|
|
out_features: int,
|
|
):
|
|
super().__init__()
|
|
self.w_in = nn.Parameter(torch.randn(hidden_features, in_features))
|
|
self.w_out = nn.Parameter(torch.randn(out_features, hidden_features))
|
|
self.b_in = nn.Parameter(torch.randn(hidden_features))
|
|
self.b_out = nn.Parameter(torch.randn(out_features))
|
|
|
|
def forward(self, x):
|
|
result = torch.utils.checkpoint.checkpoint(
|
|
mlp_forward,
|
|
x,
|
|
self.w_in,
|
|
self.w_out,
|
|
self.b_in,
|
|
self.b_out,
|
|
use_reentrant=False,
|
|
)
|
|
assert isinstance(result, torch.Tensor)
|
|
return result
|
|
|
|
x = torch.randn(100, SEQ_LEN, DIM)
|
|
y = torch.zeros(100)
|
|
dataset = torch.utils.data.TensorDataset(x, y)
|
|
dataloader = torch.utils.data.DataLoader(dataset, batch_size=10)
|
|
model = MLP(DIM, 4 * DIM, DIM)
|
|
|
|
try:
|
|
# required for DDP wrapper initialization
|
|
prior_master_addr = os.environ.get("MASTER_ADDR", None)
|
|
prior_master_port = os.environ.get("MASTER_PORT", None)
|
|
os.environ["MASTER_ADDR"] = "localhost"
|
|
os.environ["MASTER_PORT"] = "12355"
|
|
dist.init_process_group(backend="nccl", world_size=1, rank=0)
|
|
model = model.to("cuda")
|
|
model = nn.parallel.DistributedDataParallel(model)
|
|
|
|
for batch in dataloader:
|
|
x, y = batch
|
|
x = x.to("cuda")
|
|
output = model(x)
|
|
loss = output.sum()
|
|
loss.backward()
|
|
finally:
|
|
dist.destroy_process_group()
|
|
if prior_master_addr:
|
|
os.environ["MASTER_ADDR"] = prior_master_addr
|
|
else:
|
|
del os.environ["MASTER_ADDR"]
|
|
|
|
if prior_master_port:
|
|
os.environ["MASTER_PORT"] = prior_master_port
|
|
else:
|
|
del os.environ["MASTER_PORT"]
|
|
|
|
@torch._dynamo.config.patch(
|
|
recompile_limit=1,
|
|
fail_on_recompile_limit_hit=True,
|
|
)
|
|
def test_compilation_metrics_on_error(self):
|
|
torch._dynamo.utils.clear_compilation_metrics()
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
# force a recompile in a way friendly to test_dynamic_shapes
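            # (branching on numel rather than individual sizes presumably keeps
            # the recompile count the same whether this file runs with static
            # or dynamic shapes enabled)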
|
|
if x.numel() == 100:
|
|
return x.sum()
|
|
elif x.numel() == 10000:
|
|
return x.sum()
|
|
|
|
x = torch.randn(10, 10)
|
|
y = torch.randn(100, 100)
|
|
metrics = torch._dynamo.utils._compilation_metrics
|
|
self.assertEqual(len(metrics), 0)
|
|
|
|
fn(x)
|
|
self.assertTrue(metrics is torch._dynamo.utils._compilation_metrics)
|
|
self.assertEqual(len(metrics), 1)
|
|
latest_metrics = metrics[-1]
|
|
self.assertTrue(latest_metrics.dynamo_config is not None)
|
|
self.assertTrue(latest_metrics.recompile_reason is None)
|
|
|
|
with self.assertRaises(torch._dynamo.exc.FailOnRecompileLimitHit):
|
|
fn(y)
|
|
self.assertTrue(metrics is torch._dynamo.utils._compilation_metrics)
|
|
self.assertEqual(len(metrics), 2)
|
|
latest_metrics = metrics[-1]
|
|
self.assertTrue(latest_metrics.dynamo_config is not None)
|
|
self.assertTrue(latest_metrics.recompile_reason is not None)
|
|
|
|
torch._dynamo.utils.clear_compilation_metrics()
|
|
|
|
# https://github.com/pytorch/pytorch/issues/156580
|
|
@serialTest()
|
|
def test_dont_dce_rand(self):
|
|
# https://github.com/pytorch/pytorch/issues/143431
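        # Sketch of why DCE would be wrong here: the first torch.rand's values
        # are unused (only x.shape is read), but the call still advances the
        # RNG state, so eliminating it would change what the second torch.rand
        # (and hence `indices`) produces; the test compares against eager under
        # identical seeds.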
|
|
def f(image_latent):
|
|
B = 2
|
|
num_ref = 3
|
|
num_tar = 3
|
|
x = torch.rand(B, 12)
|
|
indices = torch.argsort(torch.rand(*x.shape), dim=-1)[
|
|
:, : num_ref + num_tar
|
|
]
|
|
return image_latent[torch.arange(B).unsqueeze(-1), indices][:, :num_ref]
|
|
|
|
torch.manual_seed(54321)
|
|
torch.cuda.manual_seed_all(54321)
|
|
expected = f(torch.randn((2, 12, 16, 32, 32))).sum()
|
|
|
|
# https://github.com/pytorch/pytorch/issues/147171
|
|
torch._inductor.config.fallback_random = True
|
|
|
|
for backend in ["eager", "aot_eager"]:
|
|
torch.manual_seed(54321)
|
|
torch.cuda.manual_seed_all(54321)
|
|
actual = torch.compile(backend=backend, fullgraph=True)(f)(
|
|
torch.randn((2, 12, 16, 32, 32))
|
|
).sum()
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_incompatible_configs(self):
|
|
with torch._dynamo.config.patch(
|
|
suppress_errors=False, fail_on_recompile_limit_hit=False
|
|
):
|
|
torch.compile(lambda: None)
|
|
|
|
with torch._dynamo.config.patch(
|
|
suppress_errors=True, fail_on_recompile_limit_hit=False
|
|
):
|
|
torch.compile(lambda: None)
|
|
|
|
with torch._dynamo.config.patch(
|
|
suppress_errors=False, fail_on_recompile_limit_hit=True
|
|
):
|
|
torch.compile(lambda: None)
|
|
|
|
with (
|
|
torch._dynamo.config.patch(
|
|
suppress_errors=True, fail_on_recompile_limit_hit=True
|
|
),
|
|
self.assertRaises(AssertionError),
|
|
):
|
|
torch.compile(lambda: None)
|
|
|
|
def test_str_isalnum(self):
|
|
def f(x, c):
|
|
str.isalnum(c)
|
|
return x.sin()
|
|
|
|
opt_f = torch.compile(f, backend="eager", fullgraph=True)
|
|
x = torch.randn(3)
|
|
c = "foobar"
|
|
self.assertEqual(f(x, c), opt_f(x, c))
|
|
|
|
def test_nn_param_freevar_codegen(self):
|
|
class Model2(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.conv = nn.Conv2d(in_channels=3, out_channels=5, kernel_size=3)
|
|
self.batchnorm = nn.BatchNorm2d(num_features=5)
|
|
self.conv_weight = torch.randn(5, 3, 3, 3)
|
|
self.conv_bias = torch.randn(5)
|
|
|
|
def forward(self, x):
|
|
self.conv.weight = nn.Parameter(self.conv_weight)
|
|
self.conv.bias = nn.Parameter(self.conv_bias, requires_grad=False)
|
|
self.conv.eval()
|
|
x = self.conv(x)
|
|
x = self.batchnorm(x)
|
|
x = F.relu(x)
|
|
return x
|
|
|
|
input_tensor = torch.randn(1, 3, 10, 10)
|
|
func = Model2().to("cpu")
|
|
|
|
@functools.wraps(func)
|
|
def wrapper(*args, **kwargs):
|
|
return func(*args, **kwargs)
|
|
|
|
with torch.no_grad():
|
|
func.train(False)
|
|
v1 = func(input_tensor)
|
|
jit_func = torch.compile(wrapper, backend="eager", fullgraph=True)
|
|
v2 = jit_func(input_tensor)
|
|
self.assertEqual(v1, v2)
|
|
|
|
def test_amp_foreach_fake_impl(self):
|
|
inv_scale = torch.full((1,), 0.25)
|
|
found_inf = torch.full((1,), 0.0)
|
|
grads = [torch.ones(10), torch.ones(10)]
|
|
|
|
def f():
|
|
res = torch._amp_foreach_non_finite_check_and_unscale_(
|
|
grads, found_inf, inv_scale
|
|
)
|
|
return res
|
|
|
|
ref = f()
|
|
res = torch.compile(f, backend="aot_eager")()
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_deleted_compile_wrapper_segfault(self):
|
|
def fn(x):
|
|
return x + 1
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
# This calls cached_backend.clear() which removes any strong references
|
|
# to the callback
|
|
torch._dynamo.reset()
|
|
opt_fn(torch.randn(3))
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
opt_fn(torch.randn(3)) # possible segfault due to first opt_fn deletion
|
|
|
|
def test_delete_local_error(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
y = x + 1
|
|
del y
|
|
z = y + 1 # noqa: F821
|
|
return z
|
|
|
|
with self.assertRaises(torch._dynamo.exc.Unsupported):
|
|
fn(torch.ones(3))
|
|
|
|
def test_nanmean_out(self):
|
|
def f(x, out):
|
|
torch.nanmean(x, out=out)
|
|
|
|
x = torch.randn(4)
|
|
out_ref = torch.tensor(0.0)
|
|
out_res = torch.tensor(0.0)
|
|
|
|
f(x, out_ref)
|
|
torch.compile(f, backend="eager", fullgraph=True)(x, out_res)
|
|
self.assertEqual(out_ref, out_res)
|
|
|
|
def test_unbind_copy_out(self):
|
|
def f(eye, out):
|
|
torch.unbind_copy(eye, out=out)
|
|
|
|
eye = torch.eye(3)
|
|
out_ref = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
|
|
out_res = (torch.zeros(3), torch.zeros(3), torch.zeros(3))
|
|
|
|
f(eye, out_ref)
|
|
torch.compile(f, backend="eager", fullgraph=True)(eye, out_res)
|
|
self.assertEqual(out_ref, out_res)
|
|
|
|
|
|
class ReproTestsDevice(torch._dynamo.test_case.TestCase):
|
|
def test_sub_alpha_scalar_repro(self, device):
|
|
@torch.compile(backend="aot_eager")
|
|
def f(x):
|
|
return x.sub(1, alpha=2)
|
|
|
|
f(torch.ones(2, device=device, dtype=torch.float64))
|
|
|
|
@requires_cuda
|
|
def test_norm_dtype(self, device):
|
|
def foo(_stack0):
|
|
getitem = _stack0[(slice(None, None, None), -1)]
|
|
_stack0 = None
|
|
normalize = torch.nn.functional.normalize(getitem, p=2, dim=1)
|
|
getitem = None
|
|
return (normalize,)
|
|
|
|
args = [((2, 50, 256), (1, 256, 1), torch.float16, device, False)]
|
|
args = [
|
|
rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
for (sh, st, dt, dev, rg) in args
|
|
]
|
|
|
|
        foo_c = torch.compile(foo, backend="aot_eager_decomp_partition")
|
|
with torch.cuda.amp.autocast(enabled=True):
|
|
ref = foo(*args)[0]
|
|
            res = foo_c(*args)[0]
|
|
self.assertEqual(ref.dtype, res.dtype)
|
|
|
|
self.assertTrue(same(res, ref))
|
|
|
|
def test_guard_default_device(self, device):
|
|
try:
|
|
torch.set_default_device(device)
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch._dynamo.optimize(counter)
|
|
def f():
|
|
x = torch.randn(3)
|
|
return x * 2
|
|
|
|
self.assertEqual(f().device.type + ":0", device)
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
torch.set_default_device("cpu")
|
|
|
|
self.assertEqual(f().device.type, "cpu")
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
finally:
|
|
torch.set_default_device(None)
|
|
|
|
@skipIfHpu
|
|
@unittest.skipIf(
|
|
TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
|
"flash attention not supported",
|
|
)
|
|
def test_flash_attn_backward_mixed_strides(self, device):
|
|
# in this repro, "grad_out" and "value" are transposed tensors,
|
|
# but "key" and "value" are contiguous
|
|
def gen_inputs(device):
|
|
return (
|
|
torch.randn(
|
|
2, 513, 16, 64, dtype=torch.float16, device=device
|
|
).transpose(1, 2),
|
|
torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
|
|
torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
|
|
torch.randn(
|
|
2, 513, 16, 64, dtype=torch.float16, device=device
|
|
).transpose(1, 2),
|
|
torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
|
|
torch.randn(2, 16, 513, device=device),
|
|
None,
|
|
None,
|
|
513,
|
|
513,
|
|
0.0,
|
|
False,
|
|
torch.tensor(1, dtype=torch.int64),
|
|
torch.tensor(1, dtype=torch.int64),
|
|
)
|
|
|
|
inps_device = gen_inputs(device)
|
|
inps_meta = gen_inputs("meta")
|
|
(
|
|
out1_ref,
|
|
out2_ref,
|
|
out3_ref,
|
|
) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
|
|
*inps_device, scale=0.125
|
|
)
|
|
from torch._meta_registrations import meta__scaled_dot_product_flash_backward
|
|
|
|
out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(
|
|
*inps_meta, scale=0.125
|
|
)
|
|
|
|
self.assertEqual(out1_ref.shape, out1_test.shape)
|
|
self.assertEqual(out1_ref.stride(), out1_test.stride())
|
|
self.assertEqual(out2_ref.shape, out2_test.shape)
|
|
self.assertEqual(out2_ref.stride(), out2_test.stride())
|
|
self.assertEqual(out3_ref.shape, out3_test.shape)
|
|
self.assertEqual(out3_ref.stride(), out3_test.stride())
|
|
|
|
def test_megablocks_moe(self, device):
|
|
try:
|
|
from megablocks.layers import moe
|
|
from megablocks.layers.arguments import Arguments
|
|
except ImportError as e:
|
|
raise unittest.SkipTest("requires megablocks") from e
|
|
bs, sl, hs, num_experts, top_k = (16, 1024, 512, 1, 1)
|
|
args = Arguments(
|
|
hidden_size=hs,
|
|
ffn_hidden_size=hs * 2,
|
|
moe_num_experts=num_experts,
|
|
moe_capacity_factor=1,
|
|
moe_top_k=top_k,
|
|
)
|
|
moe_mlp = moe.MoE(args)
|
|
# moe_mlp.cuda(torch.cuda.current_device()).half()
|
|
        moe_mlp = moe_mlp.to(device).half()
|
|
        x = torch.randn(sl, bs, hs, device=device).half()
|
|
out1, _ = moe_mlp(x)
|
|
out2, _ = torch.compile(moe_mlp, backend="eager")(x)
|
|
self.assertEqual(out1, out2)
|
|
|
|
def test_tensor_size_hasattr(self):
|
|
def fn(x):
|
|
if hasattr(x, "size"):
|
|
x = x * 2
|
|
if hasattr(x, "stride"):
|
|
x = x * 3
|
|
return x * 5
|
|
|
|
x = torch.ones(4)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
@requires_cuda
|
|
def test_memleak_when_graph_input_has_tensor_attr(self, device):
|
|
@torch.compile(backend="eager")
|
|
def f(x):
|
|
x.add_(1)
|
|
|
|
mem_before = torch.cuda.memory_allocated()
|
|
|
|
x = torch.ones(2, device=device)
|
|
x.foo = torch.zeros(2, device=device)
|
|
f(x)
|
|
del x.foo
|
|
del x
|
|
mem_after = torch.cuda.memory_allocated()
|
|
self.assertEqual(mem_before, mem_after)
|
|
|
|
# check when non-tensor data structure attribute contains a tensor
|
|
@torch.compile(backend="eager")
|
|
def f(x):
|
|
x.add_(1)
|
|
|
|
mem_before = torch.cuda.memory_allocated()
|
|
x = torch.ones(2, device=device)
|
|
x.foo = [torch.zeros(2, device=device) for _ in range(5)]
|
|
f(x)
|
|
del x.foo
|
|
del x
|
|
mem_after = torch.cuda.memory_allocated()
|
|
self.assertEqual(mem_before, mem_after)
|
|
|
|
# check with tensor refcycle
|
|
@torch.compile(backend="eager")
|
|
def g(x, y):
|
|
return x + y
|
|
|
|
mem_before = torch.cuda.memory_allocated()
|
|
x = torch.ones(2, device=device)
|
|
y = torch.zeros(2, device=device)
|
|
x.foo = [y]
|
|
y.foo = [x]
|
|
g(x, y)
|
|
del x.foo
|
|
del y.foo
|
|
del x
|
|
del y
|
|
mem_after = torch.cuda.memory_allocated()
|
|
self.assertEqual(mem_before, mem_after)
|
|
|
|
def test_udf_class_source(self):
|
|
class Foo:
|
|
pass
|
|
|
|
def fn(x):
|
|
foo = Foo()
|
|
bar = type(foo)() # noqa: F841
|
|
return torch.cos(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_truthiness_of_symints_no_recompiles(self, device):
|
|
def f(x):
|
|
numel = x.numel()
|
|
if numel:
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
f_compiled = torch.compile(f, backend=cnt, dynamic=True)
|
|
|
|
x1 = torch.randn(4)
|
|
_ = f_compiled(x1)
|
|
x2 = torch.randn(5)
|
|
_ = f_compiled(x2)
|
|
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
@requires_cuda
|
|
def test_sdpa_dynamic_shapes(self, device):
|
|
def f(x, s0, s1, s2):
|
|
q = x.view(2, s0, s2, s0)
|
|
return torch._C._nn.scaled_dot_product_attention(
|
|
q, q, q, attn_mask=None, dropout_p=0.0, is_causal=True
|
|
)
|
|
|
|
x = torch.randn(2, 32, 4096, dtype=torch.bfloat16, device=device)
|
|
x_ref = x.clone().detach().requires_grad_()
|
|
s0 = 32
|
|
s1 = 64
|
|
s2 = 128
|
|
|
|
f_compiled = torch.compile(f, dynamic=True)
|
|
|
|
with torch._dynamo.config.patch(assume_static_by_default=False):
|
|
out_ref = f(x_ref, s0, s1, s2)
|
|
out = f_compiled(x, s0, s1, s2)
|
|
self.assertEqual(out_ref, out)
|
|
|
|
@unittest.skipIf(not PLATFORM_SUPPORTS_FP8, "requires gpu with fp8 support")
|
|
@requires_cuda
|
|
def test_partitioner_saves_weights_for_bw(self):
|
|
def mul_tiled(a, *bs):
|
|
for b in bs:
|
|
a = a.unflatten(0, (b.shape[0], -1)).unflatten(-1, (b.shape[-1], -1))
|
|
a = a * b[:, None, :, None]
|
|
a = a.flatten(end_dim=1).flatten(start_dim=-2)
|
|
return a
|
|
|
|
def scale(t, amax_t):
|
|
max_v = torch.finfo(torch.float8_e4m3fn).max
|
|
scale_t = torch.clamp(amax_t.float(), min=1e-12) / max_v
|
|
t_fp8 = mul_tiled(t, scale_t.reciprocal()).to(torch.float8_e4m3fn)
|
|
return t_fp8, scale_t
|
|
|
|
def matmul(first, amax_first, second_t, amax_second_t, bias):
|
|
first_fp8, scale_first = scale(first, amax_first)
|
|
second_t_fp8, scale_second_t = scale(second_t, amax_second_t)
|
|
post_scales = []
|
|
post_bias = None
|
|
post_scales = [scale_first, scale_second_t.t()]
|
|
scale_first = scale_first.new_ones((1, 1))
|
|
scale_second_t = scale_second_t.t().new_ones((1, 1))
|
|
post_bias, bias = bias, None
|
|
res = torch._scaled_mm(
|
|
first_fp8,
|
|
second_t_fp8.t(),
|
|
scale_a=scale_first,
|
|
scale_b=scale_second_t.t(),
|
|
bias=bias,
|
|
out_dtype=torch.bfloat16,
|
|
use_fast_accum=False,
|
|
)
|
|
res = mul_tiled(res, *post_scales).to(torch.bfloat16)
|
|
if post_bias is not None:
|
|
res += post_bias
|
|
return res
|
|
|
|
@torch.compiler.allow_in_graph
|
|
class Fp8LinearFn(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, a, b_t, bias):
|
|
amax_a = a.abs().unflatten(-1, (1, -1)).amax(dim=-1)
|
|
amax_b_t = b_t.abs().unflatten(-1, (1, -1)).amax(dim=-1)
|
|
out = matmul(a, amax_a, b_t, amax_b_t, bias)
|
|
ctx.a_requires_grad = a.requires_grad
|
|
ctx.b_requires_grad = b_t.requires_grad
|
|
ctx.bias_requires_grad = (
|
|
bias.requires_grad if bias is not None else False
|
|
)
|
|
ctx.save_for_backward(a, b_t, amax_b_t)
|
|
return out
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_out):
|
|
a, b_t, amax_b_t = ctx.saved_tensors
|
|
# Workaround for https://github.com/pytorch/pytorch/issues/141881.
|
|
# The partitioner would pre-compute the transposed scaling of the weight
|
|
# in the forward (as it's most efficient, but it actually uses too much
|
|
# memory). We prevent that by making the scaling depend on the gradient
|
|
# in a way that has no effect and will be optimized away later.
|
|
# Care is needed to support tensor parallelism and circumvent bugs.
|
|
# b_t = b_t + grad_out[:1, :, None].squeeze(0) * 0
|
|
if ctx.a_requires_grad:
|
|
b = b_t.t().contiguous()
|
|
amax_grad_out = grad_out.abs().unflatten(-1, (1, -1)).amax(dim=-1)
|
|
amax_b = amax_b_t.t().unflatten(-1, (1, -1)).amax(dim=-1)
|
|
amax_b = amax_b.repeat_interleave(
|
|
b.shape[0] // amax_b.shape[0], dim=0, output_size=b.shape[0]
|
|
)
|
|
grad_a = matmul(grad_out, amax_grad_out, b, amax_b, None)
|
|
else:
|
|
grad_a = None
|
|
if ctx.b_requires_grad:
|
|
grad_b = grad_out.t() @ a
|
|
else:
|
|
grad_b = None
|
|
if ctx.bias_requires_grad:
|
|
grad_bias = grad_out.sum(dim=0)
|
|
else:
|
|
grad_bias = None
|
|
return grad_a, grad_b, grad_bias
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.a = torch.nn.Parameter(
|
|
torch.randn(
|
|
64, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True
|
|
)
|
|
)
|
|
self.b = torch.nn.Parameter(
|
|
torch.randn(
|
|
64, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True
|
|
)
|
|
)
|
|
self.bias = torch.nn.Parameter(
|
|
torch.randn(
|
|
64, dtype=torch.bfloat16, device="cuda", requires_grad=True
|
|
)
|
|
)
|
|
|
|
class CustomLinear(torch.nn.Linear):
|
|
def forward(self, input: torch.Tensor) -> torch.Tensor:
|
|
out = Fp8LinearFn.apply(
|
|
input.flatten(end_dim=-2), self.weight, self.bias
|
|
)
|
|
out = out.unflatten(0, input.shape[:-1])
|
|
return out
|
|
|
|
m = CustomLinear(64, 64, dtype=torch.bfloat16, device="cuda")
|
|
m = torch.compile(m, backend="aot_eager")
|
|
|
|
        # simple mode to track how many ops we saw in the backward
|
|
class TrackingMode(TorchDispatchMode):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.ops_counter = defaultdict(int)
|
|
|
|
def __torch_dispatch__(self, func, types, args=(), kwargs=None):
|
|
if kwargs is None:
|
|
kwargs = {}
|
|
rs = func(*args, **kwargs)
|
|
self.ops_counter[func] += 1
|
|
return rs
|
|
|
|
a = torch.randn(64, 64, dtype=torch.bfloat16, device="cuda", requires_grad=True)
|
|
out = m(a)
|
|
with TrackingMode() as mode:
|
|
out.sum().backward()
|
|
# If you print out the AOT fw and bw graphs,
|
|
# the main thing to look for is that both weights (primals_1/primals_2)
|
|
        # *are* saved for backward, and become backward inputs.
|
|
# The easier-to-test thing I'm checking for here is that the recompute
|
|
# on primals_2 happens in the backward. With the recompute,
|
|
# there are 5 _to_copy ops in the backward. Without it, there are 4
|
|
# (aka if you set torch._functorch.config.treat_parameters_as_free_to_save = False)
|
|
self.assertEqual(mode.ops_counter[torch.ops.aten._to_copy.default], 5)
|
|
|
|
def test_getattr_return(self):
|
|
_WrapperDescriptor = type(type.__call__)
|
|
_MethodWrapper = type(all.__call__)
|
|
_ClassMethodWrapper = type(int.__dict__["from_bytes"])
|
|
|
|
_NonUserDefinedCallables = (
|
|
_WrapperDescriptor,
|
|
_MethodWrapper,
|
|
_ClassMethodWrapper,
|
|
types.BuiltinFunctionType,
|
|
)
|
|
|
|
def _signature_get_user_defined_method(cls, method_name):
|
|
try:
|
|
meth = getattr(cls, method_name)
|
|
except AttributeError:
|
|
return
|
|
else:
|
|
if not isinstance(meth, _NonUserDefinedCallables):
|
|
# Once '__signature__' will be added to 'C'-level
|
|
# callables, this check won't be necessary
|
|
return meth
|
|
|
|
def fn(x):
|
|
s = _signature_get_user_defined_method(type(torch.nn.Linear), "__call__")
|
|
if s is None:
|
|
return torch.cos(x)
|
|
|
|
return torch.sin(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_data_dependent_error_log_no_print(self):
|
|
# This is a regression test case for
|
|
# https://github.com/pytorch/pytorch/pull/149831
|
|
from io import StringIO
|
|
|
|
capturedOutput = StringIO()
|
|
sys.stderr = capturedOutput
|
|
|
|
@torch.compile(fullgraph=True)
|
|
def func(a):
|
|
if a.sum() > 0:
|
|
return a + 1
|
|
return a + 2
|
|
|
|
a = torch.rand(10, 10)
|
|
try:
|
|
func(a)
|
|
except Exception:
|
|
pass
|
|
sys.stderr = sys.__stderr__
|
|
|
|
# Make sure we don't _print_ out the graph module.
|
|
output = capturedOutput.getvalue()
|
|
self.assertNotIn("class GraphModule", output)
|
|
|
|
def test_deepcopy_constant_tensor_in_aot_bwd(self):
|
|
class Fn(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
return x + 1
|
|
|
|
@staticmethod
|
|
def backward(ctx, grad_out):
|
|
return grad_out * torch.tensor(2) * grad_out.shape[0]
|
|
|
|
def f(x):
|
|
return Fn.apply(x)
|
|
|
|
x = torch.randn(8, requires_grad=True)
|
|
out = f(x) # should not raise
|
|
c_out = torch.compile(f, backend="aot_eager", dynamic=True)(x)
|
|
expected = torch.autograd.grad(out.sum(), inputs=(x,))
|
|
actual = torch.autograd.grad(c_out.sum(), inputs=(x,))
|
|
self.assertEqual(expected, actual)
|
|
|
|
def test_module_attribute_error(self):
|
|
@torch.compile(backend="eager")
|
|
def f1(x):
|
|
return torch._bar(x)
|
|
|
|
@torch.compile(backend="eager")
|
|
def f2(x):
|
|
try:
|
|
return torch._bar(x)
|
|
except AttributeError:
|
|
return x + 1
|
|
|
|
with self.assertRaises(AttributeError):
|
|
f1(torch.ones(3))
|
|
|
|
self.assertEqual(f2(torch.ones(3)), torch.ones(3) + 1)
|
|
|
|
def test_torch_cuda_is_initialized(self):
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def f(x):
|
|
if torch.cuda.is_initialized():
|
|
return x + 1
|
|
return x + 2
|
|
|
|
inp = torch.randn(3)
|
|
self.assertEqual(f(inp), inp + 1)
|
|
|
|
with mock.patch("torch.cuda.is_initialized", lambda: False):
|
|
self.assertEqual(f(inp), inp + 2)
|
|
|
|
def test_named_tuple_vt_clone(self):
|
|
# https://github.com/pytorch/pytorch/issues/157945
|
|
class SVDCompressor(nn.Module):
|
|
def __init__(self, k=10):
|
|
super().__init__()
|
|
self.k = k
|
|
|
|
def forward(self, x):
|
|
U, S = torch.linalg.svd(x)[:2]
|
|
reduced = U[:, :, : self.k] @ torch.diag_embed(S[:, : self.k])
|
|
return reduced
|
|
|
|
input = torch.randn(4, 8, 6)
|
|
model = SVDCompressor(k=5)
|
|
|
|
out1 = model(input.clone())
|
|
out2 = torch.compile(model, backend="eager")(input.clone())
|
|
self.assertEqual(out1, out2)


instantiate_parametrized_tests(ReproTests)

devices = ["cuda", "hpu"]
instantiate_device_type_tests(ReproTestsDevice, globals(), only_for=devices)

if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()