See https://github.com/pytorch/pytorch/pull/129751#issue-2380881501. Most changes are auto-generated by the linter. You can review these PRs via:

```bash
git diff --ignore-all-space --ignore-blank-lines HEAD~1
```

Pull Request resolved: https://github.com/pytorch/pytorch/pull/129762
Approved by: https://github.com/anijain2305
5509 lines
174 KiB
Python
"""
|
|
PYTEST_DONT_REWRITE (prevents pytest from rewriting assertions, which interferes
|
|
with test_rewrite_assert_with_msg and test_rewrite_assert_without_msg)
|
|
"""
|
|
|
|
# Owner(s): ["module: dynamo"]
|
|
import collections
|
|
import contextlib
|
|
import copy
|
|
import dataclasses
|
|
import functools
|
|
import gc
|
|
import inspect
|
|
import itertools
|
|
import random
|
|
import unittest
|
|
import warnings
|
|
import weakref
|
|
from abc import ABC
|
|
from collections import namedtuple
|
|
from copy import deepcopy
|
|
from enum import Enum
|
|
from functools import wraps
|
|
from typing import Any, Dict, Iterator, List, Tuple
|
|
from unittest import mock
|
|
|
|
import numpy as np
|
|
|
|
import torch
|
|
import torch._dynamo.test_case
|
|
import torch._dynamo.testing
|
|
import torch._dynamo.utils
|
|
import torch._functorch.config
|
|
import torch.library
|
|
import torch.utils._pytree as pytree
|
|
from torch import nn
|
|
from torch._dynamo.debug_utils import same_two_models
|
|
from torch._dynamo.testing import CompileCounter, rand_strided, same
|
|
from torch._inductor.utils import fresh_inductor_cache
|
|
from torch.nn import functional as F
|
|
from torch.testing._internal.common_cuda import PLATFORM_SUPPORTS_FLASH_ATTENTION
|
|
from torch.testing._internal.common_utils import (
|
|
disable_translation_validation_if_dynamic_shapes,
|
|
instantiate_parametrized_tests,
|
|
parametrize,
|
|
TEST_WITH_ROCM,
|
|
)
|
|
from torch.testing._internal.two_tensor import TwoTensor
|
|
|
|
|
|
_orig_module_call = torch.nn.Module.__call__

# Custom operator that only supports CPU and Meta
lib = torch.library.Library("test_sample", "DEF")  # noqa: TOR901
lib.define("foo(Tensor self) -> Tensor")
lib.impl("foo", torch.sin, "CPU")


requires_cuda = unittest.skipUnless(torch.cuda.is_available(), "requires cuda")


_GLOBAL_CPU_TENSOR = torch.randn(3)


def exists(val):
    return val is not None


def maybe(fn):
    @wraps(fn)
    def inner(x, *args, **kwargs):
        if not exists(x):
            return x
        return fn(x, *args, **kwargs)

    return inner


def is_fx_tracing_test() -> bool:
    """
    Copied from the hpc trainer codebase
    """
    return torch.nn.Module.__call__ is not _orig_module_call


def has_detectron2():
    try:
        from detectron2.layers.mask_ops import _paste_masks_tensor_shape

        return _paste_masks_tensor_shape is not None
    except ImportError:
        return False


def _do_paste_mask(masks, boxes, img_h: int, img_w: int, skip_empty: bool = True):
    # from detectron2 mask_ops.py

    device = masks.device

    if skip_empty and not torch.jit.is_scripting():
        x0_int, y0_int = torch.clamp(boxes.min(dim=0).values.floor()[:2] - 1, min=0).to(
            dtype=torch.int32
        )
        x1_int = torch.clamp(boxes[:, 2].max().ceil() + 1, max=img_w).to(
            dtype=torch.int32
        )
        y1_int = torch.clamp(boxes[:, 3].max().ceil() + 1, max=img_h).to(
            dtype=torch.int32
        )
    else:
        x0_int, y0_int = 0, 0
        x1_int, y1_int = img_w, img_h
    x0, y0, x1, y1 = torch.split(boxes, 1, dim=1)  # each is Nx1

    N = masks.shape[0]

    img_y = torch.arange(y0_int, y1_int, device=device, dtype=torch.float32) + 0.5
    img_x = torch.arange(x0_int, x1_int, device=device, dtype=torch.float32) + 0.5
    img_y = (img_y - y0) / (y1 - y0) * 2 - 1
    img_x = (img_x - x0) / (x1 - x0) * 2 - 1
    # img_x, img_y have shapes (N, w), (N, h)

    gx = img_x[:, None, :].expand(N, img_y.size(1), img_x.size(1))
    gy = img_y[:, :, None].expand(N, img_y.size(1), img_x.size(1))
    grid = torch.stack([gx, gy], dim=3)

    if not torch.jit.is_scripting():
        if not masks.dtype.is_floating_point:
            masks = masks.float()
    img_masks = F.grid_sample(masks, grid.to(masks.dtype), align_corners=False)

    if skip_empty and not torch.jit.is_scripting():
        return img_masks[:, 0], (slice(y0_int, y1_int), slice(x0_int, x1_int))
    else:
        return img_masks[:, 0], ()


def global_fn(x):
    return torch.sin(x)


def cat(tensors, dim=0):
    # from detectron2 wrappers.py
    assert isinstance(tensors, (list, tuple))
    if len(tensors) == 1:
        return tensors[0]
    return torch.cat(tensors, dim)


def shapes_to_tensor(x, device=None):
    # from detectron2 wrappers.py
    if torch.jit.is_scripting():
        return torch.as_tensor(x, device=device)
    if torch.jit.is_tracing():
        assert all(
            isinstance(t, torch.Tensor) for t in x
        ), "Shape should be tensor during tracing!"
        # as_tensor should not be used in tracing because it records a constant
        ret = torch.stack(x)
        if ret.device != device:  # avoid recording a hard-coded device if not necessary
            ret = ret.to(device=device)
        return ret
    return torch.as_tensor(x, device=device)


fw_graph = [None]
bw_graph = [None]


def aot_graph_capture_backend(gm, args):
    from functorch.compile import min_cut_rematerialization_partition
    from torch._functorch.aot_autograd import aot_module_simplified

    def fw_compiler(gm, _):
        fw_graph[0] = gm
        return gm

    def bw_compiler(gm, _):
        bw_graph[0] = gm
        return gm

    return aot_module_simplified(
        gm,
        args,
        fw_compiler,
        bw_compiler,
        partition_fn=min_cut_rematerialization_partition,
        keep_inference_input_mutations=True,
    )


class Boxes:
    # from detectron2 poolers.py
    def __init__(self, tensor: torch.Tensor):
        """
        Args:
            tensor (Tensor[float]): a Nx4 matrix. Each row is (x1, y1, x2, y2).
        """
        device = (
            tensor.device if isinstance(tensor, torch.Tensor) else torch.device("cpu")
        )
        tensor = torch.as_tensor(tensor, dtype=torch.float32, device=device)
        if tensor.numel() == 0:
            # Use reshape, so we don't end up creating a new tensor that does not depend on
            # the inputs (and consequently confuses jit)
            tensor = tensor.reshape((-1, 4)).to(dtype=torch.float32, device=device)
        assert tensor.dim() == 2 and tensor.size(-1) == 4, tensor.size()
        self.tensor = tensor

    def __len__(self) -> int:
        return self.tensor.shape[0]

    @property
    def device(self):
        return self.tensor.device


def convert_boxes_to_pooler_format(box_lists):
    # from detectron2 structures.py
    boxes = torch.cat([x.tensor for x in box_lists], dim=0)
    # __len__ returns Tensor in tracing.
    sizes = shapes_to_tensor([x.__len__() for x in box_lists], device=boxes.device)
    indices = torch.repeat_interleave(
        torch.arange(len(box_lists), dtype=boxes.dtype, device=boxes.device), sizes
    )
    return cat([indices[:, None], boxes], dim=1)


ReformerBackwardOutput = namedtuple(
    "ReformerBackwardOutput",
    ["attn_output", "hidden_states", "grad_attn_output", "grad_hidden_states"],
)
ReformerEncoderOutput = namedtuple(
    "ReformerEncoderOutput",
    ["hidden_states", "all_hidden_states", "all_attentions", "past_buckets_states"],
)


class _ReversibleFunction(torch.autograd.Function):
    # taken from modeling_reformer.py in huggingface
    @staticmethod
    def forward(
        ctx,
        hidden_states,
        layers,
        attention_mask,
        head_mask,
        num_hashes,
        all_hidden_states,
        all_attentions,
        past_buckets_states,
        use_cache,
        orig_sequence_length,
        output_hidden_states,
        output_attentions,
    ):
        all_buckets = ()

        # split duplicated tensor
        hidden_states, attn_output = torch.chunk(hidden_states, 2, dim=-1)

        for layer_id, (layer, layer_head_mask) in enumerate(zip(layers, head_mask)):
            if output_hidden_states is True:
                all_hidden_states.append(hidden_states)

            attn_output = layer(attn_output)
            all_buckets = all_buckets + (attn_output,)

        # Add last layer
        if output_hidden_states is True:
            all_hidden_states.append(hidden_states)

        # attach params to ctx for backward
        ctx.save_for_backward(attn_output.detach(), hidden_states.detach())
        ctx.layers = layers
        ctx.all_buckets = all_buckets
        ctx.head_mask = head_mask
        ctx.attention_mask = attention_mask

        # Concatenate 2 RevNet outputs
        return torch.cat([attn_output, hidden_states], dim=-1)

    @staticmethod
    def backward(ctx, grad_hidden_states):
        grad_attn_output, grad_hidden_states = torch.chunk(
            grad_hidden_states, 2, dim=-1
        )

        # free memory
        del grad_attn_output

        # num of return vars has to match num of forward() args
        # return gradient for hidden_states arg and None for other args
        return (
            grad_hidden_states,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
            None,
        )


class ReformerEncoder(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.dropout = 0.5
        self.layer_norm = torch.nn.LayerNorm(512, eps=1.0e-12)
        self.layers = [torch.nn.Linear(256, 256)]

    def forward(
        self,
        hidden_states,
        attention_mask=None,
        head_mask=[None] * 6,
        num_hashes=None,
        use_cache=False,
        orig_sequence_length=64,
        output_hidden_states=False,
        output_attentions=False,
    ):
        # hidden_states and attention lists to be filled if wished
        all_hidden_states = []
        all_attentions = []
        past_buckets_states = [((None), (None)) for i in range(len(self.layers))]

        # concat same tensor for reversible ResNet
        hidden_states = torch.cat([hidden_states, hidden_states], dim=-1)
        hidden_states = _ReversibleFunction.apply(
            hidden_states,
            self.layers,
            attention_mask,
            head_mask,
            num_hashes,
            all_hidden_states,
            all_attentions,
            past_buckets_states,
            use_cache,
            orig_sequence_length,
            output_hidden_states,
            output_attentions,
        )

        # Apply layer norm to concatenated hidden states
        hidden_states = self.layer_norm(hidden_states)

        # Apply dropout
        hidden_states = torch.nn.functional.dropout(
            hidden_states, p=self.dropout, training=self.training
        )

        return ReformerEncoderOutput(
            hidden_states=hidden_states,
            all_hidden_states=all_hidden_states,
            all_attentions=all_attentions,
            past_buckets_states=past_buckets_states,
        )


class ListConfig:
    class ValueNode:
        def __init__(self, value):
            self.value = value

        def _dereference_node(self):
            return self

        def _is_missing(self):
            return False

        def _value(self):
            return self.value

    # Based on an example from omegaconfig.listconfig
    class ListIterator(Iterator[Any]):
        def __init__(self, lst: Any, resolve: bool) -> None:
            self.resolve = resolve
            self.iterator = iter(lst.__dict__["_content"])
            self.index = 0

        def __next__(self) -> Any:
            x = next(self.iterator)
            if self.resolve:
                x = x._dereference_node()
                if x._is_missing():
                    raise AssertionError

            self.index = self.index + 1
            if isinstance(x, ListConfig.ValueNode):
                return x._value()
            raise AssertionError

    def __iter__(self):
        return self._iter_ex(True)

    def _iter_ex(self, resolve: bool) -> Iterator[Any]:
        try:
            return ListConfig.ListIterator(self, resolve)
        except Exception:
            raise AssertionError from None

    def __init__(self):
        self._content = [
            ListConfig.ValueNode(1),
            ListConfig.ValueNode(3),
            ListConfig.ValueNode(torch.tensor([7.0])),
        ]


def longformer_chunk(hidden_states, window_overlap=256):
    """convert into overlapping chunks. Chunk size = 2w, overlap size = w"""

    # non-overlapping chunks of size = 2w
    hidden_states = hidden_states.view(
        hidden_states.size(0),
        hidden_states.size(1) // (window_overlap * 2),
        window_overlap * 2,
        hidden_states.size(2),
    )

    # use `as_strided` to make the chunks overlap with an overlap size = window_overlap
    chunk_size = list(hidden_states.size())
    chunk_size[1] = chunk_size[1] * 2 - 1

    chunk_stride = list(hidden_states.stride())
    chunk_stride[1] = chunk_stride[1] // 2
    return hidden_states.as_strided(size=chunk_size, stride=chunk_stride)


class PartialT5(torch.nn.Module):
    # Highly simplified T5Attention prefix
    def __init__(self):
        super().__init__()
        self.q = torch.nn.Linear(512, 512)
        self.k = torch.nn.Linear(512, 512)
        self.v = torch.nn.Linear(512, 512)

    def forward(
        self,
        hidden_states,
        key_value_states=None,
        past_key_value=None,
        query_length=None,
    ):
        batch_size, seq_length = hidden_states.shape[:2]

        real_seq_length = seq_length

        if past_key_value is not None:
            assert (
                len(past_key_value) == 2
            ), f"past_key_value should have 2 past states: keys and values. Got { len(past_key_value)} past states"
            real_seq_length += (
                past_key_value[0].shape[2] if query_length is None else query_length
            )

        def shape(states):
            """projection"""
            return states.view(batch_size, -1, 8, 64).transpose(1, 2)

        def project(hidden_states, proj_layer, key_value_states, past_key_value):
            """projects hidden states correctly to key/query states"""
            if key_value_states is None:
                # self-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(hidden_states))
            elif past_key_value is None:
                # cross-attn
                # (batch_size, n_heads, seq_length, dim_per_head)
                hidden_states = shape(proj_layer(key_value_states))

            if past_key_value is not None:
                if key_value_states is None:
                    # self-attn
                    # (batch_size, n_heads, key_length, dim_per_head)
                    hidden_states = torch.cat([past_key_value, hidden_states], dim=2)
                else:
                    # cross-attn
                    hidden_states = past_key_value
            return hidden_states

        # get query states
        query_states = shape(
            self.q(hidden_states)
        )  # (batch_size, n_heads, seq_length, dim_per_head)

        # get key/value states
        key_states = project(
            hidden_states,
            self.k,
            key_value_states,
            past_key_value[0] if past_key_value is not None else None,
        )
        value_states = project(
            hidden_states,
            self.v,
            key_value_states,
            past_key_value[1] if past_key_value is not None else None,
        )

        # compute scores
        scores = torch.matmul(query_states, key_states.transpose(3, 2))

        # (truncated here )
        return scores, value_states


class ChunkReformerFeedForward(torch.nn.Module):
|
|
# simplified from HF modeling_reformer.py
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.layer_norm = torch.nn.LayerNorm(256, eps=1e-12)
|
|
self.dense = torch.nn.Linear(256, 256)
|
|
self.output = torch.nn.Linear(256, 256)
|
|
|
|
def forward(self, attention_output):
|
|
return apply_chunking_to_forward(
|
|
self.forward_chunk,
|
|
attention_output + 1,
|
|
)
|
|
|
|
def forward_chunk(self, hidden_states):
|
|
hidden_states = self.layer_norm(hidden_states)
|
|
hidden_states = self.dense(hidden_states)
|
|
return self.output(hidden_states)
|
|
|
|
|
|
def apply_chunking_to_forward(forward_fn, *input_tensors):
|
|
# simplified from HF model_utils.py
|
|
assert len(input_tensors) > 0
|
|
tensor_shape = input_tensors[0].shape[1]
|
|
assert all(input_tensor.shape[1] == tensor_shape for input_tensor in input_tensors)
|
|
num_args_in_forward_chunk_fn = len(inspect.signature(forward_fn).parameters)
|
|
if num_args_in_forward_chunk_fn != len(input_tensors):
|
|
raise ValueError
|
|
|
|
return forward_fn(*input_tensors)
|
|
|
|
|
|
def _validate_model_kwargs(fn, model_kwargs):
|
|
# simplified from transformers.generation.utils._validate_model_kwargs
|
|
unused_model_args = []
|
|
model_args = set(inspect.signature(fn).parameters)
|
|
for key, value in model_kwargs.items():
|
|
if value is not None and key not in model_args:
|
|
unused_model_args.append(key)
|
|
if unused_model_args:
|
|
raise ValueError(
|
|
f"The following `model_kwargs` are not used by the model: {unused_model_args} (note: typos in the"
|
|
" generate arguments will also show up in this list)"
|
|
)
|
|
|
|
|
|
class FakeMamlInner(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(784, 5)
|
|
|
|
def forward(self, x, ignored=None, bn_training=False):
|
|
return self.linear(x.view(x.shape[0], -1))
|
|
|
|
|
|
class PartialMaml(torch.nn.Module):
|
|
# Highly simplified version of maml.meta.Meta.finetuning
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.net = FakeMamlInner()
|
|
self.update_step_test = 10
|
|
self.update_lr = 0.4
|
|
|
|
def forward(self, x_spt, y_spt, x_qry, y_qry):
|
|
querysz = x_qry.size(0)
|
|
|
|
corrects = [0 for _ in range(self.update_step_test + 1)]
|
|
|
|
# in order to not ruin the state of running_mean/variance and bn_weight/bias
|
|
# we finetuning on the copied model instead of self.net
|
|
net = deepcopy(self.net)
|
|
|
|
# 1. run the i-th task and compute loss for k=0
|
|
logits = net(x_spt)
|
|
loss = F.cross_entropy(logits, y_spt)
|
|
grad = torch.autograd.grad(loss, net.parameters())
|
|
fast_weights = [
|
|
p[1] - self.update_lr * p[0] for p in zip(grad, net.parameters())
|
|
]
|
|
|
|
# this is the loss and accuracy before first update
|
|
with torch.no_grad():
|
|
# [setsz, nway]
|
|
logits_q = net(x_qry, net.parameters(), bn_training=True)
|
|
# [setsz]
|
|
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
|
|
# scalar
|
|
correct = torch.eq(pred_q, y_qry).sum().item()
|
|
corrects[0] = corrects[0] + correct
|
|
|
|
# this is the loss and accuracy after the first update
|
|
with torch.no_grad():
|
|
# [setsz, nway]
|
|
logits_q = net(x_qry, fast_weights, bn_training=True)
|
|
# [setsz]
|
|
pred_q = F.softmax(logits_q, dim=1).argmax(dim=1)
|
|
# scalar
|
|
correct = torch.eq(pred_q, y_qry).sum().item()
|
|
corrects[1] = corrects[1] + correct
|
|
|
|
del net
|
|
|
|
accs = torch.tensor(corrects) / querysz
|
|
|
|
return accs
|
|
|
|
|
|
def softmax_backward_data(parent, grad_output, output, dim, self):
|
|
from torch import _softmax_backward_data
|
|
|
|
return _softmax_backward_data(grad_output, output, parent.dim, self.dtype)
|
|
|
|
|
|
class XSoftmax(torch.autograd.Function):
|
|
# transformers.models.deberta.modeling_deberta.XSoftmax
|
|
@staticmethod
|
|
def forward(self, input, mask, dim):
|
|
self.dim = dim
|
|
rmask = ~(mask.to(torch.bool))
|
|
output = input.masked_fill(rmask, torch.tensor(torch.finfo(input.dtype).min))
|
|
output = torch.softmax(output, self.dim)
|
|
output.masked_fill_(rmask, 0)
|
|
self.save_for_backward(output, rmask)
|
|
return output
|
|
|
|
@staticmethod
|
|
def backward(self, grad_output):
|
|
(output, rmask) = self.saved_tensors
|
|
inputGrad = softmax_backward_data(self, grad_output, output, self.dim, output)
|
|
return inputGrad, None, None
|
|
|
|
|
|
class ModelOutput(collections.OrderedDict):
|
|
"""based on file_utils.py in HuggingFace"""
|
|
|
|
def __getitem__(self, k):
|
|
if isinstance(k, str):
|
|
inner_dict = dict(self.items())
|
|
return inner_dict[k]
|
|
else:
|
|
return self.to_tuple()[k]
|
|
|
|
def __setattr__(self, name, value):
|
|
if name in self.keys() and value is not None:
|
|
# Don't call self.__setitem__ to avoid recursion errors
|
|
super().__setitem__(name, value)
|
|
super().__setattr__(name, value)
|
|
|
|
def __setitem__(self, key, value):
|
|
# Will raise a KeyException if needed
|
|
super().__setitem__(key, value)
|
|
# Don't call self.__setattr__ to avoid recursion errors
|
|
super().__setattr__(key, value)
|
|
|
|
def to_tuple(self):
|
|
return tuple(self[k] for k in self.keys())
|
|
|
|
|
|
def create_rand_mask_from_inputs(
|
|
from_blocked_mask,
|
|
to_blocked_mask,
|
|
rand_attn,
|
|
num_attention_heads,
|
|
num_rand_blocks,
|
|
batch_size,
|
|
from_seq_length,
|
|
from_block_size,
|
|
):
|
|
"""taken from HF modeling_big_bird.py"""
|
|
num_windows = from_seq_length // from_block_size - 2
|
|
rand_mask = torch.stack(
|
|
[p1[i1.flatten()] for p1, i1 in zip(to_blocked_mask, rand_attn)]
|
|
)
|
|
rand_mask = rand_mask.view(
|
|
batch_size, num_attention_heads, num_windows, num_rand_blocks * from_block_size
|
|
)
|
|
rand_mask = torch.einsum("blq,bhlk->bhlqk", from_blocked_mask[:, 1:-1], rand_mask)
|
|
return rand_mask
|
|
|
|
|
|
class SequentialAppendList(torch.nn.Sequential):
|
|
"""from timm/models/vovnet.py"""
|
|
|
|
def forward(self, x: torch.Tensor, concat_list: List[torch.Tensor]) -> torch.Tensor:
|
|
for i, module in enumerate(self):
|
|
if i == 0:
|
|
concat_list.append(module(x))
|
|
else:
|
|
concat_list.append(module(concat_list[-1]))
|
|
x = torch.cat(concat_list, dim=1)
|
|
return x, concat_list
|
|
|
|
|
|
class BatchNormAct2d(torch.nn.BatchNorm2d):
|
|
"""Taken from timm"""
|
|
|
|
def __init__(
|
|
self,
|
|
num_features,
|
|
eps=1e-5,
|
|
momentum=0.1,
|
|
affine=True,
|
|
track_running_stats=True,
|
|
act_layer=torch.nn.ReLU,
|
|
inplace=True,
|
|
):
|
|
super().__init__(
|
|
num_features,
|
|
eps=eps,
|
|
momentum=momentum,
|
|
affine=affine,
|
|
track_running_stats=track_running_stats,
|
|
)
|
|
self.act = act_layer(inplace=inplace)
|
|
|
|
@torch.jit.ignore
|
|
def _forward_python(self, x):
|
|
return super().forward(x)
|
|
|
|
def forward(self, x):
|
|
if torch.jit.is_scripting():
|
|
x = self._forward_jit(x)
|
|
else:
|
|
x = self._forward_python(x)
|
|
x = self.act(x)
|
|
return x
|
|
|
|
|
|
def get_parameter_dtype(parameter):
|
|
"""from huggingface model_utils.py"""
|
|
try:
|
|
return next(parameter.parameters()).dtype
|
|
except StopIteration:
|
|
# For nn.DataParallel compatibility in PyTorch 1.5
|
|
|
|
def find_tensor_attributes(module):
|
|
tuples = [(k, v) for k, v in module.__dict__.items() if torch.is_tensor(v)]
|
|
return tuples
|
|
|
|
gen = parameter._named_members(get_members_fn=find_tensor_attributes)
|
|
first_tuple = next(gen)
|
|
return first_tuple[1].dtype
|
|
|
|
|
|
class DummyConfig:
|
|
attn_layers = ["local", "lsh", "local", "lsh", "local", "lsh"]
|
|
lsh_attn_chunk_length = 64
|
|
local_attn_chunk_length = 64
|
|
|
|
|
|
def _get_min_chunk_len(config):
|
|
"""from hf_Reformer"""
|
|
attn_types = config.attn_layers
|
|
attn_types_set = set(attn_types)
|
|
if len(attn_types_set) == 1 and attn_types[0] == "lsh":
|
|
return config.lsh_attn_chunk_length
|
|
elif len(attn_types_set) == 1 and attn_types[0] == "local":
|
|
return config.local_attn_chunk_length
|
|
elif len(attn_types_set) == 2 and attn_types_set == {"lsh", "local"}:
|
|
return min(config.lsh_attn_chunk_length, config.local_attn_chunk_length)
|
|
else:
|
|
raise NotImplementedError(
|
|
f"Only attn layer types 'lsh' and 'local' exist, but `config.attn_layers`: {config.attn_layers}. Select "
|
|
"attn layer types from ['lsh', 'local'] only."
|
|
)
|
|
|
|
|
|
def _stable_argsort(vector, dim):
|
|
"""from hf_Reformer"""
|
|
# this function scales the vector so that torch.argsort is stable.
|
|
# torch.argsort is not stable on its own
|
|
scale_offset = torch.arange(vector.shape[dim], device=vector.device).view(1, 1, -1)
|
|
scale_offset = scale_offset.expand(vector.shape)
|
|
scaled_vector = vector.shape[dim] * vector + (scale_offset % vector.shape[dim])
|
|
return torch.argsort(scaled_vector, dim=dim)
|
|
|
|
|
|
def _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(buckets):
|
|
"""from hf_Reformer"""
|
|
# no gradients are needed
|
|
with torch.no_grad():
|
|
# hash-based sort
|
|
sorted_bucket_idx = _stable_argsort(buckets, dim=-1)
|
|
|
|
# create simple indices to scatter to, to have undo sort
|
|
indices = (
|
|
torch.arange(sorted_bucket_idx.shape[-1], device=buckets.device)
|
|
.view(1, 1, -1)
|
|
.expand(sorted_bucket_idx.shape)
|
|
)
|
|
|
|
# get undo sort
|
|
undo_sorted_bucket_idx = sorted_bucket_idx.new(*sorted_bucket_idx.size())
|
|
undo_sorted_bucket_idx.scatter_(-1, sorted_bucket_idx, indices)
|
|
|
|
return sorted_bucket_idx, undo_sorted_bucket_idx
|
|
|
|
|
|
class CustomList1(list):
|
|
def __call__(self, x):
|
|
for processor in self:
|
|
x = processor(x)
|
|
return x
|
|
|
|
def clear(self):
|
|
pass # this prevents RestrictedListSubclassVariable from kicking in
|
|
|
|
|
|
class CustomList2(list):
|
|
def __call__(self, x):
|
|
for processor in self:
|
|
x = processor(x)
|
|
return x
|
|
|
|
def length_times_10(self):
|
|
return len(self) * 10
|
|
|
|
def append_twice(self, x):
|
|
self.extend([x, x])
|
|
|
|
|
|
def _merge_criteria_processor_list(default_list, custom_list):
|
|
# simplified transformers/generation/utils.py
|
|
if len(custom_list) == 0:
|
|
return default_list
|
|
for default in default_list:
|
|
for custom in custom_list:
|
|
if type(custom) is type(default):
|
|
raise ValueError
|
|
default_list.extend(custom_list)
|
|
return default_list
|
|
|
|
|
|
class FeedForwardLayer(nn.Module):
|
|
def __init__(self, d_model, dim_feedforward, activation, dropout) -> None:
|
|
super().__init__()
|
|
self.linear1 = nn.Linear(d_model, dim_feedforward)
|
|
self.activation = activation
|
|
self.dropout1 = nn.Dropout(dropout)
|
|
self.linear2 = nn.Linear(dim_feedforward, d_model)
|
|
self.dropout2 = nn.Dropout(dropout)
|
|
|
|
def forward(self, x):
|
|
return self.dropout2(
|
|
self.linear2(self.dropout1(self.activation(self.linear1(x))))
|
|
)
|
|
|
|
|
|
class TransformerEncoderLayer(nn.Module):
|
|
def __init__(
|
|
self,
|
|
d_model,
|
|
nhead,
|
|
dim_feedforward=2048,
|
|
dropout=0.1,
|
|
activation=nn.ReLU(),
|
|
layer_norm_eps=1e-5,
|
|
):
|
|
super().__init__()
|
|
self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=dropout)
|
|
self.norm1 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
|
self.norm2 = nn.LayerNorm(d_model, eps=layer_norm_eps)
|
|
self.dropout = nn.Dropout(dropout)
|
|
self.ff_block = FeedForwardLayer(d_model, dim_feedforward, activation, dropout)
|
|
|
|
def forward(self, src, src_mask=None, src_key_padding_mask=None):
|
|
x = src
|
|
x = self.norm1(x + self._sa_block(x, src_mask, src_key_padding_mask))
|
|
x = self.norm2(x + self._ff_block(x))
|
|
return x
|
|
|
|
# self-attention block
|
|
def _sa_block(self, x, attn_mask, key_padding_mask):
|
|
x = self.self_attn(
|
|
x,
|
|
x,
|
|
x,
|
|
attn_mask=attn_mask,
|
|
key_padding_mask=key_padding_mask,
|
|
need_weights=False,
|
|
)[0]
|
|
return self.dropout(x)
|
|
|
|
# feed forward block
|
|
def _ff_block(self, x):
|
|
return self.ff_block(x)
|
|
|
|
|
|
class MockModule(torch.nn.Module):
|
|
def inner_fn(self, left, right):
|
|
return tuple(left) == tuple(right)
|
|
|
|
def fn(self, tensor):
|
|
if type(tensor) is int:
|
|
return False
|
|
|
|
torch.add(tensor, tensor)
|
|
return self.inner_fn(tensor.shape, (1, 2, 3))
|
|
|
|
|
|
class IncByOne:
|
|
def __init__(self, x):
|
|
self.x = x + 1
|
|
|
|
|
|
class IncByTwo:
|
|
def __init__(self, x):
|
|
self.x = x + 2
|
|
|
|
|
|
class ReproTests(torch._dynamo.test_case.TestCase):
|
|
def test_do_paste_mask(self):
|
|
torch._dynamo.utils.counters.clear()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt__do_paste_mask = torch.compile(_do_paste_mask, backend=cnt)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 1,
|
|
427,
|
|
640,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 2,
|
|
427,
|
|
640,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 3,
|
|
612,
|
|
612,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 4,
|
|
612,
|
|
612,
|
|
True,
|
|
)
|
|
opt__do_paste_mask(
|
|
torch.randn(1, 1, 28, 28),
|
|
torch.tensor([[0.0, 1, 2, 4]]) * 2,
|
|
427,
|
|
640,
|
|
False,
|
|
)
|
|
# (dynamic shapes, static shapes)
|
|
self.assertIn(cnt.frame_count, (5, 7))
|
|
self.assertIn(cnt.op_count, (94, 106, 121))
|
|
|
|
def test_convert_boxes_to_pooler_format(self):
|
|
boxes1 = [
|
|
Boxes(torch.arange(0, 8).reshape((2, 4))),
|
|
Boxes(torch.arange(8, 16).reshape((2, 4))),
|
|
]
|
|
boxes2 = [
|
|
Boxes(torch.arange(16, 20).reshape((1, 4))),
|
|
Boxes(torch.arange(20, 24).reshape((1, 4))),
|
|
]
|
|
correct1 = convert_boxes_to_pooler_format(boxes1)
|
|
correct2 = convert_boxes_to_pooler_format(boxes2)
|
|
fn = convert_boxes_to_pooler_format
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(boxes1), correct1))
|
|
self.assertTrue(same(opt_fn(boxes2), correct2))
|
|
|
|
# repeat_interleave is a dynamic shape operator we do not execute/
|
|
# In the future, we could reduce the frame_count down to 1
|
|
# by guarding on the exact values of `Tensor repeats` arg
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """4""")
|
|
self.assertExpectedInline(cnt.op_count, """10""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """4""")
|
|
self.assertExpectedInline(cnt.op_count, """14""")
|
|
|
|
def test_boxes_len(self):
|
|
def fn(boxes):
|
|
return len(boxes) + boxes.__len__() + boxes.tensor
|
|
|
|
boxes1 = Boxes(torch.arange(0, 8).reshape((2, 4)))
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(boxes1), boxes1.tensor + 4.0))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """1""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """2""")
|
|
|
|
def _reformer(self, nopython):
|
|
input = torch.randn([1, 64, 256])
|
|
model = ReformerEncoder()
|
|
torch.manual_seed(1337)
|
|
correct = copy.deepcopy(model)(input)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
torch.manual_seed(1337)
|
|
opt_model = torch._dynamo.optimize(cnt, nopython=nopython)(model)
|
|
self.assertTrue(same(opt_model(input), correct))
|
|
return cnt
|
|
|
|
@requires_cuda
|
|
def test_sub_alpha_scalar_repro(self):
|
|
@torch.compile(backend="aot_eager")
|
|
def f(x):
|
|
return x.sub(1, alpha=2)
|
|
|
|
f(torch.ones(2, device="cuda", dtype=torch.float64))
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113010
|
|
def test_out_overload_non_contiguous(self):
|
|
def f(x, y):
|
|
return torch.abs(x, out=y.T)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
y_ref = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
x_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
y_test = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
|
|
out_ref = f(x_ref, y_ref)
|
|
out_test = f_compiled(x_test, y_test)
|
|
self.assertEqual(out_ref, out_test)
|
|
self.assertEqual(y_ref, y_test)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/109053
|
|
def test_view_dtype_overload(self):
|
|
def f(x):
|
|
return x.view(torch.int32)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x1 = torch.ones(4, requires_grad=True)
|
|
out_ref = f(x1)
|
|
out_test = f_compiled(x1)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
x2 = torch.ones(4, requires_grad=False)
|
|
out_ref = f(x2)
|
|
out_test = f_compiled(x2)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/90552
|
|
def test_intermediate_leaf_requires_grad(self):
|
|
def f(x):
|
|
leaf = torch.ones(2, requires_grad=True)
|
|
return leaf, leaf * 2
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
x = torch.arange(4, dtype=torch.float32).reshape(2, 2)
|
|
|
|
leaf, out = f(x)
|
|
leaf_test, out_test = f_compiled(x)
|
|
out.sum().backward()
|
|
out_test.sum().backward()
|
|
self.assertEqual(leaf.grad, leaf_test.grad)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113263
|
|
def test_unpack_hooks_dont_run_during_tracing(self):
|
|
def f(x, y):
|
|
return x * y
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
pack_count = 0
|
|
unpack_count = 0
|
|
|
|
def pack_hook(x):
|
|
nonlocal pack_count
|
|
pack_count += 1
|
|
return x
|
|
|
|
# unpack hook shouldn't run during compilation, while we trace the forward
|
|
def unpack_hook(x):
|
|
nonlocal unpack_count
|
|
unpack_count += 1
|
|
return x
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
y = torch.ones(4, requires_grad=False)
|
|
with torch.autograd.graph.saved_tensors_hooks(pack_hook, unpack_hook):
|
|
out_test = f_compiled(x, y)
|
|
self.assertEqual(pack_count, 1)
|
|
self.assertEqual(unpack_count, 0)
|
|
out_test.sum().backward()
|
|
self.assertEqual(pack_count, 1)
|
|
self.assertEqual(unpack_count, 1)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113263
|
|
def test_unpack_hooks_can_be_disabled(self):
|
|
def f(x, y):
|
|
return x * y
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
y = torch.ones(4, requires_grad=False)
|
|
with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
|
|
out_test = f_compiled(x, y)
|
|
out_test.sum().backward()
|
|
|
|
# https://github.com/pytorch/pytorch/issues/113263
|
|
def test_disabling_unpack_hooks_within_compiled_region(self):
|
|
def g(z):
|
|
with torch.autograd.graph.disable_saved_tensors_hooks("hooks are disabled"):
|
|
return z + 5
|
|
|
|
def f(x, y):
|
|
z = x * y
|
|
return g(z)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
x = torch.ones(4, requires_grad=True)
|
|
y = torch.ones(4, requires_grad=False)
|
|
out_test = f_compiled(x, y)
|
|
out_test.sum().backward()
|
|
|
|
# See https://github.com/pytorch/pytorch/issues/97745
|
|
def test_gan_repro_trying_to_backward_through_the_graph_a_second_time(self):
|
|
def f(a, b):
|
|
c = torch.ones(2, 2)
|
|
d = torch.ones(2, 2)
|
|
e = torch.matmul(a, c)
|
|
g_loss = torch.abs(e - d).mean()
|
|
g_loss.backward()
|
|
fake_d_pred = torch.matmul(b, e.detach())
|
|
d_loss = fake_d_pred.mean()
|
|
d_loss.backward()
|
|
|
|
a_ref = torch.randn(2, 2, requires_grad=True)
|
|
b_ref = torch.randn(2, 2, requires_grad=True)
|
|
out_ref = f(a_ref, b_ref)
|
|
|
|
a_test = a_ref.clone().detach().requires_grad_(True)
|
|
b_test = b_ref.clone().detach().requires_grad_(True)
|
|
out_test = torch.compile(f, backend="aot_eager")(a_test, b_test)
|
|
|
|
self.assertEqual(out_ref, out_test)
|
|
self.assertEqual(a_ref.grad, a_test.grad)
|
|
self.assertEqual(b_ref.grad, b_test.grad)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/111603
|
|
def test_tuple_enum_as_key_dict(self):
|
|
class MyEnum(Enum):
|
|
A = "a"
|
|
|
|
class SomeModel(torch.nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(1, 1)
|
|
|
|
def forward(self, x) -> torch.Tensor:
|
|
return self.linear(x[MyEnum.A])
|
|
|
|
x = {MyEnum.A: torch.rand(8, 1)}
|
|
model_pytorch = SomeModel()
|
|
model = torch.compile(model_pytorch)
|
|
# Executing twice works
|
|
model(x)
|
|
y = model(x)
|
|
self.assertEqual(y, model_pytorch(x))
|
|
|
|
def test_embedding_backward_broadcasting_decomp(self):
|
|
def f(grad_output, indices):
|
|
num_weights = 10
|
|
padding_idx = 1
|
|
scale_grad_by_freq = True
|
|
return torch.ops.aten.embedding_dense_backward(
|
|
grad_output, indices, num_weights, padding_idx, scale_grad_by_freq
|
|
)
|
|
|
|
f_compiled = torch.compile(f, backend="aot_eager")
|
|
|
|
grad_output = torch.ones(2, 4, 3, dtype=torch.float16)
|
|
indices = torch.ones(2, 4, dtype=torch.int64)
|
|
|
|
out_ref = f(grad_output, indices)
|
|
out_test = f_compiled(grad_output, indices)
|
|
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
def test_reformer_eval(self):
|
|
with torch.no_grad():
|
|
cnt = self._reformer(nopython=True)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 11)
|
|
|
|
def test_reformer_train(self):
|
|
with torch.enable_grad():
|
|
cnt = self._reformer(nopython=False)
|
|
expected_op_count = (
|
|
"""11""" if torch._dynamo.config.inline_inbuilt_nn_modules else """5"""
|
|
)
|
|
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, expected_op_count)
|
|
|
|
@disable_translation_validation_if_dynamic_shapes
|
|
def test_longformer_chunk(self):
|
|
input1 = torch.randn([1, 4096, 1])
|
|
input2 = torch.randn([12, 4096, 64])
|
|
correct1 = longformer_chunk(input1)
|
|
correct2 = longformer_chunk(input2)
|
|
fn = longformer_chunk
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(input1), correct1))
|
|
self.assertTrue(same(opt_fn(input2), correct2))
|
|
self.assertTrue(same(opt_fn(input1), correct1))
|
|
self.assertTrue(same(opt_fn(input2), correct2))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
if torch._dynamo.config.automatic_dynamic_shapes:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
self.assertExpectedInline(cnt.op_count, """8""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
self.assertExpectedInline(cnt.op_count, """4""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
self.assertExpectedInline(cnt.op_count, """21""")
|
|
|
|
def test_hf_t5_forward(self):
|
|
input = torch.randn([1, 2048, 512])
|
|
model = PartialT5()
|
|
correct = model(input)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
self.assertTrue(same(opt_model(input), correct))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """11""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """11""")
|
|
|
|
def test_module_in_skipfiles(self):
|
|
model = nn.Linear(10, 10)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
torch.compile(model, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_function_in_skipfiles(self):
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
torch.compile(torch.sin, backend=cnt, fullgraph=True)(torch.randn([5, 10]))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_slicing_dynamic_shape(self):
|
|
def fn(y):
|
|
x = torch.ones(8)
|
|
idx = y[0]
|
|
out = x[idx:]
|
|
return (out + 3) * 5
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(counter)(fn)
|
|
out = opt_fn(torch.ones(10, dtype=torch.long))
|
|
# idx should be 1 -> slicing off [1:] of 8 elem tensor
|
|
self.assertEqual(list(out.shape), [7])
|
|
|
|
self.assertEqual(counter.op_count, 2)
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
self.assertEqual(list(opt_fn(torch.tensor([4])).shape), [4])
|
|
|
|
def test_slicing_dynamic_shape_setitem(self):
|
|
def fn(input_lengths: torch.Tensor, new_ones_1):
|
|
getitem_13 = input_lengths[3]
|
|
new_ones_1[(3, slice(getitem_13, None, None))] = 0
|
|
setitem_13 = new_ones_1
|
|
return (setitem_13,)
|
|
|
|
x = torch.randn(10).to(dtype=torch.int64)
|
|
y = torch.randn(10, 204)
|
|
ref = fn(x, y)
|
|
opt_fn = torch._dynamo.optimize("aot_eager")(fn)
|
|
res = opt_fn(x, y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
@torch._dynamo.config.patch(error_on_recompile=True)
|
|
@torch.fx.experimental._config.patch(use_duck_shape=False)
|
|
def test_dynamic_shape_disable_duck_size(self):
|
|
class TestModel(nn.Module):
|
|
def __init__(
|
|
self,
|
|
):
|
|
super().__init__()
|
|
|
|
def forward(self, x: torch.Tensor, val: int) -> torch.Tensor:
|
|
return x + val
|
|
|
|
main_model = TestModel().to(memory_format=torch.channels_last)
|
|
opt_model = torch.compile(main_model, backend="eager", dynamic=True)
|
|
|
|
x1 = torch.rand(2, 5, 10, 10).to(memory_format=torch.channels_last)
|
|
x2 = torch.rand(2, 5, 4, 8).to(memory_format=torch.channels_last)
|
|
|
|
o1_ref = main_model(x1, 4)
|
|
o1 = opt_model(x1, 4)
|
|
|
|
o2_ref = main_model(x2, 20)
|
|
o2 = opt_model(x2, 20)
|
|
|
|
def test_chunk_reformer_ff(self):
|
|
input = torch.randn([1, 4096, 256])
|
|
model = ChunkReformerFeedForward()
|
|
correct = model(input)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
self.assertTrue(same(opt_model(input), correct))
|
|
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertLessEqual(cnt.op_count, 10)
|
|
|
|
# see: https://github.com/pytorch/pytorch/issues/80067
|
|
# NB: When you remove the expectedFailure, don't forget to
|
|
# uncomment/adjust the assertEqual below
|
|
@unittest.expectedFailure
|
|
@torch._dynamo.config.patch(
|
|
fake_tensor_propagation=True, capture_scalar_outputs=True
|
|
)
|
|
def test_maml_item_capture(self):
|
|
a = torch.randn(5, 1, 28, 28)
|
|
b = torch.zeros(5, dtype=torch.int64)
|
|
c = torch.randn(75, 1, 28, 28)
|
|
d = torch.zeros(75, dtype=torch.int64)
|
|
model = PartialMaml()
|
|
correct = model(a, b, c, d)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize(cnt)(model)
|
|
for _ in range(10):
|
|
self.assertTrue(same(opt_model(a, b, c, d), correct))
|
|
|
|
# if torch._dynamo.config.assume_static_by_default:
|
|
# self.assertExpectedInline(cnt.frame_count, """2""")
|
|
# else:
|
|
# self.assertExpectedInline(cnt.frame_count, """3""")
|
|
# TODO(jansel): figure out why op count depends on imports
|
|
self.assertIn(cnt.op_count, (36, 35, 34, 29, 28, 27))
|
|
|
|
# see: https://github.com/pytorch/pytorch/issues/80067
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=False)
|
|
def test_maml_no_item_capture(self):
|
|
a = torch.randn(5, 1, 28, 28)
|
|
b = torch.zeros(5, dtype=torch.int64)
|
|
c = torch.randn(75, 1, 28, 28)
|
|
d = torch.zeros(75, dtype=torch.int64)
|
|
model = PartialMaml()
|
|
correct = model(a, b, c, d)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize(cnt)(model)
|
|
for _ in range(10):
|
|
self.assertTrue(same(opt_model(a, b, c, d), correct))
|
|
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """4""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """5""")
|
|
|
|
def test_hf_model_output(self):
|
|
ex = ModelOutput(a=torch.randn(10), b=torch.randn(10), c=torch.randn(10))
|
|
|
|
def fn1(x):
|
|
return x["a"] + 1
|
|
|
|
def fn2(x):
|
|
return x.a + 1
|
|
|
|
def fn3(x):
|
|
return x.to_tuple()[0] + 1
|
|
|
|
def fn4(x):
|
|
return x[0] + 1
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
for fn in (fn1, fn2, fn3, fn4):
|
|
cnt.clear()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(ex), ex.a + 1))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
@disable_translation_validation_if_dynamic_shapes
|
|
def test_create_rand_mask_from_inputs(self):
|
|
args = [
|
|
torch.randn([1, 64, 64]),
|
|
torch.randn([1, 64, 64]),
|
|
torch.zeros([1, 12, 62, 3], dtype=torch.int64),
|
|
12,
|
|
3,
|
|
1,
|
|
4096,
|
|
64,
|
|
]
|
|
correct = create_rand_mask_from_inputs(*args)
|
|
fn = create_rand_mask_from_inputs
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(*args), correct))
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """8""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """11""")
|
|
|
|
def test_rng_state(self):
|
|
def fn():
|
|
state = torch.get_rng_state()
|
|
before = torch.rand(1000)
|
|
torch.set_rng_state(state)
|
|
after = torch.rand(1000)
|
|
return before, after
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
|
|
before, after = opt_fn()
|
|
self.assertTrue(same(before, after))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 2) # rand, rand
|
|
try:
|
|
graph, _ = torch._dynamo.export(fn)()
|
|
# See https://github.com/pytorch/pytorch/pull/87490
|
|
self.fail("unexpected export success")
|
|
except torch._dynamo.exc.Unsupported:
|
|
pass
|
|
|
|
def test_threading_local(self):
|
|
import threading
|
|
|
|
foo = threading.local()
|
|
foo.x = torch.rand(1)
|
|
|
|
def f(x):
|
|
return torch.cat([x, foo.x])
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
|
|
|
|
inp = torch.ones(1)
|
|
out = f(inp)
|
|
opt_out = opt_f(inp)
|
|
self.assertEqual(opt_out, out)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_seq_append_list(self):
|
|
x = torch.randn(4, 10)
|
|
model = SequentialAppendList(
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
)
|
|
# this one is tricky because it mutates the list provided as an input
|
|
l1 = [x]
|
|
l2 = [x]
|
|
correct, _ = model(x, l1)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
result, l3 = opt_model(x, l2)
|
|
self.assertTrue(same(result, correct))
|
|
self.assertTrue(same(l1, l2))
|
|
self.assertIs(l2, l3)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 5)
|
|
|
|
def test_batch_norm_act(self):
|
|
a = torch.randn(5, 1, 28, 28)
|
|
model = BatchNormAct2d(1).eval()
|
|
correct = model(a)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
if not torch._dynamo.config.specialize_int:
|
|
# _local_scalar_dense causes graph break w 0-dim tensor
|
|
opt_model = torch._dynamo.optimize(cnt)(model)
|
|
self.assertTrue(same(opt_model(a), correct))
|
|
return
|
|
|
|
opt_model = torch._dynamo.optimize_assert(cnt)(model)
|
|
self.assertTrue(same(opt_model(a), correct))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
def test_get_parameter_dtype(self):
|
|
model = SequentialAppendList(
|
|
torch.nn.Linear(10, 10),
|
|
torch.nn.ReLU(),
|
|
)
|
|
|
|
def fn(model, x):
|
|
return x + torch.randn(10, dtype=get_parameter_dtype(model))
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertEqual(opt_fn(model, torch.randn(10)).dtype, torch.float32)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
def test_nn_parameter(self):
|
|
def test_fn():
|
|
a = torch.nn.Parameter(torch.randn(5, 5))
|
|
# Checks that TensorVariable stores the type information correctly
|
|
self.assertTrue(isinstance(a, torch.nn.Parameter))
|
|
return a
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
|
|
out = opt_test_fn()
|
|
self.assertTrue(isinstance(out, torch.nn.Parameter))
|
|
|
|
def test_Size(self):
|
|
def test_fn():
|
|
a = torch.randn(4)
|
|
x = torch.Size([1, 2, 3])
|
|
# Checks that SizeVariable return torch.Size object
|
|
assert isinstance(x, torch.Size)
|
|
# Causes graph breaks and checks reconstruction of SizeVariable
|
|
# object
|
|
self.assertIsInstance(x, torch.Size)
|
|
return a
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
|
|
opt_test_fn()
|
|
|
|
# See https://github.com/pytorch/pytorch/issues/100067
|
|
def test_copy_weird_strides(self):
|
|
# This test requires inductor's copy() decomp to preserve strides properly.
|
|
def test_fn(a):
|
|
b = torch.zeros(48, 4, 256, 513)
|
|
b[:, 0, 1:256, 1:256] = a
|
|
c = b.view(4, 12, 1024, 513)
|
|
d = c.transpose(2, 1)
|
|
d.add_(1)
|
|
return d
|
|
|
|
sh, st, dt, dev, rg = (
|
|
(48, 255, 255),
|
|
(787968, 513, 1),
|
|
torch.float16,
|
|
"cpu",
|
|
True,
|
|
)
|
|
a = rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
compiled_f = torch.compile(test_fn, backend="aot_eager_decomp_partition")
|
|
out1 = test_fn(a)
|
|
out2 = compiled_f(a)
|
|
self.assertEqual(out1, out2)
|
|
|
|
def test_indexing_with_list(self):
|
|
def test_fn():
|
|
def run_test(tensor, *idx):
|
|
npt = tensor.numpy()
|
|
assert npt[idx].shape == tensor[idx].shape
|
|
|
|
x = torch.arange(0, 10)
|
|
cases = [
|
|
[None, None],
|
|
[1, None],
|
|
]
|
|
|
|
for case in cases:
|
|
run_test(x, *case)
|
|
|
|
return torch.randn(4)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_test_fn = torch._dynamo.optimize(cnt)(test_fn)
|
|
opt_test_fn()
|
|
|
|
def test_reformer_min_chunk_len(self):
|
|
def fn(cfg):
|
|
t = torch.empty(10)
|
|
t.fill_(_get_min_chunk_len(cfg))
|
|
return t[0]
|
|
|
|
cfg = DummyConfig()
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertEqual(opt_fn(cfg), 64)
|
|
# With unspec int, maximum computation is preserved
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """3""")
|
|
|
|
def test_reformer_sorting(self):
|
|
x = torch.zeros([1, 12, 4096], dtype=torch.int64)
|
|
correct = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx(x)
|
|
fn = _get_sorted_bucket_idx_and_undo_sorted_bucket_idx
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize_assert(cnt)(fn)
|
|
self.assertTrue(same(opt_fn(x), correct))
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """14""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
self.assertExpectedInline(cnt.op_count, """16""")
|
|
|
|
def test_recursive_map(self):
|
|
# https://github.com/pytorch/torchdynamo/issues/132
|
|
def _recursive_map(struct, batch_dim=0):
|
|
for k, v in struct.items():
|
|
if v is not None:
|
|
if isinstance(v, dict):
|
|
_recursive_map(v)
|
|
else:
|
|
struct[k] = v
|
|
|
|
def toy_example(a, b, v):
|
|
x = a / (torch.abs(a) + 1)
|
|
if v is not None:
|
|
_recursive_map(v)
|
|
return x * b
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_toy_example = torch._dynamo.optimize(cnt)(toy_example)
|
|
opt_toy_example(
|
|
torch.randn(10),
|
|
torch.randn(10),
|
|
{"layer0": {"memory_keys": torch.randn(10)}},
|
|
)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 4)
|
|
|
|
def test_issue114171(self):
|
|
device = torch.device("cpu")
|
|
|
|
def fcnn(in_dim, out_dim, hidden_dim, activation=torch.nn.GELU):
|
|
layers = [
|
|
torch.nn.Linear(in_dim, hidden_dim, device=device),
|
|
activation(),
|
|
torch.nn.Linear(hidden_dim, out_dim, device=device),
|
|
]
|
|
return torch.nn.Sequential(*layers)
|
|
|
|
class testmodel(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.interaction_networks = torch.nn.ModuleList(
|
|
[fcnn(262, 1174, 400) for _ in range(4)]
|
|
)
|
|
|
|
def interact(self, x, cycle):
|
|
return self.interaction_networks[cycle](x)
|
|
|
|
model = testmodel()
|
|
forward_aot = torch.compile(
|
|
model.interact, fullgraph=True, dynamic=True, backend="eager"
|
|
)
|
|
|
|
x = torch.rand([111, 262], device=device)
|
|
y2 = forward_aot(x, 2) # previously failed
|
|
|
|
def test_issue175(self):
|
|
n_heads = 2
|
|
d_model = 64
|
|
model = TransformerEncoderLayer(d_model, n_heads)
|
|
inp = torch.randn(1, d_model)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_model = torch._dynamo.optimize(cnt, nopython=True)(model)
|
|
opt_model(inp)
|
|
opt_model(inp)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
self.assertEqual(
|
|
15 if torch._dynamo.config.inline_inbuilt_nn_modules else 12, cnt.op_count
|
|
)
|
|
|
|
def test_exec_import(self):
|
|
def fn1():
|
|
exec("import math")
|
|
|
|
def fn2():
|
|
try:
|
|
math.sqrt(4)
|
|
return False
|
|
except NameError:
|
|
return True
|
|
|
|
def fn3():
|
|
fn1()
|
|
return fn2()
|
|
|
|
self.assertTrue(fn3())
|
|
opt_fn3 = torch._dynamo.optimize("eager")(fn3)
|
|
self.assertTrue(opt_fn3())
|
|
|
|
def test_exec_wildcard_import(self):
|
|
# Test that globals are not carried over from frame to frame
|
|
def fn1():
|
|
exec("from torch import *")
|
|
|
|
def fn2():
|
|
x = torch.zeros(4)
|
|
for i in range(5):
|
|
x = x + i
|
|
return x
|
|
|
|
def fn3():
|
|
fn1()
|
|
return fn2()
|
|
|
|
ref = fn3()
|
|
opt_fn3 = torch._dynamo.optimize("eager")(fn3)
|
|
res = opt_fn3()
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_with_on_graph_break_inst(self):
|
|
def reversible(x):
|
|
print("Hello world") # Cause graph break so inline fails
|
|
return torch.sin(torch.cos(x))
|
|
|
|
def fn(x):
|
|
with torch.enable_grad():
|
|
a = torch.sin(x)
|
|
b = reversible(a)
|
|
c = torch.sigmoid(b)
|
|
c.sum().backward()
|
|
return x.grad
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
x.grad = None
|
|
with torch.no_grad():
|
|
ref = fn(x)
|
|
|
|
x.grad = None
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
with torch.no_grad():
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_with_on_graph_break_nested(self):
|
|
def reversible(x):
|
|
torch._dynamo.graph_break() # Cause graph break so inline fails
|
|
return torch.sin(torch.cos(x))
|
|
|
|
def fn(x):
|
|
# nested context manager failed previously
|
|
with torch.no_grad():
|
|
with torch.enable_grad():
|
|
a = torch.sin(x)
|
|
b = reversible(a)
|
|
c = torch.sigmoid(b)
|
|
c.sum().backward()
|
|
return x.grad
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
x.grad = None
|
|
with torch.no_grad():
|
|
ref = fn(x)
|
|
|
|
x.grad = None
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
with torch.no_grad():
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
# https://github.com/pytorch/torchdynamo/issues/1446
|
|
def test_grad_mode_carrying_correct_state_after_graph_break(self):
|
|
def fn(x):
|
|
with torch.no_grad():
|
|
y = x * 3
|
|
print("Break")
|
|
z = x + 2
|
|
return y, z
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
y, z = opt_fn(x)
|
|
self.assertFalse(y.requires_grad)
|
|
self.assertFalse(z.requires_grad)
|
|
|
|
def test_abc_setattr(self):
|
|
# tests that we correctly bail out of __setattr__ calls
|
|
|
|
# TODO: does not ensure ABC classes are correctly inferred as ClassVariables
|
|
# (doesn't test the fix for 'super()')
|
|
|
|
class BaseModule(torch.nn.Module, ABC):
|
|
def blah(self, x):
|
|
return x + 1
|
|
|
|
class Derived(BaseModule):
|
|
def __setattr__(self, name, value) -> None:
|
|
super().__setattr__(name, value)
|
|
|
|
def forward(self, x):
|
|
# expect a graph break on __setattr__
|
|
self.foo = 0
|
|
return self.blah(x)
|
|
|
|
def blah(self, x):
|
|
return super().blah(x)
|
|
|
|
x = torch.randn(3, requires_grad=True)
|
|
mod = Derived()
|
|
opt_mod = torch._dynamo.optimize("eager")(mod)
|
|
opt_mod(x)
|
|
|
|
# Not sure what this test is testing. It was earlier graph breaking on
|
|
# __dict__, so the counter >= 2. With __dict__ support, there is no
|
|
# graph break.
|
|
self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["ok"], 1)
|
|
self.assertGreaterEqual(torch._dynamo.utils.counters["frames"]["total"], 1)
|
|
|
|
@torch._dynamo.config.patch("suppress_errors", True)
|
|
def test_guard_fail_tensor_bool(self):
|
|
@torch._dynamo.disable(recursive=False)
|
|
def fn():
|
|
condition_shape = (5, 5)
|
|
dtypes = (torch.bool,)
|
|
shapes = (
|
|
(),
|
|
(5,),
|
|
(1, 5),
|
|
)
|
|
|
|
tensors = [
|
|
torch.empty(shape, dtype=dtype).fill_(17)
|
|
for shape, dtype in itertools.product(shapes, dtypes)
|
|
]
|
|
|
|
x_vals = (5.0, *tensors)
|
|
y_vals = (6.0, *tensors)
|
|
|
|
@torch._dynamo.disable
|
|
def get_expected(condition, x, y):
|
|
x_np = x.cpu().numpy() if isinstance(x, torch.Tensor) else x
|
|
y_np = y.cpu().numpy() if isinstance(y, torch.Tensor) else y
|
|
return torch.from_numpy(
|
|
np.where(condition.cpu().numpy(), x_np, y_np)
|
|
).to(common_dtype)
|
|
|
|
for x, y in zip(x_vals, y_vals):
|
|
condition = torch.empty(*condition_shape, dtype=torch.bool).bernoulli_()
|
|
common_dtype = torch.result_type(x, y)
|
|
|
|
def check_equal(condition, x, y):
|
|
# NumPy aggressively promotes to double, hence cast to output to correct dtype
|
|
expected = get_expected(condition, x, y)
|
|
result = torch.where(condition, x, y)
|
|
assert torch.allclose(expected, result)
|
|
|
|
check_equal(condition, x, y)
|
|
check_equal(condition, y, x)
|
|
|
|
fn()
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
opt_fn()
|
|
|
|
def test_guard_fail_nested_tuple(self):
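# args1 nests tensors in an inner tuple while args2 is flat; the recompile caused
# by the changed input structure must still return matching results.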
|
|
def fn(args):
|
|
return torch.ones(()), args[0] * 2
|
|
|
|
# This adds a tensor check on args[1][0] and args[1][1]
|
|
args1 = (torch.ones(1), (torch.ones(1), torch.ones(1)))
|
|
args2 = (torch.ones(1), torch.ones(1))
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
ref = opt_fn(args1)
|
|
res = opt_fn(args2)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
    def test_nullcontext1(self):
        @torch.compile(fullgraph=True, backend="eager")
        def fn(x, ctx):
            x = x.sin()
            with ctx:
                x = x.cos()
            x = x.sin()
            return x

        y = torch.randn(10)
        self.assertTrue(same(fn(y, contextlib.nullcontext()), y.sin().cos().sin()))

    def test_nullcontext2(self):
        @torch.compile(fullgraph=True, backend="eager")
        def fn(x, ctx):
            x = x.sin()
            with ctx():
                x = x.cos()
            x = x.sin()
            return x

        y = torch.randn(10)
        self.assertTrue(same(fn(y, contextlib.nullcontext), y.sin().cos().sin()))

    def test_no_grad_inline(self):
        @torch.no_grad()
        def a(x):
            return x.sin()

        @torch.compile(backend="eager", fullgraph=True)
        def b(x):
            return a(x).cos()

        y = torch.randn(10)
        self.assertTrue(same(b(y), y.sin().cos()))

def test_longtensor_list(self):
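# rand_gen is disabled, so the mixed list of NumPy scalars and tensors is produced
# eagerly; building torch.LongTensor from it inside the compiled fn must match eager
# results for each partition split.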
|
|
for partition in [0, 5, 10]:
|
|
|
|
@torch._dynamo.disable
|
|
def rand_gen():
|
|
rand_vals = [random.randint(5, 10) for _ in range(10)]
|
|
# List of tensors mixed with np.arrays
|
|
return list(np.array(rand_vals[:partition])) + [
|
|
torch.tensor(val) for val in rand_vals[partition:]
|
|
]
|
|
|
|
def fn(x):
|
|
random_list = rand_gen()
|
|
z = torch.LongTensor(random_list)
|
|
return x * z
|
|
|
|
x = torch.ones(10) * 2
|
|
|
|
random.seed(0)
|
|
ref0 = fn(x)
|
|
ref1 = fn(x)
|
|
|
|
random.seed(0)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
res0 = opt_fn(x)
|
|
res1 = opt_fn(x)
|
|
|
|
self.assertTrue(same(ref0, res0))
|
|
self.assertTrue(same(ref1, res1))
|
|
|
|
def test_primtorch(self):
|
|
@torch._dynamo.optimize("eager")
|
|
def fn(x):
|
|
torch._refs.abs(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
@unittest.expectedFailure
|
|
# inline_call [('inline in skipfiles: bind ...python3.10/inspect.py', 1)]
|
|
def test_primtorch_no_graph_break(self):
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn(x):
|
|
torch._refs.abs(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_torch_tensor_ops_no_graph_break(self):
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn(x):
|
|
torch.Tensor.abs_(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
@unittest.skipIf(
|
|
not isinstance(torch.ops.aten.abs, torch._ops.OpOverloadPacket),
|
|
"old pt doesn't work",
|
|
)
|
|
def test_torch_ops_aten(self):
|
|
# Picked an op that doesn't show up in the default list
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn(x):
|
|
return torch.ops.aten.absolute(x)
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_hf_gelu_inline(self):
|
|
class GELUActivation(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.act = nn.functional.gelu
|
|
|
|
def forward(self, input):
|
|
return self.act(input)
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn(x):
|
|
return GELUActivation()(x)
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(fn(y), nn.functional.gelu(y)))
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn_returns(x):
|
|
return GELUActivation(), x + 1
|
|
|
|
act, _ = fn_returns(y)
|
|
self.assertIsInstance(act, GELUActivation)
|
|
self.assertIs(act.act, nn.functional.gelu)
|
|
self.assertTrue(hasattr(act, "_buffers")) # check that __init__ got called
|
|
|
|
def test_dropout_inline(self):
|
|
@torch._dynamo.optimize("eager")
|
|
def fn(x):
|
|
return torch.nn.Dropout(0.1)(x)
|
|
|
|
y = torch.randn(10)
|
|
torch.manual_seed(1337)
|
|
ref = nn.functional.dropout(y, 0.1)
|
|
torch.manual_seed(1337)
|
|
res = fn(y)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_setitem_boolean_mask_diff(self):
|
|
def fn(x, b, y):
|
|
x = x.clone()
|
|
x[b] = y
|
|
return x
|
|
|
|
opt_fn = torch._dynamo.optimize("aot_eager")(fn)
|
|
x = torch.randn(4, requires_grad=True)
|
|
b = torch.tensor([True, False, True, False])
|
|
y = torch.randn(2, requires_grad=True)
|
|
opt_fn(x, b, y)
|
|
|
|
def test_setitem_tuple_boolean_mask_diff(self):
|
|
def fn(x, b, y):
|
|
x = x.clone()
|
|
x[:, b] = y
|
|
return x
|
|
|
|
opt_fn = torch._dynamo.optimize("aot_eager")(fn)
|
|
x = torch.randn(8, 4, requires_grad=True)
|
|
b = torch.tensor([True, False, True, False])
|
|
y = torch.randn(2, requires_grad=True)
|
|
opt_fn(x, b, y)
|
|
|
|
def test_torch_tensor_ops(self):
|
|
def fn(x):
|
|
return torch.Tensor.abs_(x)
|
|
|
|
x = torch.randn(3)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
y = fn(x)
|
|
y_ = opt_fn(x)
|
|
self.assertTrue(same(y, y_))
|
|
|
|
def test_guard_ordering_shape_fail(self):
|
|
# If a function which takes a tensor has an inner function which
|
|
# is compiled and generates a guard on its shape,
|
|
# they are evaluated in the wrong order. So if on a subsequent call
|
|
# an int is passed instead of a tensor, guard evaluation will crash
|
|
# with a "no attribute: shape" error
|
|
m = MockModule()
|
|
opt_m = torch._dynamo.optimize("eager")(m)
|
|
opt_m.fn(torch.ones((5, 5)))
|
|
opt_m.fn(-3)
|
|
|
|
def test_tensor_isinstance_tuple(self):
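# isinstance(tensor, (int, torch.Tensor)) and str.format should both be traceable
# without raising.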
|
|
@torch._dynamo.optimize("eager")
|
|
def fn():
|
|
t = torch.ones(5, 5)
|
|
if not isinstance(t, (int, torch.Tensor)):
|
|
msg = str.format(
|
|
"{0} is not an instance of {1}",
|
|
type(t),
|
|
(int, torch.Tensor),
|
|
)
|
|
raise ValueError(msg)
|
|
return True
|
|
|
|
fn()
|
|
|
|
def test_isinstance_dtype(self):
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def fn(x):
|
|
isinstance(torch.bfloat16, torch.dtype)
|
|
return x
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_isinstance_storage(self):
|
|
@torch._dynamo.optimize("eager")
|
|
def fn(x):
|
|
f = bytearray([0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x10, 0x40])
|
|
bools = torch.BoolStorage.from_buffer(f, "big")
|
|
assert isinstance(bools, torch.BoolStorage)
|
|
return x
|
|
|
|
fn(torch.randn(3))
|
|
|
|
def test_issue111522(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f(x, y):
|
|
return x + y.a
|
|
|
|
class A:
|
|
a = 2
|
|
|
|
self.assertEqual(f(torch.zeros(2), A()), torch.full([2], 2.0))
|
|
|
|
del A.a
|
|
|
|
# graph break on missing attr
|
|
with self.assertRaises(torch._dynamo.exc.Unsupported):
|
|
f(torch.zeros(2), A())
|
|
|
|
def test_dict_list_values(self):
|
|
def inner_fn(args):
|
|
return [x[1].shape for x in args]
|
|
|
|
@torch._dynamo.optimize("eager")
|
|
def fn(tensors):
|
|
return inner_fn(zip(itertools.count(), tensors["args"]))
|
|
|
|
fn({"args": [torch.ones(5, 5), torch.ones(5, 6), torch.ones(5, 7)]})
|
|
fn({"args": [torch.ones(5, 5)]})
|
|
|
|
def test_dict_iter(self):
|
|
class MyMod(torch.nn.Module):
|
|
def forward(self, x):
|
|
z = {"my": 1, "const": 2, "dict": 3, "variable": 4}
|
|
tot = 0
|
|
for key in z:
|
|
tot += z[key]
|
|
|
|
return tot
|
|
|
|
x = torch.tensor([0])
|
|
model = MyMod()
|
|
opt_model = torch._dynamo.optimize("eager", nopython=True)(model)
|
|
y = opt_model(x)
|
|
|
|
self.assertEqual(y, 10)
|
|
|
|
def test_sort_out(self):
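# torch.sort with out= tensors that must be resized; the resized outputs are
# expected to be contiguous (stride (1,)).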
|
|
dtype = torch.float32
|
|
device = "cpu"
|
|
|
|
def fn():
|
|
tensor = torch.randn((3, 5), dtype=dtype, device=device)[:, 0]
|
|
values1 = torch.tensor(0, dtype=dtype, device=device)
|
|
indices1 = torch.tensor(0, dtype=torch.long, device=device)
|
|
torch.sort(tensor, out=(values1, indices1))
|
|
self.assertEqual(values1.stride(), (1,))
|
|
self.assertEqual(indices1.stride(), (1,))
|
|
|
|
fn()
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
opt_fn()
|
|
|
|
def test_sort_out2(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("sorted", torch.ones(4, 4))
|
|
self.register_buffer("indices", torch.ones(4, 4, dtype=torch.long))
|
|
|
|
def forward(self, x):
|
|
torch.sort(x, out=(self.sorted, self.indices))
|
|
return (x + 1, self.sorted, self.indices)
|
|
|
|
x = torch.randn(4, 4)
|
|
m = MyModule()
|
|
ref = m(x)
|
|
opt_m = torch._dynamo.optimize("eager")(m)
|
|
res = opt_m(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_sigmoid_out(self):
|
|
dtype = torch.float32
|
|
device = "cpu"
|
|
|
|
def fn():
|
|
inp = torch.randn((3, 5), dtype=dtype, device=device)
|
|
out1 = torch.tensor(0, dtype=dtype, device=device)
|
|
torch.sigmoid(inp, out=out1)
|
|
self.assertEqual(out1.numel(), 15)
|
|
|
|
fn()
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
opt_fn()
|
|
|
|
def test_sigmoid_out2(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("base", torch.ones(4, 4))
|
|
|
|
def forward(self, x):
|
|
torch.sigmoid(x, out=self.base)
|
|
return x + self.base
|
|
|
|
x = torch.randn(4, 4)
|
|
m = MyModule()
|
|
ref = m(x)
|
|
opt_m = torch._dynamo.optimize("eager")(m)
|
|
res = opt_m(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_slice_into_list_mutable(self):
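# Writing into a slice taken from the input list should compile as a single frame
# (no graph break).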
|
|
class Mod(torch.nn.Module):
|
|
def forward(self, listy):
|
|
x = listy[3:5]
|
|
for i in range(10):
|
|
z = torch.abs(torch.randn(10)) + 1
|
|
x[0] = z
|
|
return x
|
|
|
|
m = Mod()
|
|
listy = [torch.randn(10)] * 10
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_m = torch._dynamo.optimize(cnt, nopython=True)(m)
|
|
opt_m.forward(listy)
|
|
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_issue111918(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, dynamic=True)
|
|
def fn(x):
|
|
x = x + 1
|
|
y = x.item()
|
|
if y > 2:
|
|
return x * 2
|
|
else:
|
|
return x * 3
|
|
|
|
x = torch.tensor([3.0])
|
|
fn(x)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 4)
|
|
|
|
torch._dynamo.reset()
|
|
fn = torch.compile(fn, fullgraph=True, backend="eager")
|
|
with self.assertRaises(torch._dynamo.exc.UserError):
|
|
fn(x)
|
|
|
|
def test_vdd_duplicate_error(self):
|
|
def fn(a, dt):
|
|
keys = list(dt._jt_dict.keys())
|
|
p = torch.cos(dt._jt_dict[keys[0]]._value)
|
|
q = torch.sin(a)
|
|
r = torch.sigmoid(dt._jt_dict[keys[0]]._value)
|
|
return p + q + r
|
|
|
|
class Value:
|
|
def __init__(self):
|
|
self._value = torch.randn(4)
|
|
|
|
class Sample:
|
|
def __init__(self):
|
|
self._jt_dict = {}
|
|
self._jt_dict["POSITION_ID"] = Value()
|
|
|
|
a = torch.randn(4)
|
|
sample = Sample()
|
|
|
|
ref = fn(a, sample)
|
|
|
|
optimized_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = optimized_fn(a, sample)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
    def test_specialized_stride(self):
        def f():
            e = torch.empty(4)
            x = e[::2]
            return x.stride()

        self.assertEqual(f(), torch._dynamo.optimize("eager")(f)())

    def test_out_none(self):
        # https://github.com/pytorch/pytorch/issues/92814
        def fn(input):
            return torch.nn.functional.normalize(input, dim=0, out=None)

        x = torch.rand([1])
        self.assertEqual(fn(x), torch._dynamo.optimize("eager")(fn)(x))

def test_multi_import(self):
|
|
if not has_detectron2():
|
|
raise unittest.SkipTest("requires detectron2")
|
|
|
|
@torch._dynamo.optimize("eager", nopython=True)
|
|
def to_bitmasks(boxes):
|
|
from detectron2.layers.mask_ops import (
|
|
_paste_masks_tensor_shape,
|
|
paste_masks_in_image,
|
|
)
|
|
|
|
if (
|
|
paste_masks_in_image is not None
|
|
and _paste_masks_tensor_shape is not None
|
|
):
|
|
return boxes + 1
|
|
|
|
self.assertTrue((to_bitmasks(torch.zeros(10)) == torch.ones(10)).all())
|
|
|
|
def test_multi_dot_import(self):
|
|
def fn1(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
import torch.fx
|
|
|
|
_ = torch.fx.symbolic_trace(fn1)
|
|
return x * 2
|
|
|
|
x = torch.randn(10)
|
|
fn(x)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_relative_import(self):
|
|
try:
|
|
from . import utils as _ # noqa: F401
|
|
|
|
def fn(x):
|
|
from .utils import tensor_for_import_testing
|
|
|
|
return x * 2 * tensor_for_import_testing
|
|
|
|
except ImportError:
|
|
|
|
def fn(x):
|
|
from utils import tensor_for_import_testing
|
|
|
|
return x * 2 * tensor_for_import_testing
|
|
|
|
x = torch.randn(10)
|
|
fn(x)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_relative_import_no_modulename(self):
|
|
try:
|
|
from . import utils as _ # noqa: F401
|
|
|
|
def fn(x):
|
|
from . import utils
|
|
|
|
return x * 2 * utils.tensor_for_import_testing
|
|
|
|
except ImportError:
|
|
|
|
def fn(x):
|
|
import utils
|
|
|
|
return x * 2 * utils.tensor_for_import_testing
|
|
|
|
x = torch.randn(10)
|
|
fn(x)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_bigbird_unsqueeze_inplace(self):
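# In-place unsqueeze_ on a cloned input followed by cat/view, checked against eager
# under the aot_eager backend.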
|
|
def fn(reshape_2):
|
|
view_2 = reshape_2.clone()
|
|
view_2.unsqueeze_(2)
|
|
cat_11 = torch.cat([view_2], dim=2)
|
|
view_13 = cat_11.view((2, 12, 64, -1))
|
|
return (view_13,)
|
|
|
|
x = torch.randn(2, 12, 64, 64, requires_grad=True)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("aot_eager")(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_issue1466_size_aot_autograd(self):
|
|
def fn(x):
|
|
# do a tensor op and a size compute
|
|
y = x * 2
|
|
x_size = x.size()
|
|
# trigger a graph break
|
|
print("arf")
|
|
# use the tensor op and size compute
|
|
z = y.view(x_size) + 1
|
|
return z
|
|
|
|
x = torch.randn(2, 3, requires_grad=True)
|
|
ref = fn(x)
|
|
opt_fn = torch._dynamo.optimize("aot_eager")(fn)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_ellipsis(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.lnorm = torch.nn.LayerNorm(
|
|
(256,), eps=1e-06, elementwise_affine=True
|
|
)
|
|
self.linear = torch.nn.Linear(
|
|
in_features=256, out_features=256, bias=True
|
|
)
|
|
|
|
def forward(self, cat_10):
|
|
lnorm = self.lnorm(cat_10)
|
|
getitem_64 = lnorm[
|
|
(slice(None, None, None), slice(0, 1, None), Ellipsis)
|
|
]
|
|
linear = self.linear(getitem_64)
|
|
return (linear,)
|
|
|
|
args = [torch.randn(2, 197, 256)]
|
|
|
|
mod = Repro()
|
|
opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
|
|
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
|
|
def test_reinplacing(self):
|
|
class MockModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.self_layoutlm_embeddings_x_position_embeddings = (
|
|
torch.nn.Embedding(1024, 768)
|
|
)
|
|
self.self_layoutlm_embeddings_y_position_embeddings = (
|
|
torch.nn.Embedding(1024, 768)
|
|
)
|
|
|
|
def forward(self, getitem_1, getitem_2, add):
|
|
self_layoutlm_embeddings_x_position_embeddings = (
|
|
self.self_layoutlm_embeddings_x_position_embeddings(getitem_1)
|
|
)
|
|
self_layoutlm_embeddings_y_position_embeddings = (
|
|
self.self_layoutlm_embeddings_y_position_embeddings(getitem_2)
|
|
)
|
|
add_1 = add + self_layoutlm_embeddings_x_position_embeddings
|
|
add_2 = add_1 + self_layoutlm_embeddings_y_position_embeddings
|
|
return (add_2,)
|
|
|
|
mod = MockModule()
|
|
opt_mod = torch._dynamo.optimize("aot_eager_decomp_partition")(mod)
|
|
|
|
args = [
|
|
((2, 512), (2048, 4), torch.int64, "cpu", False),
|
|
((2, 512), (2048, 4), torch.int64, "cpu", False),
|
|
((2, 512, 768), (393216, 768, 1), torch.float32, "cpu", True),
|
|
]
|
|
args = [
|
|
rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
for (sh, st, dt, dev, rg) in args
|
|
]
|
|
self.assertTrue(same_two_models(mod, opt_mod, args))
|
|
|
|
def test_optimized_deepcopy(self):
|
|
# See https://github.com/pytorch/pytorch/pull/88629
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fc = torch.nn.Linear(in_features=2, out_features=3, bias=True)
|
|
|
|
def forward(self, x):
|
|
return self.fc(x)
|
|
|
|
mod = Foo()
|
|
opt_mod = torch._dynamo.optimize("eager")(mod)
|
|
args = [torch.randn(1, 2)]
|
|
self.assertTrue(same_two_models(mod, opt_mod, args))
|
|
|
|
def test_class_member(self):
|
|
class Foo(torch.nn.Module):
|
|
a = 4
|
|
b = torch.ones(3, 4)
|
|
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.c = 4
|
|
|
|
def forward(self, x):
|
|
return x.cos() + self.a + self.b + self.c
|
|
|
|
mod = Foo()
|
|
opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
|
|
args = (torch.randn(3, 4),)
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
|
|
def test_named_buffers(self):
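# Iterating self.named_buffers() inside forward must be fully traceable
# (nopython=True).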
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.register_buffer("x", torch.ones(3))
|
|
self.register_buffer("y", torch.ones(3))
|
|
|
|
def forward(self, inp):
|
|
res = 0
|
|
for name, buffer in self.named_buffers():
|
|
res += buffer.sum()
|
|
|
|
return inp.cos() + res
|
|
|
|
mod = Foo()
|
|
opt_mod = torch._dynamo.optimize("eager", nopython=True)(mod)
|
|
args = (torch.randn(3, 4),)
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
|
|
def test_requires_grad_guards_with_grad_mode1(self):
|
|
def f(x):
|
|
if x.requires_grad:
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
x = torch.ones(2, requires_grad=True)
|
|
|
|
f_compiled = torch.compile(f)
|
|
with torch.no_grad():
|
|
# compile an inference graph
|
|
f_compiled(x)
|
|
|
|
# Test: we should fail guards and recompile (even though it's still an inference graph)
|
|
out_ref = f(x.detach())
|
|
out = f_compiled(x.detach())
|
|
|
|
self.assertEqual(out_ref, out)
|
|
self.assertEqual(out_ref.requires_grad, out.requires_grad)
|
|
|
|
def test_requires_grad_guards_with_grad_mode2(self):
|
|
x = torch.ones(2, requires_grad=True)
|
|
x_ref = x.clone().detach().requires_grad_(True)
|
|
|
|
m = torch.nn.Linear(2, 2)
|
|
m_compiled = torch.compile(m)
|
|
|
|
with torch.no_grad():
|
|
# compile an inference graph
|
|
m_compiled(x)
|
|
|
|
# Test: we should fail guards and recompile a training graph
|
|
out_ref = m(x_ref)
|
|
out = m_compiled(x)
|
|
self.assertEqual(out_ref, out)
|
|
self.assertEqual(out_ref.requires_grad, out.requires_grad)
|
|
|
|
def test_is_symbolic_tracing(self):
|
|
# Ensure no graph break here
|
|
def fn(x):
|
|
if is_fx_tracing_test():
|
|
return x * 2
|
|
return x * 4
|
|
|
|
a = torch.randn(4)
|
|
ref = fn(a)
|
|
opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
|
|
res = opt_fn(a)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_tokenization(self):
|
|
from collections import UserDict
|
|
|
|
class BatchEncoding(UserDict):
|
|
"""
|
|
Copied from tokenization
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
data,
|
|
):
|
|
super().__init__(data)
|
|
|
|
def __getattr__(self, item: str):
|
|
try:
|
|
return self.data[item]
|
|
except KeyError as e:
|
|
raise AttributeError from e
|
|
|
|
def tokenization(x):
|
|
encoding = BatchEncoding({"key": x})
|
|
return encoding["key"]
|
|
|
|
opt_fn = torch._dynamo.optimize("eager")(tokenization)
|
|
x = torch.rand((1, 4))
|
|
ref = tokenization(x)
|
|
res = opt_fn(x)
|
|
self.assertTrue(same(ref, res))
|
|
|
|
def test_modules(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fc = torch.nn.Linear(4, 3)
|
|
|
|
def forward(self, inp):
|
|
res = torch.zeros(3, 3)
|
|
for mod in self.modules():
|
|
res += self.fc(inp)
|
|
return res
|
|
|
|
mod = Foo()
|
|
args = (torch.ones(3, 4),)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_mod = torch._dynamo.optimize(cnt, nopython=True)(mod)
|
|
self.assertTrue(same(mod(*args), opt_mod(*args)))
|
|
self.assertEqual(cnt.op_count, 5)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_omegaconf_listconfig_iter(self):
|
|
obj = ListConfig()
|
|
x = torch.zeros(2)
|
|
|
|
def fn():
|
|
y = x
|
|
for i in obj:
|
|
y += i
|
|
return y
|
|
|
|
expected = fn()
|
|
actual = torch.compile(fn, fullgraph=True, backend="eager")()
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_user_defined_iter(self):
|
|
class MyIter:
|
|
def __init__(self):
|
|
self.i = 0
|
|
|
|
def __iter__(self):
|
|
return self
|
|
|
|
def __next__(self):
|
|
if self.i < 3:
|
|
self.i += 1
|
|
return self.i
|
|
raise StopIteration
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
for i in MyIter():
|
|
x += i
|
|
return x
|
|
|
|
self.assertEqual(fn(torch.zeros(1)), torch.full([1], 6.0))
|
|
|
|
def test_stop_iteration_reconstruct(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
return x.sin(), StopIteration(1, 2, 3)
|
|
|
|
_, res = fn(torch.ones(1))
|
|
self.assertEqual(str(res), str(StopIteration(1, 2, 3)))
|
|
|
|
def test_tensor_data_kwarg(self):
|
|
# https://github.com/pytorch/pytorch/issues/96278
|
|
def f():
|
|
return torch.tensor(data=[[1.0, -1.0]])
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
|
|
self.assertTrue(same(f(), opt_fn()))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
@requires_cuda
|
|
def test_norm_dtype(self):
|
|
def foo(_stack0):
|
|
getitem = _stack0[(slice(None, None, None), -1)]
|
|
_stack0 = None
|
|
normalize = torch.nn.functional.normalize(getitem, p=2, dim=1)
|
|
getitem = None
|
|
return (normalize,)
|
|
|
|
args = [((2, 50, 256), (1, 256, 1), torch.float16, "cuda", False)]
|
|
args = [
|
|
rand_strided(sh, st, dt, dev).requires_grad_(rg)
|
|
for (sh, st, dt, dev, rg) in args
|
|
]
|
|
|
|
opt_foo = torch._dynamo.optimize("aot_eager_decomp_partition")(foo)
|
|
with torch.cuda.amp.autocast(enabled=True):
|
|
ref = foo(*args)[0]
|
|
res = foo(*args)[0]
|
|
self.assertEqual(ref.dtype, res.dtype)
|
|
|
|
self.assertTrue(same(res, ref))
|
|
|
|
def test_for_loop_graph_break(self):
|
|
def inner(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
for _ in range(100):
|
|
inner(x)
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_for_loop_graph_break_before(self):
|
|
# Checks that the backedge is calculated correctly
|
|
def inner(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
torch._dynamo.graph_break()
|
|
for _ in range(100):
|
|
inner(x)
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 100)
|
|
|
|
def test_avoid_dupe_specialization(self):
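# Passing the same tensor twice (f(x, x)) and then distinct tensors (f(x, y)) should
# both match eager, with and without requires_grad.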
|
|
def f(x, y):
|
|
return (x + y) * 1
|
|
|
|
opt_f = torch._dynamo.optimize("aot_eager")(f)
|
|
|
|
for b in [True, False]:
|
|
x = torch.randn(4, requires_grad=b)
|
|
y = torch.randn(4, requires_grad=b)
|
|
self.assertEqual(f(x, x), opt_f(x, x))
|
|
self.assertEqual(f(x, y), opt_f(x, y))
|
|
|
|
def test_validate_model_kwargs(self):
|
|
cnt = CompileCounter()
|
|
|
|
def f1(a, b):
|
|
return torch.sin(a) + torch.cos(b)
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def f2(**kwargs):
|
|
_validate_model_kwargs(f1, kwargs)
|
|
return f1(**kwargs)
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
|
|
self.assertEqual(f2(a=x, b=y), f1(x, y))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_swin_base_tensor_attr(self):
|
|
class Foo(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
# NB: not a parameter or buffer
|
|
self.t = torch.randn(3)
|
|
|
|
def forward(self, x):
|
|
return x + torch.cat((self.t, self.t))
|
|
|
|
mod = Foo()
|
|
opt_mod = torch._dynamo.optimize("eager")(mod)
|
|
args = [torch.randn(6)]
|
|
self.assertTrue(same_two_models(mod, opt_mod, args))
|
|
opt_mod(*args)
|
|
|
|
def test_pointless_graph_removal(self):
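# A graph break under torch.no_grad() should still produce a single compiled frame.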
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=cnt)
|
|
def fn(x):
|
|
with torch.no_grad():
|
|
torch._dynamo.graph_break()
|
|
return x + 1
|
|
|
|
fn(torch.randn(4))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_output_aliases_intermediate(self):
|
|
def f(x):
|
|
intermediate = x.mul(2)
|
|
return intermediate.view(-1), intermediate
|
|
|
|
opt_f = torch._dynamo.optimize("aot_eager")(f)
|
|
|
|
for b in [True, False]:
|
|
x = torch.randn(4, requires_grad=b)
|
|
out = f(x)
|
|
out_test = opt_f(x)
|
|
self.assertEqual(out[0], out_test[0])
|
|
self.assertEqual(out[1], out_test[1])
|
|
self.assertEqual(out[0].requires_grad, out_test[0].requires_grad)
|
|
self.assertEqual(out[1].requires_grad, out_test[1].requires_grad)
|
|
# test that the aliasing relationship of outputs is preserved
|
|
out[0].mul_(2)
|
|
out_test[0].mul_(2)
|
|
self.assertEqual(out[0], out_test[0])
|
|
self.assertEqual(out[1], out_test[1])
|
|
|
|
def test_while_loop_graph_break(self):
|
|
# Repro of tacotron2 cache_size_recompilation
|
|
def inner(x):
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
i = 20
|
|
while i > 10:
|
|
x = inner(x)
|
|
i -= 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_nested_while_loop_graph_break(self):
|
|
def inner_loop(x):
|
|
i = 3
|
|
while i > 0:
|
|
i -= 1
|
|
x += 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
def inner(x):
|
|
inner_loop(x)
|
|
return torch.sin(x)
|
|
|
|
def fn(x):
|
|
i = 20
|
|
while i > 10:
|
|
x = inner(x)
|
|
i -= 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_while_loop_graph_break_inside_call_function(self):
|
|
# Repro of huggingface graph break inside loop in `get_parameter_dtype`.
|
|
# Skip only the inner frame whose loop contains the graph break.
|
|
def inner(x):
|
|
for i in range(3):
|
|
x += 1
|
|
torch._dynamo.graph_break()
|
|
return x
|
|
|
|
def fn(x):
|
|
x += 2
|
|
inner(x)
|
|
x += 3
|
|
return x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
x = torch.randn(4)
|
|
opt_fn(x)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
def test_exception_in_dynamo_handling(self):
|
|
hit_handler = False
|
|
|
|
# See https://github.com/pytorch/pytorch/pull/96488
|
|
@contextlib.contextmanager
|
|
def ctx():
|
|
try:
|
|
yield
|
|
except RuntimeError:
|
|
nonlocal hit_handler
|
|
hit_handler = True
|
|
|
|
@torch._dynamo.optimize("eager")
|
|
def f():
|
|
with ctx():
|
|
h()
|
|
|
|
def h():
|
|
raise RuntimeError("boof")
|
|
|
|
# Should not error
|
|
f()
|
|
self.assertTrue(hit_handler)
|
|
|
|
def test_generator_dealloc(self):
|
|
# See https://github.com/pytorch/pytorch/pull/96488
|
|
#
|
|
# NB: yes, [(...)] is intentional, this is a list containing a
|
|
# generator
|
|
generator_box = [(x for x in [1, 2, 3])]
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
def g(x):
|
|
return x + 2
|
|
|
|
# TODO: This test is pretty delicate. To test if it's actually doing
|
|
# anything, rebuild eval_frame.c with '#define TORCHDYNAMO_DEBUG 1'
|
|
# and then look at the logs for:
|
|
#
|
|
# TRACE[_custom_eval_frame:650] begin <genexpr> test_repros.py 2276 -1 0 0
|
|
# TRACE[_custom_eval_frame:664] throw <genexpr>
|
|
#
|
|
# This means we're actually hitting the relevant codepath
|
|
|
|
# NB: Make sure we don't actually Dynamo this frame; if we do Dynamo
|
|
# this frame, Dynamo actually DOES understand list.clear and will
|
|
# arrange for the generator deallocation to happen when the eval frame
|
|
# handler is disabled, which will prevent the bug from happening (we
|
|
# specifically want to trigger the generator deallocation WHILE the
|
|
# dynamo eval frame handler is active), as that will cause the
|
|
# generator to become exhausted and trigger the throw_flag == TRUE
|
|
# case.
|
|
@torch._dynamo.disable(recursive=False)
|
|
def f(x):
|
|
generator_box.clear()
|
|
return g(x)
|
|
|
|
self.assertNoUnraisable(
|
|
lambda: torch._dynamo.optimize(counter)(f)(torch.randn(3))
|
|
)
|
|
|
|
# Make sure the x + 2 is captured (a previous incorrect implementation
|
|
# of this fix would have disabled the eval frame callback, which means
|
|
# g wouldn't get traced)
|
|
self.assertEqual(counter.op_count, 1)
|
|
|
|
def test_error_return_without_exception_set(self):
|
|
# https://github.com/pytorch/pytorch/issues/93781
|
|
@torch.compile
|
|
def f():
|
|
_generator_type = type(_ for _ in ())
|
|
|
|
self.assertNoUnraisable(f)
|
|
|
|
def common_merge_criteria_processor_list(self, list_cls, fullgraph):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=fullgraph)
|
|
def f(x, left, right):
|
|
combined = _merge_criteria_processor_list(left, right)
|
|
return combined(x)
|
|
|
|
l1 = list_cls([torch.nn.ReLU(), torch.nn.Sigmoid()])
|
|
l2 = list_cls([])
|
|
input = torch.randn(16)
|
|
result = f(input, l1, l2)
|
|
self.assertEqual(result, l1(input))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 2)
|
|
|
|
cnt.clear()
|
|
l3 = list_cls([torch.nn.SiLU()])
|
|
expected = l3(l1(input))
|
|
result = f(input, l1, l3)
|
|
self.assertEqual(len(l1), 3)
|
|
self.assertEqual(result, expected)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_merge_criteria_processor_list1(self):
|
|
self.common_merge_criteria_processor_list(CustomList1, False)
|
|
|
|
def test_merge_criteria_processor_list2(self):
|
|
self.common_merge_criteria_processor_list(CustomList2, True)
|
|
|
|
def test_restricted_list_subclass1(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a, b):
|
|
l = CustomList2()
|
|
l.extend([True])
|
|
l.append(a)
|
|
l.extend([b])
|
|
l.pop(0)
|
|
l.append(l.length_times_10())
|
|
return sum(l)
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
self.assertEqual(fn(x, y), x + y + 20)
|
|
self.assertEqual(cnt.op_count, 3)
|
|
|
|
def test_restricted_list_subclass2(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a, b):
|
|
l1 = CustomList2([a + 1])
|
|
l2 = CustomList2([b + 2])
|
|
l1.extend(l2)
|
|
return l1
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
z = fn(x, y)
|
|
self.assertEqual(type(z), CustomList2)
|
|
self.assertEqual(len(z), 2)
|
|
self.assertEqual(z.length_times_10(), 20)
|
|
self.assertEqual(list(z), [x + 1, y + 2])
|
|
|
|
def test_restricted_list_subclass3(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a: CustomList2, b: CustomList2):
|
|
a.extend(b)
|
|
a.append_twice(b[2] + 1)
|
|
a.append(b[3] + 2)
|
|
return b
|
|
|
|
x = torch.randn(10)
|
|
y = torch.randn(10)
|
|
l = CustomList2([x, y])
|
|
self.assertIs(fn(l, l), l)
|
|
self.assertEqual(len(l), 7)
|
|
self.assertIs(l[0], x)
|
|
self.assertIs(l[1], y)
|
|
self.assertIs(l[2], x)
|
|
self.assertIs(l[3], y)
|
|
self.assertEqual(l[4], x + 1)
|
|
self.assertIs(l[5], l[4])
|
|
self.assertEqual(l[6], y + 2)
|
|
|
|
def test_rewrite_assert_with_msg(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert x[0] == 3, "First dim need to be 3"
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
|
|
self.assertTrue(same(f(*args), opt_f(*args)))
|
|
self.assertEqual(cnt.op_count, 6)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
def test_list_aliasing(self):
|
|
cnt = CompileCounter()
|
|
|
|
@torch.compile(backend=cnt, fullgraph=True)
|
|
def fn(a):
|
|
a.append(torch.sin(a[0]))
|
|
return a
|
|
|
|
x = torch.randn(10)
|
|
l = [x]
|
|
self.assertIs(fn(l), l)
|
|
self.assertEqual(len(l), 2)
|
|
self.assertIs(l[0], x)
|
|
self.assertEqual(l[1], torch.sin(x))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(cnt.op_count, 1)
|
|
|
|
def test_not_rewrite_assert_for_other_errors(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
if not x.sum() <= 3:
|
|
raise ValueError("input sum needs to be 3")
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
opt_fn = torch._dynamo.optimize("eager")(f)
|
|
with self.assertRaisesRegex(ValueError, "input sum needs to be 3"):
|
|
opt_fn(*args)
|
|
|
|
def test_rewrite_assert_dont_change_bytecode(self):
|
|
def fn(x):
|
|
with torch.no_grad():
|
|
assert x.max() < 5, f"invalid max {x.max()}"
|
|
x = torch.sin(x)
|
|
return x
|
|
|
|
x = torch.ones(4)
|
|
opt_fn = torch._dynamo.optimize("eager")(fn)
|
|
self.assertTrue(same(fn(x), opt_fn(x)))
|
|
|
|
def test_rewrite_assert_without_msg(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert x[0] == 3
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
with self.assertRaisesRegex(RuntimeError, "assertion error"):
|
|
exported(torch.Tensor([5, 6, 7]))
|
|
|
|
def test_rewrite_assert_with_non_string_msg(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert x[0] == 2, x.size()
|
|
return x.cos() + b
|
|
|
|
torch._dynamo.utils.counters.clear()
|
|
args = torch.Tensor([3, 4, 5])
|
|
opt_f = torch._dynamo.optimize("eager")(f)
|
|
with self.assertRaisesRegex(AssertionError, "torch.Size"):
|
|
opt_f(args)
|
|
self.assertEqual(
|
|
torch._dynamo.utils.counters["graph_break"][
|
|
"assert with non-string message"
|
|
],
|
|
1,
|
|
)
|
|
|
|
def test_rewrite_assert_noop(self):
|
|
def f(x):
|
|
b = x.sin()
|
|
assert True
|
|
assert x.dtype == torch.float32
|
|
return x.cos() + b
|
|
|
|
args = (torch.Tensor([3, 4, 5]),)
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([3, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
|
|
self.assertTrue(same(f(*args), opt_f(*args)))
|
|
# torch._assert shouldn't be in the graph
|
|
self.assertEqual(cnt.op_count, 3)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
exported, _ = torch._dynamo.export(f)(torch.Tensor([4, 4, 5]))
|
|
self.assertTrue(same(exported(*args), f(*args)))
|
|
|
|
def test_size_typematch(self):
|
|
def f(x, y):
|
|
if isinstance(x, torch.Size):
|
|
return y + 1
|
|
else:
|
|
return y + 2
|
|
|
|
y = torch.zeros(1)
|
|
x1 = torch.Size((3,))
|
|
x2 = (3,)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_f = torch._dynamo.optimize(cnt, nopython=True)(f)
|
|
self.assertTrue(same(f(x1, y), opt_f(x1, y)))
|
|
self.assertTrue(same(f(x2, y), opt_f(x2, y)))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_dict_subclass_contains(self):
|
|
# pattern from huggingface
|
|
class ClassInstantier(collections.OrderedDict):
|
|
pass
|
|
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def f(x, d):
|
|
if "key1" in d:
|
|
x = x + 2
|
|
if "key2" in d:
|
|
x = x + 4
|
|
x = x + 8
|
|
return x
|
|
|
|
result = f(torch.ones(8), ClassInstantier({"key1": torch.ones(8)}))
|
|
self.assertTrue(same(result, torch.full([8], 11.0)))
|
|
|
|
result = f(torch.ones(8), ClassInstantier({"key2": torch.ones(8)}))
|
|
self.assertTrue(same(result, torch.full([8], 13.0)))
|
|
|
|
def test_hf_classinstantier(self):
|
|
# hf activations.py
|
|
class ClassInstantier(collections.OrderedDict):
|
|
def __getitem__(self, key):
|
|
content = super().__getitem__(key)
|
|
cls, kwargs = content if isinstance(content, tuple) else (content, {})
|
|
return cls(**kwargs)
|
|
|
|
ACT2CLS = ClassInstantier(
|
|
{
|
|
"relu": (nn.ReLU, {"inplace": False}),
|
|
"tanh": nn.Tanh,
|
|
}
|
|
)
|
|
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def f(x, act):
|
|
return ACT2CLS[act](x)
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(f(y, "tanh"), torch.tanh(y)))
|
|
self.assertTrue(same(f(y, "relu"), torch.relu(y)))
|
|
|
|
def test_ephemeral_module(self):
|
|
# hf activations.py
|
|
class ReLUSquaredActivation(nn.Module):
|
|
def forward(self, input):
|
|
relu_applied = torch.nn.functional.relu(input)
|
|
squared = torch.square(relu_applied)
|
|
return squared
|
|
|
|
@torch.compile(fullgraph=True, backend="eager")
|
|
def f(x):
|
|
x = x + 0.2
|
|
x = ReLUSquaredActivation()(x)
|
|
x = x + 1
|
|
return x
|
|
|
|
y = torch.randn(10)
|
|
self.assertTrue(same(f(y), ReLUSquaredActivation()(y + 0.2) + 1))
|
|
|
|
def test_inplace_unsqueeze_input(self):
|
|
def backend(gm, example_inputs):
|
|
self.assertEqual(example_inputs[-1].size(), torch.Size([1, 3, 4]))
|
|
return gm
|
|
|
|
@torch.compile(backend=backend)
|
|
def fn(x):
|
|
x.unsqueeze_(0)
|
|
return x + 1
|
|
|
|
inputs = [torch.randn(3, 4)]
|
|
self.assertEqual(fn(*inputs).size(), torch.Size([1, 3, 4]))
|
|
self.assertEqual(inputs[0].size(), torch.Size([1, 3, 4]))
|
|
|
|
def test_batchnorm_e2e(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.bn = torch.nn.BatchNorm2d(
|
|
64, eps=1e-05, momentum=0.1, affine=True, track_running_stats=True
|
|
)
|
|
self.conv1 = torch.nn.Conv2d(
|
|
64,
|
|
64,
|
|
kernel_size=(3, 3),
|
|
stride=(1, 1),
|
|
padding=(1, 1),
|
|
bias=False,
|
|
)
|
|
|
|
def forward(self, x):
|
|
x1 = self.bn(x)
|
|
x2 = self.conv1(x1)
|
|
out = torch.nn.functional.relu(x2)
|
|
return (out,)
|
|
|
|
torch.manual_seed(1337)
|
|
|
|
m_ref = Repro()
|
|
m_test = deepcopy(m_ref)
|
|
|
|
@torch._dynamo.optimize("aot_eager_decomp_partition")
|
|
def compiled_fn(x):
|
|
return m_test(x)
|
|
|
|
x_ref = torch.randn(2, 64, 32, 32, requires_grad=True)
|
|
x_test = x_ref.clone()
|
|
|
|
# Loop multiple times: each iteration the running_mean/var on batchnorm will update,
|
|
# which changes the output of the next iteration
|
|
for _ in range(3):
|
|
ref = m_ref(x_ref)
|
|
res = compiled_fn(x_test)
|
|
|
|
self.assertTrue(same(ref, res))
|
|
|
|
for r in ref:
|
|
if r.requires_grad:
|
|
r.sum().backward()
|
|
for r in res:
|
|
if r.requires_grad:
|
|
r.sum().backward()
|
|
|
|
for param_ref, param_test in zip(m_ref.parameters(), m_test.parameters()):
|
|
self.assertTrue(same(param_ref, param_test))
|
|
# Assert running_mean/var
|
|
for buffer_ref, buffer_test in zip(m_ref.buffers(), m_test.buffers()):
|
|
self.assertTrue(same(buffer_ref, buffer_test))
|
|
|
|
@torch._dynamo.config.patch("assume_static_by_default", False)
|
|
def test_dynamic_shapes_right_side(self):
|
|
def f(x):
|
|
return torch.ones(5 * x.shape[0])
|
|
|
|
inp = torch.randn(6, 5)
|
|
|
|
gm, _ = torch._dynamo.export(f, aten_graph=True)(torch.randn(4, 5))
|
|
self.assertEqual(gm(inp).shape, f(inp).shape)
|
|
|
|
@torch._dynamo.config.patch("specialize_int", False)
|
|
def test_maybe_multiply_symint(self):
|
|
# https://github.com/pytorch/pytorch/issues/97346
|
|
from torch._functorch.aot_autograd import aot_module_simplified
|
|
|
|
def my_aot_compiler(gm, example_inputs):
|
|
def my_compiler(gm, example_inputs):
|
|
return gm.forward
|
|
|
|
# Invoke AOTAutograd
|
|
return aot_module_simplified(gm, example_inputs, fw_compiler=my_compiler)
|
|
|
|
def my_example(t1, t2, d):
|
|
out = torch.add(t1, t2, alpha=d)
|
|
return out
|
|
|
|
compiled_fn = torch.compile(backend=my_aot_compiler, dynamic=True)(my_example)
|
|
|
|
t1 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
|
|
t2 = torch.arange(3, dtype=torch.float32).requires_grad_(True)
|
|
|
|
ra = compiled_fn(t1, t2, 5)
|
|
self.assertEqual(ra, torch.tensor([0.0, 6.0, 12.0]))
|
|
|
|
ra = compiled_fn(t1, t2, 6)
|
|
self.assertEqual(ra, torch.tensor([0.0, 7.0, 14.0]))
|
|
|
|
def test_build_map_unpack_with_call(self):
|
|
def forward_with_cond_scale(x, t, cond_scale, self_cond, other1, other2):
|
|
return x.sin() + t + cond_scale + self_cond + other1 + other2
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
d1 = dict(other1=5)
|
|
d2 = dict(other2=4)
|
|
text_cond = {**d1, **d2}
|
|
return forward_with_cond_scale(x, 1, cond_scale=2, self_cond=3, **text_cond)
|
|
|
|
self.assertTrue(same(fn(torch.ones(4)), torch.ones(4).sin() + 15))
|
|
|
|
@torch._dynamo.config.patch(verbose=True)
|
|
def test_graph_break_unsupported_fake(self):
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch._dynamo.optimize(counter)
|
|
def f(x):
|
|
return torch.ops.test_sample.foo(x + 1) + 1
|
|
|
|
f(torch.randn(3))
|
|
|
|
self.assertEqual(counter.op_count, 2)
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
def test_delattr(self):
|
|
class MyObj:
|
|
def __init__(self, a, b):
|
|
self.a = a
|
|
self.b = b
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x, obj):
|
|
del obj.a
|
|
obj.c = x + 1
|
|
del obj.c
|
|
tmp = MyObj(x + 2, x + 3)
|
|
del tmp.b
|
|
if hasattr(obj, "a"):
|
|
return x + 1
|
|
return tmp
|
|
|
|
x = torch.zeros([])
|
|
obj1 = MyObj(x, x)
|
|
obj2 = fn(x, obj1)
|
|
self.assertFalse(hasattr(obj1, "a"))
|
|
self.assertFalse(hasattr(obj1, "c"))
|
|
self.assertFalse(hasattr(obj2, "b"))
|
|
self.assertEqual(obj1.b.item(), 0)
|
|
self.assertEqual(obj2.a.item(), 2)
|
|
|
|
def test_delattr_raises(self):
|
|
class MyObj:
|
|
def __init__(self, a, b):
|
|
self.a = a
|
|
self.b = b
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x, obj):
|
|
del obj.a
|
|
x = x + 1
|
|
obj.a # will raise
|
|
return x
|
|
|
|
x = torch.zeros([])
|
|
obj1 = MyObj(x, x)
|
|
self.assertRaises(AttributeError, lambda: fn(x, obj1))
|
|
|
|
def test_delsubscr(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
del x["a"]
|
|
y = x["b"] + 1
|
|
return y
|
|
|
|
x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
|
|
result = fn(x)
|
|
self.assertFalse(hasattr(x, "a"))
|
|
self.assertEqual(result.item(), 2)
|
|
|
|
def test_delsubscr_raises(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
del x["a"]
|
|
y = x["a"] + 1 # should raise KeyError
|
|
return y
|
|
|
|
x = {"a": torch.tensor([1]), "b": torch.tensor([1])}
|
|
self.assertRaises(KeyError, lambda: fn(x))
|
|
|
|
def test_attached_attribute_in_dir(self):
|
|
class MyModule(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(16, 16)
|
|
self.relu = torch.nn.ReLU()
|
|
|
|
def forward(self, x):
|
|
return self.relu(self.linear(x))
|
|
|
|
mod = torch.compile(MyModule(), backend="eager")
|
|
mod.is_compiled = True
|
|
self.assertTrue("is_compiled" in dir(mod))
|
|
|
|
@torch._dynamo.config.patch("automatic_dynamic_shapes", False)
|
|
def test_dynamic_shapes_implicit_guard(self):
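# Indexing size()/sum dims with values derived from the input's own shape should
# compile as one frame under nopython.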
|
|
def f(x):
|
|
y = x * x.size(x.shape[0])
|
|
torch.sum(y, [y.shape[0]])
|
|
return y
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
|
|
opt_fn(torch.randn(3, 1, 1, 1, 1))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_dalle2_maybe(self):
|
|
def normalize(x):
|
|
return x.cos()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x, normalize_img):
|
|
lowres_cond_img = x.sin()
|
|
lowres_cond_img = maybe(normalize_img)(lowres_cond_img)
|
|
return lowres_cond_img
|
|
|
|
self.assertEqual(fn(torch.ones([]), normalize), torch.ones([]).sin().cos())
|
|
|
|
def test_functools_wraps(self):
|
|
def cool_name(x):
|
|
return x.sin()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
y = x.cos()
|
|
|
|
@functools.wraps(cool_name)
|
|
def uncool_name():
|
|
return cool_name(y)
|
|
|
|
return uncool_name
|
|
|
|
result = fn(torch.ones([]))
|
|
self.assertEqual(result.__name__, "cool_name")
|
|
self.assertEqual(result(), torch.ones([]).cos().sin())
|
|
|
|
def test_dynamic_shapes_float_guard(self):
|
|
def f(x):
|
|
return torch.nn.functional.dropout(x, x.shape[0] / 6)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(f)
|
|
opt_fn(torch.randn(3))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_tensor_item(self):
|
|
def f(x, y):
|
|
val = y.item()
|
|
return x.sum() + val
|
|
|
|
gm, _ = torch._dynamo.export(
|
|
f,
|
|
aten_graph=True,
|
|
)(
|
|
torch.zeros(6, 4),
|
|
torch.tensor(1),
|
|
)
|
|
self.assertEqual(
|
|
f(torch.zeros(6, 4), torch.tensor(1)),
|
|
gm(torch.zeros(6, 4), torch.tensor(1)),
|
|
)
|
|
self.assertEqual(
|
|
f(torch.zeros(6, 4), torch.tensor(2)),
|
|
gm(torch.zeros(6, 4), torch.tensor(2)),
|
|
)
|
|
|
|
def test_dataclass_init_with_default_factory_with_inputs(self):
|
|
@dataclasses.dataclass
|
|
class DClass:
|
|
sharding_contexts: Any = dataclasses.field(default_factory=list)
|
|
a: int = 1
|
|
|
|
def fn(x, inp_list):
|
|
d = DClass(inp_list)
|
|
d.sharding_contexts.append(x.sin() + d.a)
|
|
return d
|
|
|
|
x = torch.randn(4)
|
|
inp_list1 = [1, 2, 3]
|
|
inp_list2 = [2, 3, 4]
|
|
inp_list3 = [1, 2]
|
|
ref1 = fn(x, inp_list1)
|
|
ref2 = fn(x, inp_list2)
|
|
ref3 = fn(x, inp_list3)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, fullgraph=True)
|
|
|
|
opt_ret1 = opt_fn(x, inp_list1)
|
|
opt_ret2 = opt_fn(x, inp_list2)
|
|
opt_ret3 = opt_fn(x, inp_list3)
|
|
self.assertEqual(ref1.sharding_contexts, opt_ret1.sharding_contexts)
|
|
self.assertEqual(ref2.sharding_contexts, opt_ret2.sharding_contexts)
|
|
self.assertEqual(ref3.sharding_contexts, opt_ret3.sharding_contexts)
|
|
|
|
def test_list_index(self):
|
|
for i, list_type in enumerate(
|
|
(
|
|
list,
|
|
tuple,
|
|
torch.Size,
|
|
collections.deque,
|
|
namedtuple("FourElems", "one two three four", defaults=[0, 0, 0, 0]),
|
|
)
|
|
):
|
|
torch._dynamo.reset()
|
|
for index in ([], [2], [0, 3]):
|
|
|
|
def f(t):
|
|
if i == 4: # namedtuple
|
|
xs = list_type(1, 2, 3, 4)
|
|
else:
|
|
xs = list_type([1, 2, 3, 4])
|
|
res = xs.index(3, *index)
|
|
return t + res
|
|
|
|
res = torch._dynamo.optimize(backend="eager", nopython=True)(f)(
|
|
torch.zeros(1)
|
|
)
|
|
|
|
self.assertEqual(res, torch.tensor([2.0]))
|
|
|
|
def test_list_index_not_found(self):
|
|
def f(t):
|
|
xs = ["bar", "foo", "baz", "buzz"]
|
|
res = xs.index("non-existent")
|
|
return t + res
|
|
|
|
# Raising ValueError from item not found is unsupported
|
|
with self.assertRaises(
|
|
torch._dynamo.exc.Unsupported,
|
|
):
|
|
torch._dynamo.optimize(backend="eager", nopython=True)(f)(torch.zeros(1))
|
|
|
|
def test_list_index_tensor_unsupported(self):
|
|
for index in ([], [2], [0, 3]):
|
|
|
|
def f(t):
|
|
xs = [torch.tensor([i]) for i in range(4)]
|
|
res = xs.index(torch.tensor([2]), *index)
|
|
return t + res
|
|
|
|
with self.assertRaisesRegex(
|
|
torch._dynamo.exc.UserError, "Dynamic control flow is not supported"
|
|
):
|
|
torch._dynamo.optimize(backend="eager", nopython=True)(f)(
|
|
torch.zeros(1)
|
|
)
|
|
|
|
def test_hf_xsoftmax_inference(self):
|
|
def fn(input, mask):
|
|
return XSoftmax.apply(input + 1, mask, 1) + 2
|
|
|
|
fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
|
|
|
|
inputs = [
|
|
torch.randn(4, 10),
|
|
torch.randn(4, 10) < 0,
|
|
]
|
|
expected = fn(*inputs)
|
|
actual = fn_opt(*inputs)
|
|
self.assertTrue(same(actual, expected))
|
|
|
|
@mock.patch("torch._dynamo.config.guard_nn_modules", True)
|
|
def test_hf_xsoftmax_training(self):
|
|
from torch._dynamo.utils import counters
|
|
|
|
counters.clear()
|
|
|
|
def fn(input, mask):
|
|
return XSoftmax.apply(input, mask, 1)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
fn_opt = torch.compile(fn, backend=cnt, fullgraph=False)
|
|
|
|
torch.manual_seed(1234)
|
|
inputs1 = [
|
|
torch.randn(4, 10, requires_grad=True),
|
|
torch.randn(4, 10) < 0,
|
|
]
|
|
torch.manual_seed(1234)
|
|
inputs2 = [
|
|
torch.randn(4, 10, requires_grad=True),
|
|
torch.randn(4, 10) < 0,
|
|
]
|
|
|
|
expected = fn(*inputs1)
|
|
actual = fn_opt(*inputs2)
|
|
self.assertTrue(same(actual, expected))
|
|
self.assertEqual(dict(counters["frames"]), {"total": 1, "ok": 1})
|
|
self.assertEqual(cnt.op_count, 2)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
cnt.clear()
|
|
counters.clear()
|
|
|
|
expected.sum().backward()
|
|
actual.sum().backward()
|
|
self.assertTrue(same(inputs1[0].grad, inputs2[0].grad))
|
|
|
|
# currently we don't capture the backwards frame
|
|
self.assertEqual(cnt.frame_count, 0)
|
|
self.assertEqual(cnt.op_count, 0)
|
|
self.assertEqual(dict(counters["frames"]), {})
|
|
self.assertEqual(dict(counters["graph_break"]), {})
|
|
|
|
def test_autograd_function_graph_break(self):
|
|
class MySin(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, x):
|
|
torch._dynamo.graph_break()
|
|
ctx.save_for_backward(x)
|
|
return x.sin()
|
|
|
|
@staticmethod
|
|
def backward(ctx, gx):
|
|
(x,) = ctx.saved_tensors
|
|
return gx * x.cos()
|
|
|
|
x = torch.randn([], requires_grad=True)
|
|
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
return MySin.apply(x)
|
|
|
|
y = fn(x)
|
|
self.assertEqual(y, x.sin())
|
|
|
|
(gx,) = torch.autograd.grad(y, x)
|
|
self.assertEqual(gx, x.cos())
|
|
|
|
def test_jit_trace_errors(self):
|
|
@torch.compile(backend="eager", dynamic=True)
|
|
def f(x):
|
|
return x + 1
|
|
|
|
with self.assertRaises(RuntimeError):
|
|
torch.jit.trace(f, torch.randn(3))
|
|
|
|
with torch._dynamo.config.patch(error_on_nested_jit_trace=False):
|
|
torch.jit.trace(f, torch.randn(3))
|
|
|
|
@torch._dynamo.config.patch("assume_static_by_default", False)
|
|
def test_tensor_split(self):
|
|
def f(x):
|
|
return torch.split(x, x.shape[0] // 2, dim=0)[0]
|
|
|
|
gm, _ = torch._dynamo.export(
|
|
f,
|
|
aten_graph=True,
|
|
)(
|
|
torch.zeros(6, 4),
|
|
)
|
|
|
|
self.assertEqual(f(torch.ones(8, 4)), gm(torch.ones(8, 4)))
|
|
|
|
def test_optim_state_references_cleared(self):
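# After running a compiled optimizer.step(), deleting the optimizer must release its
# state tensors (verified with a weakref to the "square_avg" buffer).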
|
|
model = torch.nn.Linear(2048, 2048, bias=False)
|
|
x = torch.ones(2048)
|
|
state_ref = 0
|
|
|
|
optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)
|
|
|
|
def opt_step():
|
|
optimizer.step()
|
|
|
|
compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)
|
|
|
|
def compiled_model_step(x):
|
|
optimizer.zero_grad()
|
|
y = model(x)
|
|
torch.sum(y).backward()
|
|
compiled_opt_step()
|
|
|
|
compiled_model_step(x)
|
|
|
|
# Picked "square_avg" arbitrarily to check that
|
|
# optimizer state tensors are deallocated
|
|
state_ref = weakref.ref(
|
|
optimizer.state[optimizer.param_groups[0]["params"][0]]["square_avg"]
|
|
)
|
|
optimizer = None
|
|
|
|
self.assertIsNone(state_ref())
|
|
|
|
def test_grad_references_cleared(self):
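# zero_grad(set_to_none=True) after a compiled optimizer step must actually drop the
# parameter .grad tensors.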
|
|
model = torch.nn.Linear(2048, 2048, bias=False)
|
|
x = torch.ones(2048)
|
|
optimizer = torch.optim.Adadelta(model.parameters(), lr=0.01)
|
|
|
|
def opt_step():
|
|
optimizer.step()
|
|
|
|
compiled_opt_step = torch._dynamo.optimize("eager")(opt_step)
|
|
|
|
def compiled_model_step(x):
|
|
optimizer.zero_grad(True)
|
|
y = model(x)
|
|
torch.sum(y).backward()
|
|
compiled_opt_step()
|
|
|
|
compiled_model_step(x)
|
|
param_grad_ref = weakref.ref(next(iter(model.parameters())).grad)
|
|
optimizer.zero_grad(True)
|
|
self.assertIsNone(param_grad_ref())
|
|
|
|
def test_batch_encoding_clone_inputs(self):
|
|
class BatchEncoding(dict):
|
|
"""
|
|
Copied from test_tokenization
|
|
"""
|
|
|
|
def __init__(
|
|
self,
|
|
data,
|
|
):
|
|
super().__init__(data)
|
|
|
|
def __getattr__(self, item: str):
|
|
try:
|
|
return self.data[item]
|
|
except KeyError as e:
|
|
raise AttributeError from e
|
|
|
|
encoding = BatchEncoding({"key": torch.rand((1, 4))})
|
|
cloned_encoding = torch._dynamo.utils.clone_inputs(encoding)
|
|
self.assertTrue(type(cloned_encoding) is not dict)
|
|
|
|
    def test_iadd_graph_break(self):
        def fn(x):
            a = ()
            x = torch.sin(x)
            a += (x,)
            return a

        x = torch.randn(4)
        ref = fn(x)

        opt_fn = torch._dynamo.optimize("eager", nopython=True)(fn)
        res = opt_fn(x)
        self.assertTrue(same(ref, res))

def test_odict_get_item_index_name(self):
|
|
d = {float: torch.float32, np.float16: torch.float16}
|
|
|
|
@torch.compile(backend="eager")
|
|
def f(x, y1, y2):
|
|
return torch.zeros(5, dtype=d[y1]), torch.zeros(5, dtype=d[y2])
|
|
|
|
f(torch.zeros(4), float, np.float16)
|
|
|
|
    def test_dedup_global(self):
        @torch.compile()
        def f():
            return _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR

        self.assertEqual(f(), _GLOBAL_CPU_TENSOR + _GLOBAL_CPU_TENSOR)

def test_randint_out_dynamic(self):
|
|
def randint_fn(high, size, out):
|
|
return torch.randint(high, size, out=out)
|
|
|
|
opt_model = torch.compile(randint_fn)
|
|
|
|
out1 = torch.empty(10, dtype=torch.int32)
|
|
opt_model(17, (10,), out1)
|
|
|
|
out2 = torch.empty(12, dtype=torch.int32)
|
|
opt_model(17, (12,), out2)
|
|
|
|
@requires_cuda
|
|
def test_guard_default_device(self):
|
|
try:
|
|
torch.set_default_device("cuda")
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch._dynamo.optimize(counter)
|
|
def f():
|
|
x = torch.randn(3)
|
|
return x * 2
|
|
|
|
self.assertEqual(f().device.type, "cuda")
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
torch.set_default_device("cpu")
|
|
|
|
self.assertEqual(f().device.type, "cpu")
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
finally:
|
|
torch.set_default_device(None)
|
|
|
|
def test_list_self_reference(self):
|
|
# Issue - https://github.com/pytorch/pytorch/issues/100150
|
|
root = []
|
|
root[:] = [root, root, None, None]
|
|
|
|
@torch._dynamo.optimize("eager")
|
|
def test_bug():
|
|
return root
|
|
|
|
test_bug()
|
|
|
|
    def test_hf_bigbird_unsqueeze(self):
        def torch_bmm_nd(inp_1, inp_2, ndim=None):
            torch._dynamo.graph_break()
            return torch.bmm(inp_1, inp_2)

        def fn(inp1, inp2, inp3, inp4, c):
            a = torch_bmm_nd(inp1, inp2, 4)
            a.unsqueeze_(2)
            a = a * 2

            b = torch_bmm_nd(inp3, inp4, 4)
            b.unsqueeze_(2)
            l = a + b

            out = torch.cat([a, b, c], dim=2)
            return out, l

        inp1 = torch.rand(1, 64, 448)
        inp2 = torch.rand(1, 448, 64)
        inp3 = torch.rand(1, 64, 448)
        inp4 = torch.rand(1, 448, 64)
        c = torch.rand(1, 64, 1, 64)

        cnt = torch._dynamo.testing.CompileCounter()
        opt_fn = torch._dynamo.optimize(cnt)(fn)
        opt_fn(inp1, inp2, inp3, inp4, c)
        self.assertEqual(cnt.frame_count, 3)

def test_torch_variable_type(self):
|
|
# from torchvision
|
|
def check_type(obj, types_or_checks):
|
|
for type_or_check in types_or_checks:
|
|
if (
|
|
isinstance(obj, type_or_check)
|
|
if isinstance(type_or_check, type)
|
|
else type_or_check(obj)
|
|
):
|
|
return True
|
|
return False
|
|
|
|
opt_check_type = torch._dynamo.optimize("eager")(check_type)
|
|
ref = check_type(torch.randn(4), [torch.Tensor])
|
|
res = opt_check_type(torch.randn(4), [torch.Tensor])
|
|
self.assertEqual(ref, res)
|
|
|
|
# Test for https://github.com/pytorch/pytorch/issues/103132
|
|
@torch._dynamo.config.patch("assume_static_by_default", False)
|
|
def test_inference_mode_dynamic_shapes(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, param):
|
|
z = torch.matmul(param, param)
|
|
return z
|
|
|
|
model = Repro()
|
|
# Need a 3d tensor to actually cause the error:
|
|
# we go down a path of the C++ matmul decomp that calls sizes().
|
|
inp = torch.randn(4, 4, 4, requires_grad=True)
|
|
model = torch.compile(model, backend="aot_eager", dynamic=True)
|
|
with torch.inference_mode():
|
|
model(inp)
|
|
|
|
def test_kwargs_out_list_variable(self):
|
|
class Repro(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, param):
|
|
z = torch.frexp(**param)
|
|
return z
|
|
|
|
model = Repro()
|
|
params = {"input": torch.tensor([[0.0, 1, 2, 4]])}
|
|
params["out"] = [
|
|
torch.empty(0, dtype=torch.float32), # mantissa
|
|
torch.empty(0, dtype=torch.int32), # exponent
|
|
]
|
|
|
|
model = torch.compile(model, backend="eager")
|
|
mantissa, exponent = model(params)
|
|
ref_mantissa = torch.tensor([[0.0000, 0.5000, 0.5000, 0.5000]])
|
|
ref_exponent = torch.tensor([[0, 1, 2, 3]], dtype=torch.int32)
|
|
self.assertEqual(ref_mantissa, mantissa)
|
|
self.assertEqual(ref_exponent, exponent)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_split_with_sizes_aot_autograd(self):
|
|
def fn(result, split_sizes):
|
|
rs = torch.ops.aten.split_with_sizes(result, split_sizes.tolist())
|
|
return rs
|
|
|
|
example_inputs = (
|
|
torch.randn(32, requires_grad=True),
|
|
torch.tensor((7, 16, 9)),
|
|
)
|
|
actual = torch.compile(fn, fullgraph=True, backend="aot_eager")(*example_inputs)
|
|
expected = fn(*example_inputs)
|
|
self.assertEqual(actual, expected)
|
|
|
|
def test_unspecialized_nn_module_with_torch_variable_attribute(self):
|
|
"""
|
|
In this case self.fn = something that should be a TorchVariable.
|
|
When it's not a TorchVariable, dynamo tries to trace through and fails.
|
|
This makes sure that the self.fn is handled as a TorchVariable.
|
|
"""
|
|
|
|
class UserModule(torch.nn.Module):
|
|
            torchdynamo_force_dynamic = True  # forced to be an UnspecializedNNModule
|
|
|
|
def __init__(self, fn):
|
|
super().__init__()
|
|
self.fn = fn
|
|
|
|
def forward(self, **inp):
|
|
return self.fn(**inp)
|
|
|
|
inputs = {
|
|
"input": torch.randn([2, 9]).uniform_(0, 1),
|
|
"target": torch.randn([2, 9]).uniform_(0, 1),
|
|
"reduction": "mean",
|
|
}
|
|
|
|
mod = UserModule(torch.nn.functional.binary_cross_entropy)
|
|
ref = mod(**inputs)
|
|
res = torch._dynamo.optimize("eager", nopython=True)(mod)(**inputs)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_call_finally_python_3_8(self):
|
|
# Issue - https://github.com/pytorch/pytorch/issues/97811
|
|
def make_fn(g):
|
|
def fn():
|
|
while True:
|
|
try:
|
|
print(g)
|
|
break
|
|
except Exception as _:
|
|
break
|
|
|
|
return torch.compile(fn, backend="eager")
|
|
|
|
make_fn(None)()
|
|
|
|
def test_call_finally_python_3_8_2(self):
|
|
def f(x):
|
|
while x:
|
|
try:
|
|
pass
|
|
except Exception as _:
|
|
continue
|
|
|
|
torch.compile(f, backend="eager")(0)
|
|
|
|
def test_call_finally_opcode_python_3_8(self):
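        # A return inside "finally" overrides the return inside "try"; the compiled
        # function must preserve that behavior (hence the expected torch.ones(4)).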
|
|
def fn():
|
|
try:
|
|
return torch.zeros(4)
|
|
finally:
|
|
return torch.ones(4) # noqa: SIM107, B012
|
|
|
|
result = torch.compile(fn, backend="aot_eager")()
|
|
self.assertEqual(result, torch.ones(4))
|
|
|
|
def test_string_format(self):
|
|
s = "temp{i}"
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
if s.format(i=4) == "temp4":
|
|
return torch.sin(x)
|
|
return torch.cos(x)
|
|
|
|
x = torch.randn(4)
|
|
self.assertEqual(fn(x), torch.sin(x))
|
|
|
|
# Repro of torch._dynamo.exc.InternalTorchDynamoError: 'NoneType' object has no attribute 'guards'
|
|
# due to bad empty list handling
|
|
def test_empty_list_contains_with_jump(self):
|
|
def fn(x, l):
|
|
if x in l:
|
|
return x.cos()
|
|
return x.sin()
|
|
|
|
counter = CompileCounter()
|
|
compiled_fn = torch._dynamo.optimize(counter)(fn)(torch.randn([2, 2]), [])
|
|
self.assertEqual(counter.frame_count, 1)
|
|
|
|
def test_graph_break_on_jit_isinstance(self):
|
|
@torch.compile(backend="eager")
|
|
def fn(x):
|
|
if torch.jit.isinstance(x, List[str]):
|
|
return x * 2
|
|
return x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
x = torch.rand(4)
|
|
self.assertTrue(same(fn(x), opt_fn(x)))
|
|
|
|
def test_add_sub_alpha_out(self):
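        # Check that the out= variants of add/sub with an alpha argument match
        # eager results when compiled with dynamic shapes.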
|
|
inp = torch.randn(2, 3, 4)
|
|
other = 1
|
|
alpha = 2
|
|
for op in [torch.add, torch.sub]:
|
|
out = torch.zeros(2, 3, 4)
|
|
compile_out = torch.zeros(2, 3, 4)
|
|
op(inp, other, alpha=alpha, out=out)
|
|
compiled_fn = torch.compile(op, dynamic=True)
|
|
compiled_fn(inp, other, alpha=alpha, out=compile_out)
|
|
self.assertTrue(same(out, compile_out))
|
|
|
|
def test_negative_shape_guard(self):
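        # A negated size comparison (x.size() != ...) should still guard correctly;
        # the two differently-shaped inputs below are expected to compile to two frames.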
|
|
def fn(x):
|
|
if x.size() != (5, 1, 2, 3):
|
|
return x.cos()
|
|
return x.sin()
|
|
|
|
counter = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=counter, dynamic=True)
|
|
|
|
x = torch.ones(5, 1, 3, 4)
|
|
x2 = torch.ones(5, 1, 2, 3)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
self.assertEqual(fn(x2), opt_fn(x2))
|
|
self.assertEqual(counter.frame_count, 2)
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_deferred_runtime_asserts(self):
|
|
@torch.compile(fullgraph=True)
|
|
def f(x):
|
|
y = x.item()
|
|
torch._check_is_size(y)
|
|
if y >= 0:
|
|
return x * 2
|
|
else:
|
|
return x * 3
|
|
|
|
f(torch.tensor([3]))
|
|
self.assertRaises(RuntimeError, lambda: f(torch.tensor([-2])))
|
|
|
|
def test_addr_alpha_beta_out(self):
|
|
inp = torch.randn(2, 3)
|
|
vec1 = torch.randn(2)
|
|
vec2 = torch.randn(3)
|
|
alpha = 2
|
|
beta = 5
|
|
|
|
out = torch.zeros(2, 3)
|
|
compile_out = torch.zeros(2, 3)
|
|
|
|
torch.addr(inp, vec1, vec2, alpha=alpha, beta=beta, out=out)
|
|
compiled_fn = torch.compile(torch.addr, dynamic=True)
|
|
compiled_fn(inp, vec1, vec2, alpha=alpha, beta=beta, out=compile_out)
|
|
self.assertTrue(same(out, compile_out))
|
|
|
|
def test_setattr_requires_grad_graph_breaks(self):
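        # Setting requires_grad inside the traced function forces a graph break;
        # eager and compiled runs must still agree on outputs and gradients.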
|
|
def fn(x):
|
|
z = x + 4
|
|
x.requires_grad = True
|
|
y = x * z
|
|
return y
|
|
|
|
for backend in ["count", "eager", "aot_eager"]:
|
|
if backend == "count":
|
|
backend = CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=backend)
|
|
|
|
eager = torch.zeros(5)
|
|
compiled = eager.clone()
|
|
|
|
out_eager = fn(eager)
|
|
out_opt = opt_fn(compiled)
|
|
|
|
self.assertEqual(out_eager, out_opt)
|
|
|
|
out_eager.sum().backward()
|
|
out_opt.sum().backward()
|
|
|
|
self.assertEqual(eager, compiled)
|
|
if isinstance(backend, CompileCounter):
|
|
self.assertEqual(backend.frame_count, 2) # graph breaks
|
|
|
|
def test_dynamic_shapes_double_not_equal(self):
|
|
# https://github.com/pytorch/pytorch/issues/113393
|
|
def fn(x):
|
|
if x.size() != (5, 1, 2, 3):
|
|
return x.cos()
|
|
return x.sin()
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
|
|
x = torch.ones(5, 1, 2, 3)
|
|
x2 = torch.ones(5, 1, 3, 4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
self.assertEqual(fn(x2), opt_fn(x2))
|
|
|
|
def test_inductor_no_recursionerror_on_for_loops(self):
|
|
def forward(x):
|
|
for _ in range(1000):
|
|
x = 1.0 * x
|
|
return x
|
|
|
|
self.assertTrue(
|
|
same(torch.compile(forward)(torch.tensor([1.0])), torch.tensor([1.0]))
|
|
)
|
|
|
|
def test_user_defined_object_callable(self):
|
|
# https://github.com/pytorch/pytorch/issues/114019
|
|
class MyCallable:
|
|
def __call__(self, x):
|
|
return x + 1
|
|
|
|
def fn(x):
|
|
# Create in graph - will not have source
|
|
return MyCallable()(x)
|
|
|
|
fn_opt = torch.compile(fn, backend="eager", fullgraph=True)
|
|
self.assertEqual(fn_opt(torch.zeros(1)), fn(torch.zeros(1)))
|
|
|
|
@torch._dynamo.config.patch(log_compilation_metrics=True)
|
|
def test_many_views_with_mutation(self):
|
|
# When symbolic storage offsets were added in #113734, tensors_definitely_do_not_overlap
|
|
# began adding shape guards - a quadratic amount relative to the number of inputs.
|
|
# Test this configuration, and test that a reasonable number of guards are added.
|
|
# Note, when dynamic shapes are turned on, this test fails and we still get quadratic guards.
|
|
def fn(x):
|
|
x[0].relu_()
|
|
return torch.cat(x).sum()
|
|
|
|
AMT = 32
|
|
src = torch.rand(16 * (AMT + 1))
|
|
|
|
x = [src.as_strided((4, 4), (4, 1), 3 + 16 * i) for i in range(AMT)]
|
|
|
|
torch._dynamo.reset()
|
|
torch._dynamo.utils.clear_compilation_metrics()
|
|
|
|
res = torch.compile(fn, backend="aot_eager")(x)
|
|
|
|
all_metrics = torch._dynamo.utils.get_compilation_metrics()
|
|
|
|
total_guards = sum(metric.guard_count for metric in all_metrics)
|
|
self.assertLess(total_guards, AMT * 8)
|
|
|
|
total_shape_env_guards = sum(
|
|
metric.shape_env_guard_count for metric in all_metrics
|
|
)
|
|
self.assertLess(total_shape_env_guards, AMT * 8)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/118799
|
|
def test_subclass_graph_output_repro(self):
|
|
@torch._dynamo.allow_in_graph
|
|
def to_subclass(x):
|
|
return TwoTensor(x.clone(), x.clone())
|
|
|
|
def f(x):
|
|
tmp_subclass = to_subclass(x)
|
|
return tmp_subclass.view(-1)
|
|
|
|
x = torch.ones(2)
|
|
out_ref = f(x)
|
|
out_test = torch.compile(f, backend="aot_eager")(x)
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
def test_numpy_tobytes_no_error(self):
|
|
def fn(x):
|
|
x += 1
|
|
z = x.tobytes()
|
|
x += 1
|
|
return z
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
opt_arg, arg = np.array([1, 2]), np.array([1, 2])
|
|
self.assertEqual(opt_fn(opt_arg), fn(arg))
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_numpy_not_ndarray_recompiles(self):
|
|
import torch
|
|
|
|
def fn(x=None):
|
|
if x is None:
|
|
x = np.ones(3)
|
|
elif isinstance(x, int):
|
|
x = np.ones(6)
|
|
elif isinstance(x, str):
|
|
x = np.ones(9)
|
|
return x**2
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
|
|
x = np.zeros((2, 2))
|
|
|
|
self.assertEqual(opt_fn(x), fn(x))
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
self.assertEqual(opt_fn(), fn())
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
self.assertEqual(opt_fn(10), fn(10))
|
|
self.assertEqual(cnt.frame_count, 3)
|
|
self.assertEqual(opt_fn("10"), fn("10"))
|
|
self.assertEqual(cnt.frame_count, 4)
|
|
|
|
@parametrize(
|
|
"backend",
|
|
["eager", "aot_eager", "inductor"],
|
|
)
|
|
@parametrize(
|
|
"func_name",
|
|
["func1", "func2", "func3"],
|
|
)
|
|
def test_tensor_set_data(self, backend, func_name):
|
|
# https://github.com/pytorch/pytorch/issues/113030
|
|
def func1(x, y):
|
|
x.data = y
|
|
x.add_(1)
|
|
return x
|
|
|
|
def func2(x, y):
|
|
x.data = y
|
|
y.data = torch.zeros([0])
|
|
return x
|
|
|
|
def func3(x, y):
|
|
z = x
|
|
x.data = y
|
|
y.data = torch.zeros([0])
|
|
return torch.tensor(x is z)
|
|
|
|
funcs = {"func1": func1, "func2": func2, "func3": func3}
|
|
func = funcs[func_name]
|
|
|
|
if backend != "eager" and func is func1:
|
|
# add_ not working w/ aot_autograd?
|
|
return
|
|
|
|
torch._dynamo.reset()
|
|
cnt = torch._dynamo.testing.CompileCounterWithBackend(backend)
|
|
|
|
compiled_fn = torch.compile(func, backend=cnt, fullgraph=True)
|
|
requires_grad = func is not func1
|
|
for i in range(0, 5):
|
|
# Inputs
|
|
eager_a = torch.ones([6], requires_grad=requires_grad)
|
|
compiled_a = torch.ones([6], requires_grad=requires_grad)
|
|
|
|
eager_b = torch.ones([6], requires_grad=requires_grad)
|
|
compiled_b = torch.ones([6], requires_grad=requires_grad)
|
|
|
|
# Eager
|
|
out_eager = func(eager_a, eager_b)
|
|
# Compiled
|
|
out_compiled = compiled_fn(compiled_a, compiled_b)
|
|
self.assertEqual(eager_a, compiled_a)
|
|
self.assertEqual(eager_b, compiled_b)
|
|
self.assertTrue(torch.equal(out_eager, out_compiled))
|
|
|
|
            # func1 is skipped here because it would hit "a leaf Variable that
            # requires grad is being used in an in-place operation"
|
|
if requires_grad:
|
|
bwd_inp_eager = torch.randn([6])
|
|
bwd_inp_compiled = torch.clone(bwd_inp_eager)
|
|
eager_a.backward(bwd_inp_eager)
|
|
compiled_a.backward(bwd_inp_compiled)
|
|
self.assertEqual(eager_a.grad, compiled_a.grad)
|
|
|
|
# Prove guarding works - we run the compiled_fn 5 times
|
|
# frame_count should stay at 1.
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
@unittest.skipIf(
|
|
TEST_WITH_ROCM or not PLATFORM_SUPPORTS_FLASH_ATTENTION,
|
|
"flash attention not supported",
|
|
)
|
|
def test_flash_attn_backward_mixed_strides(self):
|
|
        # in this repro, "grad_out" and "value" are transposed tensors,
        # but "query" and "key" are contiguous
|
|
def gen_inputs(device):
|
|
return (
|
|
torch.randn(
|
|
2, 513, 16, 64, dtype=torch.float16, device=device
|
|
).transpose(1, 2),
|
|
torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
|
|
torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
|
|
torch.randn(
|
|
2, 513, 16, 64, dtype=torch.float16, device=device
|
|
).transpose(1, 2),
|
|
torch.randn(2, 16, 513, 64, dtype=torch.float16, device=device),
|
|
torch.randn(2, 16, 513, device=device),
|
|
None,
|
|
None,
|
|
513,
|
|
513,
|
|
0.0,
|
|
False,
|
|
torch.tensor(1, dtype=torch.int64),
|
|
torch.tensor(1, dtype=torch.int64),
|
|
)
|
|
|
|
inps_cuda = gen_inputs("cuda")
|
|
inps_meta = gen_inputs("meta")
|
|
(
|
|
out1_ref,
|
|
out2_ref,
|
|
out3_ref,
|
|
) = torch.ops.aten._scaled_dot_product_flash_attention_backward(
|
|
*inps_cuda, scale=0.125
|
|
)
|
|
from torch._meta_registrations import meta__scaled_dot_product_flash_backward
|
|
|
|
out1_test, out2_test, out3_test = meta__scaled_dot_product_flash_backward(
|
|
*inps_meta, scale=0.125
|
|
)
|
|
|
|
self.assertEqual(out1_ref.shape, out1_test.shape)
|
|
self.assertEqual(out1_ref.stride(), out1_test.stride())
|
|
self.assertEqual(out2_ref.shape, out2_test.shape)
|
|
self.assertEqual(out2_ref.stride(), out2_test.stride())
|
|
self.assertEqual(out3_ref.shape, out3_test.shape)
|
|
self.assertEqual(out3_ref.stride(), out3_test.stride())
|
|
|
|
def test_user_ctor_ctx_manager(self):
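        # Constructing (but not entering) a user-defined context manager inside the
        # compiled function should be supported without a graph break (nopython=True).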
|
|
class UserCtxManager:
|
|
def __enter__(self):
|
|
return 1
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
pass
|
|
|
|
def fn(x, y):
|
|
ucm = UserCtxManager()
|
|
return x * x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
|
|
x = torch.rand([2, 2])
|
|
opt_fn(x, x)
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
|
|
@torch._dynamo.config.patch(capture_scalar_outputs=True)
|
|
def test_unbacked_arange_in_bounds(self):
|
|
# see https://github.com/pytorch/pytorch/issues/113002
|
|
class PaddingNet(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
|
|
def forward(self, lengths):
|
|
max_seq_len = lengths.max().item()
|
|
row_vector = torch.arange(0, max_seq_len, 1)
|
|
matrix = torch.unsqueeze(lengths, dim=-1)
|
|
mask = row_vector < matrix
|
|
mask = mask.type(torch.float32)
|
|
mask_3d_btd = mask[:, :, None]
|
|
return mask_3d_btd
|
|
|
|
model = PaddingNet()
|
|
lengths = torch.tensor([5, 4, 4, 4], dtype=torch.int32)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(model)
|
|
opt_fn(lengths)
|
|
self.assertEqual(cnt.frame_count, 1)
|
|
|
|
def test_overlapping_inputs_with_dynamic_shapes_error(self):
|
|
@torch.compile(backend="aot_eager")
|
|
def fn(a, b, c, d, e, f):
|
|
a.mul_(2)
|
|
b.mul_(2)
|
|
c.mul_(2)
|
|
d.mul_(2)
|
|
e.mul_(2)
|
|
f.mul_(2)
|
|
|
|
base = torch.ones(2, 20)
|
|
a = base[:, 0:2]
|
|
b = base[:, 2:4]
|
|
c = base[:, 4:6]
|
|
d = base[:, 6:8]
|
|
e = base[:, 8:10]
|
|
f = base[:, 10:12]
|
|
f2 = base[:, 10:14]
|
|
out = fn(a, b, c, d, e, f)
|
|
with self.assertRaisesRegex(
|
|
AssertionError, "is being compiled with dynamic shapes"
|
|
):
|
|
out2 = fn(a, b, c, d, e, f2)
|
|
|
|
def test_user_ctor_ctx_manager_custom_init(self):
|
|
class UserCtxManager:
|
|
def __init__(self, x):
|
|
x[0] = 10
|
|
|
|
def __enter__(self):
|
|
return 1
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
pass
|
|
|
|
def fn(x, y):
|
|
ucm = UserCtxManager(y)
|
|
return x * y[0]
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt, nopython=True)(fn)
|
|
x = torch.rand([2, 2])
|
|
self.assertEqual(opt_fn(x, [5]), fn(x, [5]))
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
|
|
def test_user_ctor_ctx_manager_custom_init_graph_break(self):
|
|
counter = [0]
|
|
|
|
class UserCtxManager:
|
|
def __init__(self, k):
|
|
k[0] += 1
|
|
|
|
def __enter__(self):
|
|
return 1
|
|
|
|
def __exit__(self, exc_type, exc_val, exc_tb):
|
|
pass
|
|
|
|
def fn(x, counter):
|
|
x = x * x
|
|
ucm = UserCtxManager(counter)
|
|
return x * x
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch._dynamo.optimize(cnt)(fn)
|
|
x = torch.rand([2, 2])
|
|
self.assertEqual(opt_fn(x, counter), fn(x, counter))
|
|
self.assertEqual(counter[0], 2)
|
|
for i in range(0, 10):
|
|
opt_fn(x, counter)
|
|
self.assertEqual(counter[0], 12)
|
|
if torch._dynamo.config.assume_static_by_default:
|
|
self.assertExpectedInline(cnt.frame_count, """2""")
|
|
else:
|
|
self.assertExpectedInline(cnt.frame_count, """1""")
|
|
|
|
@unittest.expectedFailure
|
|
def test_many_overlapping_inputs_does_not_explode_guards(self):
|
|
from torch._dynamo.backends.common import aot_autograd
|
|
|
|
# Before, this was (9702, 0)
|
|
num_shape_guards = None
|
|
num_aot_guards = None
|
|
num_compiles = 0
|
|
|
|
def guard_count_backend(gm, *args):
|
|
nonlocal num_shape_guards
|
|
nonlocal num_aot_guards
|
|
nonlocal num_compiles
|
|
num_shape_guards = len(
|
|
torch._guards.TracingContext.try_get().fake_mode.shape_env.guards
|
|
)
|
|
num_aot_guards = len(
|
|
torch._guards.TracingContext.try_get().guards_context.aotautograd_guards
|
|
)
|
|
num_compiles += 1
|
|
return gm
|
|
|
|
aot_guard_counter = aot_autograd(fw_compiler=guard_count_backend)
|
|
|
|
@torch.compile(backend=aot_guard_counter, dynamic=True)
|
|
def f(*args):
|
|
for a in args:
|
|
a.add_(1)
|
|
|
|
x = torch.ones(1000, requires_grad=True)
|
|
args = x.split(10)
|
|
|
|
with torch.no_grad():
|
|
f(*args)
|
|
# In this example, there were 4950 guards (roughly (# tensors) ^ 2 // 2),
|
|
# because every pair of aliased inputs needs a guard.
|
|
self.assertTrue(num_aot_guards < 5000)
|
|
# But there are no dynamic shape guards.
|
|
self.assertEqual(num_shape_guards, 0)
|
|
# don't recompile
|
|
with torch.no_grad():
|
|
f(*args)
|
|
self.assertEqual(num_compiles, 1)
|
|
|
|
def test_invalid_seq_unpack(self):
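        # Unpacking a 3-tuple into two targets raises ValueError; the compiled
        # function should raise the same error rather than silently succeeding.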
|
|
def myfn(arg):
|
|
(a, b) = arg
|
|
|
|
def fn():
|
|
return myfn((1, 2, 3))
|
|
|
|
try:
|
|
torch.compile(fn)()
|
|
except ValueError:
|
|
pass
|
|
else:
|
|
self.fail("expected exception")
|
|
|
|
def test_megablocks_moe(self):
|
|
try:
|
|
from megablocks.layers import moe
|
|
from megablocks.layers.arguments import Arguments
|
|
except ImportError as e:
|
|
raise unittest.SkipTest("requires megablocks") from e
|
|
bs, sl, hs, num_experts, top_k = (16, 1024, 512, 1, 1)
|
|
args = Arguments(
|
|
hidden_size=hs,
|
|
ffn_hidden_size=hs * 2,
|
|
moe_num_experts=num_experts,
|
|
moe_capacity_factor=1,
|
|
moe_top_k=top_k,
|
|
)
|
|
moe_mlp = moe.MoE(args)
|
|
moe_mlp.cuda(torch.cuda.current_device()).half()
|
|
x = torch.randn(sl, bs, hs).cuda().half()
|
|
out1, _ = moe_mlp(x)
|
|
out2, _ = torch.compile(moe_mlp, backend="eager")(x)
|
|
self.assertEqual(out1, out2)
|
|
|
|
def test_udf_classes_reconstruction(self):
|
|
def fn(x):
|
|
o = T(5)
|
|
return o.x + x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
T = IncByOne
|
|
|
|
x = torch.randn(4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
# This should recompile
|
|
T = IncByTwo
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_contains_range_constprop(self):
|
|
def fn(x):
|
|
            # dynamo should const prop this check to True (3 is in range(0, 10))
|
|
if 3 in range(0, 10):
|
|
return x + 1
|
|
else:
|
|
return x + 2
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
x = torch.zeros(4)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
# https://github.com/pytorch/pytorch/issues/104505
|
|
def test_as_strided_on_base_with_mutation_works(self):
|
|
def foo(a):
|
|
f = a.as_strided((2,), (1,), 0)
|
|
f.add_(1.0)
|
|
return a
|
|
|
|
a = torch.randn(2, 4)
|
|
a_ref = a.clone()
|
|
out_ref = foo(a_ref)
|
|
f_compiled = torch.compile(foo, backend="aot_eager")
|
|
out = f_compiled(a)
|
|
self.assertEqual(out_ref, out)
|
|
self.assertEqual(a_ref, a)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/104505
|
|
def test_as_strided_on_existing_view_banned(self):
|
|
def foo(a):
|
|
e = a.diagonal()
|
|
f = e.as_strided((2,), (1,), 0)
|
|
f.add_(1.0)
|
|
return a
|
|
|
|
a = torch.randn(2, 4)
|
|
a_ref = a.clone()
|
|
out_ref = foo(a_ref)
|
|
f_compiled = torch.compile(foo, backend="aot_eager")
|
|
with self.assertRaisesRegex(
|
|
RuntimeError,
|
|
"encountered a mutation on a view chain of length 2, where view 1 was an as_strided",
|
|
):
|
|
out = f_compiled(a)
|
|
|
|
def test_dont_aggressively_write_assert(self):
|
|
record_graph = torch._dynamo.testing.EagerAndRecordGraphs()
|
|
|
|
@torch.compile(dynamic=True, backend=record_graph)
|
|
def f(x):
|
|
assert x.shape[0] > 3
|
|
assert x[0].sum() > 0
|
|
assert 1 % (x.shape[0] // 2) != 0
|
|
assert 32 * (x.shape[0] // 2) ** 2 - 16 * (x.shape[0] // 2) != 0
|
|
return x.cos()
|
|
|
|
f(torch.ones(6, 4))
|
|
graph = record_graph.graphs[0]
|
|
        # It is a bit annoying that we generate useless statements for
        # shape guards, but DCE should be able to remove them since
        # there is no backed assert on them. The reason this is ok is
        # because dynamo will only skip the assert statement, but not
        # the instructions before it.
|
|
self.assertExpectedInline(
|
|
str(graph.code).strip(),
|
|
"""\
|
|
def forward(self, s0 : torch.SymInt, s1 : torch.SymInt, L_x_ : torch.Tensor):
|
|
l_x_ = L_x_
|
|
getitem_2 = l_x_[0]
|
|
sum_1 = getitem_2.sum(); getitem_2 = None
|
|
gt_1 = sum_1 > 0; sum_1 = None
|
|
_assert_async = torch._assert_async(gt_1, 'assertion error'); gt_1 = None
|
|
cos = l_x_.cos(); l_x_ = None
|
|
return (cos,)""",
|
|
)
|
|
for node in graph.graph.nodes:
|
|
if "example_value" in node.meta and isinstance(
|
|
node.meta["example_value"], torch._subclasses.fake_tensor.FakeTensor
|
|
):
|
|
shape_env = node.meta["example_value"].fake_mode.shape_env
|
|
lower_ranges = [val.lower for val in shape_env.var_to_range.values()]
|
|
self.assertTrue(lower_ranges == [4, 2])
|
|
|
|
@torch.compile(dynamic=True, backend=record_graph)
|
|
def f_fail(x):
|
|
assert x.shape[0] < 3
|
|
|
|
# We graph-break here, so the failure should be eager
|
|
with self.assertRaisesRegex(AssertionError, ""):
|
|
f_fail(torch.ones(6, 4))
|
|
|
|
def test_detectron2_instances_cat(self):
|
|
class Instances:
|
|
def __init__(self, image_size: Tuple[int, int], **kwargs: Any):
|
|
self._image_size = image_size
|
|
self._fields: Dict[str, Any] = {}
|
|
for k, v in kwargs.items():
|
|
self.set(k, v)
|
|
|
|
@property
|
|
def image_size(self) -> Tuple[int, int]:
|
|
return self._image_size
|
|
|
|
def __setattr__(self, name: str, val: Any) -> None:
|
|
if name.startswith("_"):
|
|
super().__setattr__(name, val)
|
|
else:
|
|
self.set(name, val)
|
|
|
|
def __getattr__(self, name: str) -> Any:
|
|
if name == "_fields" or name not in self._fields:
|
|
raise AttributeError(
|
|
f"Cannot find field '{name}' in the given Instances!"
|
|
)
|
|
return self._fields[name]
|
|
|
|
def __len__(self) -> int:
|
|
for v in self._fields.values():
|
|
# use __len__ because len() has to be int and is not friendly to tracing
|
|
return v.__len__()
|
|
raise NotImplementedError("Empty Instances does not support __len__!")
|
|
|
|
def set(self, name: str, value: Any) -> None:
|
|
with warnings.catch_warnings(record=True):
|
|
data_len = len(value)
|
|
if len(self._fields):
|
|
assert (
|
|
len(self) == data_len
|
|
), f"Adding a field of length {data_len} to a Instances of length {len(self)}"
|
|
self._fields[name] = value
|
|
|
|
def get(self, name: str) -> Any:
|
|
return self._fields[name]
|
|
|
|
@staticmethod
|
|
def cat(instance_lists: List["Instances"]) -> "Instances":
|
|
assert all(isinstance(i, Instances) for i in instance_lists)
|
|
assert len(instance_lists) > 0
|
|
if len(instance_lists) == 1:
|
|
return instance_lists[0]
|
|
|
|
image_size = instance_lists[0].image_size
|
|
if not isinstance(
|
|
image_size, torch.Tensor
|
|
): # could be a tensor in tracing
|
|
for i in instance_lists[1:]:
|
|
assert i.image_size == image_size
|
|
ret = Instances(image_size)
|
|
for k in instance_lists[0]._fields.keys():
|
|
values = [i.get(k) for i in instance_lists]
|
|
v0 = values[0]
|
|
if isinstance(v0, torch.Tensor):
|
|
values = torch.cat(values, dim=0)
|
|
elif isinstance(v0, list):
|
|
values = list(itertools.chain(*values))
|
|
elif hasattr(type(v0), "cat"):
|
|
values = type(v0).cat(values)
|
|
else:
|
|
raise ValueError(
|
|
f"Unsupported type {type(v0)} for concatenation"
|
|
)
|
|
ret.set(k, values)
|
|
return ret
|
|
|
|
instances = [
|
|
Instances((16, 16), a=torch.randn(16, 16), b=torch.randn(16, 16))
|
|
for _ in range(3)
|
|
]
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(instances):
|
|
return instances[0].cat(instances)
|
|
|
|
actual = fn(instances)
|
|
expected = instances[0].cat(instances)
|
|
self.assertEqual(type(actual), type(expected))
|
|
self.assertEqual(actual.__dict__, expected.__dict__)
|
|
|
|
def test_weakref(self):
|
|
def fn(x_weak, weight, y):
|
|
if x_weak is not None and x_weak() is not weight:
|
|
return torch.sin(y)
|
|
return torch.cos(y)
|
|
|
|
weight = torch.randn(4)
|
|
y = torch.randn(4)
|
|
x_weak = weakref.ref(weight)
|
|
|
|
ref = fn(x_weak, weight, y)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
res = opt_fn(x_weak, weight, y)
|
|
self.assertEqual(ref, res)
|
|
|
|
def test_weakref_reconstruct(self):
|
|
def fn(x_weak, weight, y):
|
|
y = torch.sin(y)
|
|
referent = x_weak()
|
|
torch._dynamo.graph_break()
|
|
if referent is not weight:
|
|
return torch.sin(y)
|
|
return torch.cos(y)
|
|
|
|
weight = torch.randn(4)
|
|
y = torch.randn(4)
|
|
x_weak = weakref.ref(weight)
|
|
|
|
ref = fn(x_weak, weight, y)
|
|
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
opt_fn = torch.compile(fn, backend=cnt)
|
|
res = opt_fn(x_weak, weight, y)
|
|
self.assertEqual(ref, res)
|
|
self.assertEqual(cnt.frame_count, 2)
|
|
|
|
def test_weakref_del(self):
|
|
def fn(x_weak, y):
|
|
x = x_weak()
|
|
if x is not None:
|
|
return torch.sin(y)
|
|
return torch.cos(y)
|
|
|
|
weight = torch.randn(4)
|
|
x_weak = weakref.ref(weight)
|
|
y = torch.randn(4)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
|
|
ref = fn(x_weak, y)
|
|
res = opt_fn(x_weak, y)
|
|
self.assertEqual(ref, res)
|
|
|
|
del weight
|
|
gc.collect()
|
|
ref = fn(x_weak, y)
|
|
res = opt_fn(x_weak, y)
|
|
self.assertEqual(ref, res)
|
|
|
|
# @torch._functorch.config.patch(
|
|
# recompute_views=True,
|
|
# )
|
|
# def test_storage_resize_forward_full_graph(self):
|
|
# class TestModule(torch.nn.Module):
|
|
# def __init__(self):
|
|
# super().__init__()
|
|
# self.param = torch.nn.Parameter(torch.randn(4, 4))
|
|
|
|
# def forward(self, x):
|
|
# self.param.untyped_storage().resize_(
|
|
# self.param.numel() * self.param.itemsize
|
|
# )
|
|
# with torch.no_grad():
|
|
# torch._foreach_copy_([self.param], [x])
|
|
# out = torch.matmul(self.param, self.param)
|
|
# self.param.untyped_storage().resize_(0)
|
|
# return out
|
|
|
|
# def post_accumulate_grad_hook(param):
|
|
# param.untyped_storage().resize_(0)
|
|
|
|
# # Beginning of backward, resize and put data into the param
|
|
# def pre_backward_hook(module, grad) -> None:
|
|
# module.param.untyped_storage().resize_(
|
|
# self.param.numel() * self.param.itemsize
|
|
# )
|
|
# with torch.no_grad():
|
|
# # simulates loading data into param from allgather
|
|
# module.param.fill_(2)
|
|
|
|
# def post_forward_hook(module, args, output):
|
|
# output.register_hook(functools.partial(pre_backward_hook, module))
|
|
|
|
# x = torch.randn(4, 4)
|
|
|
|
# mod_ref = TestModule()
|
|
# mod_test = deepcopy(mod_ref)
|
|
|
|
# # Start the param off with zero storage size to mimic fsdp
|
|
# mod_ref.param.untyped_storage().resize_(0)
|
|
# mod_test.param.untyped_storage().resize_(0)
|
|
|
|
# # Resize storage at beginning of backward
|
|
# # Free storage at end of backward
|
|
# mod_ref.register_forward_hook(post_forward_hook, prepend=False)
|
|
# mod_ref.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
|
|
# mod_test.register_forward_hook(post_forward_hook, prepend=False)
|
|
# mod_test.param.register_post_accumulate_grad_hook(post_accumulate_grad_hook)
|
|
|
|
# mod_test = torch.compile(mod_test, backend=aot_graph_capture_backend)
|
|
|
|
# out_ref = mod_ref(x)
|
|
# out_test = mod_test(x)
|
|
# self.assertExpectedInline(
|
|
# str(fw_graph[0].code.strip()),
|
|
# """\
|
|
# def forward(self, primals_1, primals_2):
|
|
# _foreach_copy = torch.ops.aten._foreach_copy.default([primals_1], [primals_2]); primals_1 = primals_2 = None
|
|
# getitem = _foreach_copy[0]; _foreach_copy = None
|
|
# mm = torch.ops.aten.mm.default(getitem, getitem)
|
|
# return [mm, getitem]""",
|
|
# )
|
|
# self.assertEqual(out_ref, out_test)
|
|
|
|
def test_super_in_staticmethod(self):
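        # Calling zero-argument super() from a staticmethod is an error; the
        # compiled function should surface the same error message as eager.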
|
|
class A:
|
|
@staticmethod
|
|
def foo():
|
|
return super().__init__()
|
|
|
|
def fn(obj):
|
|
return obj.foo()
|
|
|
|
obj = A()
|
|
|
|
try:
|
|
fn(obj)
|
|
except Exception as e:
|
|
orig_str = str(e)
|
|
self.assertIn("no arguments", orig_str)
|
|
|
|
try:
|
|
torch.compile(backend="eager")(fn)(obj)
|
|
except Exception as e:
|
|
compiled_str = str(e)
|
|
self.assertEqual(orig_str, compiled_str)
|
|
|
|
def test_vc_bumped_in_inference_graph(self):
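        # In-place mutation inside the compiled function must still bump the
        # tensor's version counter.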
|
|
@torch.compile
|
|
def f(x):
|
|
return x.mul_(2)
|
|
|
|
x = torch.randn(4)
|
|
vc_before = x._version
|
|
f(x)
|
|
vc_after = x._version
|
|
self.assertTrue(vc_after > vc_before)
|
|
|
|
def test_nn_module_callable(self):
|
|
class M(nn.Module):
|
|
def forward(self, x):
|
|
return x.sin()
|
|
|
|
def f(m):
|
|
return callable(m)
|
|
|
|
res = torch.compile(f, fullgraph=True)(M())
|
|
self.assertTrue(res)
|
|
|
|
def test_stk_sdd_is_transposed(self):
|
|
trigger_graph_break = False
|
|
|
|
def _is_transposed(x):
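            # "transposed" here means column-major: stride (1, rows) for a 2D tensor.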
|
|
return (
|
|
not x.is_contiguous()
|
|
and x.stride()[0] == 1
|
|
and x.stride()[1] == x.size()[0]
|
|
)
|
|
|
|
class SDD(torch.autograd.Function):
|
|
@staticmethod
|
|
def forward(ctx, lhs, rhs):
|
|
ctx.save_for_backward(lhs, rhs)
|
|
out = torch.full_like(lhs, 1.0, dtype=lhs.dtype, device=lhs.device)
|
|
return out
|
|
|
|
@staticmethod
|
|
def backward(ctx, dy):
|
|
saved_tensors = ctx.saved_tensors
|
|
lhs, rhs = saved_tensors[:2]
|
|
trans_a = _is_transposed(lhs)
|
|
trans_b = _is_transposed(rhs)
|
|
dlhs = None
|
|
if ctx.needs_input_grad[0]:
|
|
dlhs = torch.full_like(lhs, 1.0 if trans_a else 2.0)
|
|
drhs = None
|
|
if ctx.needs_input_grad[1]:
|
|
drhs = torch.full_like(rhs, 1.0 if trans_b else 2.0)
|
|
if trigger_graph_break:
|
|
if _is_transposed(dy):
|
|
return dlhs + 1, drhs + 1, None, None
|
|
return dlhs, drhs, None, None
|
|
|
|
x1 = torch.randn((8, 8), requires_grad=True)
|
|
y1 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
|
|
x2 = torch.randn((8, 8), requires_grad=True)
|
|
y2 = torch.randn((8, 8)).transpose(0, 1).requires_grad_(True)
|
|
|
|
SDD.apply(x1, y1).sum().backward()
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn():
|
|
return SDD.apply(x2, y2)
|
|
|
|
fn().sum().backward()
|
|
|
|
self.assertEqual(x1.grad, x2.grad)
|
|
self.assertEqual(y1.grad, y2.grad)
|
|
|
|
trigger_graph_break = True
|
|
with self.assertRaises(torch._dynamo.exc.Unsupported):
|
|
fn().sum().backward()
|
|
|
|
def test_partially_initialized_module_property(self):
|
|
class Matrix(torch.nn.Module):
|
|
def __init__(self, data):
|
|
super().__init__()
|
|
self._data = data
|
|
self.foo = 10 * self.blocking
|
|
|
|
@property
|
|
def data(self):
|
|
return self._data
|
|
|
|
@property
|
|
def blocking(self):
|
|
return self.data.shape[1]
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn():
|
|
return Matrix(torch.randn(10, 20))
|
|
|
|
v = fn()
|
|
self.assertEqual(v.foo, 200)
|
|
self.assertEqual(v.data.shape, (10, 20))
|
|
self.assertEqual(type(v), Matrix)
|
|
|
|
def test_nn_parametrize(self):
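        # The parametrization's torch.sin should show up in the traced graph, and
        # should disappear once the parametrization is removed.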
|
|
class Module(nn.Module):
|
|
def __init__(self) -> None:
|
|
super().__init__()
|
|
self.param = torch.nn.Parameter(torch.randn(10, 10))
|
|
|
|
def forward(self, x):
|
|
return self.param @ x
|
|
|
|
class Parametrization(torch.nn.Module):
|
|
def forward(self, x):
|
|
return torch.sin(x)
|
|
|
|
m = Module()
|
|
torch.nn.utils.parametrize.register_parametrization(
|
|
m, "param", Parametrization()
|
|
)
|
|
|
|
sin_found = False
|
|
|
|
def backend(gm, _):
|
|
nonlocal sin_found
|
|
for node in gm.graph.nodes:
|
|
if node.target is torch.sin:
|
|
sin_found = True
|
|
return gm
|
|
|
|
opt_m = torch.compile(m, backend=backend, fullgraph=True)
|
|
inp = torch.randn(10, 10)
|
|
self.assertEqual(m(inp), opt_m(inp))
|
|
self.assertTrue(sin_found)
|
|
|
|
torch.nn.utils.parametrize.remove_parametrizations(m, "param")
|
|
sin_found = False
|
|
self.assertEqual(m(inp), opt_m(inp))
|
|
self.assertFalse(sin_found)
|
|
|
|
def test_nn_module_property_closure(self):
|
|
x = torch.randn(10, 10)
|
|
|
|
class Mod(torch.nn.Module):
|
|
@property
|
|
def y(self):
|
|
return torch.ones(10, 10) + x
|
|
|
|
def forward(self, x):
|
|
return x @ self.y
|
|
|
|
mod = Mod()
|
|
|
|
def fn(x):
|
|
return mod(x)
|
|
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
|
|
inp = torch.randn(10, 10)
|
|
self.assertEqual(fn(inp), opt_fn(inp))
|
|
|
|
def test_global_fn_mutation(self):
|
|
def foo(x, y):
|
|
return global_fn(x) + y
|
|
|
|
x = torch.ones(1)
|
|
y = torch.ones(1)
|
|
|
|
opt = torch.compile(foo, fullgraph=True, backend="eager")
|
|
self.assertEqual(opt(x, y), foo(x, y))
|
|
|
|
# Change global_fn
|
|
global global_fn
|
|
|
|
def new_fn(x):
|
|
return torch.cos(x)
|
|
|
|
global_fn = new_fn
|
|
self.assertEqual(opt(x, y), foo(x, y))
|
|
|
|
# ref https://github.com/pytorch/pytorch/issues/123974
|
|
def test_list_reverse(self):
|
|
def ladder(x):
|
|
trail = x.size(-1)
|
|
assert trail > 2
|
|
weights = []
|
|
for s in [trail, trail - 1, trail - 2]:
|
|
weights.append(torch.ones(s, s - 1))
|
|
|
|
for w in weights:
|
|
x = x @ w
|
|
|
|
weights.reverse()
|
|
|
|
for w in weights:
|
|
x = x @ w.t()
|
|
|
|
return x
|
|
|
|
data = torch.randn(3, 4)
|
|
opt_ladder = torch.compile(ladder, fullgraph=True, backend="eager")
|
|
self.assertEqual(opt_ladder(data), ladder(data))
|
|
|
|
@unittest.expectedFailure
|
|
def test_trace_functional_tensor_with_error(self):
|
|
from torch._subclasses.fake_tensor import FakeTensorMode
|
|
from torch._subclasses.functional_tensor import (
|
|
FunctionalTensor,
|
|
FunctionalTensorMode,
|
|
)
|
|
|
|
def f(a, tmp):
|
|
a_view = a.view(-1)
|
|
with torch.no_grad():
|
|
a.set_(tmp)
|
|
a_view.mul_(2)
|
|
return a + tmp
|
|
|
|
fake_mode = FakeTensorMode()
|
|
with FunctionalTensorMode():
|
|
inp = torch.ones(3, 3, requires_grad=True)
|
|
inp = fake_mode.from_tensor(inp, static_shapes=True)
|
|
inp = FunctionalTensor.to_functional(inp)
|
|
|
|
tmp = torch.ones(3, 3, requires_grad=True)
|
|
tmp = fake_mode.from_tensor(tmp, static_shapes=True)
|
|
tmp = FunctionalTensor.to_functional(tmp)
|
|
|
|
opt_f = torch.compile(f, backend="eager")
|
|
with self.assertRaisesRegex(
|
|
RuntimeError, "cannot mutate tensors with frozen storage"
|
|
):
|
|
opt_f(inp, tmp)
|
|
|
|
# grad state may not be properly reset after the error
|
|
self.assertTrue(torch.is_grad_enabled())
|
|
|
|
def test_const_dict_keyerror(self):
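        # A KeyError raised and caught inside the compiled function should fall
        # through to the except branch, matching eager behavior.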
|
|
d = {}
|
|
|
|
def fn(x):
|
|
try:
|
|
y = d[0]
|
|
except KeyError:
|
|
y = 1
|
|
return x + y
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
inp = torch.randn(3, 3)
|
|
self.assertEqual(fn(inp), opt_fn(inp))
|
|
|
|
def test_dict_tag_guard(self):
|
|
class Foo:
|
|
def __init__(self):
|
|
self.scalar = 10
|
|
|
|
def fn(d, x):
|
|
return d["a"] * d["b"] * d["c"].scalar * x
|
|
|
|
foo = Foo()
|
|
|
|
d = {"a": 2, "b": 3, "c": foo}
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
inp = torch.randn(3, 3)
|
|
self.assertEqual(fn(d, inp), opt_fn(d, inp))
|
|
|
|
d["a"] = 4
|
|
self.assertEqual(fn(d, inp), opt_fn(d, inp))
|
|
|
|
# Check that recompilation happens
|
|
foo.scalar = 12
|
|
self.assertEqual(fn(d, inp), opt_fn(d, inp))
|
|
|
|
def test_nonconst_issubclass(self):
|
|
def fn(x):
|
|
if issubclass(x.__class__, np.ndarray):
|
|
return 1
|
|
return 0
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
opt_fn(np.ones([3, 3]))
|
|
|
|
def test_issue126128(self):
|
|
def fn():
|
|
x = torch.randn(1, 10)
|
|
y = torch.randn(10, 1)
|
|
return torch.mm(x, y).sum()
|
|
|
|
def fn2():
|
|
x = torch.randn(10, 100)
|
|
y = torch.randn(100, 10)
|
|
return torch.mm(x, y).sum()
|
|
|
|
with fresh_inductor_cache():
|
|
torch.compile(fn)()
|
|
|
|
torch.compile(fn2)()
|
|
|
|
def test_jit_script_defaults(self):
|
|
@torch.jit.script
|
|
def fast_cos(x, c: float = 2.0):
|
|
return torch.cos(x) * c
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.fast_cos = fast_cos
|
|
|
|
def forward(self, x):
|
|
return self.fast_cos(x)
|
|
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend="eager", fullgraph=True)
|
|
x = torch.randn(4)
|
|
self.assertEqual(mod(x), opt_mod(x))
|
|
|
|
def test_enum(self):
|
|
class ExplicitEnum(str, Enum):
|
|
@classmethod
|
|
def _missing_(cls, value):
|
|
raise ValueError(
|
|
f"{value} is not a valid {cls.__name__}, please select one of {list(cls._value2member_map_.keys())}"
|
|
)
|
|
|
|
class PaddingStrategy(ExplicitEnum):
|
|
LONGEST = "longest"
|
|
MAX_LENGTH = "max_length"
|
|
DO_NOT_PAD = "do_not_pad"
|
|
|
|
def fn(x):
|
|
a = PaddingStrategy("longest")
|
|
if a == PaddingStrategy.LONGEST:
|
|
return torch.sin(x)
|
|
return torch.cos(x)
|
|
|
|
x = torch.randn(3, 3)
|
|
opt_fn = torch.compile(fn, backend="eager", fullgraph=True)
|
|
self.assertEqual(fn(x), opt_fn(x))
|
|
|
|
def test_hasattr_builtin(self):
|
|
class MyClass:
|
|
foo: int = 1
|
|
|
|
def func(x, m):
|
|
if getattr(type(m), "foo", 0):
|
|
return x + MyClass.foo
|
|
return x
|
|
|
|
opt_func = torch.compile(func, backend="eager", fullgraph=True)
|
|
m = MyClass()
|
|
x = torch.zeros(())
|
|
self.assertEqual(func(x, m), opt_func(x, m))
|
|
self.assertEqual(func(x, 0), opt_func(x, 0))
|
|
|
|
def test_grad(self):
|
|
def fn(x, y):
|
|
x._grad = y
|
|
return x.grad.data
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
y = torch.randn(4)
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
self.assertEqual(fn(x, y), opt_fn(x, y))
|
|
|
|
def test_nn_module_stack_bc(self):
|
|
from torch._dynamo.mutation_guard import GenerationTracker
|
|
|
|
def compiler(gm, *args):
|
|
module_stacks = [
|
|
node.meta.get("nn_module_stack", None) for node in gm.graph.nodes
|
|
]
|
|
module_stacks, _ = pytree.tree_flatten(module_stacks)
|
|
module_stacks = [x for x in module_stacks if isinstance(x, str)]
|
|
for stack in module_stacks:
|
|
self.assertTrue("_module" not in stack)
|
|
return gm.forward
|
|
|
|
class SubMod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.linear = torch.nn.Linear(2, 2)
|
|
|
|
def forward(self, x):
|
|
return self.linear(x)
|
|
|
|
class Mod(torch.nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.submod1 = SubMod()
|
|
self.submod2 = SubMod()
|
|
|
|
def forward(self, x):
|
|
return self.submod1(x) + self.submod2(x)
|
|
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend=compiler)
|
|
opt_mod(torch.randn(2, 2))
|
|
|
|
with torch._dynamo.config.patch(inline_inbuilt_nn_modules=True):
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend=compiler)
|
|
opt_mod(torch.randn(2, 2))
|
|
|
|
# an example similar to Pippy usecase
|
|
mod = Mod()
|
|
GenerationTracker.tag(mod.submod1)
|
|
GenerationTracker.mark_class_dynamic(type(mod.submod1))
|
|
mod = Mod()
|
|
opt_mod = torch.compile(mod, backend=compiler)
|
|
opt_mod(torch.randn(2, 2))
|
|
|
|
def test_is_make_fx_tracing(self):
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def fn(x):
|
|
torch.nn.modules.activation._is_make_fx_tracing()
|
|
return torch.sin(x)
|
|
|
|
fn(torch.rand(4))
|
|
|
|
def test_negative_floor_div_solve(self):
|
|
class CompiledClass(nn.Module):
|
|
def __init__(self):
|
|
super().__init__()
|
|
self.nums = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])
|
|
self.t = 5
|
|
|
|
def forward(self):
|
|
self.num = self.nums[self.t // 12]
|
|
self.t += 1
|
|
return self.num
|
|
|
|
m = CompiledClass()
|
|
m = torch.compile(m, backend="eager")
|
|
|
|
# the first call works
|
|
m()
|
|
# the second call causes a failure
|
|
m()
|
|
|
|
# https://github.com/pytorch/pytorch/issues/121621
|
|
def test_tensor_random(self):
|
|
def random_op(tensor, params):
|
|
res = tensor.random_(**params)
|
|
return res
|
|
|
|
random_op = torch.compile(random_op)
|
|
params = {"from": -10, "to": 10}
|
|
tensor = torch.randn([2, 3])
|
|
res = random_op(tensor, params)
|
|
|
|
def test_data_attr_mutation_after_saved_for_bw(self):
|
|
def f(x):
|
|
out = x.sin()
|
|
x.data.mul_(2)
|
|
return out
|
|
|
|
x = torch.randn(4, requires_grad=True)
|
|
x_test = x.clone().detach().requires_grad_(True)
|
|
|
|
out = f(x)
|
|
out_test = torch.compile(f, backend="aot_eager")(x_test)
|
|
self.assertEqual(out, out_test)
|
|
|
|
out.sum().backward()
|
|
out_test.sum().backward()
|
|
self.assertEqual(x.grad, x_test.grad)
|
|
|
|
# https://github.com/pytorch/pytorch/issues/128072
|
|
def test_map_with_multiple_args(self):
|
|
def f(a, b):
|
|
return a[0] * b[0] + a[1] * b[1]
|
|
|
|
def gen_inps(len_x, len_y):
|
|
x = [torch.randn(5) for _ in range(len_x)]
|
|
y = [torch.randn(5) for _ in range(len_y)]
|
|
return x, y
|
|
|
|
def g(x, y):
|
|
return map(f, x, y)
|
|
|
|
opt_g = torch.compile(g, fullgraph=True, backend="eager")
|
|
|
|
inps = gen_inps(3, 3)
|
|
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
|
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
|
|
|
inps = gen_inps(3, 5)
|
|
self.assertEqual(type(g(*inps)), type(opt_g(*inps)))
|
|
self.assertEqual(tuple(g(*inps)), tuple(opt_g(*inps)))
|
|
|
|
def test_staticmethod_allow_in_graph(self):
|
|
class MyClass:
|
|
i = 3
|
|
|
|
@staticmethod
|
|
def foo_inner(x):
|
|
return torch.mul(x, MyClass.i)
|
|
|
|
# if dynamo inlines with fullgraph, will error
|
|
# verify that dynamo doesn't inline
|
|
@staticmethod
|
|
@torch._dynamo.allow_in_graph
|
|
def foo1(x):
|
|
torch._dynamo.graph_break()
|
|
return MyClass.foo_inner(x)
|
|
|
|
@torch.compile(backend="eager", fullgraph=True)
|
|
def f_bad(x):
|
|
return MyClass.foo1(x)
|
|
|
|
f_bad(torch.ones(2, 2))
|
|
|
|
def test_guard_with_tuple_mutation(self):
|
|
class Foo:
|
|
def __init__(self):
|
|
self.x = 10
|
|
|
|
foo = Foo()
|
|
d = {
|
|
"a": 2,
|
|
"b": (foo,),
|
|
}
|
|
|
|
def fn(x, d):
|
|
return x * d["a"] * d["b"][0].x
|
|
|
|
opt_fn = torch.compile(fn, backend="eager")
|
|
inp = torch.randn(3, 3)
|
|
self.assertEqual(fn(inp, d), opt_fn(inp, d))
|
|
d["b"][0].x = 12
|
|
self.assertEqual(fn(inp, d), opt_fn(inp, d))
|
|
|
|
def test_compile_complex_conj(self):
|
|
def f(x):
|
|
return torch.mul(x, 2j)
|
|
|
|
x_ref = torch.randn(4, 2, requires_grad=True)
|
|
x_test = x_ref.clone().detach().requires_grad_(True)
|
|
|
|
out_ref = f(torch.view_as_complex(x_ref))
|
|
out_test = torch.compile(f, backend="aot_eager")(torch.view_as_complex(x_test))
|
|
self.assertEqual(out_ref, out_test)
|
|
|
|
torch.view_as_real(out_ref).sum().backward()
|
|
torch.view_as_real(out_test).sum().backward()
|
|
self.assertEqual(x_ref.grad, x_test.grad)
|
|
|
|
def test_changing_stride(self):
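        # Feeding inputs whose strides change across calls (contiguous vs. view)
        # should not cause a recompile on every iteration.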
|
|
cnt = torch._dynamo.testing.CompileCounter()
|
|
|
|
@torch.compile(backend=cnt)
|
|
def fn(x, y):
|
|
return x * y
|
|
|
|
for i in range(1, 4):
|
|
x = torch.randn(4, i)
|
|
|
|
# create a view for i > 1
|
|
if i == 1:
|
|
x1 = x
|
|
else:
|
|
x1 = x[:, 0:1]
|
|
|
|
y = torch.randn(4, 1)
|
|
print(x1.shape, y.shape)
|
|
fn(x1, y)
|
|
|
|
self.assertTrue(cnt.frame_count <= 2)
|
|
|
|
def test_optimized_module_training(self):
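        # train()/eval() calls on either the OptimizedModule or the wrapped module
        # should keep the two training flags in sync.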
|
|
mod = torch.nn.Linear(3, 3)
|
|
mod.eval()
|
|
|
|
opt_mod = torch.compile(mod, backend="eager")
|
|
self.assertFalse(opt_mod.training)
|
|
|
|
opt_mod.train()
|
|
self.assertTrue(opt_mod.training)
|
|
self.assertTrue(mod.training)
|
|
|
|
mod.eval()
|
|
self.assertFalse(opt_mod.training)


instantiate_parametrized_tests(ReproTests)


if __name__ == "__main__":
    from torch._dynamo.test_case import run_tests

    run_tests()