Add support for torch.Generator type in TorchScript (#110413)

- Add support for `torch.Generator` type in TorchScript
- Add `generator` args to all `torch.nn.init` functions that call `uniform_` or `normal_`
- Add support for `torch.Generator` in LTC's TorchScript backend (CC: @wconstab)
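
For illustration, a minimal sketch of the user-visible behavior this enables (hedged; it mirrors the new tests below, and `f` is just an example name):

import torch

def f(g: torch.Generator):
    # torch.Generator is now a first-class TorchScript type
    g.manual_seed(2023)
    t = torch.empty(2, 2)
    t.normal_(0.0, 1.0, generator=g)
    return t

script_f = torch.jit.script(f)
out = script_f(torch.Generator())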

CC: @eellison @davidberard98 @GlebKazantaev @behzad-a
Pull Request resolved: https://github.com/pytorch/pytorch/pull/110413
Approved by: https://github.com/wconstab, https://github.com/albanD, https://github.com/glebk-cerebras, https://github.com/davidberard98
Author: Antonio Kim
Date: 2023-11-06 21:26:57 +00:00
Committed by: PyTorch MergeBot
Parent: 7b99b3efb1
Commit: 27e31ab6e8
39 changed files with 650 additions and 179 deletions

View File

@ -2,6 +2,14 @@
#include <ATen/core/Tensor.h>
#include <c10/util/Exception.h>
#include <ATen/CPUGeneratorImpl.h>
#ifdef USE_CUDA
#include <ATen/cuda/CUDAGeneratorImpl.h>
#endif
#ifdef USE_MPS
#include <ATen/mps/MPSGeneratorImpl.h>
#endif
namespace at {
void Generator::set_state(const at::Tensor& new_state) {
@ -13,4 +21,32 @@ at::Tensor Generator::get_state() const {
return at::Tensor::wrap_tensor_impl(this->impl_->get_state());
}
Generator make_generator_for_device(c10::Device device, c10::optional<int64_t> seed) {
if (device.is_cpu()) {
if (seed.has_value()) {
return at::detail::createCPUGenerator(seed.value());
} else {
return at::detail::createCPUGenerator();
}
#ifdef USE_CUDA
} else if (device.is_cuda()) {
auto generator = at::cuda::detail::createCUDAGenerator(device.index());
if (seed.has_value()) {
generator.set_current_seed(seed.value());
}
return generator;
#endif
#ifdef USE_MPS
} else if (device.is_mps()) {
if (seed.has_value()) {
return at::mps::detail::createMPSGenerator(seed.value());
} else {
return at::mps::detail::createMPSGenerator();
}
#endif
} else {
AT_ERROR("Unsupported device for at::make_generator found: ", device.str());
}
}
} // namespace at
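
For reference, a hedged Python sketch of the dispatch this factory implements (`make_generator` below is an illustrative helper, not part of the API; the CUDA/MPS branches only exist when PyTorch is built with those backends):

import torch

def make_generator(device: str, seed=None):
    # Mirrors at::make_generator_for_device: construct a generator on the
    # requested device, then apply the seed only if one was provided.
    g = torch.Generator(device=device)
    if seed is not None:
        g.manual_seed(seed)
    return g

g = make_generator("cpu", seed=42)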

View File

@ -145,6 +145,9 @@ Generator make_generator(Args&&... args) {
return Generator(c10::make_intrusive<Impl>(std::forward<Args>(args)...));
}
Generator make_generator_for_device(
c10::Device device, c10::optional<int64_t> seed = c10::nullopt);
/**
* Utility function to static cast input Generator* to
* the backend generator type (CPU/CUDAGeneratorImpl etc.)

View File

@ -644,6 +644,13 @@ std::ostream& IValue::repr(
c10::printQuotedString(out, device_stream.str());
return out << ")";
}
case IValue::Tag::Generator: {
auto generator = v.toGenerator();
out << "torch.Generator(device=";
c10::printQuotedString(out, generator.device().str());
out << ", seed=" << generator.current_seed() << ")";
return out;
}
case IValue::Tag::GenericDict:
return printMaybeAnnotatedDict(out, v, formatter);
case IValue::Tag::Enum: {
@ -956,6 +963,7 @@ IValue IValue::deepcopy(
case IValue::Tag::SymBool:
case IValue::Tag::Bool:
case IValue::Tag::Device:
case IValue::Tag::Generator:
case IValue::Tag::Uninitialized: {
copy = *this;
} break;
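
The practical effect of the new `repr` case: a generator IValue now prints with its device and current seed, e.g.

torch.Generator(device="cpu", seed=1)

which is the exact string asserted against in test/lazy/test_generator.py below.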

View File

@ -28,6 +28,7 @@ namespace c10 {
_(complex, ComplexType) \
_(str, StringType) \
_(Device, DeviceObjType) \
_(Generator, GeneratorType) \
_(Stream, StreamObjType) \
_(number, NumberType) \
_(None, NoneType) \

View File

@ -168,6 +168,9 @@ full_codegen:
- slice_scatter
- diagonal_scatter
- as_strided_scatter
# random ops
- normal_functional
- uniform
ir_gen:
- selu
supported:
@ -177,7 +180,6 @@ supported:
- empty.memory_format
- empty_strided
- fill_.Scalar
- normal_
- max_pool3d_with_indices
- max_pool3d_with_indices_backward
- _to_copy
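
Moving `normal_functional` and `uniform` into `full_codegen` means seeded random ops are lowered into the TS graph rather than unconditionally falling back to eager. A hedged usage sketch (it mirrors test/lazy/test_generator.py below and assumes a build with the LTC TS backend):

import torch
import torch._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()
with torch.device("lazy"):
    g = torch.Generator()
    g.manual_seed(2023)
    t = torch.tensor(1.0)
    t.uniform_(generator=g)  # stays on the lazy graph instead of falling back
torch._lazy.mark_step()      # compile and execute the accumulated graph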

View File

@ -446,7 +446,6 @@ lazy_tensor_ts_sources = [
"torch/csrc/lazy/ts_backend/dynamic_ir.cpp",
"torch/csrc/lazy/ts_backend/config.cpp",
"torch/csrc/lazy/ts_backend/ops/device_data.cpp",
"torch/csrc/lazy/ts_backend/ops/random_ops.cpp",
"torch/csrc/lazy/ts_backend/ops/generic.cpp",
"torch/csrc/lazy/ts_backend/tensor_aten_ops.cpp",
"torch/csrc/lazy/ts_backend/ts_autograd_functions.cpp",

View File

@ -88,4 +88,3 @@ we suggest using :meth:`torch.jit.trace`.
* :class:`torch.nn.AdaptiveLogSoftmaxWithLoss`
* :class:`torch.autograd.Function`
* :class:`torch.autograd.enable_grad`
* :class:`torch.Generator`

test/jit/test_generator.py (new file, 195 lines)
View File

@ -0,0 +1,195 @@
# Owner(s): ["oncall: jit"]
import io
import math
import unittest
import torch
from torch.nn import init
from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
from torch.testing._internal.jit_utils import JitTestCase
if __name__ == "__main__":
raise RuntimeError(
"This test file is not meant to be run directly, use:\n\n"
"\tpython test/test_jit.py TESTNAME\n\n"
"instead."
)
class TestGenerator(JitTestCase):
# torch.jit.trace does not properly capture the generator's manual seed
# and thus is non-deterministic even if the generator is manually seeded
@skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type")
@unittest.expectedFailure
def test_trace(self):
def f():
generator = torch.Generator()
generator.seed()
generator.manual_seed(2023)
generator.initial_seed()
tensor = torch.empty(2, 2)
tensor.uniform_(0, 1, generator=generator)
return tensor
traced_f = torch.jit.trace(f, ())
# Run this 3 times to ensure that the generator is being manually seeded
# each time the traced function is run
for i in range(3):
torch.manual_seed(1)
eager_tensor = f()
# Change the seed of the default generator to
# check that we're using the generator from the
# trace
torch.manual_seed(2)
traced_tensor = traced_f()
self.assertEqual(eager_tensor, traced_tensor)
def test_script(self):
def f():
generator = torch.Generator()
generator.seed()
generator.manual_seed(2023)
generator.initial_seed()
tensor = torch.empty(2, 2)
tensor.normal_(-1.0, 1.0, generator=generator)
return tensor
script_f = torch.jit.script(f, ())
# Run this 3 times to ensure that the generator is being manually seeded
# each time the scripted function is run
for i in range(3):
torch.manual_seed(1)
eager_tensor = f()
# Change the seed of the default generator to
# check that we're using the generator from the
# script
torch.manual_seed(2)
script_tensor = script_f()
self.assertEqual(eager_tensor, script_tensor)
def test_default_generator(self):
def f():
# check that calling manual seed for the default generator works
torch.manual_seed(2023)
tensor = torch.empty(2, 2)
tensor.normal_(-1.0, 1.0)
return tensor
torch.manual_seed(1)
eager_tensor = f()
torch.manual_seed(2)
script_f = torch.jit.script(f, ())
script_tensor = script_f()
self.assertEqual(eager_tensor, script_tensor)
def test_generator_arg(self):
def f(generator: torch.Generator):
tensor = torch.empty(2, 2)
tensor.normal_(-1.0, 1.0, generator=generator)
return tensor
generator = torch.Generator()
generator.manual_seed(2023)
script_f = torch.jit.script(f, (generator,))
for i in range(3):
generator = torch.Generator()
generator.manual_seed(2023 + i)
torch.manual_seed(1 + i)
eager_tensor = f(generator)
generator = torch.Generator()
generator.manual_seed(2023 + i)
torch.manual_seed(1 + i)
script_tensor = script_f(generator)
self.assertEqual(eager_tensor, script_tensor)
def test_save_load(self):
class Foo(torch.nn.Module):
def __init__(self):
super().__init__()
self.foo = torch.nn.Linear(2, 2, bias=False)
self.bar = torch.nn.Linear(2, 2, bias=False)
self.reset_parameters()
def reset_linear(self, module, generator):
init.kaiming_uniform_(
module.weight, a=math.sqrt(5), generator=generator
)
def reset_parameters(self):
generator = torch.Generator()
generator.manual_seed(1)
self.reset_linear(self.foo, generator)
generator = torch.Generator()
generator.manual_seed(2)
self.reset_linear(self.bar, generator)
def forward(self, x):
x = self.foo(x)
x = self.bar(x)
generator = torch.Generator()
generator.manual_seed(3)
r = torch.empty_like(x)
r.normal_(0.0, 1.0, generator=generator)
return x, r
eager_foo = Foo()
script_module = torch.jit.script(Foo())
saved_module = io.BytesIO()
torch.jit.save(script_module, saved_module)
saved_module.seek(0)
loaded_module = torch.jit.load(saved_module)
self.assertEqual(eager_foo.foo.weight, loaded_module.foo.weight)
self.assertEqual(eager_foo.bar.weight, loaded_module.bar.weight)
try:
# Run this 3 times to make sure that the generator seed is being set
# every time forward is called
for i in range(3):
x = torch.ones(2, 2)
out1, r1 = eager_foo(x)
out2, r2 = loaded_module(x)
try:
self.assertEqual(out1, out2)
except: # noqa: B001, E722
print(f"Iteration {i}:\n{out1=}\n{out2=}")
raise
try:
self.assertEqual(r1, r2)
except: # noqa: B001, E722
print(f"Iteration {i}:\n{r1=}\n{r2=}")
raise
except: # noqa: B001, E722
print(loaded_module.forward.code)
raise

test/lazy/test_generator.py (new file, 103 lines)
View File

@ -0,0 +1,103 @@
# Owner(s): ["oncall: jit"]
import torch
import torch._lazy.metrics as metrics
import torch._lazy.ts_backend
from torch.testing._internal.common_utils import run_tests, skipIfTorchDynamo, TestCase
torch._lazy.ts_backend.init()
class LazyGeneratorTest(TestCase):
def test_generator(self):
"""
Test that generators are being inserted into the TorchScript
graph: set a different default-generator seed before each call to
generate_tensor and check that the resulting tensors still match,
since the explicit generators carry their own seeds
"""
def generate_tensor():
g1 = torch.Generator()
g1.manual_seed(2023)
t1 = torch.tensor(1.0)
t1.uniform_(generator=g1)
g2 = torch.Generator()
g2.manual_seed(2024)
t2 = torch.tensor(1.0)
t2.normal_(generator=g2)
return t1, t2
torch.manual_seed(1)
with torch.device("cpu"):
cpu_t1, cpu_t2 = generate_tensor()
torch.manual_seed(2)
with torch.device("lazy"):
lazy_t1, lazy_t2 = generate_tensor()
torch._lazy.mark_step()
assert torch.allclose(
cpu_t1, lazy_t1.to("cpu")
), f"Expected {cpu_t1}, got {lazy_t1.to('cpu')}"
assert torch.allclose(
cpu_t2, lazy_t2.to("cpu")
), f"Expected {cpu_t2}, got {lazy_t2.to('cpu')}"
@skipIfTorchDynamo("Torch Dynamo does not support torch.Generator type")
def test_generator_causes_multiple_compiles(self):
"""
Test that inserting generators with different seeds causes a recompile
"""
def generate_tensor(seed):
t = torch.tensor(1.0)
g = torch.Generator()
g.manual_seed(seed)
t.uniform_(-1, 1, generator=g)
return t
metrics.reset()
with torch.device("lazy"):
t = generate_tensor(1)
torch._lazy.mark_step()
uncached_compile = metrics.counter_value("UncachedCompile")
assert (
uncached_compile == 1
), f"Expected 1 uncached compiles, got {uncached_compile}"
t = generate_tensor(2)
torch._lazy.mark_step()
uncached_compile = metrics.counter_value("UncachedCompile")
assert (
uncached_compile == 2
), f"Expected 2 uncached compiles, got {uncached_compile}"
t = generate_tensor(1)
torch._lazy.mark_step()
uncached_compile = metrics.counter_value("UncachedCompile")
assert (
uncached_compile == 2
), f"Expected 2 uncached compiles, got {uncached_compile}"
cached_compile = metrics.counter_value("CachedCompile")
assert (
cached_compile == 1
), f"Expected 1 cached compile, got {cached_compile}"
metrics.reset()
latest_graph = torch._C._lazy_ts_backend._get_latest_computation_graph()
assert 'torch.Generator(device="cpu", seed=1)' in latest_graph
assert "aten::uniform" in latest_graph
if __name__ == "__main__":
run_tests()

View File

@ -231,6 +231,9 @@ class TestLazyOpInfo(TestCase):
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
for sample in samples:
# Need to run mark step so that all random ops are computed in the right order
torch._lazy.mark_step()
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
copy_args = clone_to_device(args, test_device)
@ -238,6 +241,7 @@ class TestLazyOpInfo(TestCase):
r_exp = op(*copy_args, **kwargs)
r_actual = op(*args, **kwargs)
torch._lazy.mark_step()
assert_allclose_rec((r_actual, r_exp))
@ops([op for op in op_db if op.name in LAZY_OPS_LIST and op.name not in SKIP_RUNTIME_ERROR_LIST | SKIP_INCORRECT_RESULTS_LIST], allowed_dtypes=(torch.float,)) # noqa: B950
@ -263,6 +267,9 @@ class TestLazyOpInfo(TestCase):
samples = op.sample_inputs("lazy", dtype, requires_grad=False)
for sample in samples:
# Need to run mark step so that all random ops are computed in the right order
torch._lazy.mark_step()
args = [sample.input] + list(sample.args)
kwargs = sample.kwargs
copy_args = clone_to_device(args, test_device)

View File

@ -75,6 +75,7 @@ from jit.test_dce import TestDCE # noqa: F401
from jit.test_sparse import TestSparse # noqa: F401
from jit.test_tensor_methods import TestTensorMethods # noqa: F401
from jit.test_dataclasses import TestDataclasses # noqa: F401
from jit.test_generator import TestGenerator # noqa: F401
# Torch
from torch import Tensor
@ -14169,6 +14170,41 @@ dedent """
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
def test_nn_init_generator(self):
init_fns = (
'uniform_', 'normal_', 'xavier_normal_', 'xavier_uniform_',
)
for name in init_fns:
# Build test code
code = dedent('''
def test(tensor, generator):
# type: (Tensor, Generator)
return torch.nn.init.{name}(tensor, generator=generator)
''').format(name=name)
cu = torch.jit.CompilationUnit(code)
# Compare functions
init_fn = getattr(torch.nn.init, name)
torch.manual_seed(1)
g = torch.Generator()
g.manual_seed(2023)
script_out = cu.test(torch.ones(2, 2), g)
# Change the seed of the default generator to make
# sure that we're using the provided generator
torch.manual_seed(2)
g = torch.Generator()
g.manual_seed(2023)
eager_out = init_fn(torch.ones(2, 2), generator=g)
self.assertEqual(script_out, eager_out)
FileCheck().check_not("prim::PythonOp").run(cu.test.graph)
def test_early_return_rewrite(self):
def test_foo(x: bool):
if x:

View File

@ -94,6 +94,7 @@ class TestPublicBindings(TestCase):
"Future",
"FutureType",
"Generator",
"GeneratorType",
"get_autocast_cpu_dtype",
"get_autocast_ipu_dtype",
"get_default_dtype",

View File

@ -1909,6 +1909,10 @@ class DeviceObjType(JitType):
@staticmethod
def get() -> DeviceObjType: ...
class _GeneratorType(JitType):
@staticmethod
def get() -> _GeneratorType: ...
class StreamObjType(JitType):
@staticmethod
def get() -> StreamObjType: ...

View File

@ -2331,6 +2331,7 @@ def uniform(
x: Tensor,
low: Union[bool, int, float] = 0.0,
high: Union[bool, int, float] = 1.0,
generator: Optional[torch.Generator] = None,
):
return prims._uniform_helper(
x.shape,
@ -2338,13 +2339,13 @@ def uniform(
high=sym_float(high),
dtype=x.dtype,
device=x.device,
generator=generator,
)
@register_decomposition(aten.uniform_)
def uniform_(self, low=0, high=1, generator=None):
assert generator is None
return self.copy_(uniform(self, low, high))
return self.copy_(uniform(self, low, high, generator))
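
A hedged sketch of the op this decomposition targets; the point is that the `generator` keyword is now forwarded instead of being asserted `None`:

import torch

g = torch.Generator()
g.manual_seed(0)
x = torch.empty(3)
# aten::uniform_ takes the generator as a keyword-only argument
torch.ops.aten.uniform_(x, 0.0, 1.0, generator=g)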
# aten/src/ATen/native/UpSample.cpp compute_output_size

View File

@ -2765,8 +2765,6 @@ svd = _make_prim(
#
# TODO: add generator support
# NOTE: there is currently no way of acquiring the "default" torch generator
def _normal_meta(
shape: ShapeType,
*,
@ -2775,6 +2773,7 @@ def _normal_meta(
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
generator: Optional[torch.Generator] = None,
) -> TensorLikeType:
torch._check(
std >= 0.0,
@ -2798,11 +2797,12 @@ def _normal_aten(
dtype: torch.dtype,
device: torch.device,
requires_grad: bool,
generator: Optional[torch.Generator] = None,
) -> Tensor:
a = torch.empty(shape, dtype=dtype, device=device, requires_grad=requires_grad)
with torch.no_grad():
# NOTE: normal_ is incorrectly annotated to expect mean to be a float
a.normal_(mean, std) # type: ignore[arg-type]
a.normal_(mean, std, generator=generator) # type: ignore[arg-type]
return a
@ -2815,7 +2815,7 @@ _normal_doc = """
normal = _make_prim(
schema=(
"normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad) -> Tensor"
"normal(SymInt[] shape, *, Scalar mean, Scalar std, ScalarType dtype, Device device, bool requires_grad, Generator? generator=None) -> Tensor" # noqa: B950
),
return_type=RETURN_TYPE.NEW,
meta=_normal_meta,
@ -2831,6 +2831,7 @@ def _uniform_meta(
high: float,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator] = None,
) -> TensorLikeType:
strides = utils.make_contiguous_strides_for(shape)
return TensorMeta(shape=shape, strides=strides, dtype=dtype, device=device)
@ -2843,9 +2844,10 @@ def _uniform_aten(
high: float,
dtype: torch.dtype,
device: torch.device,
generator: Optional[torch.Generator] = None,
) -> Tensor:
a = torch.empty(shape, dtype=dtype, device=device)
a.uniform_(low, high)
a.uniform_(low, high, generator=generator)
return a
@ -2856,7 +2858,7 @@ _uniform_doc = """
# TODO: we should more seriously review randomness modeling and prims
_uniform_helper = _make_prim(
schema=(
"uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device) -> Tensor"
"uniform(SymInt[] shape, *, Scalar low, Scalar high, ScalarType dtype, Device device, Generator? generator=None) -> Tensor"
),
return_type=RETURN_TYPE.NEW,
meta=_uniform_meta,

View File

@ -5931,7 +5931,6 @@ def normal(
device=None,
pin_memory=None,
):
assert generator is None
assert layout is None or layout == torch.strided
if not isinstance(std, TensorLike):
@ -5968,6 +5967,7 @@ def normal(
dtype=dtype,
device=device,
requires_grad=False,
generator=generator,
)
return std * normal_samples + mean

View File

@ -241,6 +241,17 @@ std::shared_ptr<SugaredValue> SimpleValue::attr(
return SpecialFormValue::create(aten::index);
}
if (auto generator_type = value_->type()->cast<GeneratorType>()) {
// Handle access to Generator's `manual_seed`, `initial_seed` and `seed`
// attributes.
if (field == "manual_seed" || field == "initial_seed" || field == "seed") {
if (auto builtin = BuiltinFunction::tryCreate(
Symbol::aten(field), NamedValue(loc, "self", value_))) {
return builtin;
}
}
}
ErrorReport report(loc);
report << "'" << value_->type()->repr_str()
<< "' object has no attribute or method '" << field << "'.";

View File

@ -679,12 +679,14 @@ void addInputs(
Node* n,
const char* name,
const c10::optional<at::Generator>& value) {
if (value.has_value() && value->defined()) {
detail::badArgType(*value);
}
Graph* g = n->owningGraph();
Value* undef_gen = g->insertNode(g->createNone())->output();
n->addInput(undef_gen);
if (value.has_value() && value->defined()) {
detail::genericAddInput(n, *value);
} else {
Value* undef_gen = g->insertNode(g->createNone())->output();
n->addInput(undef_gen);
}
}
void addInputs(Node* n, const char* name, at::Device value) {
detail::genericAddInput(n, value);

View File

@ -5,6 +5,7 @@
#include <torch/csrc/jit/ir/ir.h>
#include <torch/csrc/jit/runtime/custom_operator.h>
#include <torch/csrc/jit/runtime/operator.h>
#include <torch/csrc/jit/runtime/register_ops_utils.h>
namespace torch::jit {
@ -108,6 +109,10 @@ c10::optional<Value*> tryInsertConstant(
ss << val.toDevice();
n->s_(attr::value, ss.str());
n->output()->setType(DeviceObjType::get());
} else if (val.isGenerator()) {
auto generator = val.toGenerator();
n->ival_(attr::value, generator);
n->output()->setType(GeneratorType::get());
} else if (val.isStream()) {
// packing into int64_t removed
n->ival_(attr::value, val);
@ -194,6 +199,9 @@ c10::optional<IValue> toIValue(const Value* v) {
} else if (type == DeviceObjType::get()) {
auto d = c10::Device(node->s(attr::value));
return d;
} else if (type == GeneratorType::get()) {
auto generator = node->ival(attr::value).toGenerator();
return generator;
} else if (type == StreamObjType::get()) {
// int64_t packing removed
auto s = node->ival(attr::value).toStream();

View File

@ -142,6 +142,9 @@ bool ivaluesEqual(const IValue& a1, const IValue& a2) {
if (a1.isObject()) {
return &a1.toObjectRef() == &a2.toObjectRef();
}
if (a1.isGenerator()) {
return a1.toGenerator() == a2.toGenerator();
}
TORCH_INTERNAL_ASSERT(false);
}

View File

@ -397,6 +397,8 @@ inline InferredType tryToInferType(py::handle input) {
return InferredType(IntType::get());
} else if (THPDevice_Check(input.ptr())) {
return InferredType(DeviceObjType::get());
} else if (THPGenerator_Check(input.ptr())) {
return InferredType(GeneratorType::get());
} else if (THPStream_Check(input.ptr())) {
return InferredType(StreamObjType::get());
} else if (THPDtype_Check(input.ptr())) {

View File

@ -1012,6 +1012,10 @@ void initPythonIRBindings(PyObject* module_) {
.def_static("get", &StringType::get);
py::class_<DeviceObjType, Type, DeviceObjTypePtr>(m, "DeviceObjType")
.def_static("get", &DeviceObjType::get);
// TODO(antoniojkim): Add GeneratorType to the public API once it's been added
// to the public documentation
py::class_<GeneratorType, Type, GeneratorTypePtr>(m, "_GeneratorType")
.def_static("get", &GeneratorType::get);
py::class_<StreamObjType, Type, StreamObjTypePtr>(m, "StreamObjType")
.def_static("get", &StreamObjType::get);
py::class_<PyObjectType, Type, PyObjectTypePtr>(m, "PyObjectType")

View File

@ -1,4 +1,5 @@
#include <ATen/autocast_mode.h>
#include <ATen/core/Generator.h>
#include <c10/util/Optional.h>
#include <c10/util/irange.h>
#include <torch/csrc/jit/mobile/promoted_prim_ops.h>
@ -2492,6 +2493,44 @@ static const std::vector<OperatorGeneratorArgs> opGenArgs1{
TORCH_SELECTIVE_SCHEMA("aten::manual_seed(int seed) -> ()"),
[](Stack& stack) { at::manual_seed(pop(stack).toInt()); },
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::Generator(*, Device? device=None, int? seed=None) -> Generator"),
[](Stack& stack) {
auto seed = pop(stack).toOptional<int64_t>();
auto device = pop(stack).toOptional<c10::Device>();
push(
stack,
at::make_generator_for_device(
device.value_or(c10::Device("cpu")), seed));
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::initial_seed(Generator self) -> int"),
[](Stack& stack) {
auto generator = pop(stack);
auto current_seed = generator.toGenerator().current_seed();
push(stack, (int64_t)current_seed);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA(
"aten::manual_seed.generator(Generator(a!) self, int seed) -> Generator(a!)"),
[](Stack& stack) {
auto seed = pop(stack).toInt();
auto generator = pop(stack);
generator.toGenerator().set_current_seed(seed);
push(stack, generator);
},
aliasAnalysisFromSchema()),
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::seed(Generator(a!) self) -> int"),
[](Stack& stack) {
auto generator = pop(stack);
auto current_seed = generator.toGenerator().seed();
push(stack, (int64_t)current_seed);
},
aliasAnalysisFromSchema()),
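
Taken together, these schemas let TorchScript construct and seed generators without any Python fallback; a hedged sketch (`make_seeded` is an illustrative name, and returning a Generator to Python is assumed to round-trip through the new IValue support):

import torch

@torch.jit.script
def make_seeded() -> torch.Generator:
    g = torch.Generator()  # lowers to the aten::Generator op registered above
    g.manual_seed(2023)    # aten::manual_seed.generator
    return g

assert make_seeded().initial_seed() == 2023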
OperatorGeneratorArgs(
TORCH_SELECTIVE_SCHEMA("aten::cuda(Tensor(a) self) -> Tensor(a|b)"),
[](Stack& stack) {

View File

@ -392,7 +392,7 @@ RegisterOperators reg({
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b) -> Tensor(a!)"),
"aten::_no_grad_uniform_(Tensor(a!) tensor, float a, float b, Generator? generator=None) -> Tensor(a!)"),
[](Stack& stack) {
// TODO: remove when script supports setting grad mode
torch::NoGradGuard no_grad;
@ -402,13 +402,16 @@ RegisterOperators reg({
double a;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double b;
c10::optional<at::Generator> generator =
pop(stack).toOptional<at::Generator>();
pop(stack, tensor, a, b);
push(stack, tensor.uniform_(a, b));
push(stack, tensor.uniform_(a, b, generator));
},
aliasAnalysisFromSchema()),
OperatorGenerator(
TORCH_SELECTIVE_SCHEMA(
"aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std) -> Tensor(a!)"),
"aten::_no_grad_normal_(Tensor(a!) tensor, float mean, float std, Generator? generator=None) -> Tensor(a!)"),
[](Stack& stack) {
// TODO: remove when script supports setting grad mode
torch::NoGradGuard no_grad;
@ -418,8 +421,11 @@ RegisterOperators reg({
double mean;
// NOLINTNEXTLINE(cppcoreguidelines-init-variables)
double std;
c10::optional<at::Generator> generator =
pop(stack).toOptional<at::Generator>();
pop(stack, tensor, mean, std);
push(stack, tensor.normal_(mean, std));
push(stack, tensor.normal_(mean, std, generator));
},
aliasAnalysisFromSchema()),
OperatorGenerator(
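
With the generator threaded through these `_no_grad_*` builtins, `torch.nn.init` functions can be scripted with an explicit generator, as `test_nn_init_generator` above checks. A minimal sketch (`init_weight` is an illustrative name):

import torch

@torch.jit.script
def init_weight(w: torch.Tensor, g: torch.Generator):
    # scripted nn.init call with an explicit generator; no prim::PythonOp
    return torch.nn.init.uniform_(w, generator=g)

g = torch.Generator()
g.manual_seed(2023)
w = init_weight(torch.ones(2, 2), g)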

View File

@ -148,6 +148,11 @@ static inline hash_t Hash(const std::string& value) {
static inline hash_t Hash(const c10::string_view& value) {
return DataHash(value.data(), value.size());
}
static inline hash_t Hash(const at::Generator& value) {
return TensorHash(value.get_state());
}
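
The hash is derived from the generator's state tensor, so identically seeded generators hash identically, which is what lets the lazy backend reuse cached graphs (see `test_generator_causes_multiple_compiles` above). A hedged sketch of the Python-visible invariant, assuming CPU generators:

import torch

g1 = torch.Generator()
g1.manual_seed(5)
g2 = torch.Generator()
g2.manual_seed(5)
# same seed => same state tensor => same IR hash, so the graph can be reused
assert torch.equal(g1.get_state(), g2.get_state())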
// Taken from glibc's implementation of hashing optionals,
// we want to include a contribution to the hash to distinguish
// cases where one or another option was null, but we hope it doesn't

View File

@ -1369,6 +1369,22 @@ std::vector<Shape> compute_shape_as_strided_scatter_symint(
return {Shape(out_meta.scalar_type(), out_meta.sizes().vec())};
}
std::vector<Shape> compute_shape_normal_functional(
const at::Tensor& self,
double mean,
double std,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
std::vector<Shape> compute_shape_uniform(
const at::Tensor& self,
double from,
double to,
c10::optional<at::Generator> generator) {
return {Shape(self.scalar_type(), self.sizes().vec())};
}
// Restore unused-parameters warnings
#pragma GCC diagnostic pop

View File

@ -70,6 +70,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_new_empty_strided(const
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_backward(const at::Tensor & grad_output, const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index, const at::Tensor & total_weight);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nll_loss2d_forward(const at::Tensor & self, const at::Tensor & target, const c10::optional<at::Tensor> & weight, int64_t reduction, int64_t ignore_index);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_nonzero(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_normal_functional(const at::Tensor & self, double mean, double std, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t to, c10::optional<at::Generator> generator);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_random(const at::Tensor & self, int64_t from, c10::optional<int64_t> to, c10::optional<at::Generator> generator);
@ -91,6 +92,7 @@ TORCH_API std::vector<torch::lazy::Shape> compute_shape_narrow_copy_symint(const
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_hardswish_backward(const at::Tensor & grad_output, const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_selu(const at::Tensor & self);
TORCH_API std::vector<torch::lazy::Shape> compute_shape_uniform(const at::Tensor & self, double from, double to, c10::optional<at::Generator> generator);
// Non-Native ops
TORCH_API std::vector<Shape> compute_shape_scalar(const at::Scalar& value, const at::ScalarType& type);

View File

@ -307,6 +307,15 @@ void initLazyBindings(PyObject* module) {
#endif // !(defined(FBCODE_CAFFE2) || defined(OVRSOURCE))
return result;
});
lazy_ts_backend.def("_get_latest_computation_graph", []() {
auto computation = LazyGraphExecutor::Get()
->GetComputationCache()
->GetLatest()
->computation;
auto ts_computation = dynamic_cast<TSComputation*>(computation.get());
TORCH_CHECK(ts_computation, "Found non-TSComputation in cache");
return ts_computation->graph()->toString();
});
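
Usage, as exercised by `test_generator_causes_multiple_compiles` above (hedged; `_get_latest_computation_graph` is a private debugging hook):

import torch
import torch._lazy
import torch._lazy.ts_backend

torch._lazy.ts_backend.init()
with torch.device("lazy"):
    t = torch.tensor(1.0)
    t.uniform_(generator=torch.Generator().manual_seed(1))
torch._lazy.mark_step()
print(torch._C._lazy_ts_backend._get_latest_computation_graph())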
// GetPythonFramesFunction() has not ever worked with torchdeploy/multipy
// possibly because GetPythonFrames resolves to external cpython rather

View File

@ -1,47 +0,0 @@
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
namespace torch {
namespace lazy {
Normal::Normal(
const torch::lazy::Value& self,
const double& mean,
const double& std,
std::vector<torch::lazy::Shape>&& shapes)
: torch::lazy::TsNode(
ClassOpKind(),
{self},
std::move(shapes),
/* num_outputs */ 1,
torch::lazy::MHash(mean, std)),
mean_(mean),
std_(std) {}
std::string Normal::ToString() const {
std::stringstream ss;
ss << TsNode::ToString();
ss << ", mean=" << mean_;
ss << ", std=" << std_;
return ss.str();
}
torch::lazy::TSOpVector Normal::Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
torch::lazy::TSLoweringContext* loctx) const {
std::vector<torch::jit::NamedValue> arguments;
std::vector<torch::jit::NamedValue> kwarguments;
arguments.reserve(3);
size_t i = 0;
arguments.emplace_back(loctx->GetOutputOp(operand(i++)));
arguments.emplace_back("mean", mean_);
arguments.emplace_back("std", std_);
torch::lazy::TSOpVector normal__out =
torch::lazy::LowerTSBuiltin(function, op().op, arguments, kwarguments);
TORCH_CHECK_EQ(normal__out.size(), 1);
return normal__out;
}
} // namespace lazy
} // namespace torch

View File

@ -1,30 +0,0 @@
#pragma once
#include <torch/csrc/lazy/ts_backend/ts_node.h>
namespace torch {
namespace lazy {
class Normal : public torch::lazy::TsNode {
public:
static OpKind ClassOpKind() {
return OpKind::Get("aten::normal_");
}
Normal(
const torch::lazy::Value& self,
const double& mean,
const double& std,
std::vector<torch::lazy::Shape>&& shapes);
std::string ToString() const override;
torch::lazy::TSOpVector Lower(
std::shared_ptr<torch::jit::GraphFunction> function,
torch::lazy::TSLoweringContext* loctx) const override;
double mean_;
double std_;
};
} // namespace lazy
} // namespace torch

View File

@ -13,7 +13,6 @@
#include <torch/csrc/lazy/core/tensor.h>
#include <torch/csrc/lazy/core/util.h>
#include <torch/csrc/lazy/generated/LazyIr.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
#include <algorithm>
#include <functional>

View File

@ -14,7 +14,6 @@
#include <torch/csrc/lazy/core/tensor_util.h>
#include <torch/csrc/lazy/generated/LazyNativeFunctions.h>
#include <torch/csrc/lazy/ts_backend/config.h>
#include <torch/csrc/lazy/ts_backend/ops/random_ops.h>
#include <torch/csrc/lazy/ts_backend/ops/to_copy.h>
#include <torch/csrc/lazy/ts_backend/tensor_aten_ops.h>
#include <torch/csrc/lazy/ts_backend/ts_autograd_functions.h>
@ -372,36 +371,6 @@ at::Tensor LazyNativeFunctions::max_pool3d_with_indices_backward(
indices);
}
at::Tensor& LazyNativeFunctions::normal_(
at::Tensor& self,
double mean,
double std,
c10::optional<at::Generator> generator) {
// Unconditionally fall back.
// implementing normal_ via lazy tensor caused differences in results compared
// to eager.
return at::native::call_fallback_fn<&ltc_eager_fallback, ATEN_OP(normal_)>::
call(self, mean, std, generator);
// if (force_eager_fallback(c10::Symbol::fromQualString("aten::normal_"))) {
// return at::native::call_fallback_fn<&ltc_eager_fallback,
// ATEN_OP(normal_)>::call(self, mean, std, generator);
// }
// if (generator.has_value()) {
// return at::native::call_fallback_fn<&ltc_eager_fallback,
// ATEN_OP(normal_)>::call(self, mean, std, generator);
// }
// TORCH_LAZY_FN_COUNTER("lazy::");
// auto device = bridge::GetBackendDevice(self);
// LazyTensor lazy_self = GetLtcTensorOrCreateForWrappedNumber(self, *device);
// std::vector<torch::lazy::Shape> shapes =
// {torch::lazy::Shape(self.scalar_type(), self.sizes().vec())}; auto node =
// torch::lazy::MakeNode<Normal>(lazy_self.GetIrValue(), mean, std,
// std::move(shapes)); lazy_self.SetInPlaceIrValue(node); return self;
};
at::Tensor LazyNativeFunctions::_unsafe_view(
const at::Tensor& self,
at::IntArrayRef size) {

View File

@ -13,6 +13,7 @@ from typing import Type
import torch
from torch._C import (
_GeneratorType,
AnyType,
AwaitType,
BoolType,
@ -479,6 +480,8 @@ def try_ann_to_type(ann, loc, rcb=None):
return InterfaceType(ann.__torch_script_interface__)
if ann is torch.device:
return DeviceObjType.get()
if ann is torch.Generator:
return _GeneratorType.get()
if ann is torch.Stream:
return StreamObjType.get()
if ann is torch.dtype:

View File

@ -10,14 +10,14 @@ from typing import Optional as _Optional
# functions that use `with torch.no_grad()`. The JIT doesn't support context
# managers, so these need to be implemented as builtins. Using these wrappers
# lets us keep those builtins small and re-usable.
def _no_grad_uniform_(tensor, a, b):
def _no_grad_uniform_(tensor, a, b, generator=None):
with torch.no_grad():
return tensor.uniform_(a, b)
return tensor.uniform_(a, b, generator=generator)
def _no_grad_normal_(tensor, mean, std):
def _no_grad_normal_(tensor, mean, std, generator=None):
with torch.no_grad():
return tensor.normal_(mean, std)
return tensor.normal_(mean, std, generator=generator)
def _no_grad_trunc_normal_(tensor, mean, std, a, b, generator=None):
@ -121,7 +121,12 @@ def calculate_gain(nonlinearity, param=None):
raise ValueError(f"Unsupported nonlinearity {nonlinearity}")
def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
def uniform_(
tensor: Tensor,
a: float = 0.0,
b: float = 1.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input Tensor with values drawn from the uniform distribution.
:math:`\mathcal{U}(a, b)`.
@ -130,17 +135,25 @@ def uniform_(tensor: Tensor, a: float = 0., b: float = 1.) -> Tensor:
tensor: an n-dimensional `torch.Tensor`
a: the lower bound of the uniform distribution
b: the upper bound of the uniform distribution
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.uniform_(w)
"""
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(uniform_, (tensor,), tensor=tensor, a=a, b=b)
return _no_grad_uniform_(tensor, a, b)
return torch.overrides.handle_torch_function(
uniform_, (tensor,), tensor=tensor, a=a, b=b, generator=generator
)
return _no_grad_uniform_(tensor, a, b, generator)
def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
def normal_(
tensor: Tensor,
mean: float = 0.0,
std: float = 1.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input Tensor with values drawn from the normal distribution.
:math:`\mathcal{N}(\text{mean}, \text{std}^2)`.
@ -149,14 +162,17 @@ def normal_(tensor: Tensor, mean: float = 0., std: float = 1.) -> Tensor:
tensor: an n-dimensional `torch.Tensor`
mean: the mean of the normal distribution
std: the standard deviation of the normal distribution
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
>>> nn.init.normal_(w)
"""
if torch.overrides.has_torch_function_variadic(tensor):
return torch.overrides.handle_torch_function(normal_, (tensor,), tensor=tensor, mean=mean, std=std)
return _no_grad_normal_(tensor, mean, std)
return torch.overrides.handle_torch_function(
normal_, (tensor,), tensor=tensor, mean=mean, std=std, generator=generator
)
return _no_grad_normal_(tensor, mean, std, generator)
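
Eager-mode usage of the new keyword (a hedged sketch; with the same seed the results are reproducible regardless of the global RNG state):

import torch
from torch.nn import init

g = torch.Generator()
g.manual_seed(7)
w = torch.empty(3, 5)
init.uniform_(w, a=0.0, b=1.0, generator=g)
init.kaiming_uniform_(w, a=0.0, generator=g)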
def trunc_normal_(
tensor: Tensor,
@ -180,6 +196,7 @@ def trunc_normal_(
std: the standard deviation of the normal distribution
a: the minimum cutoff value
b: the maximum cutoff value
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
@ -314,7 +331,9 @@ def _calculate_fan_in_and_fan_out(tensor):
return fan_in, fan_out
def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
def xavier_uniform_(
tensor: Tensor, gain: float = 1.0, generator: _Optional[torch.Generator] = None
) -> Tensor:
r"""Fill the input `Tensor` with values using a Xavier uniform distribution.
The method is described in `Understanding the difficulty of training
@ -330,6 +349,7 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
Args:
tensor: an n-dimensional `torch.Tensor`
gain: an optional scaling factor
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
@ -339,10 +359,14 @@ def xavier_uniform_(tensor: Tensor, gain: float = 1.) -> Tensor:
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
a = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
return _no_grad_uniform_(tensor, -a, a)
return _no_grad_uniform_(tensor, -a, a, generator)
def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
def xavier_normal_(
tensor: Tensor,
gain: float = 1.0,
generator: _Optional[torch.Generator] = None,
) -> Tensor:
r"""Fill the input `Tensor` with values using a Xavier normal distribution.
The method is described in `Understanding the difficulty of training deep feedforward
@ -357,6 +381,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
Args:
tensor: an n-dimensional `torch.Tensor`
gain: an optional scaling factor
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
@ -365,7 +390,7 @@ def xavier_normal_(tensor: Tensor, gain: float = 1.) -> Tensor:
fan_in, fan_out = _calculate_fan_in_and_fan_out(tensor)
std = gain * math.sqrt(2.0 / float(fan_in + fan_out))
return _no_grad_normal_(tensor, 0., std)
return _no_grad_normal_(tensor, 0., std, generator)
def _calculate_correct_fan(tensor, mode):
@ -379,7 +404,11 @@ def _calculate_correct_fan(tensor, mode):
def kaiming_uniform_(
tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
tensor: Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: _Optional[torch.Generator] = None,
):
r"""Fill the input `Tensor` with values using a Kaiming uniform distribution.
@ -403,6 +432,7 @@ def kaiming_uniform_(
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
@ -415,7 +445,8 @@ def kaiming_uniform_(
tensor=tensor,
a=a,
mode=mode,
nonlinearity=nonlinearity)
nonlinearity=nonlinearity,
generator=generator)
if 0 in tensor.shape:
warnings.warn("Initializing zero-element tensors is a no-op")
@ -425,11 +456,15 @@ def kaiming_uniform_(
std = gain / math.sqrt(fan)
bound = math.sqrt(3.0) * std # Calculate uniform bounds from standard deviation
with torch.no_grad():
return tensor.uniform_(-bound, bound)
return tensor.uniform_(-bound, bound, generator=generator)
def kaiming_normal_(
tensor: Tensor, a: float = 0, mode: str = 'fan_in', nonlinearity: str = 'leaky_relu'
tensor: Tensor,
a: float = 0,
mode: str = "fan_in",
nonlinearity: str = "leaky_relu",
generator: _Optional[torch.Generator] = None,
):
r"""Fill the input `Tensor` with values using a Kaiming normal distribution.
@ -453,6 +488,7 @@ def kaiming_normal_(
backwards pass.
nonlinearity: the non-linear function (`nn.functional` name),
recommended to use only with ``'relu'`` or ``'leaky_relu'`` (default).
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
@ -465,10 +501,14 @@ def kaiming_normal_(
gain = calculate_gain(nonlinearity, a)
std = gain / math.sqrt(fan)
with torch.no_grad():
return tensor.normal_(0, std)
return tensor.normal_(0, std, generator=generator)
def orthogonal_(tensor, gain=1):
def orthogonal_(
tensor,
gain=1,
generator: _Optional[torch.Generator] = None,
):
r"""Fill the input `Tensor` with a (semi) orthogonal matrix.
Described in `Exact solutions to the nonlinear dynamics of learning in deep
@ -479,6 +519,7 @@ def orthogonal_(tensor, gain=1):
Args:
tensor: an n-dimensional `torch.Tensor`, where :math:`n \geq 2`
gain: optional scaling factor
generator: the torch Generator to sample from (default: None)
Examples:
>>> # xdoctest: +REQUIRES(env:TORCH_DOCTEST_LAPACK)
@ -493,7 +534,7 @@ def orthogonal_(tensor, gain=1):
return tensor
rows = tensor.size(0)
cols = tensor.numel() // rows
flattened = tensor.new(rows, cols).normal_(0, 1)
flattened = tensor.new(rows, cols).normal_(0, 1, generator=generator)
if rows < cols:
flattened.t_()
@ -514,7 +555,12 @@ def orthogonal_(tensor, gain=1):
return tensor
def sparse_(tensor, sparsity, std=0.01):
def sparse_(
tensor,
sparsity,
std=0.01,
generator: _Optional[torch.Generator] = None,
):
r"""Fill the 2D input `Tensor` as a sparse matrix.
The non-zero elements will be drawn from the normal distribution
@ -526,6 +572,7 @@ def sparse_(tensor, sparsity, std=0.01):
sparsity: The fraction of elements in each column to be set to zero
std: the standard deviation of the normal distribution used to generate
the non-zero values
generator: the torch Generator to sample from (default: None)
Examples:
>>> w = torch.empty(3, 5)
@ -538,7 +585,7 @@ def sparse_(tensor, sparsity, std=0.01):
num_zeros = int(math.ceil(sparsity * rows))
with torch.no_grad():
tensor.normal_(0, std)
tensor.normal_(0, std, generator=generator)
for col_idx in range(cols):
row_indices = torch.randperm(rows)
zero_indices = row_indices[:num_zeros]

View File

@ -933,10 +933,10 @@ def get_testing_overrides() -> Dict[Callable, Callable]:
distance_function=None, margin=1.0,
swap=False, reduction='mean': -1),
torch.nn.functional.unfold: lambda input, kernel_size, dilation=1, padding=0, stride=1: -1,
torch.nn.init.uniform_: lambda tensor, a=0., b=1.: -1,
torch.nn.init.normal_: lambda tensor, mean=0., std=1.: -1,
torch.nn.init.uniform_: lambda tensor, a=0., b=1., generator=None: -1,
torch.nn.init.normal_: lambda tensor, mean=0., std=1., generator=None: -1,
torch.nn.init.constant_: lambda tensor, val: -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu': -1,
torch.nn.init.kaiming_uniform_: lambda tensor, a=0, mode='fan_in', nonlinearity='leaky_relu', generator=None: -1,
torch.nonzero: lambda input, as_tuple=False: -1,
torch.nonzero_static: lambda input, *, size, fill_value=-1: -1,
torch.argwhere: lambda input: -1,

View File

@ -9322,7 +9322,14 @@ def wrapper_set_seed(op, *args, **kwargs):
"""
with freeze_rng_state():
torch.manual_seed(42)
return op(*args, **kwargs)
output = op(*args, **kwargs)
if isinstance(output, torch.Tensor) and output.device.type == "lazy":
# We need to call mark step inside freeze_rng_state so that numerics
# match eager execution
torch._lazy.mark_step()
return output
def reference_layer_norm(inp: np.ndarray, normalized_shape: Tuple[int], weight=None, bias=None, eps=1e-5):
@ -17840,8 +17847,14 @@ op_db: List[OpInfo] = [
DecorateInfo(unittest.skip('Skipped!'), 'TestJit', 'test_variant_consistency_jit'),
# Lazy tensor failures
DecorateInfo(unittest.skip('Skipped!'), 'TestLazyOpInfo', 'test_dispatched_to_lazy'),
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness'),
DecorateInfo(unittest.expectedFailure, 'TestLazyOpInfo', 'test_correctness_with_reusing_ir'),
# These tests fail only when built with ASAN
DecorateInfo(unittest.skip("Fails with ASAN"), 'TestLazyOpInfo', 'test_correctness', active_if=TEST_WITH_ASAN),
DecorateInfo(
unittest.skip("Fails with ASAN"),
'TestLazyOpInfo',
'test_correctness_with_reusing_ir',
active_if=TEST_WITH_ASAN
),
),
),
OpInfo(

View File

@ -1313,6 +1313,28 @@ def skipIfTorchInductor(msg="test doesn't currently work with torchinductor",
def skipRocmIfTorchInductor(msg="test doesn't currently work with torchinductor on the ROCm stack"):
return skipIfTorchInductor(msg=msg, condition=TEST_WITH_ROCM and TEST_WITH_TORCHINDUCTOR)
def skipIfLegacyJitExecutor(msg="test doesn't currently work with legacy JIT executor"):
def decorator(fn):
if not isinstance(fn, type):
@wraps(fn)
def wrapper(*args, **kwargs):
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
raise unittest.SkipTest(msg)
else:
fn(*args, **kwargs)
return wrapper
assert(isinstance(fn, type))
if GRAPH_EXECUTOR == ProfilingMode.LEGACY:
fn.__unittest_skip__ = True
fn.__unittest_skip_why__ = msg
return fn
return decorator
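
Usage, as applied to `test_trace` above; the decorator handles both plain test functions and test classes (`TestFoo` is an illustrative name):

from torch.testing._internal.common_utils import skipIfLegacyJitExecutor
from torch.testing._internal.jit_utils import JitTestCase

class TestFoo(JitTestCase):
    @skipIfLegacyJitExecutor("legacy JIT executor does not support Generator type")
    def test_generator(self):
        pass  # body elided; skipped when GRAPH_EXECUTOR is ProfilingMode.LEGACY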
# Run PyTorch tests with translation validation on.
TEST_WITH_TV = os.getenv('PYTORCH_TEST_WITH_TV') == '1'

View File

@ -7,6 +7,7 @@ from torchgen.api.types import (
CType,
deviceT,
doubleT,
generatorT,
layoutT,
ListCType,
longT,
@ -109,6 +110,8 @@ def process_ir_type(
return BaseCType(stringT)
elif typ.name == BaseTy.Device:
return BaseCType(deviceT)
elif typ.name == BaseTy.Generator:
return BaseCType(generatorT)
elif typ.name == BaseTy.Layout:
return BaseCType(layoutT)
elif typ.name == BaseTy.MemoryFormat:
@ -218,16 +221,7 @@ class LazyArgument:
self.symint = symint
self.is_optional = isinstance(arg.type, OptionalType)
self.is_generator = isGeneratorType(arg.type)
if self.is_generator:
assert (
self.is_optional
), "We expect all generators are optional since currently they are"
# there is no handling for generators in TorchScript IR (or XLA)
# so we fall back to eager if the (optional)generator has value, and otherwise
# its null and safe to exclude from lazy IR
self.lazy_type_ = None
else:
self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
self.lazy_type_ = process_ir_type(arg.type, properties, symint=symint)
self.is_wrapped_scalar = isWrappedScalarType(arg.type)
self.is_symint_or_list = symint and (
isSymIntType(arg.type)
@ -236,9 +230,7 @@ class LazyArgument:
# or (isinstance(arg.type, ListType) and isSymIntType(arg.type.elem))
)
self.is_lazy_value = not self.is_generator and isValueType(
self.lazy_type, properties
)
self.is_lazy_value = isValueType(self.lazy_type, properties)
@property
def lazy_type(self) -> CType:
@ -419,7 +411,7 @@ class LazyIrSchema:
keyword: bool = True,
values: bool = True,
scalars: bool = True,
generator: bool = False,
generator: bool = True,
) -> List[LazyArgument]:
# This function maintains the sorted order of arguments but provides different filtered views.
# Some parts of the code care about kwargs vs args (TS lowerings),

View File

@ -122,12 +122,8 @@ def gen_fallback_code(
aten_op_str = f"ATEN_OP2({schema.aten_name}, {overload_name})"
else:
aten_op_str = f"ATEN_OP({schema.aten_name})"
or_has_generator = ""
if schema.generator_arg:
# generators are always optional and there is never more than one, at least currently
or_has_generator = f" || ({schema.generator_arg.name}.has_value() && {schema.generator_arg.name}->defined())"
return f"""
if (force_eager_fallback({aten_symbol(schema)}){or_has_generator}) {{
if (force_eager_fallback({aten_symbol(schema)})) {{
return at::native::call_fallback_fn_symint<&ltc_eager_fallback, {aten_op_str}>::call(
{fallback_args}
);
@ -290,9 +286,12 @@ class GenLazyIR(ABC):
members_to_string = []
for arg in scalar_args:
if isinstance(arg.lazy_type, OptionalCType):
value = f"{arg.name}.value()"
if arg.is_generator:
value = '"torch.Generator()"'
members_to_string.append(
f"""if ({arg.name}.has_value()) {{
ss << ", {arg.name}=" << {arg.name}.value();
ss << ", {arg.name}=" << {value};
}} else {{
ss << ", {arg.name}=null";
}}"""